assistant: Propagate LLM stop reason upwards (#17358)

This PR makes it so we propagate the `stop_reason` from Anthropic up to
the Assistant so that we can take action based on it.

The `extract_content_from_events` function was moved from `anthropic` to
the `anthropic` module in `language_model` since it is more useful if it
is able to name the `LanguageModelCompletionEvent` type, as otherwise
we'd need an additional layer of plumbing.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-04 12:31:10 -04:00 committed by GitHub
parent 7c8f62e943
commit f38956943b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 143 additions and 144 deletions

View file

@ -18,7 +18,6 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
collections.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true

View file

@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use collections::HashMap;
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
use util::{maybe, ResultExt as _};
use util::ResultExt as _;
pub use supported_countries::*;
@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
}
}
pub fn extract_content_from_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
struct RawToolUse {
id: String,
name: String,
input_json: String,
}
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
tool_uses_by_index: HashMap<usize, RawToolUse>,
}
futures::stream::unfold(
State {
events,
tool_uses_by_index: HashMap::default(),
},
|mut state| async move {
while let Some(event) = state.events.next().await {
match event {
Ok(event) => match event {
Event::ContentBlockStart {
index,
content_block,
} => match content_block {
ResponseContent::Text { text } => {
return Some((Some(Ok(ResponseContent::Text { text })), state));
}
ResponseContent::ToolUse { id, name, .. } => {
state.tool_uses_by_index.insert(
index,
RawToolUse {
id,
name,
input_json: String::new(),
},
);
return Some((None, state));
}
},
Event::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => {
return Some((Some(Ok(ResponseContent::Text { text })), state));
}
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
return Some((None, state));
}
}
},
Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
return Some((
Some(maybe!({
Ok(ResponseContent::ToolUse {
id: tool_use.id,
name: tool_use.name,
input: serde_json::Value::from_str(
&tool_use.input_json,
)
.map_err(|err| anyhow!(err))?,
})
})),
state,
));
}
}
Event::Error { error } => {
return Some((Some(Err(AnthropicError::ApiError(error))), state));
}
_ => {}
},
Err(err) => {
return Some((Some(Err(err)), state));
}
}
}
None
},
)
.filter_map(|event| async move { event })
}
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,