Always stream completions through the LLM service (#16113)
This PR removes the `llm-service` feature flag and makes it so all completions are done via the LLM service when using the Zed provider. Release Notes: - N/A
This commit is contained in:
parent
bab4da78b7
commit
6389c613a2
1 changed files with 264 additions and 509 deletions
|
@ -4,10 +4,10 @@ use crate::{
|
||||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, bail, Context as _, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
|
use feature_flags::{FeatureFlagAppExt, LanguageModels};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Response};
|
use http_client::{AsyncBody, HttpClient, Method, Response};
|
||||||
|
@ -228,16 +228,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LlmServiceFeatureFlag;
|
|
||||||
|
|
||||||
impl FeatureFlag for LlmServiceFeatureFlag {
|
|
||||||
const NAME: &'static str = "llm-service";
|
|
||||||
|
|
||||||
fn enabled_for_staff() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct CloudLanguageModel {
|
pub struct CloudLanguageModel {
|
||||||
id: LanguageModelId,
|
id: LanguageModelId,
|
||||||
model: CloudModel,
|
model: CloudModel,
|
||||||
|
@ -354,232 +344,148 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let request = request.into_anthropic(model.id().into());
|
let request = request.into_anthropic(model.id().into());
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
if cx
|
let future = self.request_limiter.stream(async move {
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
let response = Self::perform_llm_completion(
|
||||||
.unwrap_or(false)
|
client.clone(),
|
||||||
{
|
llm_api_token,
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
PerformCompletionParams {
|
||||||
let future = self.request_limiter.stream(async move {
|
provider: client::LanguageModelProvider::Anthropic,
|
||||||
let response = Self::perform_llm_completion(
|
model: request.model.clone(),
|
||||||
client.clone(),
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
llm_api_token,
|
&request,
|
||||||
PerformCompletionParams {
|
)?)?,
|
||||||
provider: client::LanguageModelProvider::Anthropic,
|
},
|
||||||
model: request.model.clone(),
|
)
|
||||||
provider_request: RawValue::from_string(serde_json::to_string(
|
.await?;
|
||||||
&request,
|
let body = BufReader::new(response.into_body());
|
||||||
)?)?,
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
},
|
let mut buffer = String::new();
|
||||||
)
|
match body.read_line(&mut buffer).await {
|
||||||
.await?;
|
Ok(0) => Ok(None),
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(_) => {
|
||||||
let stream =
|
let event: anthropic::Event = serde_json::from_str(&buffer)?;
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
Ok(Some((event, body)))
|
||||||
let mut buffer = String::new();
|
}
|
||||||
match body.read_line(&mut buffer).await {
|
Err(e) => Err(e.into()),
|
||||||
Ok(0) => Ok(None),
|
}
|
||||||
Ok(_) => {
|
|
||||||
let event: anthropic::Event =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(anthropic::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
} else {
|
Ok(anthropic::extract_text_from_events(stream))
|
||||||
let future = self.request_limiter.stream(async move {
|
});
|
||||||
let request = serde_json::to_string(&request)?;
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?
|
|
||||||
.map(|event| Ok(serde_json::from_str(&event?.event)?));
|
|
||||||
Ok(anthropic::extract_text_from_events(stream))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_open_ai(model.id().into());
|
let request = request.into_open_ai(model.id().into());
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
if cx
|
let future = self.request_limiter.stream(async move {
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
let response = Self::perform_llm_completion(
|
||||||
.unwrap_or(false)
|
client.clone(),
|
||||||
{
|
llm_api_token,
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
PerformCompletionParams {
|
||||||
let future = self.request_limiter.stream(async move {
|
provider: client::LanguageModelProvider::OpenAi,
|
||||||
let response = Self::perform_llm_completion(
|
model: request.model.clone(),
|
||||||
client.clone(),
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
llm_api_token,
|
&request,
|
||||||
PerformCompletionParams {
|
)?)?,
|
||||||
provider: client::LanguageModelProvider::OpenAi,
|
},
|
||||||
model: request.model.clone(),
|
)
|
||||||
provider_request: RawValue::from_string(serde_json::to_string(
|
.await?;
|
||||||
&request,
|
let body = BufReader::new(response.into_body());
|
||||||
)?)?,
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
},
|
let mut buffer = String::new();
|
||||||
)
|
match body.read_line(&mut buffer).await {
|
||||||
.await?;
|
Ok(0) => Ok(None),
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(_) => {
|
||||||
let stream =
|
let event: open_ai::ResponseStreamEvent =
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
serde_json::from_str(&buffer)?;
|
||||||
let mut buffer = String::new();
|
Ok(Some((event, body)))
|
||||||
match body.read_line(&mut buffer).await {
|
}
|
||||||
Ok(0) => Ok(None),
|
Err(e) => Err(e.into()),
|
||||||
Ok(_) => {
|
}
|
||||||
let event: open_ai::ResponseStreamEvent =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(open_ai::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
} else {
|
Ok(open_ai::extract_text_from_events(stream))
|
||||||
let future = self.request_limiter.stream(async move {
|
});
|
||||||
let request = serde_json::to_string(&request)?;
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(open_ai::extract_text_from_events(
|
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
|
||||||
))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_google(model.id().into());
|
let request = request.into_google(model.id().into());
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
if cx
|
let future = self.request_limiter.stream(async move {
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
let response = Self::perform_llm_completion(
|
||||||
.unwrap_or(false)
|
client.clone(),
|
||||||
{
|
llm_api_token,
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
PerformCompletionParams {
|
||||||
let future = self.request_limiter.stream(async move {
|
provider: client::LanguageModelProvider::Google,
|
||||||
let response = Self::perform_llm_completion(
|
model: request.model.clone(),
|
||||||
client.clone(),
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
llm_api_token,
|
&request,
|
||||||
PerformCompletionParams {
|
)?)?,
|
||||||
provider: client::LanguageModelProvider::Google,
|
},
|
||||||
model: request.model.clone(),
|
)
|
||||||
provider_request: RawValue::from_string(serde_json::to_string(
|
.await?;
|
||||||
&request,
|
let body = BufReader::new(response.into_body());
|
||||||
)?)?,
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
},
|
let mut buffer = String::new();
|
||||||
)
|
match body.read_line(&mut buffer).await {
|
||||||
.await?;
|
Ok(0) => Ok(None),
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(_) => {
|
||||||
let stream =
|
let event: google_ai::GenerateContentResponse =
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
serde_json::from_str(&buffer)?;
|
||||||
let mut buffer = String::new();
|
Ok(Some((event, body)))
|
||||||
match body.read_line(&mut buffer).await {
|
}
|
||||||
Ok(0) => Ok(None),
|
Err(e) => Err(e.into()),
|
||||||
Ok(_) => {
|
}
|
||||||
let event: google_ai::GenerateContentResponse =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(google_ai::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
} else {
|
Ok(google_ai::extract_text_from_events(stream))
|
||||||
let future = self.request_limiter.stream(async move {
|
});
|
||||||
let request = serde_json::to_string(&request)?;
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Google as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(google_ai::extract_text_from_events(
|
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
|
||||||
))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::Zed(model) => {
|
CloudModel::Zed(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
request.max_tokens = Some(4000);
|
request.max_tokens = Some(4000);
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
if cx
|
let future = self.request_limiter.stream(async move {
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
let response = Self::perform_llm_completion(
|
||||||
.unwrap_or(false)
|
client.clone(),
|
||||||
{
|
llm_api_token,
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
PerformCompletionParams {
|
||||||
let future = self.request_limiter.stream(async move {
|
provider: client::LanguageModelProvider::Zed,
|
||||||
let response = Self::perform_llm_completion(
|
model: request.model.clone(),
|
||||||
client.clone(),
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
llm_api_token,
|
&request,
|
||||||
PerformCompletionParams {
|
)?)?,
|
||||||
provider: client::LanguageModelProvider::Zed,
|
},
|
||||||
model: request.model.clone(),
|
)
|
||||||
provider_request: RawValue::from_string(serde_json::to_string(
|
.await?;
|
||||||
&request,
|
let body = BufReader::new(response.into_body());
|
||||||
)?)?,
|
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
},
|
let mut buffer = String::new();
|
||||||
)
|
match body.read_line(&mut buffer).await {
|
||||||
.await?;
|
Ok(0) => Ok(None),
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(_) => {
|
||||||
let stream =
|
let event: open_ai::ResponseStreamEvent =
|
||||||
futures::stream::try_unfold(body, move |mut body| async move {
|
serde_json::from_str(&buffer)?;
|
||||||
let mut buffer = String::new();
|
Ok(Some((event, body)))
|
||||||
match body.read_line(&mut buffer).await {
|
}
|
||||||
Ok(0) => Ok(None),
|
Err(e) => Err(e.into()),
|
||||||
Ok(_) => {
|
}
|
||||||
let event: open_ai::ResponseStreamEvent =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(open_ai::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
} else {
|
Ok(open_ai::extract_text_from_events(stream))
|
||||||
let future = self.request_limiter.stream(async move {
|
});
|
||||||
let request = serde_json::to_string(&request)?;
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let stream = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Zed as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(open_ai::extract_text_from_events(
|
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
|
||||||
))
|
|
||||||
});
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -590,7 +496,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
|
@ -605,106 +511,67 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
input_schema,
|
input_schema,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
if cx
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
self.request_limiter
|
||||||
.unwrap_or(false)
|
.run(async move {
|
||||||
{
|
let response = Self::perform_llm_completion(
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
client.clone(),
|
||||||
self.request_limiter
|
llm_api_token,
|
||||||
.run(async move {
|
PerformCompletionParams {
|
||||||
let response = Self::perform_llm_completion(
|
provider: client::LanguageModelProvider::Anthropic,
|
||||||
client.clone(),
|
model: request.model.clone(),
|
||||||
llm_api_token,
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
PerformCompletionParams {
|
&request,
|
||||||
provider: client::LanguageModelProvider::Anthropic,
|
)?)?,
|
||||||
model: request.model.clone(),
|
},
|
||||||
provider_request: RawValue::from_string(
|
)
|
||||||
serde_json::to_string(&request)?,
|
.await?;
|
||||||
)?,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut tool_use_index = None;
|
let mut tool_use_index = None;
|
||||||
let mut tool_input = String::new();
|
let mut tool_input = String::new();
|
||||||
let mut body = BufReader::new(response.into_body());
|
let mut body = BufReader::new(response.into_body());
|
||||||
let mut line = String::new();
|
let mut line = String::new();
|
||||||
while body.read_line(&mut line).await? > 0 {
|
while body.read_line(&mut line).await? > 0 {
|
||||||
let event: anthropic::Event = serde_json::from_str(&line)?;
|
let event: anthropic::Event = serde_json::from_str(&line)?;
|
||||||
line.clear();
|
line.clear();
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
anthropic::Event::ContentBlockStart {
|
anthropic::Event::ContentBlockStart {
|
||||||
content_block,
|
content_block,
|
||||||
index,
|
index,
|
||||||
} => {
|
} => {
|
||||||
if let anthropic::Content::ToolUse { name, .. } =
|
if let anthropic::Content::ToolUse { name, .. } = content_block
|
||||||
content_block
|
|
||||||
{
|
|
||||||
if name == tool_name {
|
|
||||||
tool_use_index = Some(index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anthropic::Event::ContentBlockDelta { index, delta } => {
|
|
||||||
match delta {
|
|
||||||
anthropic::ContentDelta::TextDelta { .. } => {}
|
|
||||||
anthropic::ContentDelta::InputJsonDelta {
|
|
||||||
partial_json,
|
|
||||||
} => {
|
|
||||||
if Some(index) == tool_use_index {
|
|
||||||
tool_input.push_str(&partial_json);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anthropic::Event::ContentBlockStop { index } => {
|
|
||||||
if Some(index) == tool_use_index {
|
|
||||||
return Ok(serde_json::from_str(&tool_input)?);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tool_use_index.is_some() {
|
|
||||||
Err(anyhow!("tool content incomplete"))
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("tool not used"))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
} else {
|
|
||||||
self.request_limiter
|
|
||||||
.run(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request(proto::CompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let response: anthropic::Response =
|
|
||||||
serde_json::from_str(&response.completion)?;
|
|
||||||
response
|
|
||||||
.content
|
|
||||||
.into_iter()
|
|
||||||
.find_map(|content| {
|
|
||||||
if let anthropic::Content::ToolUse { name, input, .. } = content
|
|
||||||
{
|
{
|
||||||
if name == tool_name {
|
if name == tool_name {
|
||||||
Some(input)
|
tool_use_index = Some(index);
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
.context("tool not used")
|
anthropic::Event::ContentBlockDelta { index, delta } => match delta
|
||||||
})
|
{
|
||||||
.boxed()
|
anthropic::ContentDelta::TextDelta { .. } => {}
|
||||||
}
|
anthropic::ContentDelta::InputJsonDelta { partial_json } => {
|
||||||
|
if Some(index) == tool_use_index {
|
||||||
|
tool_input.push_str(&partial_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
anthropic::Event::ContentBlockStop { index } => {
|
||||||
|
if Some(index) == tool_use_index {
|
||||||
|
return Ok(serde_json::from_str(&tool_input)?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_use_index.is_some() {
|
||||||
|
Err(anyhow!("tool content incomplete"))
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("tool not used"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
|
@ -723,115 +590,59 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
function.parameters = Some(input_schema);
|
function.parameters = Some(input_schema);
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
|
||||||
if cx
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
self.request_limiter
|
||||||
.unwrap_or(false)
|
.run(async move {
|
||||||
{
|
let response = Self::perform_llm_completion(
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
client.clone(),
|
||||||
self.request_limiter
|
llm_api_token,
|
||||||
.run(async move {
|
PerformCompletionParams {
|
||||||
let response = Self::perform_llm_completion(
|
provider: client::LanguageModelProvider::OpenAi,
|
||||||
client.clone(),
|
model: request.model.clone(),
|
||||||
llm_api_token,
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
PerformCompletionParams {
|
&request,
|
||||||
provider: client::LanguageModelProvider::OpenAi,
|
)?)?,
|
||||||
model: request.model.clone(),
|
},
|
||||||
provider_request: RawValue::from_string(
|
)
|
||||||
serde_json::to_string(&request)?,
|
.await?;
|
||||||
)?,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut body = BufReader::new(response.into_body());
|
let mut body = BufReader::new(response.into_body());
|
||||||
let mut line = String::new();
|
let mut line = String::new();
|
||||||
let mut load_state = None;
|
let mut load_state = None;
|
||||||
|
|
||||||
while body.read_line(&mut line).await? > 0 {
|
while body.read_line(&mut line).await? > 0 {
|
||||||
let part: open_ai::ResponseStreamEvent =
|
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
||||||
serde_json::from_str(&line)?;
|
line.clear();
|
||||||
line.clear();
|
|
||||||
|
|
||||||
for choice in part.choices {
|
for choice in part.choices {
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
for call in tool_calls {
|
for call in tool_calls {
|
||||||
if let Some(func) = call.function {
|
if let Some(func) = call.function {
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
load_state = Some((String::default(), call.index));
|
load_state = Some((String::default(), call.index));
|
||||||
}
|
}
|
||||||
if let Some((arguments, (output, index))) =
|
if let Some((arguments, (output, index))) =
|
||||||
func.arguments.zip(load_state.as_mut())
|
func.arguments.zip(load_state.as_mut())
|
||||||
{
|
{
|
||||||
if call.index == *index {
|
if call.index == *index {
|
||||||
output.push_str(&arguments);
|
output.push_str(&arguments);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if let Some((arguments, _)) = load_state {
|
if let Some((arguments, _)) = load_state {
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
} else {
|
} else {
|
||||||
bail!("tool not used");
|
bail!("tool not used");
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
} else {
|
|
||||||
self.request_limiter
|
|
||||||
.run(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let mut load_state = None;
|
|
||||||
let mut response = response.map(
|
|
||||||
|item: Result<
|
|
||||||
proto::StreamCompleteWithLanguageModelResponse,
|
|
||||||
anyhow::Error,
|
|
||||||
>| {
|
|
||||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
|
||||||
serde_json::from_str(&item?.event)?,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
);
|
|
||||||
while let Some(Ok(part)) = response.next().await {
|
|
||||||
for choice in part.choices {
|
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(_) => {
|
CloudModel::Google(_) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||||
|
@ -854,114 +665,58 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
function.parameters = Some(input_schema);
|
function.parameters = Some(input_schema);
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
|
||||||
if cx
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
self.request_limiter
|
||||||
.unwrap_or(false)
|
.run(async move {
|
||||||
{
|
let response = Self::perform_llm_completion(
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
client.clone(),
|
||||||
self.request_limiter
|
llm_api_token,
|
||||||
.run(async move {
|
PerformCompletionParams {
|
||||||
let response = Self::perform_llm_completion(
|
provider: client::LanguageModelProvider::Zed,
|
||||||
client.clone(),
|
model: request.model.clone(),
|
||||||
llm_api_token,
|
provider_request: RawValue::from_string(serde_json::to_string(
|
||||||
PerformCompletionParams {
|
&request,
|
||||||
provider: client::LanguageModelProvider::Zed,
|
)?)?,
|
||||||
model: request.model.clone(),
|
},
|
||||||
provider_request: RawValue::from_string(
|
)
|
||||||
serde_json::to_string(&request)?,
|
.await?;
|
||||||
)?,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut body = BufReader::new(response.into_body());
|
let mut body = BufReader::new(response.into_body());
|
||||||
let mut line = String::new();
|
let mut line = String::new();
|
||||||
let mut load_state = None;
|
let mut load_state = None;
|
||||||
|
|
||||||
while body.read_line(&mut line).await? > 0 {
|
while body.read_line(&mut line).await? > 0 {
|
||||||
let part: open_ai::ResponseStreamEvent =
|
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
||||||
serde_json::from_str(&line)?;
|
line.clear();
|
||||||
line.clear();
|
|
||||||
|
|
||||||
for choice in part.choices {
|
for choice in part.choices {
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
for call in tool_calls {
|
for call in tool_calls {
|
||||||
if let Some(func) = call.function {
|
if let Some(func) = call.function {
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
load_state = Some((String::default(), call.index));
|
load_state = Some((String::default(), call.index));
|
||||||
}
|
}
|
||||||
if let Some((arguments, (output, index))) =
|
if let Some((arguments, (output, index))) =
|
||||||
func.arguments.zip(load_state.as_mut())
|
func.arguments.zip(load_state.as_mut())
|
||||||
{
|
{
|
||||||
if call.index == *index {
|
if call.index == *index {
|
||||||
output.push_str(&arguments);
|
output.push_str(&arguments);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some((arguments, _)) = load_state {
|
}
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
if let Some((arguments, _)) = load_state {
|
||||||
} else {
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
bail!("tool not used");
|
} else {
|
||||||
}
|
bail!("tool not used");
|
||||||
})
|
}
|
||||||
.boxed()
|
})
|
||||||
} else {
|
.boxed()
|
||||||
self.request_limiter
|
|
||||||
.run(async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
|
||||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let mut load_state = None;
|
|
||||||
let mut response = response.map(
|
|
||||||
|item: Result<
|
|
||||||
proto::StreamCompleteWithLanguageModelResponse,
|
|
||||||
anyhow::Error,
|
|
||||||
>| {
|
|
||||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
|
||||||
serde_json::from_str(&item?.event)?,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
);
|
|
||||||
while let Some(Ok(part)) = response.next().await {
|
|
||||||
for choice in part.choices {
|
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue