Make LanguageModel::use_any_tool return a stream of chunks (#16262)

This PR is a refactor to pave the way for allowing the user to view and
edit workflow step resolutions. I've made tool calls work more like
normal streaming completions for all providers. The `use_any_tool`
method returns a stream of strings (which contain chunks of JSON). I've
also done some minor cleanup of language model providers in general,
removing the duplication around handling streaming responses.

Release Notes:

- N/A
This commit is contained in:
Max Brunsfeld 2024-08-14 18:02:46 -07:00 committed by GitHub
parent 1117d89057
commit 4c390b82fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 253 additions and 400 deletions

View file

@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
use strum::EnumIter;
pub use supported_countries::*;
@ -384,6 +384,57 @@ pub fn embed<'a>(
}
}
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
let mut tool_use_index = None;
let mut first_chunk = None;
while let Some(event) = events.next().await {
let call = event?.choices.into_iter().find_map(|choice| {
choice.delta.tool_calls?.into_iter().find_map(|call| {
if call.function.as_ref()?.name.as_deref()? == tool_name {
Some(call)
} else {
None
}
})
});
if let Some(call) = call {
tool_use_index = Some(call.index);
first_chunk = call.function.and_then(|func| func.arguments);
break;
}
}
let Some(tool_use_index) = tool_use_index else {
return Err(anyhow!("tool not used"));
};
Ok(events.filter_map(move |event| {
let result = match event {
Err(error) => Some(Err(error)),
Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
choice.delta.tool_calls?.into_iter().find_map(|call| {
if call.index == tool_use_index {
let func = call.function?;
let mut arguments = func.arguments?;
if let Some(mut first_chunk) = first_chunk.take() {
first_chunk.push_str(&arguments);
arguments = first_chunk
}
Some(Ok(arguments))
} else {
None
}
})
}),
};
async move { result }
}))
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {