Always stream completions through the LLM service (#16113)

This PR removes the `llm-service` feature flag and makes it so all
completions are done via the LLM service when using the Zed provider.

Release Notes:

- N/A
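
For reviewers unfamiliar with the surviving code path: the LLM service streams completions back as newline-delimited JSON events, which the code below parses with `futures::stream::try_unfold` over a buffered reader. Here is a minimal, self-contained sketch of that pattern, with a placeholder `Event` type and an in-memory `Cursor` standing in for the real provider event enums and response body (these are illustrative stand-ins, not the actual Zed types):

```rust
use anyhow::Result;
use futures::{
    io::{BufReader, Cursor},
    stream, AsyncBufRead, AsyncBufReadExt, Stream, TryStreamExt,
};
use serde::Deserialize;

// Placeholder for the provider-specific event types used in the real code
// (anthropic::Event, open_ai::ResponseStreamEvent, ...).
#[derive(Deserialize, Debug)]
struct Event {
    text: String,
}

// Turn a buffered async reader of newline-delimited JSON into a stream of events,
// mirroring the `try_unfold` loop in `stream_completion`.
fn events_from_body<R>(body: R) -> impl Stream<Item = Result<Event>>
where
    R: AsyncBufRead + Unpin,
{
    stream::try_unfold(body, |mut body| async move {
        let mut buffer = String::new();
        match body.read_line(&mut buffer).await {
            Ok(0) => Ok(None), // EOF: end the stream
            Ok(_) => {
                let event: Event = serde_json::from_str(&buffer)?;
                Ok(Some((event, body))) // yield the event, keep the reader as state
            }
            Err(e) => Err(e.into()),
        }
    })
}

fn main() -> Result<()> {
    futures::executor::block_on(async {
        // Stand-in for `BufReader::new(response.into_body())` in the real code.
        let body = BufReader::new(Cursor::new("{\"text\":\"Hello\"}\n{\"text\":\", world\"}\n"));
        let mut events = Box::pin(events_from_body(body));
        while let Some(event) = events.try_next().await? {
            print!("{}", event.text);
        }
        Ok(())
    })
}
```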
Marshall Bowers 2024-08-12 09:33:24 -04:00 committed by GitHub
parent bab4da78b7
commit 6389c613a2

@@ -4,10 +4,10 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, bail, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
+use feature_flags::{FeatureFlagAppExt, LanguageModels};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response};
@@ -228,16 +228,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 }
 
-struct LlmServiceFeatureFlag;
-
-impl FeatureFlag for LlmServiceFeatureFlag {
-    const NAME: &'static str = "llm-service";
-
-    fn enabled_for_staff() -> bool {
-        false
-    }
-}
-
 pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
@@ -354,232 +344,148 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        cx: &AsyncAppContext,
+        _cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into());
                 let client = self.client.clone();
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Anthropic,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: anthropic::Event =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-                        Ok(anthropic::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Anthropic as i32,
-                                request,
-                            })
-                            .await?
-                            .map(|event| Ok(serde_json::from_str(&event?.event)?));
-                        Ok(anthropic::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Anthropic,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+                    Ok(anthropic::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::OpenAi,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: open_ai::ResponseStreamEvent =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-                        Ok(open_ai::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(open_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::OpenAi,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+                    Ok(open_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Google,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: google_ai::GenerateContentResponse =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-                        Ok(google_ai::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Google as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(google_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Google,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: google_ai::GenerateContentResponse =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+                    Ok(google_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Zed,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: open_ai::ResponseStreamEvent =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-                        Ok(open_ai::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Zed as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(open_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Zed,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+                    Ok(open_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
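
The `provider_request` field above is built the same way in every arm: the provider-specific request is serialized once and wrapped in a `serde_json::value::RawValue` so the LLM service can forward it verbatim without re-parsing. A rough sketch of that wrapping, with hypothetical `ProviderRequest`/`PerformCompletionParams` stand-ins rather than the real `client` crate types (needs serde_json's `raw_value` feature):

```rust
use anyhow::Result;
use serde::Serialize;
use serde_json::value::RawValue;

// Hypothetical provider-specific request body; the real code serializes the
// anthropic/open_ai/google request types here.
#[derive(Serialize)]
struct ProviderRequest {
    model: String,
    max_tokens: u32,
}

// Simplified stand-in for client::PerformCompletionParams: the provider request
// travels as a pre-serialized `RawValue` embedded in the outer JSON body.
#[derive(Serialize)]
struct PerformCompletionParams {
    model: String,
    provider_request: Box<RawValue>,
}

fn main() -> Result<()> {
    // Hypothetical model name, purely for illustration.
    let request = ProviderRequest {
        model: "example-model".into(),
        max_tokens: 4000,
    };
    let params = PerformCompletionParams {
        model: request.model.clone(),
        // Same pattern as the diff: serialize, then wrap as a raw JSON value.
        provider_request: RawValue::from_string(serde_json::to_string(&request)?)?,
    };
    println!("{}", serde_json::to_string(&params)?);
    Ok(())
}
```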
@@ -590,7 +496,7 @@ impl LanguageModel for CloudLanguageModel {
         tool_name: String,
         tool_description: String,
         input_schema: serde_json::Value,
-        cx: &AsyncAppContext,
+        _cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
@@ -605,106 +511,67 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::Anthropic,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
-                            let mut tool_use_index = None;
-                            let mut tool_input = String::new();
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            while body.read_line(&mut line).await? > 0 {
-                                let event: anthropic::Event = serde_json::from_str(&line)?;
-                                line.clear();
-                                match event {
-                                    anthropic::Event::ContentBlockStart {
-                                        content_block,
-                                        index,
-                                    } => {
-                                        if let anthropic::Content::ToolUse { name, .. } =
-                                            content_block
-                                        {
-                                            if name == tool_name {
-                                                tool_use_index = Some(index);
-                                            }
-                                        }
-                                    }
-                                    anthropic::Event::ContentBlockDelta { index, delta } => {
-                                        match delta {
-                                            anthropic::ContentDelta::TextDelta { .. } => {}
-                                            anthropic::ContentDelta::InputJsonDelta {
-                                                partial_json,
-                                            } => {
-                                                if Some(index) == tool_use_index {
-                                                    tool_input.push_str(&partial_json);
-                                                }
-                                            }
-                                        }
-                                    }
-                                    anthropic::Event::ContentBlockStop { index } => {
-                                        if Some(index) == tool_use_index {
-                                            return Ok(serde_json::from_str(&tool_input)?);
-                                        }
-                                    }
-                                    _ => {}
-                                }
-                            }
-                            if tool_use_index.is_some() {
-                                Err(anyhow!("tool content incomplete"))
-                            } else {
-                                Err(anyhow!("tool not used"))
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request(proto::CompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::Anthropic as i32,
-                                    request,
-                                })
-                                .await?;
-                            let response: anthropic::Response =
-                                serde_json::from_str(&response.completion)?;
-                            response
-                                .content
-                                .into_iter()
-                                .find_map(|content| {
-                                    if let anthropic::Content::ToolUse { name, input, .. } = content
-                                    {
-                                        if name == tool_name {
-                                            Some(input)
-                                        } else {
-                                            None
-                                        }
-                                    } else {
-                                        None
-                                    }
-                                })
-                                .context("tool not used")
-                        })
-                        .boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Anthropic,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+                        let mut tool_use_index = None;
+                        let mut tool_input = String::new();
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        while body.read_line(&mut line).await? > 0 {
+                            let event: anthropic::Event = serde_json::from_str(&line)?;
+                            line.clear();
+                            match event {
+                                anthropic::Event::ContentBlockStart {
+                                    content_block,
+                                    index,
+                                } => {
+                                    if let anthropic::Content::ToolUse { name, .. } = content_block
+                                    {
+                                        if name == tool_name {
+                                            tool_use_index = Some(index);
+                                        }
+                                    }
+                                }
+                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
+                                {
+                                    anthropic::ContentDelta::TextDelta { .. } => {}
+                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
+                                        if Some(index) == tool_use_index {
+                                            tool_input.push_str(&partial_json);
+                                        }
+                                    }
+                                },
+                                anthropic::Event::ContentBlockStop { index } => {
+                                    if Some(index) == tool_use_index {
+                                        return Ok(serde_json::from_str(&tool_input)?);
+                                    }
+                                }
+                                _ => {}
+                            }
+                        }
+                        if tool_use_index.is_some() {
+                            Err(anyhow!("tool content incomplete"))
+                        } else {
+                            Err(anyhow!("tool not used"))
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
@@ -723,115 +590,59 @@ impl LanguageModel for CloudLanguageModel {
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::OpenAi,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            let mut load_state = None;
-                            while body.read_line(&mut line).await? > 0 {
-                                let part: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&line)?;
-                                line.clear();
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request_stream(proto::StreamCompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                                    request,
-                                })
-                                .await?;
-                            let mut load_state = None;
-                            let mut response = response.map(
-                                |item: Result<
-                                    proto::StreamCompleteWithLanguageModelResponse,
-                                    anyhow::Error,
-                                >| {
-                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                        serde_json::from_str(&item?.event)?,
-                                    )
-                                },
-                            );
-                            while let Some(Ok(part)) = response.next().await {
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        let mut load_state = None;
+                        while body.read_line(&mut line).await? > 0 {
+                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+                            line.clear();
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
@@ -854,114 +665,58 @@ impl LanguageModel for CloudLanguageModel {
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::Zed,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            let mut load_state = None;
-                            while body.read_line(&mut line).await? > 0 {
-                                let part: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&line)?;
-                                line.clear();
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request_stream(proto::StreamCompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                                    request,
-                                })
-                                .await?;
-                            let mut load_state = None;
-                            let mut response = response.map(
-                                |item: Result<
-                                    proto::StreamCompleteWithLanguageModelResponse,
-                                    anyhow::Error,
-                                >| {
-                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                        serde_json::from_str(&item?.event)?,
-                                    )
-                                },
-                            );
-                            while let Some(Ok(part)) = response.next().await {
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        let mut load_state = None;
+                        while body.read_line(&mut line).await? > 0 {
+                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+                            line.clear();
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
         }
     }
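
A closing note on the `use_any_tool` hunks above: the retained path accumulates streamed tool-call argument fragments keyed by tool-call index and only parses them as JSON once the stream ends. A simplified, self-contained sketch of that accumulation, with stand-in structs instead of the real `open_ai::ResponseStreamEvent` types (the names and field types here are assumptions for illustration):

```rust
use anyhow::{bail, Result};

// Simplified stand-ins for the streamed OpenAI-style tool-call deltas in the diff.
struct ToolCallFunction {
    name: Option<String>,
    arguments: Option<String>,
}
struct ToolCall {
    index: usize,
    function: Option<ToolCallFunction>,
}

/// Collect the argument fragments for `tool_name` from a sequence of streamed
/// tool-call deltas, mirroring the `load_state` loop in `use_any_tool`.
fn collect_tool_arguments(
    tool_name: &str,
    calls: impl IntoIterator<Item = ToolCall>,
) -> Result<serde_json::Value> {
    // (accumulated argument text, index of the matching tool call)
    let mut load_state: Option<(String, usize)> = None;
    for call in calls {
        if let Some(func) = call.function {
            // A delta carrying the function name marks the start of the call we care about.
            if func.name.as_deref() == Some(tool_name) {
                load_state = Some((String::default(), call.index));
            }
            // Subsequent deltas for the same index carry argument fragments.
            if let Some((arguments, (output, index))) = func.arguments.zip(load_state.as_mut()) {
                if call.index == *index {
                    output.push_str(&arguments);
                }
            }
        }
    }
    if let Some((arguments, _)) = load_state {
        Ok(serde_json::from_str(&arguments)?)
    } else {
        bail!("tool not used");
    }
}

fn main() -> Result<()> {
    // Hypothetical stream of deltas: the name arrives first, then argument fragments.
    let deltas = vec![
        ToolCall {
            index: 0,
            function: Some(ToolCallFunction {
                name: Some("get_weather".into()),
                arguments: None,
            }),
        },
        ToolCall {
            index: 0,
            function: Some(ToolCallFunction {
                name: None,
                arguments: Some("{\"city\":".into()),
            }),
        },
        ToolCall {
            index: 0,
            function: Some(ToolCallFunction {
                name: None,
                arguments: Some("\"Boulder\"}".into()),
            }),
        },
    ];
    let input = collect_tool_arguments("get_weather", deltas)?;
    println!("{input}");
    Ok(())
}
```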