Always stream completions through the LLM service (#16113)
This PR removes the `llm-service` feature flag and routes all completions through the LLM service when using the Zed provider.

Release Notes:

- N/A
parent bab4da78b7
commit 6389c613a2

1 changed file with 264 additions and 509 deletions
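For context on the pattern this diff repeats in every provider branch: the LLM service responds with newline-delimited JSON events, and each branch turns the response body into an event stream by reading one line at a time inside `futures::stream::try_unfold`. Below is a minimal standalone sketch of that pattern, assuming the `anyhow`, `futures`, `serde` (with derive), and `serde_json` crates; the `Event` type is a hypothetical stand-in for the per-provider event types in the diff (`anthropic::Event`, `open_ai::ResponseStreamEvent`, `google_ai::GenerateContentResponse`).

use anyhow::Result;
use futures::{stream, AsyncBufRead, AsyncBufReadExt, Stream};
use serde::Deserialize;

// Hypothetical stand-in for the per-provider event types in the diff.
#[derive(Debug, Deserialize)]
struct Event {
    text: String,
}

// Unfold a buffered async body into a stream of parsed events: each step
// reads one line, parses it as JSON, and yields the event together with
// the reader so the next step can continue from where it left off.
fn events(body: impl AsyncBufRead + Unpin) -> impl Stream<Item = Result<Event>> {
    stream::try_unfold(body, |mut body| async move {
        let mut buffer = String::new();
        match body.read_line(&mut buffer).await {
            Ok(0) => Ok(None), // EOF ends the stream
            Ok(_) => {
                let event: Event = serde_json::from_str(&buffer)?;
                Ok(Some((event, body)))
            }
            Err(e) => Err(e.into()),
        }
    })
}

In the diff, the resulting stream is then handed to the provider-specific `extract_text_from_events` helper to pull out the text deltas.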
@@ -4,10 +4,10 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, bail, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
+use feature_flags::{FeatureFlagAppExt, LanguageModels};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response};
@@ -228,16 +228,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 }
-
-struct LlmServiceFeatureFlag;
-
-impl FeatureFlag for LlmServiceFeatureFlag {
-    const NAME: &'static str = "llm-service";
-
-    fn enabled_for_staff() -> bool {
-        false
-    }
-}

 pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
@@ -354,232 +344,148 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        cx: &AsyncAppContext,
+        _cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into());
                 let client = self.client.clone();
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Anthropic,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: anthropic::Event =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(anthropic::extract_text_from_events(stream))
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Anthropic,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
                     });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Anthropic as i32,
-                                request,
-                            })
-                            .await?
-                            .map(|event| Ok(serde_json::from_str(&event?.event)?));
-                        Ok(anthropic::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+
+                    Ok(anthropic::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::OpenAi,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: open_ai::ResponseStreamEvent =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(open_ai::extract_text_from_events(stream))
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::OpenAi,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
                     });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(open_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+
+                    Ok(open_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Google,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: google_ai::GenerateContentResponse =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(google_ai::extract_text_from_events(stream))
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Google,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: google_ai::GenerateContentResponse =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
                     });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Google as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(google_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+
+                    Ok(google_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Zed,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: open_ai::ResponseStreamEvent =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(open_ai::extract_text_from_events(stream))
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Zed,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
                     });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Zed as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(open_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+
+                    Ok(open_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
@@ -590,7 +496,7 @@ impl LanguageModel for CloudLanguageModel {
         tool_name: String,
         tool_description: String,
         input_schema: serde_json::Value,
-        cx: &AsyncAppContext,
+        _cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
@@ -605,106 +511,67 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];

-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::Anthropic,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Anthropic,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;

-                            let mut tool_use_index = None;
-                            let mut tool_input = String::new();
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            while body.read_line(&mut line).await? > 0 {
-                                let event: anthropic::Event = serde_json::from_str(&line)?;
-                                line.clear();
+                        let mut tool_use_index = None;
+                        let mut tool_input = String::new();
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        while body.read_line(&mut line).await? > 0 {
+                            let event: anthropic::Event = serde_json::from_str(&line)?;
+                            line.clear();

-                                match event {
-                                    anthropic::Event::ContentBlockStart {
-                                        content_block,
-                                        index,
-                                    } => {
-                                        if let anthropic::Content::ToolUse { name, .. } =
-                                            content_block
-                                        {
-                                            if name == tool_name {
-                                                tool_use_index = Some(index);
-                                            }
-                                        }
-                                    }
-                                    anthropic::Event::ContentBlockDelta { index, delta } => {
-                                        match delta {
-                                            anthropic::ContentDelta::TextDelta { .. } => {}
-                                            anthropic::ContentDelta::InputJsonDelta {
-                                                partial_json,
-                                            } => {
-                                                if Some(index) == tool_use_index {
-                                                    tool_input.push_str(&partial_json);
-                                                }
-                                            }
-                                        }
-                                    }
-                                    anthropic::Event::ContentBlockStop { index } => {
-                                        if Some(index) == tool_use_index {
-                                            return Ok(serde_json::from_str(&tool_input)?);
-                                        }
-                                    }
-                                    _ => {}
-                                }
-                            }
-
-                            if tool_use_index.is_some() {
-                                Err(anyhow!("tool content incomplete"))
-                            } else {
-                                Err(anyhow!("tool not used"))
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request(proto::CompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::Anthropic as i32,
-                                    request,
-                                })
-                                .await?;
-                            let response: anthropic::Response =
-                                serde_json::from_str(&response.completion)?;
-                            response
-                                .content
-                                .into_iter()
-                                .find_map(|content| {
-                                    if let anthropic::Content::ToolUse { name, input, .. } = content
+                            match event {
+                                anthropic::Event::ContentBlockStart {
+                                    content_block,
+                                    index,
+                                } => {
+                                    if let anthropic::Content::ToolUse { name, .. } = content_block
                                     {
                                         if name == tool_name {
-                                            Some(input)
-                                        } else {
-                                            None
+                                            tool_use_index = Some(index);
                                         }
-                                    } else {
-                                        None
                                     }
-                                })
-                                .context("tool not used")
-                        })
-                        .boxed()
-                }
+                                }
+                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
+                                {
+                                    anthropic::ContentDelta::TextDelta { .. } => {}
+                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
+                                        if Some(index) == tool_use_index {
+                                            tool_input.push_str(&partial_json);
+                                        }
+                                    }
+                                },
+                                anthropic::Event::ContentBlockStop { index } => {
+                                    if Some(index) == tool_use_index {
+                                        return Ok(serde_json::from_str(&tool_input)?);
+                                    }
+                                }
+                                _ => {}
+                            }
+                        }
+
+                        if tool_use_index.is_some() {
+                            Err(anyhow!("tool content incomplete"))
+                        } else {
+                            Err(anyhow!("tool not used"))
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
@@ -723,115 +590,59 @@ impl LanguageModel for CloudLanguageModel {
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];

-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::OpenAi,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;

-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            let mut load_state = None;
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        let mut load_state = None;

-                            while body.read_line(&mut line).await? > 0 {
-                                let part: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&line)?;
-                                line.clear();
+                        while body.read_line(&mut line).await? > 0 {
+                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+                            line.clear();

-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };

-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
                                             }
                                         }
                                     }
                                 }
                             }
+                        }

-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request_stream(proto::StreamCompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                                    request,
-                                })
-                                .await?;
-                            let mut load_state = None;
-                            let mut response = response.map(
-                                |item: Result<
-                                    proto::StreamCompleteWithLanguageModelResponse,
-                                    anyhow::Error,
-                                >| {
-                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                        serde_json::from_str(&item?.event)?,
-                                    )
-                                },
-                            );
-                            while let Some(Ok(part)) = response.next().await {
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                }
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
@@ -854,114 +665,58 @@ impl LanguageModel for CloudLanguageModel {
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];

-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::Zed,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;

-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            let mut load_state = None;
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        let mut load_state = None;

-                            while body.read_line(&mut line).await? > 0 {
-                                let part: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&line)?;
-                                line.clear();
+                        while body.read_line(&mut line).await? > 0 {
+                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+                            line.clear();

-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };

-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
                                             }
                                         }
                                     }
                                 }
                             }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request_stream(proto::StreamCompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                                    request,
-                                })
-                                .await?;
-                            let mut load_state = None;
-                            let mut response = response.map(
-                                |item: Result<
-                                    proto::StreamCompleteWithLanguageModelResponse,
-                                    anyhow::Error,
-                                >| {
-                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                        serde_json::from_str(&item?.event)?,
-                                    )
-                                },
-                            );
-                            while let Some(Ok(part)) = response.next().await {
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                }
+                        }
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
         }
     }
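One detail of the `use_any_tool` branches above that is easy to miss: OpenAI-style tool-call arguments arrive as fragments tied to a call index, and the code concatenates them into `load_state` until the stream ends. Here is a rough standalone sketch of that accumulation, under the same crate assumptions as before; the `ToolCallFragment` type is hypothetical, whereas the real code walks the choices of each `open_ai::ResponseStreamEvent`.

use anyhow::{bail, Result};

// Hypothetical flattened view of one streamed tool-call delta.
struct ToolCallFragment {
    index: usize,
    name: Option<String>,      // set on the fragment that opens a call
    arguments: Option<String>, // partial JSON for that call's arguments
}

// Accumulate argument fragments for the named tool, mirroring the
// `load_state` pattern in the diff, then parse the completed JSON.
fn collect_tool_input(fragments: Vec<ToolCallFragment>, tool_name: &str) -> Result<serde_json::Value> {
    let mut load_state: Option<(String, usize)> = None;
    for call in fragments {
        if call.name.as_deref() == Some(tool_name) {
            // A fragment naming our tool starts a fresh accumulation.
            load_state = Some((String::default(), call.index));
        }
        if let Some((arguments, (output, index))) = call.arguments.zip(load_state.as_mut()) {
            // Only append fragments whose index matches the call we track.
            if call.index == *index {
                output.push_str(&arguments);
            }
        }
    }
    if let Some((arguments, _)) = load_state {
        Ok(serde_json::from_str(&arguments)?)
    } else {
        bail!("tool not used")
    }
}

This also suggests why the diff keeps a `(String, index)` pair rather than a single buffer: fragments belonging to other concurrent tool calls must be ignored unless their index matches the call being collected.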