agent: Handle context window exceeded errors from Anthropic (#28688)
 Release Notes: - agent: Handle context window exceeded errors from Anthropic
This commit is contained in:
parent
4a57664c7f
commit
b45230784d
9 changed files with 190 additions and 28 deletions
|
@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
|
|||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||
Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::Project;
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
|
@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
|
|||
pub ratio: TokenUsageRatio,
|
||||
}
|
||||
|
||||
#[derive(Default, PartialEq, Eq)]
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
pub enum TokenUsageRatio {
|
||||
#[default]
|
||||
Normal,
|
||||
|
@ -260,11 +261,20 @@ pub struct Thread {
|
|||
pending_checkpoint: Option<ThreadCheckpoint>,
|
||||
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
|
||||
cumulative_token_usage: TokenUsage,
|
||||
exceeded_window_error: Option<ExceededWindowError>,
|
||||
feedback: Option<ThreadFeedback>,
|
||||
message_feedback: HashMap<MessageId, ThreadFeedback>,
|
||||
last_auto_capture_at: Option<Instant>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExceededWindowError {
|
||||
/// Model used when last message exceeded context window
|
||||
model_id: LanguageModelId,
|
||||
/// Token count including last message
|
||||
token_count: usize,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
|
@ -301,6 +311,7 @@ impl Thread {
|
|||
.shared()
|
||||
},
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
exceeded_window_error: None,
|
||||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
|
@ -367,6 +378,7 @@ impl Thread {
|
|||
action_log: cx.new(|_| ActionLog::new(project)),
|
||||
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
|
||||
cumulative_token_usage: serialized.cumulative_token_usage,
|
||||
exceeded_window_error: None,
|
||||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
|
@ -817,6 +829,7 @@ impl Thread {
|
|||
initial_project_snapshot,
|
||||
cumulative_token_usage: this.cumulative_token_usage.clone(),
|
||||
detailed_summary_state: this.detailed_summary_state.clone(),
|
||||
exceeded_window_error: this.exceeded_window_error.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -1129,6 +1142,20 @@ impl Thread {
|
|||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::MaxMonthlySpendReached,
|
||||
));
|
||||
} else if let Some(known_error) =
|
||||
error.downcast_ref::<LanguageModelKnownError>()
|
||||
{
|
||||
match known_error {
|
||||
LanguageModelKnownError::ContextWindowLimitExceeded {
|
||||
tokens,
|
||||
} => {
|
||||
thread.exceeded_window_error = Some(ExceededWindowError {
|
||||
model_id: model.id(),
|
||||
token_count: *tokens,
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let error_message = error
|
||||
.chain()
|
||||
|
@ -1784,10 +1811,6 @@ impl Thread {
|
|||
&self.project
|
||||
}
|
||||
|
||||
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||
self.cumulative_token_usage.clone()
|
||||
}
|
||||
|
||||
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
|
||||
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
|
||||
return;
|
||||
|
@ -1840,6 +1863,16 @@ impl Thread {
|
|||
|
||||
let max = model.model.max_token_count();
|
||||
|
||||
if let Some(exceeded_error) = &self.exceeded_window_error {
|
||||
if model.model.id() == exceeded_error.model_id {
|
||||
return TotalTokenUsage {
|
||||
total: exceeded_error.token_count,
|
||||
max,
|
||||
ratio: TokenUsageRatio::Exceeded,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
|
||||
.unwrap_or("0.8".to_string())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue