diff --git a/Cargo.lock b/Cargo.lock index b7868f2c67..20ea6472b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20139,9 +20139,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203" +checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0" dependencies = [ "anyhow", "serde", diff --git a/Cargo.toml b/Cargo.toml index 2fee4059c6..bc686419e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [ wasmtime-wasi = "29" which = "6.0.0" workspace-hack = "0.1.0" -zed_llm_client = "0.8.4" +zed_llm_client = "0.8.5" zstd = "0.11" [workspace.dependencies.async-stripe] diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 2f965e232a..028dabbd91 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -23,11 +23,10 @@ use gpui::{ }; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, - TokenUsage, + LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, + LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, + Role, SelectedModel, StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::{ @@ -1531,82 +1530,7 @@ impl Thread { } thread.update(cx, |thread, cx| { - let event = match event { - Ok(event) => event, - Err(error) => { - match error { - LanguageModelCompletionError::RateLimitExceeded { retry_after } => { - anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after }); - } - LanguageModelCompletionError::Overloaded => { - anyhow::bail!(LanguageModelKnownError::Overloaded); - } - LanguageModelCompletionError::ApiInternalServerError =>{ - anyhow::bail!(LanguageModelKnownError::ApiInternalServerError); - } - LanguageModelCompletionError::PromptTooLarge { tokens } => { - let tokens = tokens.unwrap_or_else(|| { - // We didn't get an exact token count from the API, so fall back on our estimate. - thread.total_token_usage() - .map(|usage| usage.total) - .unwrap_or(0) - // We know the context window was exceeded in practice, so if our estimate was - // lower than max tokens, the estimate was wrong; return that we exceeded by 1. 
- .max(model.max_token_count().saturating_add(1)) - }); - - anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens }) - } - LanguageModelCompletionError::ApiReadResponseError(io_error) => { - anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error)); - } - LanguageModelCompletionError::UnknownResponseFormat(error) => { - anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error)); - } - LanguageModelCompletionError::HttpResponseError { status, ref body } => { - if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) { - anyhow::bail!(known_error); - } else { - return Err(error.into()); - } - } - LanguageModelCompletionError::DeserializeResponse(error) => { - anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error)); - } - LanguageModelCompletionError::BadInputJson { - id, - tool_name, - raw_input: invalid_input_json, - json_parse_error, - } => { - thread.receive_invalid_tool_json( - id, - tool_name, - invalid_input_json, - json_parse_error, - window, - cx, - ); - return Ok(()); - } - // These are all errors we can't automatically attempt to recover from (e.g. by retrying) - err @ LanguageModelCompletionError::BadRequestFormat | - err @ LanguageModelCompletionError::AuthenticationError | - err @ LanguageModelCompletionError::PermissionError | - err @ LanguageModelCompletionError::ApiEndpointNotFound | - err @ LanguageModelCompletionError::SerializeRequest(_) | - err @ LanguageModelCompletionError::BuildRequestBody(_) | - err @ LanguageModelCompletionError::HttpSend(_) => { - anyhow::bail!(err); - } - LanguageModelCompletionError::Other(error) => { - return Err(error); - } - } - } - }; - - match event { + match event? { LanguageModelCompletionEvent::StartMessage { .. } => { request_assistant_message_id = Some(thread.insert_assistant_message( @@ -1683,9 +1607,7 @@ impl Thread { }; } } - LanguageModelCompletionEvent::RedactedThinking { - data - } => { + LanguageModelCompletionEvent::RedactedThinking { data } => { thread.received_chunk(); if let Some(last_message) = thread.messages.last_mut() { @@ -1734,6 +1656,21 @@ impl Thread { }); } } + LanguageModelCompletionEvent::ToolUseJsonParseError { + id, + tool_name, + raw_input: invalid_input_json, + json_parse_error, + } => { + thread.receive_invalid_tool_json( + id, + tool_name, + invalid_input_json, + json_parse_error, + window, + cx, + ); + } LanguageModelCompletionEvent::StatusUpdate(status_update) => { if let Some(completion) = thread .pending_completions @@ -1741,23 +1678,34 @@ impl Thread { .find(|completion| completion.id == pending_completion_id) { match status_update { - CompletionRequestStatus::Queued { - position, - } => { - completion.queue_state = QueueState::Queued { position }; + CompletionRequestStatus::Queued { position } => { + completion.queue_state = + QueueState::Queued { position }; } CompletionRequestStatus::Started => { - completion.queue_state = QueueState::Started; + completion.queue_state = QueueState::Started; } CompletionRequestStatus::Failed { - code, message, request_id + code, + message, + request_id: _, + retry_after, } => { - anyhow::bail!("completion request failed. 
request_id: {request_id}, code: {code}, message: {message}"); + return Err( + LanguageModelCompletionError::from_cloud_failure( + model.upstream_provider_name(), + code, + message, + retry_after.map(Duration::from_secs_f64), + ), + ); } - CompletionRequestStatus::UsageUpdated { - amount, limit - } => { - thread.update_model_request_usage(amount as u32, limit, cx); + CompletionRequestStatus::UsageUpdated { amount, limit } => { + thread.update_model_request_usage( + amount as u32, + limit, + cx, + ); } CompletionRequestStatus::ToolUseLimitReached => { thread.tool_use_limit_reached = true; @@ -1808,10 +1756,11 @@ impl Thread { Ok(stop_reason) => { match stop_reason { StopReason::ToolUse => { - let tool_uses = thread.use_pending_tools(window, model.clone(), cx); + let tool_uses = + thread.use_pending_tools(window, model.clone(), cx); cx.emit(ThreadEvent::UsePendingTools { tool_uses }); } - StopReason::EndTurn | StopReason::MaxTokens => { + StopReason::EndTurn | StopReason::MaxTokens => { thread.project.update(cx, |project, cx| { project.set_agent_location(None, cx); }); @@ -1827,7 +1776,9 @@ impl Thread { { let mut messages_to_remove = Vec::new(); - for (ix, message) in thread.messages.iter().enumerate().rev() { + for (ix, message) in + thread.messages.iter().enumerate().rev() + { messages_to_remove.push(message.id); if message.role == Role::User { @@ -1835,7 +1786,9 @@ impl Thread { break; } - if let Some(prev_message) = thread.messages.get(ix - 1) { + if let Some(prev_message) = + thread.messages.get(ix - 1) + { if prev_message.role == Role::Assistant { break; } @@ -1850,14 +1803,16 @@ impl Thread { cx.emit(ThreadEvent::ShowError(ThreadError::Message { header: "Language model refusal".into(), - message: "Model refused to generate content for safety reasons.".into(), + message: + "Model refused to generate content for safety reasons." + .into(), })); } } // We successfully completed, so cancel any remaining retries. thread.retry_state = None; - }, + } Err(error) => { thread.project.update(cx, |project, cx| { project.set_agent_location(None, cx); @@ -1883,26 +1838,38 @@ impl Thread { cx.emit(ThreadEvent::ShowError( ThreadError::ModelRequestLimitReached { plan: error.plan }, )); - } else if let Some(known_error) = - error.downcast_ref::() + } else if let Some(completion_error) = + error.downcast_ref::() { - match known_error { - LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => { + use LanguageModelCompletionError::*; + match &completion_error { + PromptTooLarge { tokens, .. } => { + let tokens = tokens.unwrap_or_else(|| { + // We didn't get an exact token count from the API, so fall back on our estimate. + thread + .total_token_usage() + .map(|usage| usage.total) + .unwrap_or(0) + // We know the context window was exceeded in practice, so if our estimate was + // lower than max tokens, the estimate was wrong; return that we exceeded by 1. + .max(model.max_token_count().saturating_add(1)) + }); thread.exceeded_window_error = Some(ExceededWindowError { model_id: model.id(), - token_count: *tokens, + token_count: tokens, }); cx.notify(); } - LanguageModelKnownError::RateLimitExceeded { retry_after } => { - let provider_name = model.provider_name(); - let error_message = format!( - "{}'s API rate limit exceeded", - provider_name.0.as_ref() - ); - + RateLimitExceeded { + retry_after: Some(retry_after), + .. + } + | ServerOverloaded { + retry_after: Some(retry_after), + .. 
+ } => { thread.handle_rate_limit_error( - &error_message, + &completion_error, *retry_after, model.clone(), intent, @@ -1911,15 +1878,9 @@ impl Thread { ); retry_scheduled = true; } - LanguageModelKnownError::Overloaded => { - let provider_name = model.provider_name(); - let error_message = format!( - "{}'s API servers are overloaded right now", - provider_name.0.as_ref() - ); - + RateLimitExceeded { .. } | ServerOverloaded { .. } => { retry_scheduled = thread.handle_retryable_error( - &error_message, + &completion_error, model.clone(), intent, window, @@ -1929,15 +1890,11 @@ impl Thread { emit_generic_error(error, cx); } } - LanguageModelKnownError::ApiInternalServerError => { - let provider_name = model.provider_name(); - let error_message = format!( - "{}'s API server reported an internal server error", - provider_name.0.as_ref() - ); - + ApiInternalServerError { .. } + | ApiReadResponseError { .. } + | HttpSend { .. } => { retry_scheduled = thread.handle_retryable_error( - &error_message, + &completion_error, model.clone(), intent, window, @@ -1947,12 +1904,16 @@ impl Thread { emit_generic_error(error, cx); } } - LanguageModelKnownError::ReadResponseError(_) | - LanguageModelKnownError::DeserializeResponse(_) | - LanguageModelKnownError::UnknownResponseFormat(_) => { - // In the future we will attempt to re-roll response, but only once - emit_generic_error(error, cx); - } + NoApiKey { .. } + | HttpResponseError { .. } + | BadRequestFormat { .. } + | AuthenticationError { .. } + | PermissionError { .. } + | ApiEndpointNotFound { .. } + | SerializeRequest { .. } + | BuildRequestBody { .. } + | DeserializeResponse { .. } + | Other { .. } => emit_generic_error(error, cx), } } else { emit_generic_error(error, cx); @@ -2084,7 +2045,7 @@ impl Thread { fn handle_rate_limit_error( &mut self, - error_message: &str, + error: &LanguageModelCompletionError, retry_after: Duration, model: Arc, intent: CompletionIntent, @@ -2092,9 +2053,10 @@ impl Thread { cx: &mut Context, ) { // For rate limit errors, we only retry once with the specified duration - let retry_message = format!( - "{error_message}. Retrying in {} seconds…", - retry_after.as_secs() + let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs()); + log::warn!( + "Retrying completion request in {} seconds: {error:?}", + retry_after.as_secs(), ); // Add a UI-only message instead of a regular message @@ -2127,18 +2089,18 @@ impl Thread { fn handle_retryable_error( &mut self, - error_message: &str, + error: &LanguageModelCompletionError, model: Arc, intent: CompletionIntent, window: Option, cx: &mut Context, ) -> bool { - self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx) + self.handle_retryable_error_with_delay(error, None, model, intent, window, cx) } fn handle_retryable_error_with_delay( &mut self, - error_message: &str, + error: &LanguageModelCompletionError, custom_delay: Option, model: Arc, intent: CompletionIntent, @@ -2168,8 +2130,12 @@ impl Thread { // Add a transient message to inform the user let delay_secs = delay.as_secs(); let retry_message = format!( - "{}. Retrying (attempt {} of {}) in {} seconds...", - error_message, attempt, max_attempts, delay_secs + "{error}. Retrying (attempt {attempt} of {max_attempts}) \ + in {delay_secs} seconds..." 
+ ); + log::warn!( + "Retrying completion request (attempt {attempt} of {max_attempts}) \ + in {delay_secs} seconds: {error:?}", ); // Add a UI-only message instead of a regular message @@ -4139,9 +4105,15 @@ fn main() {{ >, > { let error = match self.error_type { - TestError::Overloaded => LanguageModelCompletionError::Overloaded, + TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded { + provider: self.provider_name(), + retry_after: None, + }, TestError::InternalServerError => { - LanguageModelCompletionError::ApiInternalServerError + LanguageModelCompletionError::ApiInternalServerError { + provider: self.provider_name(), + message: "I'm a teapot orbiting the sun".to_string(), + } } }; async move { @@ -4649,9 +4621,13 @@ fn main() {{ > { if !*self.failed_once.lock() { *self.failed_once.lock() = true; + let provider = self.provider_name(); // Return error on first attempt let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::Overloaded) + Err(LanguageModelCompletionError::ServerOverloaded { + provider, + retry_after: None, + }) }); async move { Ok(stream.boxed()) }.boxed() } else { @@ -4814,9 +4790,13 @@ fn main() {{ > { if !*self.failed_once.lock() { *self.failed_once.lock() = true; + let provider = self.provider_name(); // Return error on first attempt let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::Overloaded) + Err(LanguageModelCompletionError::ServerOverloaded { + provider, + retry_after: None, + }) }); async move { Ok(stream.boxed()) }.boxed() } else { @@ -4969,10 +4949,12 @@ fn main() {{ LanguageModelCompletionError, >, > { + let provider = self.provider_name(); async move { let stream = futures::stream::once(async move { Err(LanguageModelCompletionError::RateLimitExceeded { - retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS), + provider, + retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)), }) }); Ok(stream.boxed()) diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 560e87b1c2..40d8a29c59 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -2025,9 +2025,7 @@ impl AgentPanel { .thread() .read(cx) .configured_model() - .map_or(false, |model| { - model.provider.id().0 == ZED_CLOUD_PROVIDER_ID - }); + .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID); if !is_using_zed_provider { return false; diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 015b50a801..70d2b6e066 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -1250,9 +1250,7 @@ impl MessageEditor { self.thread .read(cx) .configured_model() - .map_or(false, |model| { - model.provider.id().0 == ZED_CLOUD_PROVIDER_ID - }) + .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID) } fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context) -> Option
{ diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 7f0ab7550d..c73f606045 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::http::{self, HeaderMap, HeaderValue}; -use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString}; use thiserror::Error; @@ -356,7 +356,7 @@ pub async fn complete( .send(request) .await .map_err(AnthropicError::HttpSend)?; - let status = response.status(); + let status_code = response.status(); let mut body = String::new(); response .body_mut() @@ -364,12 +364,12 @@ pub async fn complete( .await .map_err(AnthropicError::ReadResponse)?; - if status.is_success() { + if status_code.is_success() { Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?) } else { Err(AnthropicError::HttpResponseError { - status: status.as_u16(), - body, + status_code, + message: body, }) } } @@ -444,11 +444,7 @@ impl RateLimitInfo { } Self { - retry_after: headers - .get("retry-after") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.parse::().ok()) - .map(Duration::from_secs), + retry_after: parse_retry_after(headers), requests: RateLimit::from_headers("requests", headers).ok(), tokens: RateLimit::from_headers("tokens", headers).ok(), input_tokens: RateLimit::from_headers("input-tokens", headers).ok(), @@ -457,6 +453,17 @@ impl RateLimitInfo { } } +/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses +/// seconds). Note that other services might specify an HTTP date or some other format for this +/// header. Returns `None` if the header is not present or cannot be parsed. 
+pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
+    headers
+        .get("retry-after")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.parse::<u64>().ok())
+        .map(Duration::from_secs)
+}
+
 fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
     Ok(headers
         .get(key)
@@ -520,6 +527,10 @@
                 })
                 .boxed();
             Ok((stream, Some(rate_limits)))
+        } else if response.status().as_u16() == 529 {
+            Err(AnthropicError::ServerOverloaded {
+                retry_after: rate_limits.retry_after,
+            })
         } else if let Some(retry_after) = rate_limits.retry_after {
             Err(AnthropicError::RateLimit { retry_after })
         } else {
@@ -532,10 +543,9 @@
             match serde_json::from_str::<Event>(&body) {
                 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
-                Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
-                Err(_) => Err(AnthropicError::HttpResponseError {
-                    status: response.status().as_u16(),
-                    body: body,
+                Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
+                    status_code: response.status(),
+                    message: body,
                 }),
             }
         }
@@ -801,16 +811,19 @@ pub enum AnthropicError {
     ReadResponse(io::Error),
 
     /// HTTP error response from the API
-    HttpResponseError { status: u16, body: String },
+    HttpResponseError {
+        status_code: StatusCode,
+        message: String,
+    },
 
     /// Rate limit exceeded
     RateLimit { retry_after: Duration },
 
+    /// Server overloaded
+    ServerOverloaded { retry_after: Option<Duration> },
+
     /// API returned an error response
     ApiError(ApiError),
-
-    /// Unexpected response format
-    UnexpectedResponseFormat(String),
 }
 
 #[derive(Debug, Serialize, Deserialize, Error)]
diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs
index 0be8afcf69..aaaef15250 100644
--- a/crates/assistant_context/src/assistant_context.rs
+++ b/crates/assistant_context/src/assistant_context.rs
@@ -2140,7 +2140,8 @@ impl AssistantContext {
                         );
                     }
                     LanguageModelCompletionEvent::ToolUse(_) |
-                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
+                    LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
+                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
                 }
             });
 
diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs
index 7beb2ec919..8df8f677f2 100644
--- a/crates/assistant_tools/src/edit_agent/evals.rs
+++ b/crates/assistant_tools/src/edit_agent/evals.rs
@@ -29,6 +29,7 @@ use std::{
     path::Path,
     str::FromStr,
     sync::mpsc,
+    time::Duration,
 };
 use util::path;
 
@@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
         match request().await {
             Ok(result) => return Ok(result),
             Err(err) => match err.downcast::<LanguageModelCompletionError>() {
-                Ok(err) => match err {
-                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
+                Ok(err) => match &err {
+                    LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
+                    | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
+                        let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
                         // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
                         let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
                         eprintln!(
-                            "Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
+                            "Attempt #{attempt}: {err}. 
Retry after {retry_after:?} + jitter of {jitter:?}" ); Timer::after(retry_after + jitter).await; continue; diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index bb66a04e1f..d17dc89d0b 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown( | LanguageModelCompletionEvent::StartMessage { .. } | LanguageModelCompletionEvent::StatusUpdate { .. }, ) => {} + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + json_parse_error, .. + }) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!( + "**Error**: parse error in tool use JSON: {}\n\n", + json_parse_error + )); + } Err(error) => { flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); response.push_str(&format!("**Error**: {}\n\n", error)); @@ -1132,6 +1141,17 @@ impl ThreadDialog { | Ok(LanguageModelCompletionEvent::StartMessage { .. }) | Ok(LanguageModelCompletionEvent::Stop(_)) => {} + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + json_parse_error, + .. + }) => { + flush_text(&mut current_text, &mut content); + content.push(MessageContent::Text(format!( + "ERROR: parse error in tool use JSON: {}", + json_parse_error + ))); + } + Err(error) => { flush_text(&mut current_text, &mut content); content.push(MessageContent::Text(format!("ERROR: {}", error))); diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index ccde40c05f..c06ae426cc 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -9,17 +9,18 @@ mod telemetry; pub mod fake_provider; use anthropic::{AnthropicError, parse_prompt_too_long}; -use anyhow::Result; +use anyhow::{Result, anyhow}; use client::Client; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; -use http_client::http; +use http_client::{StatusCode, http}; use icons::IconName; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::ops::{Add, Sub}; +use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; @@ -34,11 +35,22 @@ pub use crate::request::*; pub use crate::role::*; pub use crate::telemetry::*; -pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; +pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = + LanguageModelProviderId::new("anthropic"); +pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Anthropic"); -/// If we get a rate limit error that doesn't tell us when we can retry, -/// default to waiting this long before retrying. 
-const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
+pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
+pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("Google AI");
+
+pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
+pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("OpenAI");
+
+pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
+pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("Zed");
 
 pub fn init(client: Arc<Client>, cx: &mut App) {
     init_settings(cx);
@@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
         data: String,
     },
     ToolUse(LanguageModelToolUse),
+    ToolUseJsonParseError {
+        id: LanguageModelToolUseId,
+        tool_name: Arc<str>,
+        raw_input: Arc<str>,
+        json_parse_error: String,
+    },
     StartMessage {
         message_id: String,
     },
@@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
 
 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
-    #[error("rate limit exceeded, retry after {retry_after:?}")]
-    RateLimitExceeded { retry_after: Duration },
-    #[error("received bad input JSON")]
-    BadInputJson {
-        id: LanguageModelToolUseId,
-        tool_name: Arc<str>,
-        raw_input: Arc<str>,
-        json_parse_error: String,
-    },
-    #[error("language model provider's API is overloaded")]
-    Overloaded,
-    #[error(transparent)]
-    Other(#[from] anyhow::Error),
-    #[error("invalid request format to language model provider's API")]
-    BadRequestFormat,
-    #[error("authentication error with language model provider's API")]
-    AuthenticationError,
-    #[error("permission error with language model provider's API")]
-    PermissionError,
-    #[error("language model provider API endpoint not found")]
-    ApiEndpointNotFound,
     #[error("prompt too large for context window")]
     PromptTooLarge { tokens: Option<u64> },
-    #[error("internal server error in language model provider's API")]
-    ApiInternalServerError,
-    #[error("I/O error reading response from language model provider's API: {0:?}")]
-    ApiReadResponseError(io::Error),
-    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
-    HttpResponseError { status: u16, body: String },
-    #[error("error serializing request to language model provider API: {0}")]
-    SerializeRequest(serde_json::Error),
-    #[error("error building request body to language model provider API: {0}")]
-    BuildRequestBody(http::Error),
-    #[error("error sending HTTP request to language model provider API: {0}")]
-    HttpSend(anyhow::Error),
-    #[error("error deserializing language model provider API response: {0}")]
-    DeserializeResponse(serde_json::Error),
-    #[error("unexpected language model provider API response format: {0}")]
-    UnknownResponseFormat(String),
+    #[error("missing {provider} API key")]
+    NoApiKey { provider: LanguageModelProviderName },
+    #[error("{provider}'s API rate limit exceeded")]
+    RateLimitExceeded {
+        provider: LanguageModelProviderName,
+        retry_after: Option<Duration>,
+    },
+    #[error("{provider}'s API servers are overloaded right now")]
+    ServerOverloaded {
+        provider: LanguageModelProviderName,
+        retry_after: Option<Duration>,
+    },
+    #[error("{provider}'s API server reported an internal server error: {message}")]
+    ApiInternalServerError {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
+    HttpResponseError {
+        provider: LanguageModelProviderName,
+        status_code: StatusCode,
+        message: String,
+    },
+
+    // Client errors
+    #[error("invalid request format to {provider}'s API: {message}")]
+    BadRequestFormat {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("authentication error with {provider}'s API: {message}")]
+    AuthenticationError {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("permission error with {provider}'s API: {message}")]
+    PermissionError {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("language model provider API endpoint not found")]
+    ApiEndpointNotFound { provider: LanguageModelProviderName },
+    #[error("I/O error reading response from {provider}'s API")]
+    ApiReadResponseError {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: io::Error,
+    },
+    #[error("error serializing request to {provider} API")]
+    SerializeRequest {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: serde_json::Error,
+    },
+    #[error("error building request body to {provider} API")]
+    BuildRequestBody {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: http::Error,
+    },
+    #[error("error sending HTTP request to {provider} API")]
+    HttpSend {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: anyhow::Error,
+    },
+    #[error("error deserializing {provider} API response")]
+    DeserializeResponse {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: serde_json::Error,
+    },
+
+    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
+    #[error(transparent)]
+    Other(#[from] anyhow::Error),
+}
+
+impl LanguageModelCompletionError {
+    pub fn from_cloud_failure(
+        upstream_provider: LanguageModelProviderName,
+        code: String,
+        message: String,
+        retry_after: Option<Duration>,
+    ) -> Self {
+        if let Some(tokens) = parse_prompt_too_long(&message) {
+            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
+            // to be reported. This is a temporary workaround to handle this in the case where the
+            // token limit has been exceeded.
+            Self::PromptTooLarge {
+                tokens: Some(tokens),
+            }
+        } else if let Some(status_code) = code
+            .strip_prefix("upstream_http_")
+            .and_then(|code| StatusCode::from_str(code).ok())
+        {
+            Self::from_http_status(upstream_provider, status_code, message, retry_after)
+        } else if let Some(status_code) = code
+            .strip_prefix("http_")
+            .and_then(|code| StatusCode::from_str(code).ok())
+        {
+            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
+        } else {
+            anyhow!("completion request failed, code: {code}, message: {message}").into()
+        }
+    }
+
+    pub fn from_http_status(
+        provider: LanguageModelProviderName,
+        status_code: StatusCode,
+        message: String,
+        retry_after: Option<Duration>,
+    ) -> Self {
+        match status_code {
+            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
+            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
+            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
+            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
+            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
+                tokens: parse_prompt_too_long(&message),
+            },
+            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
+                provider,
+                retry_after,
+            },
+            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
+            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
+                provider,
+                retry_after,
+            },
+            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
+                provider,
+                retry_after,
+            },
+            _ => Self::HttpResponseError {
+                provider,
+                status_code,
+                message,
+            },
+        }
+    }
+}
 
 impl From<AnthropicError> for LanguageModelCompletionError {
     fn from(error: AnthropicError) -> Self {
+        let provider = ANTHROPIC_PROVIDER_NAME;
         match error {
-            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
-            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
-            AnthropicError::HttpSend(error) => Self::HttpSend(error),
-            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
-            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
-            AnthropicError::HttpResponseError { status, body } => {
-                Self::HttpResponseError { status, body }
+            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
+            AnthropicError::DeserializeResponse(error) => {
+                Self::DeserializeResponse { provider, error }
             }
-            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
+            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+            AnthropicError::HttpResponseError {
+                status_code,
+                message,
+            } => Self::HttpResponseError {
+                provider,
+                status_code,
+                message,
+            },
+            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
+                provider,
+                retry_after: Some(retry_after),
+            },
+            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+                provider,
+                retry_after: retry_after,
+            },
             AnthropicError::ApiError(api_error) => api_error.into(),
-            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
         }
     }
 }
@@ -141,23 +277,39 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
     fn from(error: anthropic::ApiError) -> Self {
         use anthropic::ApiErrorCode::*;
-
+        let provider = ANTHROPIC_PROVIDER_NAME;
         match error.code() {
             Some(code) => match code {
-                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
-                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
-                PermissionError => LanguageModelCompletionError::PermissionError,
-                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
-                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
+                InvalidRequestError => Self::BadRequestFormat {
+                    provider,
+                    message: error.message,
+                },
+                AuthenticationError => Self::AuthenticationError {
+                    provider,
+                    message: error.message,
+                },
+                PermissionError => Self::PermissionError {
+                    provider,
+                    message: error.message,
+                },
+                NotFoundError => Self::ApiEndpointNotFound { provider },
+                RequestTooLarge => Self::PromptTooLarge {
                     tokens: parse_prompt_too_long(&error.message),
                 },
-                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
-                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+                RateLimitError => Self::RateLimitExceeded {
+                    provider,
+                    retry_after: None,
+                },
+                ApiError => Self::ApiInternalServerError {
+                    provider,
+                    message: error.message,
+                },
+                OverloadedError => Self::ServerOverloaded {
+                    provider,
+                    retry_after: None,
                 },
-                ApiError => LanguageModelCompletionError::ApiInternalServerError,
-                OverloadedError => LanguageModelCompletionError::Overloaded,
             },
-            None => LanguageModelCompletionError::Other(error.into()),
+            None => Self::Other(error.into()),
         }
     }
 }
@@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
     fn name(&self) -> LanguageModelName;
     fn provider_id(&self) -> LanguageModelProviderId;
     fn provider_name(&self) -> LanguageModelProviderName;
+    fn upstream_provider_id(&self) -> LanguageModelProviderId {
+        self.provider_id()
+    }
+    fn upstream_provider_name(&self) -> LanguageModelProviderName {
+        self.provider_name()
+    }
+
     fn telemetry_id(&self) -> String;
 
     fn api_key(&self, _cx: &App) -> Option<String> {
@@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
                     Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                     Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                     Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                    Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                        ..
+                    }) => None,
                     Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                         *last_token_usage.lock() = token_usage;
                         None
@@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
     }
 }
 
-#[derive(Debug, Error)]
-pub enum LanguageModelKnownError {
-    #[error("Context window limit exceeded ({tokens})")]
-    ContextWindowLimitExceeded { tokens: u64 },
-    #[error("Language model provider's API is currently overloaded")]
-    Overloaded,
-    #[error("Language model provider's API encountered an internal server error")]
-    ApiInternalServerError,
-    #[error("I/O error while reading response from language model provider's API: {0:?}")]
-    ReadResponseError(io::Error),
-    #[error("Error deserializing response from language model provider's API: {0:?}")]
-    DeserializeResponse(serde_json::Error),
-    #[error("Language model provider's API returned a response in an unknown format")]
-    UnknownResponseFormat(String),
-    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
-    RateLimitExceeded { retry_after: Duration },
-}
-
-impl LanguageModelKnownError {
-    /// Attempts to map an HTTP response status code to a known error type.
-    /// Returns None if the status code doesn't map to a specific known error.
-    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
-        match status {
-            429 => Some(Self::RateLimitExceeded {
-                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
-            }),
-            503 => Some(Self::Overloaded),
-            500..=599 => Some(Self::ApiInternalServerError),
-            _ => None,
-        }
-    }
-}
-
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
@@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
 #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 pub struct LanguageModelProviderName(pub SharedString);
 
+impl LanguageModelProviderId {
+    pub const fn new(id: &'static str) -> Self {
+        Self(SharedString::new_static(id))
+    }
+}
+
+impl LanguageModelProviderName {
+    pub const fn new(id: &'static str) -> Self {
+        Self(SharedString::new_static(id))
+    }
+}
+
 impl fmt::Display for LanguageModelProviderId {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}", self.0)
     }
 }
 
+impl fmt::Display for LanguageModelProviderName {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 impl From<String> for LanguageModelId {
     fn from(value: String) -> Self {
         Self(SharedString::from(value))
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index e9f03cc1ff..840fda38de 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -98,7 +98,7 @@ impl ConfiguredModel {
     }
 
     pub fn is_provided_by_zed(&self) -> bool {
-        self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
+        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
     }
 }
 
diff --git a/crates/language_model/src/telemetry.rs b/crates/language_model/src/telemetry.rs
index 9bd9b903c2..ccdcb0ad0c 100644
--- a/crates/language_model/src/telemetry.rs
+++ b/crates/language_model/src/telemetry.rs
@@ -1,3 +1,4 @@
+use crate::ANTHROPIC_PROVIDER_ID;
 use anthropic::ANTHROPIC_API_URL;
 use anyhow::{Context as _, anyhow};
 use client::telemetry::Telemetry;
@@ -8,8 +9,6 @@ use std::sync::Arc;
 use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 use util::ResultExt;
 
-pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
-
 pub fn report_assistant_event(
     event: AssistantEventData,
     telemetry: Option<Arc<Telemetry>>,
@@ -19,7 +18,7 @@
 ) {
     if let Some(telemetry) = telemetry.as_ref() {
         telemetry.report_assistant_event(event.clone());
-        if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
+        if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
             if let Some(api_key) = model_api_key {
                 executor
                     .spawn(async move {
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index aa500f4b4d..6ddb1a4381 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -33,8 +33,8 @@ use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
 use util::ResultExt;
 
-const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
-const PROVIDER_NAME: &str = "Anthropic";
+const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct AnthropicSettings {
@@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 
 impl LanguageModelProvider for 
AnthropicLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -403,7 +403,11 @@ impl AnthropicModel { }; async move { - let api_key = api_key.context("Missing Anthropic API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request); request.await.map_err(Into::into) @@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -806,12 +810,14 @@ impl AnthropicEventMapper { raw_input: tool_use.input_json.clone(), }, )), - Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson { - id: tool_use.id.into(), - tool_name: tool_use.name.into(), - raw_input: input_json.into(), - json_parse_error: json_parse_err.to_string(), - }), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } }; vec![event_result] diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index a55fc5bc11..dd19915f93 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -52,8 +52,8 @@ use util::ResultExt; use crate::AllLanguageModelSettings; -const PROVIDER_ID: &str = "amazon-bedrock"; -const PROVIDER_NAME: &str = "Amazon Bedrock"; +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock"); +const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock"); #[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)] pub struct BedrockCredentials { @@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider { impl LanguageModelProvider for BedrockLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 62a24282dd..ecc07daa69 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,4 +1,4 @@ -use anthropic::{AnthropicModelMode, parse_prompt_too_long}; +use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use futures::{ @@ -8,25 +8,21 @@ use google_ai::GoogleModelMode; use gpui::{ AnyElement, AnyView, App, 
AsyncApp, Context, Entity, SemanticVersion, Subscription, Task, }; +use http_client::http::{HeaderMap, HeaderValue}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, - ZED_CLOUD_PROVIDER_ID, -}; -use language_model::{ - LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError, - RefreshLlmTokenListener, + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, + ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use settings::SettingsStore; -use smol::Timer; use smol::io::{AsyncReadExt, BufReader}; use std::pin::Pin; use std::str::FromStr as _; @@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i use crate::provider::google::{GoogleEventMapper, into_google}; use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai}; -pub const PROVIDER_NAME: &str = "Zed"; +const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { @@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider { impl LanguageModelProvider for CloudLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse { } impl CloudLanguageModel { - const MAX_RETRIES: usize = 3; - async fn perform_llm_completion( client: Arc, llm_api_token: LlmApiToken, @@ -547,8 +542,7 @@ impl CloudLanguageModel { let http_client = &client.http_client(); let mut token = llm_api_token.acquire(&client).await?; - let mut retries_remaining = Self::MAX_RETRIES; - let mut retry_delay = Duration::from_secs(1); + let mut refreshed_token = false; loop { let request_builder = http_client::Request::builder() @@ -590,14 +584,20 @@ impl CloudLanguageModel { includes_status_messages, tool_use_limit_reached, }); - } else if response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() + } + + if !refreshed_token + && response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() { - retries_remaining -= 1; token = llm_api_token.refresh(&client).await?; - } else if status == StatusCode::FORBIDDEN + refreshed_token = true; + continue; + } + + if status == StatusCode::FORBIDDEN && response .headers() .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME) @@ -622,35 +622,18 @@ impl CloudLanguageModel { 
return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } - - anyhow::bail!("Forbidden"); - } else if status.as_u16() >= 500 && status.as_u16() < 600 { - // If we encounter an error in the 500 range, retry after a delay. - // We've seen at least these in the wild from API providers: - // * 500 Internal Server Error - // * 502 Bad Gateway - // * 529 Service Overloaded - - if retries_remaining == 0 { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "cloud language model completion failed after {} retries with status {status}: {body}", - Self::MAX_RETRIES - ); - } - - Timer::after(retry_delay).await; - - retries_remaining -= 1; - retry_delay *= 2; // If it fails again, wait longer. } else if status == StatusCode::PAYMENT_REQUIRED { return Err(anyhow!(PaymentRequiredError)); - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!(ApiError { status, body })); } + + let mut body = String::new(); + let headers = response.headers().clone(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!(ApiError { + status, + body, + headers + })); } } } @@ -660,6 +643,19 @@ impl CloudLanguageModel { struct ApiError { status: StatusCode, body: String, + headers: HeaderMap, +} + +impl From for LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + let retry_after = None; + LanguageModelCompletionError::from_http_status( + PROVIDER_NAME, + error.status, + error.body, + retry_after, + ) + } } impl LanguageModel for CloudLanguageModel { @@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME + } + + fn upstream_provider_id(&self) -> LanguageModelProviderId { + use zed_llm_client::LanguageModelProvider::*; + match self.model.provider { + Anthropic => language_model::ANTHROPIC_PROVIDER_ID, + OpenAi => language_model::OPEN_AI_PROVIDER_ID, + Google => language_model::GOOGLE_PROVIDER_ID, + } + } + + fn upstream_provider_name(&self) -> LanguageModelProviderName { + use zed_llm_client::LanguageModelProvider::*; + match self.model.provider { + Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, + OpenAi => language_model::OPEN_AI_PROVIDER_NAME, + Google => language_model::GOOGLE_PROVIDER_NAME, + } } fn supports_tools(&self) -> bool { @@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel { .body(serde_json::to_string(&request_body)?.into())?; let mut response = http_client.send(request).await?; let status = response.status(); + let headers = response.headers().clone(); let mut response_body = String::new(); response .body_mut() @@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel { } else { Err(anyhow!(ApiError { status, - body: response_body + body: response_body, + headers })) } } @@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel { ) .await .map_err(|err| match err.downcast::() { - Ok(api_err) => { - if api_err.status == StatusCode::BAD_REQUEST { - if let Some(tokens) = parse_prompt_too_long(&api_err.body) { - return anyhow!( - LanguageModelKnownError::ContextWindowLimitExceeded { - tokens - } - ); - } - } - anyhow!(api_err) - } + Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), Err(err) => anyhow!(err), })?; @@ -995,7 +1000,7 @@ where .flat_map(move 
|event| { futures::stream::iter(match event { Err(error) => { - vec![Err(LanguageModelCompletionError::Other(error))] + vec![Err(LanguageModelCompletionError::from(error))] } Ok(CloudCompletionEvent::Status(event)) => { vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index b00ec7570c..5411fbc63c 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; use super::open_ai::count_open_ai_tokens; -const PROVIDER_ID: &str = "copilot_chat"; -const PROVIDER_NAME: &str = "GitHub Copilot Chat"; +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); +const PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("GitHub Copilot Chat"); pub struct CopilotChatLanguageModelProvider { state: Entity, @@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider { impl LanguageModelProvider for CopilotChatLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -391,24 +392,24 @@ pub fn map_to_language_model_completion_events( serde_json::Value::from_str(&tool_call.arguments) }; match arguments { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.clone().into(), - name: tool_call.name.as_str().into(), - is_input_complete: true, - input, - raw_input: tool_call.arguments.clone(), - }, - )), - Err(error) => { - Err(LanguageModelCompletionError::BadInputJson { - id: tool_call.id.into(), - tool_name: tool_call.name.as_str().into(), - raw_input: tool_call.arguments.into(), - json_parse_error: error.to_string(), - }) - } - } + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => Ok( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }, + ), + } }, )); diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 99a1ca70c6..a568ef4034 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -28,8 +28,8 @@ use util::ResultExt; use crate::{AllLanguageModelSettings, ui::InstructionListItem}; -const PROVIDER_ID: &str = "deepseek"; -const PROVIDER_NAME: &str = "DeepSeek"; +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); +const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); const 
DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; #[derive(Default)] @@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider { impl LanguageModelProvider for DeepSeekLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -466,7 +466,7 @@ impl DeepSeekEventMapper { events.flat_map(move |event| { futures::stream::iter(match event { Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + Err(error) => vec![Err(LanguageModelCompletionError::from(error))], }) }) } @@ -476,7 +476,7 @@ impl DeepSeekEventMapper { event: deepseek::StreamResponse, ) -> Vec> { let Some(choice) = event.choices.first() else { - return vec![Err(LanguageModelCompletionError::Other(anyhow!( + return vec![Err(LanguageModelCompletionError::from(anyhow!( "Response contained no choices" )))]; }; @@ -538,8 +538,8 @@ impl DeepSeekEventMapper { raw_input: tool_call.arguments.clone(), }, )), - Err(error) => Err(LanguageModelCompletionError::BadInputJson { - id: tool_call.id.into(), + Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.clone().into(), tool_name: tool_call.name.as_str().into(), raw_input: tool_call.arguments.into(), json_parse_error: error.to_string(), diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 5972798523..bb19a3901a 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -37,8 +37,8 @@ use util::ResultExt; use crate::AllLanguageModelSettings; use crate::ui::InstructionListItem; -const PROVIDER_ID: &str = "google"; -const PROVIDER_NAME: &str = "Google AI"; +const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; #[derive(Default, Clone, Debug, PartialEq)] pub struct GoogleSettings { @@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider { impl LanguageModelProvider for GoogleLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel { ); let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { - let response = request - .await - .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?; + let response = 
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 5972798523..bb19a3901a 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -37,8 +37,8 @@ use util::ResultExt;
 use crate::AllLanguageModelSettings;
 use crate::ui::InstructionListItem;

-const PROVIDER_ID: &str = "google";
-const PROVIDER_NAME: &str = "Google AI";
+const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct GoogleSettings {
@@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
 impl LanguageModelProvider for GoogleLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
@@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
         );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
-            let response = request
-                .await
-                .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
+            let response = request.await.map_err(LanguageModelCompletionError::from)?;
             Ok(GoogleEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
@@ -622,7 +620,7 @@ impl GoogleEventMapper {
             futures::stream::iter(match event {
                 Some(Ok(event)) => self.map_event(event),
                 Some(Err(error)) => {
-                    vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
+                    vec![Err(LanguageModelCompletionError::from(error))]
                 }
                 None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
             })
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 519647b3bc..01600f3646 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
 const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
 const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

-const PROVIDER_ID: &str = "lmstudio";
-const PROVIDER_NAME: &str = "LM Studio";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");

 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct LmStudioSettings {
@@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
 impl LanguageModelProvider for LmStudioLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
@@ -474,7 +474,7 @@ impl LmStudioEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
             })
         })
     }
@@ -484,7 +484,7 @@ impl LmStudioEventMapper {
         event: lmstudio::ResponseStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.into_iter().next() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -553,7 +553,7 @@ impl LmStudioEventMapper {
                         raw_input: tool_call.arguments,
                     },
                 )),
-                Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                     id: tool_call.id.into(),
                     tool_name: tool_call.name.into(),
                     raw_input: tool_call.arguments.into(),
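The mappers also stop wrapping transport and protocol errors in LanguageModelCompletionError::Other(anyhow!(..)) and call LanguageModelCompletionError::from(..) instead. A hedged sketch of why that is more than cosmetic: a single From impl becomes the funnel where errors can later be classified into typed variants without touching any call site. The type and variant names below are simplified stand-ins, not the real language_model definitions:

    use anyhow::anyhow;

    #[derive(Debug)]
    enum CompletionError {
        Other(anyhow::Error),
    }

    impl From<anyhow::Error> for CompletionError {
        fn from(error: anyhow::Error) -> Self {
            // Catch-all today; the real impl is free to downcast or inspect
            // the error and pick a richer variant, upgrading every call
            // site at once.
            CompletionError::Other(error)
        }
    }

    fn no_choices() -> CompletionError {
        // Call sites shrink from `Other(anyhow!(..))` to `from(..)`.
        CompletionError::from(anyhow!("Response contained no choices"))
    }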
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 171ce05896..c58622d4e0 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
-use futures::stream::BoxStream;
-use futures::{FutureExt, StreamExt, future::BoxFuture};
+use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
 };
@@ -15,6 +14,7 @@ use language_model::{
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason, TokenUsage,
 };
+use mistral::StreamResponse;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -29,8 +29,8 @@ use util::ResultExt;

 use crate::{AllLanguageModelSettings, ui::InstructionListItem};

-const PROVIDER_ID: &str = "mistral";
-const PROVIDER_NAME: &str = "Mistral";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct MistralSettings {
@@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
 impl LanguageModelProvider for MistralLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
@@ -579,13 +579,13 @@ impl MistralEventMapper {
     pub fn map_stream(
         mut self,
-        events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
-    ) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
     {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
             })
         })
     }
@@ -595,7 +595,7 @@ impl MistralEventMapper {
         event: mistral::StreamResponse,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.first() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -660,7 +660,7 @@ impl MistralEventMapper {
         for (_, tool_call) in self.tool_calls_by_index.drain() {
             if tool_call.id.is_empty() || tool_call.name.is_empty() {
-                results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                     "Received incomplete tool call: missing id or name"
                 ))));
                 continue;
@@ -676,12 +676,14 @@ impl MistralEventMapper {
                         raw_input: tool_call.arguments,
                     },
                 ))),
-                Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
-                    id: tool_call.id.into(),
-                    tool_name: tool_call.name.into(),
-                    raw_input: tool_call.arguments.into(),
-                    json_parse_error: error.to_string(),
-                })),
+                Err(error) => {
+                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                        id: tool_call.id.into(),
+                        tool_name: tool_call.name.into(),
+                        raw_input: tool_call.arguments.into(),
+                        json_parse_error: error.to_string(),
+                    }))
+                }
             }
         }
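All of these mappers share one shape: a stateful struct that flat_maps each upstream chunk into zero or more completion events via futures::stream::iter, which is also why the Mistral signature above can swap futures::Stream for the imported Stream trait without changing behavior. A self-contained sketch of that pattern with toy types (not the real mistral or language_model definitions):

    use futures::{Stream, StreamExt, stream};

    struct Mapper;

    impl Mapper {
        // One upstream chunk may produce several events, or none.
        fn map_event(&mut self, raw: u32) -> Vec<Result<String, String>> {
            vec![Ok(format!("event {raw}"))]
        }

        fn map_stream(
            mut self,
            events: impl Stream<Item = Result<u32, String>>,
        ) -> impl Stream<Item = Result<String, String>> {
            // `flat_map` + `stream::iter` expands each chunk into its
            // events (or a single error) without boxing the stream.
            events.flat_map(move |event| {
                stream::iter(match event {
                    Ok(event) => self.map_event(event),
                    Err(error) => vec![Err(error)],
                })
            })
        }
    }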
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 205dab6c87..0866cfa4c8 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
 const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
 const OLLAMA_SITE: &str = "https://ollama.com/";

-const PROVIDER_ID: &str = "ollama";
-const PROVIDER_NAME: &str = "Ollama";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");

 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct OllamaSettings {
@@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
 impl LanguageModelProvider for OllamaLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
@@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
         let delta = match response {
             Ok(delta) => delta,
             Err(e) => {
-                let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
+                let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
                 return Some((vec![event], state));
             }
         };
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index ad4203ff81..476c1715ae 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -31,8 +31,8 @@ use util::ResultExt;
 use crate::OpenAiSettingsContent;
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};

-const PROVIDER_ID: &str = "openai";
-const PROVIDER_NAME: &str = "OpenAI";
+const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenAiSettings {
@@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 impl LanguageModelProvider for OpenAiLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
         };

         let future = self.request_limiter.stream(async move {
-            let api_key = api_key.context("Missing OpenAI API Key")?;
+            let Some(api_key) = api_key else {
+                return Err(LanguageModelCompletionError::NoApiKey {
+                    provider: PROVIDER_NAME,
+                });
+            };
             let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
             let response = request.await?;
             Ok(response)
@@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
@@ -525,7 +529,7 @@ impl OpenAiEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
             })
         })
     }
@@ -588,10 +592,10 @@ impl OpenAiEventMapper {
                         raw_input: tool_call.arguments.clone(),
                     },
                 )),
-                Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                     id: tool_call.id.into(),
-                    tool_name: tool_call.name.as_str().into(),
-                    raw_input: tool_call.arguments.into(),
+                    tool_name: tool_call.name.into(),
+                    raw_input: tool_call.arguments.clone().into(),
                     json_parse_error: error.to_string(),
                 }),
             }
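In open_ai.rs above (and vercel.rs below), a missing API key stops being an anyhow context string and becomes a structured LanguageModelCompletionError::NoApiKey { provider } built from the new typed PROVIDER_NAME constant. A sketch of the let-else shape, with simplified stand-in types:

    #[derive(Clone, Copy, Debug)]
    struct ProviderName(&'static str);

    #[derive(Debug)]
    enum CompletionError {
        NoApiKey { provider: ProviderName },
    }

    fn stream(api_key: Option<String>) -> Result<String, CompletionError> {
        // let-else replaces `api_key.context("Missing ... API Key")?`: the
        // failure is now a typed variant carrying the provider name, which
        // callers can match on instead of string-parsing a message.
        let Some(api_key) = api_key else {
            return Err(CompletionError::NoApiKey {
                provider: ProviderName("OpenAI"),
            });
        };
        Ok(format!("authorized with key of length {}", api_key.len()))
    }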
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
index 3a8a450cf6..5883da1e2f 100644
--- a/crates/language_models/src/provider/open_router.rs
+++ b/crates/language_models/src/provider/open_router.rs
@@ -29,8 +29,8 @@ use util::ResultExt;

 use crate::{AllLanguageModelSettings, ui::InstructionListItem};

-const PROVIDER_ID: &str = "openrouter";
-const PROVIDER_NAME: &str = "OpenRouter";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenRouterSettings {
@@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
 impl LanguageModelProvider for OpenRouterLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
@@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
             })
         })
     }
@@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
         event: ResponseStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.first() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
                         raw_input: tool_call.arguments.clone(),
                     },
                 )),
-                Err(error) => Err(LanguageModelCompletionError::BadInputJson {
-                    id: tool_call.id.into(),
+                Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                    id: tool_call.id.clone().into(),
                     tool_name: tool_call.name.as_str().into(),
-                    raw_input: tool_call.arguments.into(),
+                    raw_input: tool_call.arguments.clone().into(),
                     json_parse_error: error.to_string(),
                 }),
            }
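For a sense of why the mappers now emit Ok(ToolUseJsonParseError { .. }) rather than an error, here is a hypothetical consumer loop (not Zed's actual thread code): the ? operator still aborts on real stream errors, while a JSON parse failure is just another event to handle and move past:

    enum Event {
        Text(String),
        ToolUseJsonParseError { tool_name: String, json_parse_error: String },
    }

    fn drain(events: Vec<Result<Event, String>>) -> Result<(), String> {
        for event in events {
            match event? {
                Event::Text(text) => print!("{text}"),
                // A parse failure is data, not a fatal error: report it
                // and keep consuming the remaining events.
                Event::ToolUseJsonParseError { tool_name, json_parse_error } => {
                    eprintln!("tool {tool_name} sent invalid JSON: {json_parse_error}");
                }
            }
        }
        Ok(())
    }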
diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs
index 2f64115d20..037ce467d0 100644
--- a/crates/language_models/src/provider/vercel.rs
+++ b/crates/language_models/src/provider/vercel.rs
@@ -25,8 +25,8 @@ use util::ResultExt;

 use crate::{AllLanguageModelSettings, ui::InstructionListItem};

-const PROVIDER_ID: &str = "vercel";
-const PROVIDER_NAME: &str = "Vercel";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct VercelSettings {
@@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
 impl LanguageModelProvider for VercelLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn icon(&self) -> IconName {
@@ -269,7 +269,11 @@ impl VercelLanguageModel {
         };

         let future = self.request_limiter.stream(async move {
-            let api_key = api_key.context("Missing Vercel API Key")?;
+            let Some(api_key) = api_key else {
+                return Err(LanguageModelCompletionError::NoApiKey {
+                    provider: PROVIDER_NAME,
+                });
+            };
             let request =
                 open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
             let response = request.await?;
@@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
     }

     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }

     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }

     fn supports_tools(&self) -> bool {
diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs
index de280cd5c0..adf79b0ff6 100644
--- a/crates/web_search_providers/src/cloud.rs
+++ b/crates/web_search_providers/src/cloud.rs
@@ -7,10 +7,7 @@ use gpui::{App, AppContext, Context, Entity, Subscription, Task};
 use http_client::{HttpClient, Method};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use web_search::{WebSearchProvider, WebSearchProviderId};
-use zed_llm_client::{
-    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    WebSearchBody, WebSearchResponse,
-};
+use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};

 pub struct CloudWebSearchProvider {
     state: Entity<State>,
@@ -92,7 +89,6 @@ async fn perform_web_search(
         .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {token}"))
-        .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
         .body(serde_json::to_string(&body)?.into())?;
     let mut response = http_client
         .send(request)