assistant: Stream tool uses as structured data (#17322)

This PR adjusts the approach we use to encode tool uses in the
completion response to use a structured format rather than simply
injecting them into the response stream as text.

In #17170 we would encode the tool uses as XML and insert them as text.
This would then require re-parsing the tool uses out of the buffer in
order to use them.

The approach taken in this PR is to make `stream_completion` return a
stream of `LanguageModelCompletionEvent`s. Each of these events can be
either text, or a tool use.

A new `stream_completion_text` method has been added to `LanguageModel`
for scenarios where we only care about textual content (currently,
everywhere that isn't the Assistant context editor).

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-03 15:04:51 -04:00 committed by GitHub
parent 132e8e8064
commit 452272e5df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 235 additions and 83 deletions

View file

@ -18,6 +18,7 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
collections.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true

View file

@ -1,17 +1,19 @@
mod supported_countries;
use std::time::Duration;
use std::{pin::Pin, str::FromStr};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use collections::HashMap;
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
use util::ResultExt as _;
use util::{maybe, ResultExt as _};
pub use supported_countries::*;
@ -332,19 +334,22 @@ pub async fn stream_completion_with_rate_limit_info(
pub fn extract_content_from_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<String, AnthropicError>> {
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
current_tool_use_index: Option<usize>,
) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
struct RawToolUse {
id: String,
name: String,
input_json: String,
}
const INDENT: &str = " ";
const NEWLINE: char = '\n';
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
tool_uses_by_index: HashMap<usize, RawToolUse>,
}
futures::stream::unfold(
State {
events,
current_tool_use_index: None,
tool_uses_by_index: HashMap::default(),
},
|mut state| async move {
while let Some(event) = state.events.next().await {
@ -355,62 +360,56 @@ pub fn extract_content_from_events(
content_block,
} => match content_block {
ResponseContent::Text { text } => {
return Some((Ok(text), state));
return Some((Some(Ok(ResponseContent::Text { text })), state));
}
ResponseContent::ToolUse { id, name, .. } => {
state.current_tool_use_index = Some(index);
state.tool_uses_by_index.insert(
index,
RawToolUse {
id,
name,
input_json: String::new(),
},
);
let mut text = String::new();
text.push(NEWLINE);
text.push_str("<tool_use>");
text.push(NEWLINE);
text.push_str(INDENT);
text.push_str("<id>");
text.push_str(&id);
text.push_str("</id>");
text.push(NEWLINE);
text.push_str(INDENT);
text.push_str("<name>");
text.push_str(&name);
text.push_str("</name>");
text.push(NEWLINE);
text.push_str(INDENT);
text.push_str("<input>");
return Some((Ok(text), state));
return Some((None, state));
}
},
Event::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => {
return Some((Ok(text), state));
return Some((Some(Ok(ResponseContent::Text { text })), state));
}
ContentDelta::InputJsonDelta { partial_json } => {
if Some(index) == state.current_tool_use_index {
return Some((Ok(partial_json), state));
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
return Some((None, state));
}
}
},
Event::ContentBlockStop { index } => {
if Some(index) == state.current_tool_use_index.take() {
let mut text = String::new();
text.push_str("</input>");
text.push(NEWLINE);
text.push_str("</tool_use>");
return Some((Ok(text), state));
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
return Some((
Some(maybe!({
Ok(ResponseContent::ToolUse {
id: tool_use.id,
name: tool_use.name,
input: serde_json::Value::from_str(
&tool_use.input_json,
)
.map_err(|err| anyhow!(err))?,
})
})),
state,
));
}
}
Event::Error { error } => {
return Some((Err(AnthropicError::ApiError(error)), state));
return Some((Some(Err(AnthropicError::ApiError(error))), state));
}
_ => {}
},
Err(err) => {
return Some((Err(err), state));
return Some((Some(Err(err)), state));
}
}
}
@ -418,6 +417,7 @@ pub fn extract_content_from_events(
None
},
)
.filter_map(|event| async move { event })
}
pub async fn extract_tool_args_from_events(