From 66a1c356bfc8510a2528d34c9df825d5db540e7a Mon Sep 17 00:00:00 2001
From: Bennet Bo Fenner
Date: Mon, 7 Jul 2025 23:13:24 +0200
Subject: [PATCH] agent: Fix max token count mismatch when not using burn mode
 (#34025)

Closes #31854

Release Notes:

- agent: Fixed an issue where the maximum token count would be displayed
  incorrectly when burn mode was not being used.

---
 Cargo.lock                                   |  4 +--
 Cargo.toml                                   |  2 +-
 crates/agent/src/thread.rs                   | 32 +++++++++++++++-----
 crates/agent/src/tool_use.rs                 | 12 ++++++--
 crates/agent_ui/src/text_thread_editor.rs    |  6 ++--
 crates/language_model/src/language_model.rs  | 18 ++++++++++-
 crates/language_models/src/provider/cloud.rs |  7 +++++
 7 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 57b97cb853..a19397bdf9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -20145,9 +20145,9 @@ dependencies = [

 [[package]]
 name = "zed_llm_client"
-version = "0.8.5"
+version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
+checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc"
 dependencies = [
  "anyhow",
  "serde",

diff --git a/Cargo.toml b/Cargo.toml
index 82cbb53397..8dd7892329 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
 wasmtime-wasi = "29"
 which = "6.0.0"
 workspace-hack = "0.1.0"
-zed_llm_client = "= 0.8.5"
+zed_llm_client = "= 0.8.6"
 zstd = "0.11"

 [workspace.dependencies.async-stripe]
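The `zed_llm_client` bump pulls in the per-mode token metadata (`max_token_count_in_max_mode`, `supports_max_mode`) that the provider code reads further down. The Rust changes that follow thread the active completion mode through every place the agent reads a model's context window. Below is a minimal standalone sketch of the intended dispatch, using illustrative types and made-up token counts rather than Zed's actual API (the real patch blanket-implements the extension trait only for `dyn LanguageModel`):

```rust
#[derive(Clone, Copy, Debug)]
enum CompletionMode {
    Normal,
    Max, // "burn mode"
}

trait Model {
    fn max_token_count(&self) -> u64;

    /// `None` when the model has no separate burn-mode window.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
}

trait ModelExt: Model {
    /// Pick the window for the active mode, falling back to the normal
    /// window when burn mode is unsupported.
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl<T: Model + ?Sized> ModelExt for T {}

/// A hypothetical model with a larger burn-mode window.
struct ExampleModel;

impl Model for ExampleModel {
    fn max_token_count(&self) -> u64 {
        200_000
    }
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        Some(1_000_000)
    }
}

fn main() {
    let model = ExampleModel;
    // The bug: the displayed maximum did not follow the active mode. With
    // the mode-aware lookup, each mode reports its own window.
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Normal), 200_000);
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Max), 1_000_000);
}
```

Keeping the fallback inside one extension method means call sites never have to know whether a given model distinguishes the two windows.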
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 72417cfe99..1f2654dac5 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -23,10 +23,11 @@ use gpui::{
 };
 use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
-    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
+    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
+    TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::{
@@ -1582,6 +1583,7 @@ impl Thread {
             tool_name,
             tool_output,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );

         pending_tool_use
@@ -1610,6 +1612,10 @@ impl Thread {
             prompt_id: prompt_id.clone(),
         };

+        let completion_mode = request
+            .mode
+            .unwrap_or(zed_llm_client::CompletionMode::Normal);
+
         self.last_received_chunk_at = Some(Instant::now());

         let task = cx.spawn(async move |thread, cx| {
@@ -1959,7 +1965,11 @@ impl Thread {
                         .unwrap_or(0)
                         // We know the context window was exceeded in practice, so if our estimate was
                         // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
-                        .max(model.max_token_count().saturating_add(1))
+                        .max(
+                            model
+                                .max_token_count_for_mode(completion_mode)
+                                .saturating_add(1),
+                        )
                 });
                 thread.exceeded_window_error = Some(ExceededWindowError {
                     model_id: model.id(),
@@ -2507,6 +2517,7 @@ impl Thread {
             hallucinated_tool_name,
             Err(anyhow!("Missing tool call: {error_message}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );

         cx.emit(ThreadEvent::MissingToolUse {
@@ -2533,6 +2544,7 @@ impl Thread {
             tool_name,
             Err(anyhow!("Error parsing input JSON: {error}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
             pending_tool_use.ui_text.clone()
@@ -2608,6 +2620,7 @@ impl Thread {
                     tool_name,
                     output,
                     thread.configured_model.as_ref(),
+                    thread.completion_mode,
                 );
                 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
             })
@@ -3084,7 +3097,9 @@ impl Thread {
             return TotalTokenUsage::default();
         };

-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());

         let index = self
             .messages
@@ -3111,7 +3126,9 @@ impl Thread {
     pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
         let model = self.configured_model.as_ref()?;

-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());

         if let Some(exceeded_error) = &self.exceeded_window_error {
             if model.model.id() == exceeded_error.model_id {
@@ -3177,6 +3194,7 @@ impl Thread {
                 tool_name,
                 err,
                 self.configured_model.as_ref(),
+                self.completion_mode,
             );
             self.tool_finished(tool_use_id.clone(), None, true, window, cx);
         }

diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs
index 76de3d2022..74c719b4e6 100644
--- a/crates/agent/src/tool_use.rs
+++ b/crates/agent/src/tool_use.rs
@@ -2,6 +2,7 @@ use crate::{
     thread::{MessageId, PromptId, ThreadId},
     thread_store::SerializedMessage,
 };
+use agent_settings::CompletionMode;
 use anyhow::Result;
 use assistant_tool::{
     AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
@@ -11,8 +12,9 @@ use futures::{FutureExt as _, future::Shared};
 use gpui::{App, Entity, SharedString, Task, Window};
 use icons::IconName;
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
-    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
+    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
+    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+    LanguageModelToolUseId, Role,
 };
 use project::Project;
 use std::sync::Arc;
@@ -400,6 +402,7 @@ impl ToolUseState {
         tool_name: Arc<str>,
         output: Result<ToolResultOutput>,
         configured_model: Option<&ConfiguredModel>,
+        completion_mode: CompletionMode,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@@ -426,7 +429,10 @@ impl ToolUseState {
                 // Protect from overly large output
                 let tool_output_limit = configured_model
-                    .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
+                    .map(|model| {
+                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
+                            * BYTES_PER_TOKEN_ESTIMATE
+                    })
                     .unwrap_or(usize::MAX);

                 let content = match tool_result {
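Two derived limits in the hunks above now also follow the mode-aware window: the clamp applied when a provider reports that the context window was exceeded, and the byte budget that protects against oversized tool output. A self-contained sketch of both calculations follows; the helper names are hypothetical, and the bytes-per-token value is an assumption, since the patch shows only the constant's name:

```rust
// Assumed value for illustration; the patch only references the name.
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

/// thread.rs: the provider said the window was exceeded, so if the local
/// estimate is at or below the mode-aware window, the estimate was wrong;
/// report at least window + 1.
fn exceeded_token_estimate(local_estimate: Option<u64>, window_for_mode: u64) -> u64 {
    local_estimate
        .unwrap_or(0)
        .max(window_for_mode.saturating_add(1))
}

/// tool_use.rs: clamp tool output to a byte budget proportional to the
/// mode-aware window; with no configured model, do not clamp at all.
fn tool_output_limit(window_for_mode: Option<u64>) -> usize {
    window_for_mode
        .map(|window| window as usize * BYTES_PER_TOKEN_ESTIMATE)
        .unwrap_or(usize::MAX)
}

fn main() {
    // A 100k estimate against a 200k normal-mode window becomes 200_001,
    // so the UI consistently reports the thread as over its window.
    assert_eq!(exceeded_token_estimate(Some(100_000), 200_000), 200_001);
    assert_eq!(exceeded_token_estimate(Some(250_000), 200_000), 250_000);

    // A 200k-token window allows roughly 600 kB of tool output.
    assert_eq!(tool_output_limit(Some(200_000)), 600_000);
    assert_eq!(tool_output_limit(None), usize::MAX);
}
```

Passing the mode into these paths matters because both previously used the bare `max_token_count()`, so their limits could disagree with the window the request was actually made against.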
diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs
index 465b3b4e58..de7606dbfb 100644
--- a/crates/agent_ui/src/text_thread_editor.rs
+++ b/crates/agent_ui/src/text_thread_editor.rs
@@ -38,8 +38,8 @@ use language::{
     language_settings::{SoftWrap, all_language_settings},
 };
 use language_model::{
-    ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry,
-    Role,
+    ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView,
+    LanguageModelRegistry, Role,
 };
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, popover_menu::PickerPopoverMenu};
@@ -3063,7 +3063,7 @@ fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
-    let max_token_count = model.max_token_count();
+    let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into());

diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -484,4 +484,8 @@ pub trait LanguageModel: Send + Sync {
     fn max_token_count(&self) -> u64;
+    /// Returns the maximum token count for this model in burn mode (if `supports_burn_mode` is `false`, this returns `None`)
+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        None
+    }
     fn max_output_tokens(&self) -> Option<u64> {
         None
     }
@@ -557,6 +561,18 @@ pub trait LanguageModel: Send + Sync {
     }
 }

+pub trait LanguageModelExt: LanguageModel {
+    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
+        match mode {
+            CompletionMode::Normal => self.max_token_count(),
+            CompletionMode::Max => self
+                .max_token_count_in_burn_mode()
+                .unwrap_or_else(|| self.max_token_count()),
+        }
+    }
+}
+impl LanguageModelExt for dyn LanguageModel {}
+
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;

diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 505caa2e42..1cd673710c 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -730,6 +730,13 @@ impl LanguageModel for CloudLanguageModel {
         self.model.max_token_count as u64
     }

+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        self.model
+            .max_token_count_in_max_mode
+            .filter(|_| self.model.supports_max_mode)
+            .map(|max_token_count| max_token_count as u64)
+    }
+
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
         match &self.model.provider {
             zed_llm_client::LanguageModelProvider::Anthropic => {
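The `cloud.rs` override only advertises a burn-mode window when the backend both supplies one and flags the model as supporting max mode; otherwise `max_token_count_in_burn_mode` stays `None` and `LanguageModelExt` falls back to the normal window. A sketch of that guard, with assumed stand-in metadata types rather than `zed_llm_client`'s real structs:

```rust
/// Stand-in for the model metadata fields read by the override; the real
/// fields live on zed_llm_client's model type.
struct ModelMeta {
    max_token_count: u32,
    max_token_count_in_max_mode: Option<u32>,
    supports_max_mode: bool,
}

/// Advertise a burn-mode window only when the backend both defines one and
/// marks the model as supporting max mode.
fn burn_mode_window(meta: &ModelMeta) -> Option<u64> {
    meta.max_token_count_in_max_mode
        .filter(|_| meta.supports_max_mode)
        .map(u64::from)
}

fn main() {
    // Metadata carries a max-mode count, but the support flag wins, so the
    // model reports no burn-mode window and callers fall back to 200k.
    let meta = ModelMeta {
        max_token_count: 200_000,
        max_token_count_in_max_mode: Some(1_000_000),
        supports_max_mode: false,
    };
    assert_eq!(burn_mode_window(&meta), None);
    assert_eq!(meta.max_token_count, 200_000);
}
```

This guard is what keeps the displayed window honest: a model can ship max-mode metadata without the support flag, and nothing downstream will show the inflated count.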