agent: Improve error handling and retry for zed-provided models (#33565)

* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after`
when Anthropic provides it (see the retry-delay sketch after this list).

* Distinguishes upstream provider errors and rate limits from errors
that originate from Zed's servers (see the error-mapping sketch below).

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError` (see the mapper
sketch below). While this is arguably an error case, the thread logic is
cleaner with this move. There is also precedent for including errors in
the event type: `CompletionRequestStatus::Failed` is how cloud errors
arrive.

* Updates the `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in const contexts (see
the `const fn` sketch below).

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`,
as the server no longer reads this header and defaults to that
behavior.
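
A minimal sketch of the retry-delay idea behind the first bullet, using
hypothetical names (this is not the actual `zed_llm_client` API): honor a
provider-supplied `retry_after` when present, and fall back to exponential
backoff otherwise.

```rust
use std::time::Duration;

/// Hypothetical helper: choose how long to wait before the next attempt.
/// `retry_after` is the provider-supplied delay, when one was given.
fn retry_delay(retry_after: Option<Duration>, attempt: u32) -> Duration {
    match retry_after {
        // The upstream provider said exactly how long to back off.
        Some(delay) => delay,
        // Otherwise use exponential backoff: 1s, 2s, 4s, 8s, capped at 16s.
        None => Duration::from_secs(1u64 << attempt.min(4)),
    }
}
```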
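
The second bullet comes down to tagging each HTTP failure with the provider
it came from. A simplified sketch, assuming a cut-down error enum (the real
`LanguageModelCompletionError` has more variants, and the status mapping
here is only illustrative):

```rust
use std::time::Duration;

// Cut-down stand-in for LanguageModelCompletionError: every variant records
// the provider name, so an Anthropic 429/529 reads as an Anthropic problem
// rather than a failure of Zed's servers.
#[derive(Debug)]
enum CompletionError {
    RateLimitExceeded { provider: &'static str, retry_after: Option<Duration> },
    UpstreamProviderError { provider: &'static str, status: u16, body: String },
    Other { provider: &'static str, status: u16, body: String },
}

fn from_http_status(
    provider: &'static str,
    status: u16,
    body: String,
    retry_after: Option<Duration>,
) -> CompletionError {
    match status {
        429 => CompletionError::RateLimitExceeded { provider, retry_after },
        // Seen in the wild: 500 Internal Server Error, 502 Bad Gateway,
        // 529 Service Overloaded.
        500..=599 => CompletionError::UpstreamProviderError { provider, status, body },
        _ => CompletionError::Other { provider, status, body },
    }
}
```

This pairs with `ApiError` retaining the response headers in the diff below,
so a provider-supplied `retry-after` header can eventually be threaded into
the conversion.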
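
For the `ToolUseJsonParseError` move, the shape of the change in each event
mapper is roughly the following, with simplified stand-in types:

```rust
use serde_json::Value;

// Simplified stand-ins for the real event and error types.
enum Event {
    ToolUse { name: String, input: Value },
    ToolUseJsonParseError { name: String, raw_input: String, json_parse_error: String },
}
struct StreamError; // transport-level failures only

fn map_tool_call(name: &str, raw_input: &str) -> Result<Event, StreamError> {
    match serde_json::from_str::<Value>(raw_input) {
        Ok(input) => Ok(Event::ToolUse { name: name.into(), input }),
        // Previously this arm returned Err(BadInputJson { .. }), ending the
        // stream; now it is an ordinary event, so the thread can report the
        // parse failure back to the model and keep going.
        Err(error) => Ok(Event::ToolUseJsonParseError {
            name: name.into(),
            raw_input: raw_input.into(),
            json_parse_error: error.to_string(),
        }),
    }
}
```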
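
For the constant change, the enabling detail is a `const fn` constructor on
the newtypes. A sketch assuming the payload is a `Cow<'static, str>` (the
actual field type in `language_model` may differ):

```rust
use std::borrow::Cow;

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LanguageModelProviderId(Cow<'static, str>);

impl LanguageModelProviderId {
    /// `const fn` constructor: no allocation, so the value can live in a
    /// `const` item instead of being built with `.into()` at runtime.
    pub const fn new(id: &'static str) -> Self {
        Self(Cow::Borrowed(id))
    }
}

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
```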

Release notes for this are covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
Michael Sloan, 2025-06-30 21:01:32 -06:00 (committed by GitHub)
commit d497f52e17 (parent f022a13091)
25 changed files with 656 additions and 479 deletions


@@ -33,8 +33,8 @@ use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic";
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
@@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -403,7 +403,11 @@ impl AnthropicModel {
};
async move {
let api_key = api_key.context("Missing Anthropic API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(Into::into)
@@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -806,12 +810,14 @@ impl AnthropicEventMapper {
raw_input: tool_use.input_json.clone(),
},
)),
Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
}),
Err(json_parse_err) => {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
vec![event_result]


@@ -52,8 +52,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
const PROVIDER_ID: &str = "amazon-bedrock";
const PROVIDER_NAME: &str = "Amazon Bedrock";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
@@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider {
impl LanguageModelProvider for BedrockLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {


@@ -1,4 +1,4 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
@@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
RefreshLlmTokenListener,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
@@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_NAME: &str = "Zed";
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
}
impl CloudLanguageModel {
const MAX_RETRIES: usize = 3;
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
@@ -547,8 +542,7 @@ impl CloudLanguageModel {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
let mut retries_remaining = Self::MAX_RETRIES;
let mut retry_delay = Duration::from_secs(1);
let mut refreshed_token = false;
loop {
let request_builder = http_client::Request::builder()
@@ -590,14 +584,20 @@ impl CloudLanguageModel {
includes_status_messages,
tool_use_limit_reached,
});
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
}
if !refreshed_token
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?;
} else if status == StatusCode::FORBIDDEN
refreshed_token = true;
continue;
}
if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
@@ -622,35 +622,18 @@ impl CloudLanguageModel {
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
}
anyhow::bail!("Forbidden");
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
// If we encounter an error in the 500 range, retry after a delay.
// We've seen at least these in the wild from API providers:
// * 500 Internal Server Error
// * 502 Bad Gateway
// * 529 Service Overloaded
if retries_remaining == 0 {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"cloud language model completion failed after {} retries with status {status}: {body}",
Self::MAX_RETRIES
);
}
Timer::after(retry_delay).await;
retries_remaining -= 1;
retry_delay *= 2; // If it fails again, wait longer.
} else if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError { status, body }));
}
let mut body = String::new();
let headers = response.headers().clone();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError {
status,
body,
headers
}));
}
}
}
@@ -660,6 +643,19 @@ impl CloudLanguageModel {
struct ApiError {
status: StatusCode,
body: String,
headers: HeaderMap<HeaderValue>,
}
impl From<ApiError> for LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
}
}
impl LanguageModel for CloudLanguageModel {
@@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn upstream_provider_id(&self) -> LanguageModelProviderId {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
Google => language_model::GOOGLE_PROVIDER_ID,
}
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
Google => language_model::GOOGLE_PROVIDER_NAME,
}
}
fn supports_tools(&self) -> bool {
@@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
.body(serde_json::to_string(&request_body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
let headers = response.headers().clone();
let mut response_body = String::new();
response
.body_mut()
@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
} else {
Err(anyhow!(ApiError {
status,
body: response_body
body: response_body,
headers
}))
}
}
@@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
)
.await
.map_err(|err| match err.downcast::<ApiError>() {
Ok(api_err) => {
if api_err.status == StatusCode::BAD_REQUEST {
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
return anyhow!(
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens
}
);
}
}
anyhow!(api_err)
}
Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
Err(err) => anyhow!(err),
})?;
@@ -995,7 +1000,7 @@ where
.flat_map(move |event| {
futures::stream::iter(match event {
Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))]
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CloudCompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]


@@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
use super::open_ai::count_open_ai_tokens;
const PROVIDER_ID: &str = "copilot_chat";
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("GitHub Copilot Chat");
pub struct CopilotChatLanguageModelProvider {
state: Entity<State>,
@@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -391,24 +392,24 @@ pub fn map_to_language_model_completion_events(
serde_json::Value::from_str(&tool_call.arguments)
};
match arguments {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
}
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Ok(
LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
},
),
}
},
));


@@ -28,8 +28,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "deepseek";
const PROVIDER_NAME: &str = "DeepSeek";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
#[derive(Default)]
@@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
impl LanguageModelProvider for DeepSeekLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -466,7 +466,7 @@ impl DeepSeekEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@@ -476,7 +476,7 @@ impl DeepSeekEventMapper {
event: deepseek::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -538,8 +538,8 @@ impl DeepSeekEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),


@@ -37,8 +37,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
@@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
impl LanguageModelProvider for GoogleLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
let response = request.await.map_err(LanguageModelCompletionError::from)?;
Ok(GoogleEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -622,7 +620,7 @@ impl GoogleEventMapper {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
vec![Err(LanguageModelCompletionError::from(error))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})


@@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
@@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
impl LanguageModelProvider for LmStudioLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -474,7 +474,7 @@ impl LmStudioEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@@ -484,7 +484,7 @@ impl LmStudioEventMapper {
event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -553,7 +553,7 @@ impl LmStudioEventMapper {
raw_input: tool_call.arguments,
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),


@@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
@@ -15,6 +14,7 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
};
use mistral::StreamResponse;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "mistral";
const PROVIDER_NAME: &str = "Mistral";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
@@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
impl LanguageModelProvider for MistralLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -579,13 +579,13 @@ impl MistralEventMapper {
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@@ -595,7 +595,7 @@ impl MistralEventMapper {
event: mistral::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -660,7 +660,7 @@ impl MistralEventMapper {
for (_, tool_call) in self.tool_calls_by_index.drain() {
if tool_call.id.is_empty() || tool_call.name.is_empty() {
results.push(Err(LanguageModelCompletionError::Other(anyhow!(
results.push(Err(LanguageModelCompletionError::from(anyhow!(
"Received incomplete tool call: missing id or name"
))));
continue;
@@ -676,12 +676,14 @@ impl MistralEventMapper {
raw_input: tool_call.arguments,
},
))),
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})),
Err(error) => {
results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
}))
}
}
}


@@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
@@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
let delta = match response {
Ok(delta) => delta,
Err(e) => {
let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
return Some((vec![event], state));
}
};


@@ -31,8 +31,8 @@ use util::ResultExt;
use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openai";
const PROVIDER_NAME: &str = "OpenAI";
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing OpenAI API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
@@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -525,7 +529,7 @@ impl OpenAiEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
})
})
}
@@ -588,10 +592,10 @@ impl OpenAiEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(),
}),
}


@@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openrouter";
const PROVIDER_NAME: &str = "OpenRouter";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
@@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
impl LanguageModelProvider for OpenRouterLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
})
})
}
@@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(),
}),
}


@@ -25,8 +25,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "vercel";
const PROVIDER_NAME: &str = "Vercel";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
@@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
impl LanguageModelProvider for VercelLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -269,7 +269,11 @@ impl VercelLanguageModel {
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing Vercel API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {