diff --git a/Cargo.lock b/Cargo.lock
index e9e2fb8fdd..ef35a57f28 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7708,6 +7708,7 @@ dependencies = [
  "smol",
  "strum",
  "theme",
+ "thiserror 2.0.12",
  "tiktoken-rs",
  "tokio",
  "ui",
diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs
index 3b8881f2d8..573e4b4d03 100644
--- a/crates/agent/src/message_editor.rs
+++ b/crates/agent/src/message_editor.rs
@@ -761,13 +761,29 @@ impl MessageEditor {
         })
     }
 
-    fn render_reaching_token_limit(&self, line_height: Pixels, cx: &mut Context<Self>) -> Div {
+    fn render_token_limit_callout(
+        &self,
+        line_height: Pixels,
+        token_usage_ratio: TokenUsageRatio,
+        cx: &mut Context<Self>,
+    ) -> Div {
+        let heading = if token_usage_ratio == TokenUsageRatio::Exceeded {
+            "Thread reached the token limit"
+        } else {
+            "Thread reaching the token limit soon"
+        };
+
         h_flex()
             .p_2()
             .gap_2()
             .flex_wrap()
             .justify_between()
-            .bg(cx.theme().status().warning_background.opacity(0.1))
+            .bg(
+                if token_usage_ratio == TokenUsageRatio::Exceeded {
+                    cx.theme().status().error_background.opacity(0.1)
+                } else {
+                    cx.theme().status().warning_background.opacity(0.1)
+                })
             .border_t_1()
             .border_color(cx.theme().colors().border)
             .child(
@@ -779,15 +795,21 @@
                     .h(line_height)
                     .justify_center()
                     .child(
-                        Icon::new(IconName::Warning)
-                            .color(Color::Warning)
-                            .size(IconSize::XSmall),
+                        if token_usage_ratio == TokenUsageRatio::Exceeded {
+                            Icon::new(IconName::X)
+                                .color(Color::Error)
+                                .size(IconSize::XSmall)
+                        } else {
+                            Icon::new(IconName::Warning)
+                                .color(Color::Warning)
+                                .size(IconSize::XSmall)
+                        }
                     ),
             )
             .child(
                 v_flex()
                     .mr_auto()
-                    .child(Label::new("Thread reaching the token limit soon").size(LabelSize::Small))
+                    .child(Label::new(heading).size(LabelSize::Small))
                     .child(
                         Label::new(
                             "Start a new thread from a summary to continue the conversation.",
@@ -875,7 +897,13 @@ impl Render for MessageEditor {
             .child(self.render_editor(font_size, line_height, window, cx))
             .when(
                 total_token_usage.ratio != TokenUsageRatio::Normal,
-                |parent| parent.child(self.render_reaching_token_limit(line_height, cx)),
+                |parent| {
+                    parent.child(self.render_token_limit_callout(
+                        line_height,
+                        total_token_usage.ratio,
+                        cx,
+                    ))
+                },
             )
     }
 }
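Note (not part of the diff): the callout above keys entirely off `TokenUsageRatio`. Any value other than `Normal` renders it, and `Exceeded` swaps the heading, icon, and background over to the error styling. Below is a minimal standalone sketch of that selection logic; the `Warning` variant and the `Severity` enum are illustrative stand-ins, since the real code reads icons and colors from the theme:

```rust
// Standalone sketch of the heading/severity selection performed by
// `render_token_limit_callout`. `Severity` stands in for the icon/color
// pairing the real code builds from the theme.
#[derive(Debug, Default, PartialEq, Eq)]
enum TokenUsageRatio {
    #[default]
    Normal,
    Warning, // assumed name for the non-exceeded, non-normal state
    Exceeded,
}

#[derive(Debug, PartialEq, Eq)]
enum Severity {
    Warning,
    Error,
}

fn callout_content(ratio: TokenUsageRatio) -> Option<(&'static str, Severity)> {
    match ratio {
        // No callout is rendered at all while usage is normal.
        TokenUsageRatio::Normal => None,
        TokenUsageRatio::Warning => {
            Some(("Thread reaching the token limit soon", Severity::Warning))
        }
        TokenUsageRatio::Exceeded => {
            Some(("Thread reached the token limit", Severity::Error))
        }
    }
}

fn main() {
    assert_eq!(callout_content(TokenUsageRatio::Normal), None);
    assert_eq!(
        callout_content(TokenUsageRatio::Warning),
        Some(("Thread reaching the token limit soon", Severity::Warning))
    );
    assert_eq!(
        callout_content(TokenUsageRatio::Exceeded),
        Some(("Thread reached the token limit", Severity::Error))
    );
}
```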
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index abd449a4f0..190f2ace0e 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
 use git::repository::DiffType;
 use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
-    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
-    PaymentRequiredError, Role, StopReason, TokenUsage,
+    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
+    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
+    Role, StopReason, TokenUsage,
 };
 use project::Project;
 use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
@@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
     pub ratio: TokenUsageRatio,
 }
 
-#[derive(Default, PartialEq, Eq)]
+#[derive(Debug, Default, PartialEq, Eq)]
 pub enum TokenUsageRatio {
     #[default]
     Normal,
@@ -260,11 +261,20 @@ pub struct Thread {
     pending_checkpoint: Option,
     initial_project_snapshot: Shared>>>,
     cumulative_token_usage: TokenUsage,
+    exceeded_window_error: Option<ExceededWindowError>,
     feedback: Option,
     message_feedback: HashMap,
     last_auto_capture_at: Option,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ExceededWindowError {
+    /// Model used when last message exceeded context window
+    model_id: LanguageModelId,
+    /// Token count including last message
+    token_count: usize,
+}
+
 impl Thread {
     pub fn new(
         project: Entity,
@@ -301,6 +311,7 @@ impl Thread {
                     .shared()
             },
             cumulative_token_usage: TokenUsage::default(),
+            exceeded_window_error: None,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -367,6 +378,7 @@ impl Thread {
             action_log: cx.new(|_| ActionLog::new(project)),
             initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
             cumulative_token_usage: serialized.cumulative_token_usage,
+            exceeded_window_error: None,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -817,6 +829,7 @@ impl Thread {
                     initial_project_snapshot,
                     cumulative_token_usage: this.cumulative_token_usage.clone(),
                     detailed_summary_state: this.detailed_summary_state.clone(),
+                    exceeded_window_error: this.exceeded_window_error.clone(),
                 })
             })
     }
@@ -1129,6 +1142,20 @@ impl Thread {
                         cx.emit(ThreadEvent::ShowError(
                             ThreadError::MaxMonthlySpendReached,
                         ));
+                    } else if let Some(known_error) =
+                        error.downcast_ref::<LanguageModelKnownError>()
+                    {
+                        match known_error {
+                            LanguageModelKnownError::ContextWindowLimitExceeded {
+                                tokens,
+                            } => {
+                                thread.exceeded_window_error = Some(ExceededWindowError {
+                                    model_id: model.id(),
+                                    token_count: *tokens,
+                                });
+                                cx.notify();
+                            }
+                        }
                     } else {
                         let error_message = error
                             .chain()
@@ -1784,10 +1811,6 @@ impl Thread {
         &self.project
     }
 
-    pub fn cumulative_token_usage(&self) -> TokenUsage {
-        self.cumulative_token_usage.clone()
-    }
-
     pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
         if !cx.has_flag::() {
             return;
         }
@@ -1840,6 +1863,16 @@ impl Thread {
         let max = model.model.max_token_count();
 
+        if let Some(exceeded_error) = &self.exceeded_window_error {
+            if model.model.id() == exceeded_error.model_id {
+                return TotalTokenUsage {
+                    total: exceeded_error.token_count,
+                    max,
+                    ratio: TokenUsageRatio::Exceeded,
+                };
+            }
+        }
+
         #[cfg(debug_assertions)]
         let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
             .unwrap_or("0.8".to_string())
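Note (not part of the diff): the new `exceeded_window_error` field pins the thread's token ratio to `Exceeded` for as long as the same model stays selected, and stops applying once a different model is chosen. A simplified sketch of that short-circuit, using a `String`-backed stand-in for `LanguageModelId`, a tuple in place of `TotalTokenUsage`, and an illustrative model id:

```rust
// Sketch of the short-circuit added to `Thread::total_token_usage`: once a
// completion fails with ContextWindowLimitExceeded, the stored token count
// forces the ratio to Exceeded, but only while the same model is selected.
#[derive(Clone, PartialEq, Eq)]
struct LanguageModelId(String);

struct ExceededWindowError {
    model_id: LanguageModelId,
    token_count: usize,
}

#[derive(Debug, PartialEq, Eq)]
enum TokenUsageRatio {
    Normal,
    Exceeded,
}

fn usage_for(
    active_model: &LanguageModelId,
    exceeded: Option<&ExceededWindowError>,
    running_total: usize,
    max: usize,
) -> (usize, TokenUsageRatio) {
    // A recorded overflow for the *same* model wins over the running total.
    if let Some(err) = exceeded {
        if err.model_id == *active_model {
            return (err.token_count, TokenUsageRatio::Exceeded);
        }
    }
    // Otherwise fall back to normal accounting (warning thresholds elided).
    let ratio = if running_total >= max {
        TokenUsageRatio::Exceeded
    } else {
        TokenUsageRatio::Normal
    };
    (running_total, ratio)
}

fn main() {
    let claude = LanguageModelId("example-claude-model".into());
    let other = LanguageModelId("example-other-model".into());
    let overflow = ExceededWindowError {
        model_id: claude.clone(),
        token_count: 220_000,
    };

    // Same model: the stored overflow forces Exceeded.
    assert_eq!(
        usage_for(&claude, Some(&overflow), 10_000, 200_000).1,
        TokenUsageRatio::Exceeded
    );
    // A different model ignores the stored overflow entirely.
    assert_eq!(
        usage_for(&other, Some(&overflow), 10_000, 200_000).1,
        TokenUsageRatio::Normal
    );
}
```

Because the check compares model ids, switching models is enough to drop back to normal accounting without explicitly clearing the stored error.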
diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs
index c8f8d239a2..ebb673a86f 100644
--- a/crates/agent/src/thread_store.rs
+++ b/crates/agent/src/thread_store.rs
@@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize};
 use settings::{Settings as _, SettingsStore};
 use util::ResultExt as _;
 
-use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
+use crate::thread::{
+    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
+};
 
 const RULES_FILE_NAMES: [&'static str; 6] = [
     ".rules",
@@ -491,6 +493,8 @@ pub struct SerializedThread {
     pub cumulative_token_usage: TokenUsage,
     #[serde(default)]
     pub detailed_summary_state: DetailedSummaryState,
+    #[serde(default)]
+    pub exceeded_window_error: Option<ExceededWindowError>,
 }
 
 impl SerializedThread {
@@ -577,6 +581,7 @@ impl LegacySerializedThread {
             initial_project_snapshot: self.initial_project_snapshot,
             cumulative_token_usage: TokenUsage::default(),
             detailed_summary_state: DetailedSummaryState::default(),
+            exceeded_window_error: None,
         }
     }
 }
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index e0c215bc3a..2e403fd0fa 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -724,4 +724,54 @@ impl ApiError {
     pub fn is_rate_limit_error(&self) -> bool {
         matches!(self.error_type.as_str(), "rate_limit_error")
     }
+
+    pub fn match_window_exceeded(&self) -> Option<usize> {
+        let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
+            return None;
+        };
+
+        parse_prompt_too_long(&self.message)
+    }
+}
+
+pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
+    message
+        .strip_prefix("prompt is too long: ")?
+        .split_once(" tokens")?
+        .0
+        .parse::<usize>()
+        .ok()
+}
+
+#[test]
+fn test_match_window_exceeded() {
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "prompt is too long: 220000 tokens > 200000".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), Some(220_000));
+
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "prompt is too long: 1234953 tokens".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), Some(1234953));
+
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "not a prompt length error".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), None);
+
+    let error = ApiError {
+        error_type: "rate_limit_error".to_string(),
+        message: "prompt is too long: 12345 tokens".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), None);
+
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "prompt is too long: invalid tokens".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), None);
 }
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 98456e7db4..e1ec23410e 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -278,6 +278,12 @@
     }
 }
 
+#[derive(Debug, Error)]
+pub enum LanguageModelKnownError {
+    #[error("Context window limit exceeded ({tokens})")]
+    ContextWindowLimitExceeded { tokens: usize },
+}
+
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
@@ -347,7 +353,7 @@ pub trait LanguageModelProviderState: 'static {
     }
 }
 
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
 pub struct LanguageModelId(pub SharedString);
 
 #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index c1bea29691..6f2e11f493 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -47,6 +47,7 @@ settings.workspace = true
 smol.workspace = true
 strum.workspace = true
 theme.workspace = true
+thiserror.workspace = true
 tiktoken-rs.workspace = true
 tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
 ui.workspace = true
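Note (not part of the diff): `LanguageModelKnownError` is a `thiserror` enum that providers wrap with `anyhow!` and that the thread later recovers with `downcast_ref`, as in the match added to `thread.rs` above. A self-contained sketch of that round trip, assuming `anyhow` and `thiserror` as dependencies:

```rust
// Sketch of how the new LanguageModelKnownError travels through anyhow:
// providers wrap it with anyhow!(...) and the thread recovers the typed
// value with downcast_ref, while generic error handling still just sees
// the formatted message.
use anyhow::anyhow;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

// Stand-in for a provider returning a failed completion.
fn completion_failed(tokens: usize) -> anyhow::Error {
    anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
}

fn main() {
    let err = completion_failed(220_000);

    // Generic handling still sees a formatted message...
    assert_eq!(err.to_string(), "Context window limit exceeded (220000)");

    // ...while callers that care can recover the typed error.
    match err.downcast_ref::<LanguageModelKnownError>() {
        Some(LanguageModelKnownError::ContextWindowLimitExceeded { tokens }) => {
            assert_eq!(*tokens, 220_000);
        }
        None => unreachable!("the known error should downcast"),
    }
}
```

This is what lets the thread special-case the overflow without the providers having to know anything about `Thread` or the UI.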
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 4540a08268..7746d214b4 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -13,8 +13,9 @@ use gpui::{
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
-    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
+    RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;
@@ -454,7 +455,12 @@
         );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
-            let response = request.await.map_err(|err| anyhow!(err))?;
+            let response = request
+                .await
+                .map_err(|err| match err.downcast::<AnthropicError>() {
+                    Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
+                    Err(err) => anyhow!(err),
+                })?;
             Ok(map_to_language_model_completion_events(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
@@ -746,7 +752,7 @@
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((vec![Err(anyhow!(err))], state));
+                        return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
                     }
                 }
             }
@@ -757,6 +763,16 @@
         .flat_map(futures::stream::iter)
 }
 
+pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
+    if let AnthropicError::ApiError(api_err) = &err {
+        if let Some(tokens) = api_err.match_window_exceeded() {
+            return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
+        }
+    }
+
+    anyhow!(err)
+}
+
 /// Updates usage data by preferring counts from `new`.
 fn update_usage(usage: &mut Usage, new: &Usage) {
     if let Some(input_tokens) = new.input_tokens {
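Note (not part of the diff): both providers use the same rewrite pattern around the request: take the concrete error back by value with `downcast::<E>()`, convert the "prompt is too long" case into the known error, and hand anything else through untouched. A standalone sketch of that pattern, where `UpstreamError` is a hypothetical stand-in for `anthropic::AnthropicError` (or the cloud provider's `ApiError` in the next file):

```rust
// Sketch of the error-rewriting closure passed to map_err: downcast the
// anyhow::Error by value, rewrite the context-window case, pass the rest on.
use anyhow::anyhow;
use thiserror::Error;

#[derive(Debug, Error)]
#[error("upstream error: {message}")]
struct UpstreamError {
    message: String,
}

#[derive(Debug, Error)]
enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

// Same message format the new `parse_prompt_too_long` in crates/anthropic expects.
fn parse_prompt_too_long(message: &str) -> Option<usize> {
    message
        .strip_prefix("prompt is too long: ")?
        .split_once(" tokens")?
        .0
        .parse::<usize>()
        .ok()
}

fn rewrite(err: anyhow::Error) -> anyhow::Error {
    match err.downcast::<UpstreamError>() {
        // We own the concrete upstream error now and can inspect it.
        Ok(upstream) => {
            if let Some(tokens) = parse_prompt_too_long(&upstream.message) {
                return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
            }
            anyhow!(upstream)
        }
        // Not an UpstreamError: hand the original error back unchanged.
        Err(other) => other,
    }
}

fn main() {
    let err = rewrite(anyhow!(UpstreamError {
        message: "prompt is too long: 220000 tokens > 200000".into(),
    }));
    assert!(err.downcast_ref::<LanguageModelKnownError>().is_some());

    let passthrough = rewrite(anyhow!("some unrelated failure"));
    assert!(passthrough.downcast_ref::<LanguageModelKnownError>().is_none());
}
```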
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 6a08f48522..38d8c79d35 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -1,4 +1,4 @@
-use anthropic::{AnthropicError, AnthropicModelMode};
+use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::{
     Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
@@ -14,7 +14,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
     AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
-    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
     LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
 };
@@ -33,6 +33,7 @@ use std::{
     time::Duration,
 };
 use strum::IntoEnumIterator;
+use thiserror::Error;
 use ui::{TintColor, prelude::*};
 
 use crate::AllLanguageModelSettings;
@@ -575,14 +576,19 @@
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
-                return Err(anyhow!(
-                    "cloud language model completion failed with status {status}: {body}",
-                ));
+                return Err(anyhow!(ApiError { status, body }));
             }
         }
     }
 }
 
+#[derive(Debug, Error)]
+#[error("cloud language model completion failed with status {status}: {body}")]
+struct ApiError {
+    status: StatusCode,
+    body: String,
+}
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -696,7 +702,23 @@
                     )?)?,
                 },
             )
-            .await?;
+            .await
+            .map_err(|err| match err.downcast::<ApiError>() {
+                Ok(api_err) => {
+                    if api_err.status == StatusCode::BAD_REQUEST {
+                        if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
+                            return anyhow!(
+                                LanguageModelKnownError::ContextWindowLimitExceeded { tokens }
+                            );
+                        }
+                    }
+                    anyhow!(api_err)
+                }
+                Err(err) => anyhow!(err),
+            })?;
+
             Ok(
                 crate::provider::anthropic::map_to_language_model_completion_events(
                     Box::pin(response_lines(response).map_err(AnthropicError::Other)),