agent: Handle context window exceeded errors from Anthropic (#28688)
 Release Notes: - agent: Handle context window exceeded errors from Anthropic
This commit is contained in:
parent
4a57664c7f
commit
b45230784d
9 changed files with 190 additions and 28 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -7708,6 +7708,7 @@ dependencies = [
|
|||
"smol",
|
||||
"strum",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"tiktoken-rs",
|
||||
"tokio",
|
||||
"ui",
|
||||
|
|
|
@ -761,13 +761,29 @@ impl MessageEditor {
|
|||
})
|
||||
}
|
||||
|
||||
fn render_reaching_token_limit(&self, line_height: Pixels, cx: &mut Context<Self>) -> Div {
|
||||
fn render_token_limit_callout(
|
||||
&self,
|
||||
line_height: Pixels,
|
||||
token_usage_ratio: TokenUsageRatio,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Div {
|
||||
let heading = if token_usage_ratio == TokenUsageRatio::Exceeded {
|
||||
"Thread reached the token limit"
|
||||
} else {
|
||||
"Thread reaching the token limit soon"
|
||||
};
|
||||
|
||||
h_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.flex_wrap()
|
||||
.justify_between()
|
||||
.bg(cx.theme().status().warning_background.opacity(0.1))
|
||||
.bg(
|
||||
if token_usage_ratio == TokenUsageRatio::Exceeded {
|
||||
cx.theme().status().error_background.opacity(0.1)
|
||||
} else {
|
||||
cx.theme().status().warning_background.opacity(0.1)
|
||||
})
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
|
@ -779,15 +795,21 @@ impl MessageEditor {
|
|||
.h(line_height)
|
||||
.justify_center()
|
||||
.child(
|
||||
Icon::new(IconName::Warning)
|
||||
.color(Color::Warning)
|
||||
.size(IconSize::XSmall),
|
||||
if token_usage_ratio == TokenUsageRatio::Exceeded {
|
||||
Icon::new(IconName::X)
|
||||
.color(Color::Error)
|
||||
.size(IconSize::XSmall)
|
||||
} else {
|
||||
Icon::new(IconName::Warning)
|
||||
.color(Color::Warning)
|
||||
.size(IconSize::XSmall)
|
||||
}
|
||||
),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.mr_auto()
|
||||
.child(Label::new("Thread reaching the token limit soon").size(LabelSize::Small))
|
||||
.child(Label::new(heading).size(LabelSize::Small))
|
||||
.child(
|
||||
Label::new(
|
||||
"Start a new thread from a summary to continue the conversation.",
|
||||
|
@ -875,7 +897,13 @@ impl Render for MessageEditor {
|
|||
.child(self.render_editor(font_size, line_height, window, cx))
|
||||
.when(
|
||||
total_token_usage.ratio != TokenUsageRatio::Normal,
|
||||
|parent| parent.child(self.render_reaching_token_limit(line_height, cx)),
|
||||
|parent| {
|
||||
parent.child(self.render_token_limit_callout(
|
||||
line_height,
|
||||
total_token_usage.ratio,
|
||||
cx,
|
||||
))
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
|
|||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||
Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::Project;
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
|
@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
|
|||
pub ratio: TokenUsageRatio,
|
||||
}
|
||||
|
||||
#[derive(Default, PartialEq, Eq)]
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
pub enum TokenUsageRatio {
|
||||
#[default]
|
||||
Normal,
|
||||
|
@ -260,11 +261,20 @@ pub struct Thread {
|
|||
pending_checkpoint: Option<ThreadCheckpoint>,
|
||||
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
|
||||
cumulative_token_usage: TokenUsage,
|
||||
exceeded_window_error: Option<ExceededWindowError>,
|
||||
feedback: Option<ThreadFeedback>,
|
||||
message_feedback: HashMap<MessageId, ThreadFeedback>,
|
||||
last_auto_capture_at: Option<Instant>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExceededWindowError {
|
||||
/// Model used when last message exceeded context window
|
||||
model_id: LanguageModelId,
|
||||
/// Token count including last message
|
||||
token_count: usize,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
|
@ -301,6 +311,7 @@ impl Thread {
|
|||
.shared()
|
||||
},
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
exceeded_window_error: None,
|
||||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
|
@ -367,6 +378,7 @@ impl Thread {
|
|||
action_log: cx.new(|_| ActionLog::new(project)),
|
||||
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
|
||||
cumulative_token_usage: serialized.cumulative_token_usage,
|
||||
exceeded_window_error: None,
|
||||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
|
@ -817,6 +829,7 @@ impl Thread {
|
|||
initial_project_snapshot,
|
||||
cumulative_token_usage: this.cumulative_token_usage.clone(),
|
||||
detailed_summary_state: this.detailed_summary_state.clone(),
|
||||
exceeded_window_error: this.exceeded_window_error.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -1129,6 +1142,20 @@ impl Thread {
|
|||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::MaxMonthlySpendReached,
|
||||
));
|
||||
} else if let Some(known_error) =
|
||||
error.downcast_ref::<LanguageModelKnownError>()
|
||||
{
|
||||
match known_error {
|
||||
LanguageModelKnownError::ContextWindowLimitExceeded {
|
||||
tokens,
|
||||
} => {
|
||||
thread.exceeded_window_error = Some(ExceededWindowError {
|
||||
model_id: model.id(),
|
||||
token_count: *tokens,
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let error_message = error
|
||||
.chain()
|
||||
|
@ -1784,10 +1811,6 @@ impl Thread {
|
|||
&self.project
|
||||
}
|
||||
|
||||
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||
self.cumulative_token_usage.clone()
|
||||
}
|
||||
|
||||
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
|
||||
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
|
||||
return;
|
||||
|
@ -1840,6 +1863,16 @@ impl Thread {
|
|||
|
||||
let max = model.model.max_token_count();
|
||||
|
||||
if let Some(exceeded_error) = &self.exceeded_window_error {
|
||||
if model.model.id() == exceeded_error.model_id {
|
||||
return TotalTokenUsage {
|
||||
total: exceeded_error.token_count,
|
||||
max,
|
||||
ratio: TokenUsageRatio::Exceeded,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
|
||||
.unwrap_or("0.8".to_string())
|
||||
|
|
|
@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize};
|
|||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
|
||||
use crate::thread::{
|
||||
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
|
||||
};
|
||||
|
||||
const RULES_FILE_NAMES: [&'static str; 6] = [
|
||||
".rules",
|
||||
|
@ -491,6 +493,8 @@ pub struct SerializedThread {
|
|||
pub cumulative_token_usage: TokenUsage,
|
||||
#[serde(default)]
|
||||
pub detailed_summary_state: DetailedSummaryState,
|
||||
#[serde(default)]
|
||||
pub exceeded_window_error: Option<ExceededWindowError>,
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
|
@ -577,6 +581,7 @@ impl LegacySerializedThread {
|
|||
initial_project_snapshot: self.initial_project_snapshot,
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
detailed_summary_state: DetailedSummaryState::default(),
|
||||
exceeded_window_error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -724,4 +724,54 @@ impl ApiError {
|
|||
pub fn is_rate_limit_error(&self) -> bool {
|
||||
matches!(self.error_type.as_str(), "rate_limit_error")
|
||||
}
|
||||
|
||||
pub fn match_window_exceeded(&self) -> Option<usize> {
|
||||
let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
|
||||
return None;
|
||||
};
|
||||
|
||||
parse_prompt_too_long(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
|
||||
message
|
||||
.strip_prefix("prompt is too long: ")?
|
||||
.split_once(" tokens")?
|
||||
.0
|
||||
.parse::<usize>()
|
||||
.ok()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_match_window_exceeded() {
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "prompt is too long: 220000 tokens > 200000".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), Some(220_000));
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "prompt is too long: 1234953 tokens".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), Some(1234953));
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "not a prompt length error".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), None);
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "rate_limit_error".to_string(),
|
||||
message: "prompt is too long: 12345 tokens".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), None);
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "prompt is too long: invalid tokens".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), None);
|
||||
}
|
||||
|
|
|
@ -278,6 +278,12 @@ pub trait LanguageModel: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LanguageModelKnownError {
|
||||
#[error("Context window limit exceeded ({tokens})")]
|
||||
ContextWindowLimitExceeded { tokens: usize },
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||
fn name() -> String;
|
||||
fn description() -> String;
|
||||
|
@ -347,7 +353,7 @@ pub trait LanguageModelProviderState: 'static {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
pub struct LanguageModelId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
|
|
|
@ -47,6 +47,7 @@ settings.workspace = true
|
|||
smol.workspace = true
|
||||
strum.workspace = true
|
||||
theme.workspace = true
|
||||
thiserror.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
ui.workspace = true
|
||||
|
|
|
@ -13,8 +13,9 @@ use gpui::{
|
|||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
|
||||
RateLimiter, Role,
|
||||
};
|
||||
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||
use schemars::JsonSchema;
|
||||
|
@ -454,7 +455,12 @@ impl LanguageModel for AnthropicModel {
|
|||
);
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
let response = request
|
||||
.await
|
||||
.map_err(|err| match err.downcast::<AnthropicError>() {
|
||||
Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
|
||||
Err(err) => anyhow!(err),
|
||||
})?;
|
||||
Ok(map_to_language_model_completion_events(response))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
|
@ -746,7 +752,7 @@ pub fn map_to_language_model_completion_events(
|
|||
_ => {}
|
||||
},
|
||||
Err(err) => {
|
||||
return Some((vec![Err(anyhow!(err))], state));
|
||||
return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -757,6 +763,16 @@ pub fn map_to_language_model_completion_events(
|
|||
.flat_map(futures::stream::iter)
|
||||
}
|
||||
|
||||
pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
|
||||
if let AnthropicError::ApiError(api_err) = &err {
|
||||
if let Some(tokens) = api_err.match_window_exceeded() {
|
||||
return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
|
||||
}
|
||||
}
|
||||
|
||||
anyhow!(err)
|
||||
}
|
||||
|
||||
/// Updates usage data by preferring counts from `new`.
|
||||
fn update_usage(usage: &mut Usage, new: &Usage) {
|
||||
if let Some(input_tokens) = new.input_tokens {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use anthropic::{AnthropicError, AnthropicModelMode};
|
||||
use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::{
|
||||
Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
||||
|
@ -14,7 +14,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta
|
|||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||
use language_model::{
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||
LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
|
@ -33,6 +33,7 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use thiserror::Error;
|
||||
use ui::{TintColor, prelude::*};
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
|
@ -575,14 +576,19 @@ impl CloudLanguageModel {
|
|||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"cloud language model completion failed with status {status}: {body}",
|
||||
));
|
||||
return Err(anyhow!(ApiError { status, body }));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("cloud language model completion failed with status {status}: {body}")]
|
||||
struct ApiError {
|
||||
status: StatusCode,
|
||||
body: String,
|
||||
}
|
||||
|
||||
impl LanguageModel for CloudLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
|
@ -696,7 +702,23 @@ impl LanguageModel for CloudLanguageModel {
|
|||
)?)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|err| match err.downcast::<ApiError>() {
|
||||
Ok(api_err) => {
|
||||
if api_err.status == StatusCode::BAD_REQUEST {
|
||||
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
|
||||
return anyhow!(
|
||||
LanguageModelKnownError::ContextWindowLimitExceeded {
|
||||
tokens
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
anyhow!(api_err)
|
||||
}
|
||||
Err(err) => anyhow!(err),
|
||||
})?;
|
||||
|
||||
Ok(
|
||||
crate::provider::anthropic::map_to_language_model_completion_events(
|
||||
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue