diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index dfbb21a196..e3080fd0ad 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -1495,27 +1495,76 @@ impl Thread {
         thread.update(cx, |thread, cx| {
             let event = match event {
                 Ok(event) => event,
-                Err(LanguageModelCompletionError::BadInputJson {
-                    id,
-                    tool_name,
-                    raw_input: invalid_input_json,
-                    json_parse_error,
-                }) => {
-                    thread.receive_invalid_tool_json(
-                        id,
-                        tool_name,
-                        invalid_input_json,
-                        json_parse_error,
-                        window,
-                        cx,
-                    );
-                    return Ok(());
-                }
-                Err(LanguageModelCompletionError::Other(error)) => {
-                    return Err(error);
-                }
-                Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
-                    return Err(err.into());
+                Err(error) => {
+                    match error {
+                        LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
+                            anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
+                        }
+                        LanguageModelCompletionError::Overloaded => {
+                            anyhow::bail!(LanguageModelKnownError::Overloaded);
+                        }
+                        LanguageModelCompletionError::ApiInternalServerError => {
+                            anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
+                        }
+                        LanguageModelCompletionError::PromptTooLarge { tokens } => {
+                            let tokens = tokens.unwrap_or_else(|| {
+                                // We didn't get an exact token count from the API, so fall back on our estimate.
+                                thread
+                                    .total_token_usage()
+                                    .map(|usage| usage.total)
+                                    .unwrap_or(0)
+                                    // We know the context window was exceeded in practice, so if our estimate was
+                                    // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
+                                    .max(model.max_token_count().saturating_add(1))
+                            });
+
+                            anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
+                        }
+                        LanguageModelCompletionError::ApiReadResponseError(io_error) => {
+                            anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
+                        }
+                        LanguageModelCompletionError::UnknownResponseFormat(error) => {
+                            anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
+                        }
+                        LanguageModelCompletionError::HttpResponseError { status, ref body } => {
+                            if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
+                                anyhow::bail!(known_error);
+                            } else {
+                                return Err(error.into());
+                            }
+                        }
+                        LanguageModelCompletionError::DeserializeResponse(error) => {
+                            anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
+                        }
+                        LanguageModelCompletionError::BadInputJson {
+                            id,
+                            tool_name,
+                            raw_input: invalid_input_json,
+                            json_parse_error,
+                        } => {
+                            thread.receive_invalid_tool_json(
+                                id,
+                                tool_name,
+                                invalid_input_json,
+                                json_parse_error,
+                                window,
+                                cx,
+                            );
+                            return Ok(());
+                        }
+                        // These are all errors we can't automatically attempt to recover from (e.g. by retrying).
+                        err @ LanguageModelCompletionError::BadRequestFormat |
+                        err @ LanguageModelCompletionError::AuthenticationError |
+                        err @ LanguageModelCompletionError::PermissionError |
+                        err @ LanguageModelCompletionError::ApiEndpointNotFound |
+                        err @ LanguageModelCompletionError::SerializeRequest(_) |
+                        err @ LanguageModelCompletionError::BuildRequestBody(_) |
+                        err @ LanguageModelCompletionError::HttpSend(_) => {
+                            anyhow::bail!(err);
+                        }
+                        LanguageModelCompletionError::Other(error) => {
+                            return Err(error);
+                        }
+                    }
                 }
             };
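The `PromptTooLarge` arm above combines two signals: the token count the API reported, if any, and otherwise our local estimate clamped to at least one past the model's limit. A minimal sketch of that fallback rule (`reported`, `estimate`, and `max_tokens` are illustrative stand-ins for the API value, `thread.total_token_usage()`, and `model.max_token_count()`):

```rust
/// Token count to report when the provider says the context window was exceeded.
fn exceeded_token_count(reported: Option<u64>, estimate: Option<u64>, max_tokens: u64) -> u64 {
    reported.unwrap_or_else(|| {
        // The API said we exceeded the window, so an estimate at or below the
        // limit must be wrong; report one past the limit instead.
        estimate.unwrap_or(0).max(max_tokens.saturating_add(1))
    })
}

fn main() {
    assert_eq!(exceeded_token_count(Some(250_123), None, 200_000), 250_123);
    assert_eq!(exceeded_token_count(None, Some(180_000), 200_000), 200_001);
    assert_eq!(exceeded_token_count(None, Some(230_000), 200_000), 230_000);
}
```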
@@ -1751,6 +1800,18 @@ impl Thread {
                 project.set_agent_location(None, cx);
             });

+            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
+                let error_message = error
+                    .chain()
+                    .map(|err| err.to_string())
+                    .collect::<Vec<String>>()
+                    .join("\n");
+                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+                    header: "Error interacting with language model".into(),
+                    message: SharedString::from(error_message.clone()),
+                }));
+            }
+
             if error.is::<PaymentRequiredError>() {
                 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
             } else if let Some(error) =
@@ -1763,26 +1824,34 @@ impl Thread {
                 error.downcast_ref::<LanguageModelKnownError>()
             {
                 match known_error {
-                    LanguageModelKnownError::ContextWindowLimitExceeded {
-                        tokens,
-                    } => {
+                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                         thread.exceeded_window_error = Some(ExceededWindowError {
                             model_id: model.id(),
                             token_count: *tokens,
                         });
                         cx.notify();
                     }
+                    LanguageModelKnownError::RateLimitExceeded { .. } => {
+                        // In the future we will report the error to the user, wait retry_after, and then retry.
+                        emit_generic_error(error, cx);
+                    }
+                    LanguageModelKnownError::Overloaded => {
+                        // In the future we will wait and then retry, up to N times.
+                        emit_generic_error(error, cx);
+                    }
+                    LanguageModelKnownError::ApiInternalServerError => {
+                        // In the future we will retry the request, but only once.
+                        emit_generic_error(error, cx);
+                    }
+                    LanguageModelKnownError::ReadResponseError(_) |
+                    LanguageModelKnownError::DeserializeResponse(_) |
+                    LanguageModelKnownError::UnknownResponseFormat(_) => {
+                        // In the future we will attempt to re-roll the response, but only once.
+                        emit_generic_error(error, cx);
+                    }
                 }
             } else {
-                let error_message = error
-                    .chain()
-                    .map(|err| err.to_string())
-                    .collect::<Vec<String>>()
-                    .join("\n");
-                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
-                    header: "Error interacting with language model".into(),
-                    message: SharedString::from(error_message.clone()),
-                }));
+                emit_generic_error(error, cx);
             }

             thread.cancel_last_completion(window, cx);
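These hunks lean on `anyhow`'s ability to carry a typed error through an opaque `anyhow::Error`: the stream handler `bail!`s a `LanguageModelKnownError`, and the completion handler later recovers it with `downcast_ref` to pick a recovery strategy. A self-contained sketch of that round trip, assuming the `anyhow` and `thiserror` crates (both already used here); the names are illustrative:

```rust
use std::time::Duration;

#[derive(Debug, thiserror::Error)]
enum KnownError {
    #[error("rate limit exceeded; retry in {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
}

fn stream_step() -> anyhow::Result<()> {
    // Like `anyhow::bail!(LanguageModelKnownError::...)` above: the typed
    // value is preserved inside the returned `anyhow::Error`.
    anyhow::bail!(KnownError::RateLimitExceeded {
        retry_after: Duration::from_secs(4),
    })
}

fn main() {
    let error = stream_step().unwrap_err();
    // Like the `downcast_ref::<LanguageModelKnownError>()` branch above.
    match error.downcast_ref::<KnownError>() {
        Some(KnownError::RateLimitExceeded { retry_after }) => {
            println!("typed handling: retry after {retry_after:?}");
        }
        None => println!("generic handling: {error:#}"),
    }
}
```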
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 97ebec710a..7f0ab7550d 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -1,10 +1,11 @@
+use std::io;
 use std::str::FromStr;
 use std::time::Duration;

 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::http::{HeaderMap, HeaderValue};
+use http_client::http::{self, HeaderMap, HeaderValue};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use serde::{Deserialize, Serialize};
 use strum::{EnumIter, EnumString};
@@ -336,7 +337,7 @@ pub async fn complete(
     let uri = format!("{api_url}/v1/messages");
     let beta_headers = Model::from_id(&request.model)
         .map(|model| model.beta_headers())
-        .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
+        .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(","));
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
@@ -346,39 +347,30 @@ pub async fn complete(
         .header("Content-Type", "application/json");
     let serialized_request =
-        serde_json::to_string(&request).context("failed to serialize request")?;
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
     let request = request_builder
         .body(AsyncBody::from(serialized_request))
-        .context("failed to construct request body")?;
+        .map_err(AnthropicError::BuildRequestBody)?;

     let mut response = client
         .send(request)
         .await
-        .context("failed to send request to Anthropic")?;
-    if response.status().is_success() {
-        let mut body = Vec::new();
-        response
-            .body_mut()
-            .read_to_end(&mut body)
-            .await
-            .context("failed to read response body")?;
-        let response_message: Response =
-            serde_json::from_slice(&body).context("failed to deserialize response body")?;
-        Ok(response_message)
+        .map_err(AnthropicError::HttpSend)?;
+    let status = response.status();
+    let mut body = String::new();
+    response
+        .body_mut()
+        .read_to_string(&mut body)
+        .await
+        .map_err(AnthropicError::ReadResponse)?;
+
+    if status.is_success() {
+        Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
     } else {
-        let mut body = Vec::new();
-        response
-            .body_mut()
-            .read_to_end(&mut body)
-            .await
-            .context("failed to read response body")?;
-        let body_str =
-            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
-        Err(AnthropicError::Other(anyhow!(
-            "Failed to connect to API: {} {}",
-            response.status(),
-            body_str
-        )))
+        Err(AnthropicError::HttpResponseError {
+            status: status.as_u16(),
+            body,
+        })
     }
 }
@@ -491,7 +483,7 @@ pub async fn stream_completion_with_rate_limit_info(
     let uri = format!("{api_url}/v1/messages");
     let beta_headers = Model::from_id(&request.base.model)
         .map(|model| model.beta_headers())
-        .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
+        .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(","));
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
@@ -500,15 +492,15 @@ pub async fn stream_completion_with_rate_limit_info(
         .header("X-Api-Key", api_key)
         .header("Content-Type", "application/json");
     let serialized_request =
-        serde_json::to_string(&request).context("failed to serialize request")?;
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
     let request = request_builder
         .body(AsyncBody::from(serialized_request))
-        .context("failed to construct request body")?;
+        .map_err(AnthropicError::BuildRequestBody)?;

     let mut response = client
         .send(request)
         .await
-        .context("failed to send request to Anthropic")?;
+        .map_err(AnthropicError::HttpSend)?;
     let rate_limits = RateLimitInfo::from_headers(response.headers());
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -520,37 +512,31 @@ pub async fn stream_completion_with_rate_limit_info(
                     let line = line.strip_prefix("data: ")?;
                     match serde_json::from_str(line) {
                         Ok(response) => Some(Ok(response)),
-                        Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
+                        Err(error) => Some(Err(AnthropicError::DeserializeResponse(error))),
                     }
                 }
-                Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
+                Err(error) => Some(Err(AnthropicError::ReadResponse(error))),
             }
         })
         .boxed();
         Ok((stream, Some(rate_limits)))
     } else if let Some(retry_after) = rate_limits.retry_after {
-        Err(AnthropicError::RateLimit(retry_after))
+        Err(AnthropicError::RateLimit { retry_after })
     } else {
-        let mut body = Vec::new();
+        let mut body = String::new();
         response
             .body_mut()
-            .read_to_end(&mut body)
+            .read_to_string(&mut body)
             .await
-            .context("failed to read response body")?;
+            .map_err(AnthropicError::ReadResponse)?;

-        let body_str =
-            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
-
-        match serde_json::from_str::<Event>(body_str) {
+        match serde_json::from_str::<Event>(&body) {
             Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
-            Ok(_) => Err(AnthropicError::Other(anyhow!(
-                "Unexpected success response while expecting an error: '{body_str}'",
-            ))),
-            Err(_) => Err(AnthropicError::Other(anyhow!(
-                "Failed to connect to API: {} {}",
-                response.status(),
-                body_str,
-            ))),
+            Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
+            Err(_) => Err(AnthropicError::HttpResponseError {
+                status: response.status().as_u16(),
+                body,
+            }),
         }
     }
 }
@@ -797,17 +783,38 @@ pub struct MessageDelta {
     pub stop_sequence: Option<String>,
 }

-#[derive(Error, Debug)]
+#[derive(Debug)]
 pub enum AnthropicError {
-    #[error("rate limit exceeded, retry after {0:?}")]
-    RateLimit(Duration),
-    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
+    /// Failed to serialize the HTTP request body to JSON
+    SerializeRequest(serde_json::Error),
+
+    /// Failed to construct the HTTP request body
+    BuildRequestBody(http::Error),
+
+    /// Failed to send the HTTP request
+    HttpSend(anyhow::Error),
+
+    /// Failed to deserialize the response from JSON
+    DeserializeResponse(serde_json::Error),
+
+    /// Failed to read from the response stream
+    ReadResponse(io::Error),
+
+    /// HTTP error response from the API
+    HttpResponseError { status: u16, body: String },
+
+    /// Rate limit exceeded
+    RateLimit { retry_after: Duration },
+
+    /// API returned an error response
     ApiError(ApiError),
-    #[error("{0}")]
-    Other(#[from] anyhow::Error),
+
+    /// Unexpected response format
+    UnexpectedResponseFormat(String),
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Error)]
+#[error("Anthropic API Error: {error_type}: {message}")]
 pub struct ApiError {
     #[serde(rename = "type")]
     pub error_type: String,
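With the stringly-typed `Other(anyhow::Error)` gone, callers can classify failures without inspecting message text. As a hypothetical illustration (not part of this change), a helper over the `AnthropicError` defined above might look like:

```rust
use std::time::Duration;

/// Hypothetical helper: decide whether a failure is worth retrying and how
/// long to wait, using only the structured variants.
fn retry_delay(error: &AnthropicError) -> Option<Duration> {
    match error {
        // The server told us exactly when to come back.
        AnthropicError::RateLimit { retry_after } => Some(*retry_after),
        // Server-side trouble: retry after a conservative default.
        AnthropicError::HttpResponseError { status, .. } if (500u16..=599).contains(status) => {
            Some(Duration::from_secs(4))
        }
        // Everything else (bad request, auth, serialization bugs) won't
        // get better by itself.
        _ => None,
    }
}
```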
diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs
index b5744d455a..116654e382 100644
--- a/crates/assistant_tools/src/edit_agent/evals.rs
+++ b/crates/assistant_tools/src/edit_agent/evals.rs
@@ -1659,13 +1659,13 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
             Ok(result) => return Ok(result),
             Err(err) => match err.downcast::<LanguageModelCompletionError>() {
                 Ok(err) => match err {
-                    LanguageModelCompletionError::RateLimit(duration) => {
+                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                         // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
-                        let jitter = duration.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
+                        let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
                         eprintln!(
-                            "Attempt #{attempt}: Rate limit exceeded. Retry after {duration:?} + jitter of {jitter:?}"
+                            "Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
                         );
-                        Timer::after(duration + jitter).await;
+                        Timer::after(retry_after + jitter).await;
                         continue;
                     }
                     _ => return Err(err.into()),
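The retry loop above applies full jitter: sleep for the server-supplied `retry_after` plus a uniformly random fraction of it, so concurrent evals don't all retry at the same instant. Just that computation, extracted as a sketch (assuming the `rand` crate, as the code above already does):

```rust
use rand::Rng;
use std::time::Duration;

/// Full-jitter delay: `retry_after` plus a random 0-100% of itself.
fn delay_with_jitter(retry_after: Duration) -> Duration {
    let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
    retry_after + jitter
}

fn main() {
    let delay = delay_with_jitter(Duration::from_secs(30));
    println!("sleeping {delay:?} (somewhere in 30s..60s)");
}
```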
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 900d7f6f39..9f165df301 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -8,19 +8,21 @@ mod telemetry;
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;

+use anthropic::{AnthropicError, parse_prompt_too_long};
 use anyhow::Result;
 use client::Client;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
+use http_client::http;
 use icons::IconName;
 use parking_lot::Mutex;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
-use std::fmt;
 use std::ops::{Add, Sub};
 use std::sync::Arc;
 use std::time::Duration;
+use std::{fmt, io};
 use thiserror::Error;
 use util::serde::is_default;
 use zed_llm_client::CompletionRequestStatus;
@@ -34,6 +36,10 @@ pub use crate::telemetry::*;

 pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

+/// If we get a rate limit error that doesn't tell us when we can retry,
+/// default to waiting this long before retrying.
+const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
+
 pub fn init(client: Arc<Client>, cx: &mut App) {
     init_settings(cx);
     RefreshLlmTokenListener::register(client.clone(), cx);
@@ -70,8 +76,8 @@ pub enum LanguageModelCompletionEvent {

 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
-    #[error("rate limit exceeded, retry after {0:?}")]
-    RateLimit(Duration),
+    #[error("rate limit exceeded, retry after {retry_after:?}")]
+    RateLimitExceeded { retry_after: Duration },
     #[error("received bad input JSON")]
     BadInputJson {
         id: LanguageModelToolUseId,
@@ -79,8 +85,78 @@ pub enum LanguageModelCompletionError {
         raw_input: Arc<str>,
         json_parse_error: String,
     },
+    #[error("language model provider's API is overloaded")]
+    Overloaded,
     #[error(transparent)]
     Other(#[from] anyhow::Error),
+    #[error("invalid request format to language model provider's API")]
+    BadRequestFormat,
+    #[error("authentication error with language model provider's API")]
+    AuthenticationError,
+    #[error("permission error with language model provider's API")]
+    PermissionError,
+    #[error("language model provider API endpoint not found")]
+    ApiEndpointNotFound,
+    #[error("prompt too large for context window")]
+    PromptTooLarge { tokens: Option<u64> },
+    #[error("internal server error in language model provider's API")]
+    ApiInternalServerError,
+    #[error("I/O error reading response from language model provider's API: {0:?}")]
+    ApiReadResponseError(io::Error),
+    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
+    HttpResponseError { status: u16, body: String },
+    #[error("error serializing request to language model provider API: {0}")]
+    SerializeRequest(serde_json::Error),
+    #[error("error building request body to language model provider API: {0}")]
+    BuildRequestBody(http::Error),
+    #[error("error sending HTTP request to language model provider API: {0}")]
+    HttpSend(anyhow::Error),
+    #[error("error deserializing language model provider API response: {0}")]
+    DeserializeResponse(serde_json::Error),
+    #[error("unexpected language model provider API response format: {0}")]
+    UnknownResponseFormat(String),
 }
+
+impl From<AnthropicError> for LanguageModelCompletionError {
+    fn from(error: AnthropicError) -> Self {
+        match error {
+            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
+            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
+            AnthropicError::HttpSend(error) => Self::HttpSend(error),
+            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
+            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
+            AnthropicError::HttpResponseError { status, body } => {
+                Self::HttpResponseError { status, body }
+            }
+            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
+            AnthropicError::ApiError(api_error) => api_error.into(),
+            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
+        }
+    }
+}
+
+impl From<anthropic::ApiError> for LanguageModelCompletionError {
+    fn from(error: anthropic::ApiError) -> Self {
+        use anthropic::ApiErrorCode::*;
+
+        match error.code() {
+            Some(code) => match code {
+                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
+                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
+                PermissionError => LanguageModelCompletionError::PermissionError,
+                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
+                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
+                    tokens: parse_prompt_too_long(&error.message),
+                },
+                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
+                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+                },
+                ApiError => LanguageModelCompletionError::ApiInternalServerError,
+                OverloadedError => LanguageModelCompletionError::Overloaded,
+            },
+            None => LanguageModelCompletionError::Other(error.into()),
+        }
+    }
+}

 /// Indicates the format used to define the input schema for a language model tool.
@@ -319,6 +395,33 @@ pub trait LanguageModel: Send + Sync {
 pub enum LanguageModelKnownError {
     #[error("Context window limit exceeded ({tokens})")]
     ContextWindowLimitExceeded { tokens: u64 },
+    #[error("Language model provider's API is currently overloaded")]
+    Overloaded,
+    #[error("Language model provider's API encountered an internal server error")]
+    ApiInternalServerError,
+    #[error("I/O error while reading response from language model provider's API: {0:?}")]
+    ReadResponseError(io::Error),
+    #[error("Error deserializing response from language model provider's API: {0:?}")]
+    DeserializeResponse(serde_json::Error),
+    #[error("Language model provider's API returned a response in an unknown format")]
+    UnknownResponseFormat(String),
+    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
+    RateLimitExceeded { retry_after: Duration },
+}
+
+impl LanguageModelKnownError {
+    /// Attempts to map an HTTP response status code to a known error type.
+    /// Returns None if the status code doesn't map to a specific known error.
+    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
+        match status {
+            429 => Some(Self::RateLimitExceeded {
+                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+            }),
+            503 => Some(Self::Overloaded),
+            500..=599 => Some(Self::ApiInternalServerError),
+            _ => None,
+        }
+    }
 }

 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
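One subtlety in `from_http_response` above: match arms are tried top to bottom, so the specific `503` arm must precede the `500..=599` range arm or it would never match. A standalone sketch of the same mapping, with hypothetical variant names:

```rust
#[derive(Debug, PartialEq)]
enum Known {
    RateLimited,
    Overloaded,
    InternalServerError,
}

// Mirrors the status mapping above; arm order matters, since 503 also
// falls inside 500..=599.
fn classify(status: u16) -> Option<Known> {
    match status {
        429 => Some(Known::RateLimited),
        503 => Some(Known::Overloaded),
        500..=599 => Some(Known::InternalServerError),
        _ => None,
    }
}

fn main() {
    assert_eq!(classify(429), Some(Known::RateLimited));
    assert_eq!(classify(503), Some(Known::Overloaded));
    assert_eq!(classify(500), Some(Known::InternalServerError));
    assert_eq!(classify(404), None);
}
```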
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index a8423fefa5..719975c1d5 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -16,10 +16,10 @@ use gpui::{
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
+    LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
+    RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;
@@ -407,14 +407,7 @@ impl AnthropicModel {
             let api_key = api_key.context("Missing Anthropic API Key")?;
             let request =
                 anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            request.await.map_err(|err| match err {
-                AnthropicError::RateLimit(duration) => {
-                    LanguageModelCompletionError::RateLimit(duration)
-                }
-                err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
-                    LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
-                }
-            })
+            request.await.map_err(Into::into)
         }
         .boxed()
     }
@@ -714,7 +707,7 @@ impl AnthropicEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(error.into())],
             })
         })
     }
@@ -859,9 +852,7 @@ impl AnthropicEventMapper {
                 vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
             }
             Event::Error { error } => {
-                vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                    AnthropicError::ApiError(error)
-                )))]
+                vec![Err(error.into())]
             }
             _ => Vec::new(),
         }
@@ -874,16 +865,6 @@ struct RawToolUse {
     input_json: String,
 }

-pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
-    if let AnthropicError::ApiError(api_err) = &err {
-        if let Some(tokens) = api_err.match_window_exceeded() {
-            return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
-        }
-    }
-
-    anyhow!(err)
-}
-
 /// Updates usage data by preferring counts from `new`.
 fn update_usage(usage: &mut Usage, new: &Usage) {
     if let Some(input_tokens) = new.input_tokens {
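The provider-side simplification rests on a single idiom: once `From<AnthropicError> for LanguageModelCompletionError` exists, `request.await.map_err(Into::into)` replaces the hand-rolled match, and the compiler selects the conversion from the function's return type. A minimal self-contained illustration of that pattern, with illustrative type names:

```rust
#[derive(Debug)]
struct ProviderError(&'static str);

#[derive(Debug)]
struct CompletionError(String);

impl From<ProviderError> for CompletionError {
    fn from(error: ProviderError) -> Self {
        CompletionError(format!("provider error: {}", error.0))
    }
}

fn call_provider() -> Result<(), ProviderError> {
    Err(ProviderError("overloaded"))
}

fn complete() -> Result<(), CompletionError> {
    // The compiler picks the `From` impl from the return type, so no
    // explicit match over provider variants is needed here.
    call_provider().map_err(Into::into)
}

fn main() {
    println!("{:?}", complete());
}
```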