Always stream completions through the LLM service (#16113)
This PR removes the `llm-service` feature flag and makes it so all completions are done via the LLM service when using the Zed provider. Release Notes: - N/A
This commit is contained in:
parent
bab4da78b7
commit
6389c613a2
1 changed files with 264 additions and 509 deletions
|
@ -4,10 +4,10 @@ use crate::{
|
||||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, bail, Context as _, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
|
use feature_flags::{FeatureFlagAppExt, LanguageModels};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Response};
|
use http_client::{AsyncBody, HttpClient, Method, Response};
|
||||||
|
@ -228,16 +228,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LlmServiceFeatureFlag;
|
|
||||||
|
|
||||||
impl FeatureFlag for LlmServiceFeatureFlag {
|
|
||||||
const NAME: &'static str = "llm-service";
|
|
||||||
|
|
||||||
fn enabled_for_staff() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct CloudLanguageModel {
|
pub struct CloudLanguageModel {
|
||||||
id: LanguageModelId,
|
id: LanguageModelId,
|
||||||
model: CloudModel,
|
model: CloudModel,
|
||||||
|
@ -354,17 +344,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let request = request.into_anthropic(model.id().into());
|
let request = request.into_anthropic(model.id().into());
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -380,14 +365,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
let body = BufReader::new(response.into_body());
|
||||||
let stream =
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
match body.read_line(&mut buffer).await {
|
match body.read_line(&mut buffer).await {
|
||||||
Ok(0) => Ok(None),
|
Ok(0) => Ok(None),
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let event: anthropic::Event =
|
let event: anthropic::Event = serde_json::from_str(&buffer)?;
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
Ok(Some((event, body)))
|
||||||
}
|
}
|
||||||
Err(e) => Err(e.into()),
|
Err(e) => Err(e.into()),
|
||||||
|
@ -397,29 +380,10 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
Ok(anthropic::extract_text_from_events(stream))
|
Ok(anthropic::extract_text_from_events(stream))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
} else {
|
|
||||||
let future = self.request_limiter.stream(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?
|
|
||||||
.map(|event| Ok(serde_json::from_str(&event?.event)?));
|
|
||||||
Ok(anthropic::extract_text_from_events(stream))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_open_ai(model.id().into());
|
let request = request.into_open_ai(model.id().into());
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -435,8 +399,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
let body = BufReader::new(response.into_body());
|
||||||
let stream =
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
match body.read_line(&mut buffer).await {
|
match body.read_line(&mut buffer).await {
|
||||||
Ok(0) => Ok(None),
|
Ok(0) => Ok(None),
|
||||||
|
@ -452,30 +415,10 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
Ok(open_ai::extract_text_from_events(stream))
|
Ok(open_ai::extract_text_from_events(stream))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
} else {
|
|
||||||
let future = self.request_limiter.stream(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(open_ai::extract_text_from_events(
|
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
|
||||||
))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_google(model.id().into());
|
let request = request.into_google(model.id().into());
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -491,8 +434,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
let body = BufReader::new(response.into_body());
|
||||||
let stream =
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
match body.read_line(&mut buffer).await {
|
match body.read_line(&mut buffer).await {
|
||||||
Ok(0) => Ok(None),
|
Ok(0) => Ok(None),
|
||||||
|
@ -508,31 +450,11 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
Ok(google_ai::extract_text_from_events(stream))
|
Ok(google_ai::extract_text_from_events(stream))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
} else {
|
|
||||||
let future = self.request_limiter.stream(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Google as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(google_ai::extract_text_from_events(
|
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
|
||||||
))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::Zed(model) => {
|
CloudModel::Zed(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
request.max_tokens = Some(4000);
|
request.max_tokens = Some(4000);
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -548,8 +470,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
let body = BufReader::new(response.into_body());
|
||||||
let stream =
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
match body.read_line(&mut buffer).await {
|
match body.read_line(&mut buffer).await {
|
||||||
Ok(0) => Ok(None),
|
Ok(0) => Ok(None),
|
||||||
|
@ -565,21 +486,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
Ok(open_ai::extract_text_from_events(stream))
|
Ok(open_ai::extract_text_from_events(stream))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
} else {
|
|
||||||
let future = self.request_limiter.stream(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Zed as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(open_ai::extract_text_from_events(
|
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
|
||||||
))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -590,7 +496,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
|
@ -605,10 +511,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
input_schema,
|
input_schema,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
|
@ -618,9 +520,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
PerformCompletionParams {
|
PerformCompletionParams {
|
||||||
provider: client::LanguageModelProvider::Anthropic,
|
provider: client::LanguageModelProvider::Anthropic,
|
||||||
model: request.model.clone(),
|
model: request.model.clone(),
|
||||||
provider_request: RawValue::from_string(
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
serde_json::to_string(&request)?,
|
&request,
|
||||||
)?,
|
)?)?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -638,26 +540,22 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
content_block,
|
content_block,
|
||||||
index,
|
index,
|
||||||
} => {
|
} => {
|
||||||
if let anthropic::Content::ToolUse { name, .. } =
|
if let anthropic::Content::ToolUse { name, .. } = content_block
|
||||||
content_block
|
|
||||||
{
|
{
|
||||||
if name == tool_name {
|
if name == tool_name {
|
||||||
tool_use_index = Some(index);
|
tool_use_index = Some(index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
anthropic::Event::ContentBlockDelta { index, delta } => {
|
anthropic::Event::ContentBlockDelta { index, delta } => match delta
|
||||||
match delta {
|
{
|
||||||
anthropic::ContentDelta::TextDelta { .. } => {}
|
anthropic::ContentDelta::TextDelta { .. } => {}
|
||||||
anthropic::ContentDelta::InputJsonDelta {
|
anthropic::ContentDelta::InputJsonDelta { partial_json } => {
|
||||||
partial_json,
|
|
||||||
} => {
|
|
||||||
if Some(index) == tool_use_index {
|
if Some(index) == tool_use_index {
|
||||||
tool_input.push_str(&partial_json);
|
tool_input.push_str(&partial_json);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
anthropic::Event::ContentBlockStop { index } => {
|
anthropic::Event::ContentBlockStop { index } => {
|
||||||
if Some(index) == tool_use_index {
|
if Some(index) == tool_use_index {
|
||||||
return Ok(serde_json::from_str(&tool_input)?);
|
return Ok(serde_json::from_str(&tool_input)?);
|
||||||
|
@ -674,37 +572,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
} else {
|
|
||||||
self.request_limiter
|
|
||||||
.run(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request(proto::CompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let response: anthropic::Response =
|
|
||||||
serde_json::from_str(&response.completion)?;
|
|
||||||
response
|
|
||||||
.content
|
|
||||||
.into_iter()
|
|
||||||
.find_map(|content| {
|
|
||||||
if let anthropic::Content::ToolUse { name, input, .. } = content
|
|
||||||
{
|
|
||||||
if name == tool_name {
|
|
||||||
Some(input)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.context("tool not used")
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
|
@ -723,10 +590,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
function.parameters = Some(input_schema);
|
function.parameters = Some(input_schema);
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
|
@ -736,9 +599,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
PerformCompletionParams {
|
PerformCompletionParams {
|
||||||
provider: client::LanguageModelProvider::OpenAi,
|
provider: client::LanguageModelProvider::OpenAi,
|
||||||
model: request.model.clone(),
|
model: request.model.clone(),
|
||||||
provider_request: RawValue::from_string(
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
serde_json::to_string(&request)?,
|
&request,
|
||||||
)?,
|
)?)?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -748,8 +611,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
let mut load_state = None;
|
let mut load_state = None;
|
||||||
|
|
||||||
while body.read_line(&mut line).await? > 0 {
|
while body.read_line(&mut line).await? > 0 {
|
||||||
let part: open_ai::ResponseStreamEvent =
|
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
||||||
serde_json::from_str(&line)?;
|
|
||||||
line.clear();
|
line.clear();
|
||||||
|
|
||||||
for choice in part.choices {
|
for choice in part.choices {
|
||||||
|
@ -781,57 +643,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
} else {
|
|
||||||
self.request_limiter
|
|
||||||
.run(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let mut load_state = None;
|
|
||||||
let mut response = response.map(
|
|
||||||
|item: Result<
|
|
||||||
proto::StreamCompleteWithLanguageModelResponse,
|
|
||||||
anyhow::Error,
|
|
||||||
>| {
|
|
||||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
|
||||||
serde_json::from_str(&item?.event)?,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
);
|
|
||||||
while let Some(Ok(part)) = response.next().await {
|
|
||||||
for choice in part.choices {
|
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(_) => {
|
CloudModel::Google(_) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||||
|
@ -854,10 +665,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
function.parameters = Some(input_schema);
|
function.parameters = Some(input_schema);
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
|
||||||
if cx
|
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
|
@ -867,9 +674,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
PerformCompletionParams {
|
PerformCompletionParams {
|
||||||
provider: client::LanguageModelProvider::Zed,
|
provider: client::LanguageModelProvider::Zed,
|
||||||
model: request.model.clone(),
|
model: request.model.clone(),
|
||||||
provider_request: RawValue::from_string(
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
serde_json::to_string(&request)?,
|
&request,
|
||||||
)?,
|
)?)?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -879,8 +686,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
let mut load_state = None;
|
let mut load_state = None;
|
||||||
|
|
||||||
while body.read_line(&mut line).await? > 0 {
|
while body.read_line(&mut line).await? > 0 {
|
||||||
let part: open_ai::ResponseStreamEvent =
|
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
||||||
serde_json::from_str(&line)?;
|
|
||||||
line.clear();
|
line.clear();
|
||||||
|
|
||||||
for choice in part.choices {
|
for choice in part.choices {
|
||||||
|
@ -911,57 +717,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
} else {
|
|
||||||
self.request_limiter
|
|
||||||
.run(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let mut load_state = None;
|
|
||||||
let mut response = response.map(
|
|
||||||
|item: Result<
|
|
||||||
proto::StreamCompleteWithLanguageModelResponse,
|
|
||||||
anyhow::Error,
|
|
||||||
>| {
|
|
||||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
|
||||||
serde_json::from_str(&item?.event)?,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
);
|
|
||||||
while let Some(Ok(part)) = response.next().await {
|
|
||||||
for choice in part.choices {
|
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue