Make LanguageModel::use_any_tool return a stream of chunks (#16262)
This PR is a refactor to pave the way for allowing the user to view and edit workflow step resolutions. I've made tool calls work more like normal streaming completions for all providers: the `use_any_tool` method now returns a stream of strings, each containing a chunk of the tool's JSON input. I've also done some minor cleanup of the language model providers in general, removing the duplication around handling streaming responses.

Release Notes:

- N/A
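For illustration, here is a minimal sketch of how a caller might consume the new interface. Only the stream-of-`Result<String>` shape comes from the diff below; the `collect_tool_input` helper is hypothetical, not part of this change:

```rust
use anyhow::Result;
use futures::{stream::BoxStream, StreamExt};

// Hypothetical caller: `use_any_tool` now yields chunks of the tool's JSON
// input instead of one fully-parsed `serde_json::Value`, so a UI can show
// partial input while it streams and parse only once the stream ends.
async fn collect_tool_input(
    mut chunks: BoxStream<'static, Result<String>>,
) -> Result<serde_json::Value> {
    let mut raw_json = String::new();
    while let Some(chunk) = chunks.next().await {
        raw_json.push_str(&chunk?);
        // `raw_json` is valid JSON only after the final chunk arrives.
    }
    Ok(serde_json::from_str(&raw_json)?)
}
```

This mirrors what the old implementation did internally (buffer the partial JSON, then `serde_json::from_str` at the end), but the buffering now happens at the call site, which is what makes viewing partially resolved tool input possible.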
parent 1117d89057
commit 4c390b82fb
14 changed files with 253 additions and 400 deletions
@@ -5,18 +5,21 @@ use crate::{
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anthropic::AnthropicError;
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, ZedPro};
-use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
+use futures::{
+    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
+    TryStreamExt as _,
+};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
     Subscription, Task,
 };
 use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
 use settings::{Settings, SettingsStore};
 use smol::{
@@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: anthropic::Event = serde_json::from_str(&buffer)
-                            .context("failed to parse Anthropic event")?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(err) => Err(AnthropicError::Other(err.into())),
-                }
-            });
-
-            Ok(anthropic::extract_text_from_events(stream))
+            Ok(anthropic::extract_text_from_events(
+                response_lines(response).map_err(AnthropicError::Other),
+            ))
         });
         async move {
             Ok(future
@@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: open_ai::ResponseStreamEvent =
-                            serde_json::from_str(&buffer)?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(e) => Err(e.into()),
-                }
-            });
-
-            Ok(open_ai::extract_text_from_events(stream))
+            Ok(open_ai::extract_text_from_events(response_lines(response)))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: google_ai::GenerateContentResponse =
-                            serde_json::from_str(&buffer)?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(e) => Err(e.into()),
-                }
-            });
-
-            Ok(google_ai::extract_text_from_events(stream))
+            Ok(google_ai::extract_text_from_events(response_lines(
+                response,
+            )))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: open_ai::ResponseStreamEvent =
-                            serde_json::from_str(&buffer)?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(e) => Err(e.into()),
-                }
-            });
-
-            Ok(open_ai::extract_text_from_events(stream))
+            Ok(open_ai::extract_text_from_events(response_lines(response)))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
         tool_description: String,
         input_schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let client = self.client.clone();
+        let llm_api_token = self.llm_api_token.clone();
+
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let client = self.client.clone();
                 let mut request = request.into_anthropic(model.tool_model_id().into());
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
@@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
                         )
                         .await?;
 
-                        let mut tool_use_index = None;
-                        let mut tool_input = String::new();
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        while body.read_line(&mut line).await? > 0 {
-                            let event: anthropic::Event = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            match event {
-                                anthropic::Event::ContentBlockStart {
-                                    content_block,
-                                    index,
-                                } => {
-                                    if let anthropic::Content::ToolUse { name, .. } = content_block
-                                    {
-                                        if name == tool_name {
-                                            tool_use_index = Some(index);
-                                        }
-                                    }
-                                }
-                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
-                                {
-                                    anthropic::ContentDelta::TextDelta { .. } => {}
-                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
-                                        if Some(index) == tool_use_index {
-                                            tool_input.push_str(&partial_json);
-                                        }
-                                    }
-                                },
-                                anthropic::Event::ContentBlockStop { index } => {
-                                    if Some(index) == tool_use_index {
-                                        return Ok(serde_json::from_str(&tool_input)?);
-                                    }
-                                }
-                                _ => {}
-                            }
-                        }
-
-                        if tool_use_index.is_some() {
-                            Err(anyhow!("tool content incomplete"))
-                        } else {
-                            Err(anyhow!("tool not used"))
-                        }
+                        Ok(anthropic::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
-                let client = self.client.clone();
-                let mut function = open_ai::FunctionDefinition {
-                    name: tool_name.clone(),
-                    description: None,
-                    parameters: None,
-                };
-                let func = open_ai::ToolDefinition::Function {
-                    function: function.clone(),
-                };
-                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
-                // Fill in description and params separately, as they're not needed for tool_choice field.
-                function.description = Some(tool_description);
-                function.parameters = Some(input_schema);
-                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                request.tool_choice = Some(open_ai::ToolChoice::Other(
+                    open_ai::ToolDefinition::Function {
+                        function: open_ai::FunctionDefinition {
+                            name: tool_name.clone(),
+                            description: None,
+                            parameters: None,
+                        },
+                    },
+                ));
+                request.tools = vec![open_ai::ToolDefinition::Function {
+                    function: open_ai::FunctionDefinition {
+                        name: tool_name.clone(),
+                        description: Some(tool_description),
+                        parameters: Some(input_schema),
+                    },
+                }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -700,41 +616,12 @@ impl LanguageModel for CloudLanguageModel {
                         )
                         .await?;
 
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        let mut load_state = None;
-
-                        while body.read_line(&mut line).await? > 0 {
-                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
+                        Ok(open_ai::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
@@ -744,22 +631,23 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Zed(model) => {
                 // All Zed models are OpenAI-based at the time of writing.
                 let mut request = request.into_open_ai(model.id().into());
-                let client = self.client.clone();
-                let mut function = open_ai::FunctionDefinition {
-                    name: tool_name.clone(),
-                    description: None,
-                    parameters: None,
-                };
-                let func = open_ai::ToolDefinition::Function {
-                    function: function.clone(),
-                };
-                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
-                // Fill in description and params separately, as they're not needed for tool_choice field.
-                function.description = Some(tool_description);
-                function.parameters = Some(input_schema);
-                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                request.tool_choice = Some(open_ai::ToolChoice::Other(
+                    open_ai::ToolDefinition::Function {
+                        function: open_ai::FunctionDefinition {
+                            name: tool_name.clone(),
+                            description: None,
+                            parameters: None,
+                        },
+                    },
+                ));
+                request.tools = vec![open_ai::ToolDefinition::Function {
+                    function: open_ai::FunctionDefinition {
+                        name: tool_name.clone(),
+                        description: Some(tool_description),
+                        parameters: Some(input_schema),
+                    },
+                }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -775,40 +663,12 @@ impl LanguageModel for CloudLanguageModel {
                         )
                         .await?;
 
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        let mut load_state = None;
-
-                        while body.read_line(&mut line).await? > 0 {
-                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
+                        Ok(open_ai::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
@@ -816,6 +676,25 @@ impl LanguageModel for CloudLanguageModel {
     }
 }
 
+fn response_lines<T: DeserializeOwned>(
+    response: Response<AsyncBody>,
+) -> impl Stream<Item = Result<T>> {
+    futures::stream::try_unfold(
+        (String::new(), BufReader::new(response.into_body())),
+        move |(mut line, mut body)| async {
+            match body.read_line(&mut line).await {
+                Ok(0) => Ok(None),
+                Ok(_) => {
+                    let event: T = serde_json::from_str(&line)?;
+                    line.clear();
+                    Ok(Some((event, (line, body))))
+                }
+                Err(e) => Err(e.into()),
+            }
+        },
+    )
+}
+
 impl LlmApiToken {
     async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
         let lock = self.0.upgradable_read().await;
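A note on the extracted helper that closes out the diff: `response_lines` centralizes the assumption that the service streams newline-delimited JSON, one serialized event per line, and is generic over the event type via `T: DeserializeOwned`, which is what lets the Anthropic, OpenAI, and Google arms all share it. The Anthropic call site additionally applies `.map_err(AnthropicError::Other)` (via the newly imported `TryStreamExt`), presumably because `extract_text_from_events` expects Anthropic's error type rather than `anyhow::Error`. Below is a self-contained sketch of that line-per-event framing; `DemoEvent` and `parse_ndjson` are stand-ins for illustration, not names from this commit:

```rust
use anyhow::Result;
use futures::{stream, StreamExt};
use serde::Deserialize;

// Stand-in for event types like `anthropic::Event` or
// `open_ai::ResponseStreamEvent`.
#[derive(Deserialize, Debug)]
struct DemoEvent {
    text: String,
}

// Toy version of the framing `response_lines` decodes: each line of the
// body is an independent JSON document describing one streaming event.
fn parse_ndjson(body: &str) -> impl futures::Stream<Item = Result<DemoEvent>> + '_ {
    stream::iter(body.lines()).map(|line| Ok(serde_json::from_str(line)?))
}

fn main() -> Result<()> {
    let body = "{\"text\":\"hel\"}\n{\"text\":\"lo\"}\n";
    let events: Vec<Result<DemoEvent>> =
        futures::executor::block_on(parse_ndjson(body).collect());
    assert_eq!(events.len(), 2);
    Ok(())
}
```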