agent: Fix max token count mismatch when not using burn mode (#34025)

Closes #31854

Release Notes:

- agent: Fixed an issue where the maximum token count would be displayed incorrectly when burn mode was not being used.
Bennet Bo Fenner, 2025-07-07 23:13:24 +02:00 (committed by GitHub)
Parent: a9107dfaeb
Commit: 66a1c356bf
7 changed files with 64 additions and 17 deletions
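
The gist of the change, before the hunks below: token limits were previously read from a single `max_token_count()`, which does not reflect the effective context window of the active completion mode. The diff threads the mode through and reads the limit via `max_token_count_for_mode` instead. Here is a minimal, self-contained sketch of that idea; the `Max` variant name for burn mode and all numbers are illustrative assumptions, not Zed's actual values:

```rust
// Sketch only: mirrors the names used in the diff, but the burn-mode
// variant name and the numbers are assumptions for illustration.
#[derive(Clone, Copy)]
enum CompletionMode {
    Normal,
    Max, // "burn mode" (assumed variant name)
}

struct Model {
    max_tokens: u64,                   // context window in normal mode
    burn_mode_max_tokens: Option<u64>, // larger window when burn mode is available
}

impl Model {
    // The mode-aware accessor the diff switches to: fall back to the
    // normal limit for models without a dedicated burn-mode limit.
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Max => self.burn_mode_max_tokens.unwrap_or(self.max_tokens),
            CompletionMode::Normal => self.max_tokens,
        }
    }
}

fn main() {
    let model = Model {
        max_tokens: 120_000,
        burn_mode_max_tokens: Some(200_000),
    };
    // The old code always used one limit; the new code follows the mode.
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Normal), 120_000);
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Max), 200_000);
    println!("mode-aware limits check out");
}
```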

@@ -23,10 +23,11 @@ use gpui::{
 };
 use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
-    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
+    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
+    TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::{
@@ -1582,6 +1583,7 @@ impl Thread {
             tool_name,
             tool_output,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         pending_tool_use
@@ -1610,6 +1612,10 @@ impl Thread {
             prompt_id: prompt_id.clone(),
         };
+
+        let completion_mode = request
+            .mode
+            .unwrap_or(zed_llm_client::CompletionMode::Normal);
         self.last_received_chunk_at = Some(Instant::now());
         let task = cx.spawn(async move |thread, cx| {
@@ -1959,7 +1965,11 @@ impl Thread {
                             .unwrap_or(0)
                             // We know the context window was exceeded in practice, so if our estimate was
                             // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
-                            .max(model.max_token_count().saturating_add(1))
+                            .max(
+                                model
+                                    .max_token_count_for_mode(completion_mode)
+                                    .saturating_add(1),
+                            )
                     });
                     thread.exceeded_window_error = Some(ExceededWindowError {
                         model_id: model.id(),
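
One subtlety in the hunk just above: the pre-existing "exceeded by at least one token" clamp is kept, but it now measures against the mode-aware limit, so a normal-mode overflow is no longer compared to a burn-mode window. A tiny sketch of the clamp; the function name is ours, purely for illustration:

```rust
// Sketch of the clamp above: once we know the context window was
// exceeded in practice, never report an estimate at or below the limit.
// `reported_usage` is a hypothetical name, not from the diff.
fn reported_usage(estimate: u64, mode_aware_limit: u64) -> u64 {
    estimate.max(mode_aware_limit.saturating_add(1))
}

fn main() {
    // Estimator undercounted (90k) but the 100k window was exceeded:
    assert_eq!(reported_usage(90_000, 100_000), 100_001);
    // Estimator already above the limit: keep its value.
    assert_eq!(reported_usage(130_000, 100_000), 130_000);
}
```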
@@ -2507,6 +2517,7 @@ impl Thread {
             hallucinated_tool_name,
             Err(anyhow!("Missing tool call: {error_message}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         cx.emit(ThreadEvent::MissingToolUse {
@@ -2533,6 +2544,7 @@ impl Thread {
             tool_name,
             Err(anyhow!("Error parsing input JSON: {error}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
             pending_tool_use.ui_text.clone()
@@ -2608,6 +2620,7 @@ impl Thread {
                     tool_name,
                     output,
                     thread.configured_model.as_ref(),
+                    thread.completion_mode,
                 );
                 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
             })
@@ -3084,7 +3097,9 @@ impl Thread {
             return TotalTokenUsage::default();
         };
-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());
         let index = self
             .messages
@@ -3111,7 +3126,9 @@ impl Thread {
     pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
         let model = self.configured_model.as_ref()?;
-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());
         if let Some(exceeded_error) = &self.exceeded_window_error {
             if model.model.id() == exceeded_error.model_id {
@@ -3177,6 +3194,7 @@ impl Thread {
             tool_name,
             err,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         self.tool_finished(tool_use_id.clone(), None, true, window, cx);
     }
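
Two details the hunks imply but do not show: `LanguageModelExt as _` means `max_token_count_for_mode` arrives via an extension trait imported only for its method, and `self.completion_mode().into()` implies a conversion from the agent's completion mode into `zed_llm_client::CompletionMode`. A hedged, self-contained sketch of what both could look like; the trait body, variant names, and numbers are assumptions, and the real definitions live in the `language_model` crate:

```rust
// Stub enums so the sketch compiles on its own; the real types live in
// the agent settings and the zed_llm_client crate respectively.
#[derive(Clone, Copy)]
enum AgentCompletionMode {
    Normal,
    Burn,
}

#[derive(Clone, Copy)]
enum ClientCompletionMode {
    Normal,
    Max,
}

// The `.into()` in the total_token_usage hunks implies a conversion
// like this (assumed, not shown in the diff).
impl From<AgentCompletionMode> for ClientCompletionMode {
    fn from(mode: AgentCompletionMode) -> Self {
        match mode {
            AgentCompletionMode::Normal => ClientCompletionMode::Normal,
            AgentCompletionMode::Burn => ClientCompletionMode::Max,
        }
    }
}

// Stub of the base trait; models without a burn mode report None.
trait LanguageModel {
    fn max_token_count(&self) -> u64;
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
}

// Extension trait in the spirit of `LanguageModelExt as _`: imported
// only for its method, hence the `as _`.
trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: ClientCompletionMode) -> u64 {
        match mode {
            ClientCompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
            ClientCompletionMode::Normal => self.max_token_count(),
        }
    }
}
impl<T: LanguageModel + ?Sized> LanguageModelExt for T {}

struct DemoModel;
impl LanguageModel for DemoModel {
    fn max_token_count(&self) -> u64 {
        100_000 // illustrative number
    }
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        Some(200_000) // illustrative number
    }
}

fn main() {
    let model = DemoModel;
    assert_eq!(
        model.max_token_count_for_mode(AgentCompletionMode::Normal.into()),
        100_000
    );
    assert_eq!(
        model.max_token_count_for_mode(AgentCompletionMode::Burn.into()),
        200_000
    );
}
```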