assistant: Propagate LLM stop reason upwards (#17358)
This PR makes it so we propagate the `stop_reason` from Anthropic up to the Assistant so that we can take action based on it. The `extract_content_from_events` function was moved from `anthropic` to the `anthropic` module in `language_model` since it is more useful if it is able to name the `LanguageModelCompletionEvent` type, as otherwise we'd need an additional layer of plumbing. Release Notes: - N/A
This commit is contained in:
parent
7c8f62e943
commit
f38956943b
7 changed files with 143 additions and 144 deletions
|
@ -18,7 +18,6 @@ path = "src/anthropic.rs"
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
isahc.workspace = true
|
||||
|
|
|
@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
|
|||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
|
@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
use util::{maybe, ResultExt as _};
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
|
@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn extract_content_from_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
struct State {
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
||||
}
|
||||
|
||||
futures::stream::unfold(
|
||||
State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
},
|
||||
|mut state| async move {
|
||||
while let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => match event {
|
||||
Event::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => match content_block {
|
||||
ResponseContent::Text { text } => {
|
||||
return Some((Some(Ok(ResponseContent::Text { text })), state));
|
||||
}
|
||||
ResponseContent::ToolUse { id, name, .. } => {
|
||||
state.tool_uses_by_index.insert(
|
||||
index,
|
||||
RawToolUse {
|
||||
id,
|
||||
name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
|
||||
return Some((None, state));
|
||||
}
|
||||
},
|
||||
Event::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
return Some((Some(Ok(ResponseContent::Text { text })), state));
|
||||
}
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
||||
tool_use.input_json.push_str(&partial_json);
|
||||
return Some((None, state));
|
||||
}
|
||||
}
|
||||
},
|
||||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
||||
return Some((
|
||||
Some(maybe!({
|
||||
Ok(ResponseContent::ToolUse {
|
||||
id: tool_use.id,
|
||||
name: tool_use.name,
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_use.input_json,
|
||||
)
|
||||
.map_err(|err| anyhow!(err))?,
|
||||
})
|
||||
})),
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
Event::Error { error } => {
|
||||
return Some((Some(Err(AnthropicError::ApiError(error))), state));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Err(err) => {
|
||||
return Some((Some(Err(err)), state));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
},
|
||||
)
|
||||
.filter_map(|event| async move { event })
|
||||
}
|
||||
|
||||
pub async fn extract_tool_args_from_events(
|
||||
tool_name: String,
|
||||
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue