agent: Fix max token count mismatch when not using burn mode (#34025)

Closes #31854

Release Notes:

- agent: Fixed an issue where the maximum token count would be displayed
incorrectly when burn mode was not being used.
Author: Bennet Bo Fenner (committed via GitHub)
Date: 2025-07-07 23:13:24 +02:00
Commit: 66a1c356bf (parent a9107dfaeb)
7 changed files with 64 additions and 17 deletions
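The diffs below route every max-token-count lookup through a new `LanguageModelExt::max_token_count_for_mode` helper, so the reported limit reflects the active completion mode rather than always assuming one mode. Here is a self-contained sketch of that helper; `CompletionMode` stands in for `zed_llm_client::CompletionMode`, and `FixedModel` with its token limits is an illustrative stand-in, not code from this repository:

```rust
// Sketch of the mode-aware lookup introduced by this change.
#[derive(Clone, Copy)]
enum CompletionMode {
    Normal,
    Max, // "burn mode"
}

trait LanguageModel {
    fn max_token_count(&self) -> u64;

    /// Models that support burn mode report a separate, larger limit.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
}

trait LanguageModelExt: LanguageModel {
    /// Pick the limit matching the requested mode, falling back to the
    /// normal limit when burn mode is unsupported.
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}

// Blanket impl so the sketch compiles; the real crate implements the
// extension trait for `dyn LanguageModel` (see language_model below).
impl<T: LanguageModel + ?Sized> LanguageModelExt for T {}

struct FixedModel;

impl LanguageModel for FixedModel {
    fn max_token_count(&self) -> u64 {
        200_000
    }
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        Some(1_000_000)
    }
}

fn main() {
    let model = FixedModel;
    // Without burn mode the UI should report the normal window...
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Normal), 200_000);
    // ...and only burn mode unlocks the larger one.
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Max), 1_000_000);
}
```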

Cargo.lock (generated)

@@ -20145,9 +20145,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.8.5"
+version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
+checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc"
 dependencies = [
  "anyhow",
  "serde",


@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
 wasmtime-wasi = "29"
 which = "6.0.0"
 workspace-hack = "0.1.0"
-zed_llm_client = "= 0.8.5"
+zed_llm_client = "= 0.8.6"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]


@@ -23,10 +23,11 @@ use gpui::{
 };
 use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
-    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
+    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
+    TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::{

@@ -1582,6 +1583,7 @@ impl Thread {
             tool_name,
             tool_output,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
 
         pending_tool_use

@@ -1610,6 +1612,10 @@ impl Thread {
             prompt_id: prompt_id.clone(),
         };
 
+        let completion_mode = request
+            .mode
+            .unwrap_or(zed_llm_client::CompletionMode::Normal);
+
         self.last_received_chunk_at = Some(Instant::now());
 
         let task = cx.spawn(async move |thread, cx| {

@@ -1959,7 +1965,11 @@ impl Thread {
                         .unwrap_or(0)
                         // We know the context window was exceeded in practice, so if our estimate was
                         // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
-                        .max(model.max_token_count().saturating_add(1))
+                        .max(
+                            model
+                                .max_token_count_for_mode(completion_mode)
+                                .saturating_add(1),
+                        )
                 });
                 thread.exceeded_window_error = Some(ExceededWindowError {
                     model_id: model.id(),

@@ -2507,6 +2517,7 @@ impl Thread {
                 hallucinated_tool_name,
                 Err(anyhow!("Missing tool call: {error_message}")),
                 self.configured_model.as_ref(),
+                self.completion_mode,
             );
 
             cx.emit(ThreadEvent::MissingToolUse {

@@ -2533,6 +2544,7 @@ impl Thread {
                 tool_name,
                 Err(anyhow!("Error parsing input JSON: {error}")),
                 self.configured_model.as_ref(),
+                self.completion_mode,
             );
             let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
                 pending_tool_use.ui_text.clone()

@@ -2608,6 +2620,7 @@ impl Thread {
                         tool_name,
                         output,
                         thread.configured_model.as_ref(),
+                        thread.completion_mode,
                     );
                     thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                 })

@@ -3084,7 +3097,9 @@ impl Thread {
             return TotalTokenUsage::default();
         };
 
-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());
 
         let index = self
             .messages

@@ -3111,7 +3126,9 @@ impl Thread {
     pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
         let model = self.configured_model.as_ref()?;
 
-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());
 
         if let Some(exceeded_error) = &self.exceeded_window_error {
             if model.model.id() == exceeded_error.model_id {

@@ -3177,6 +3194,7 @@ impl Thread {
                 tool_name,
                 err,
                 self.configured_model.as_ref(),
+                self.completion_mode,
             );
             self.tool_finished(tool_use_id.clone(), None, true, window, cx);
         }


@@ -2,6 +2,7 @@ use crate::{
     thread::{MessageId, PromptId, ThreadId},
     thread_store::SerializedMessage,
 };
+use agent_settings::CompletionMode;
 use anyhow::Result;
 use assistant_tool::{
     AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,

@@ -11,8 +12,9 @@ use futures::{FutureExt as _, future::Shared};
 use gpui::{App, Entity, SharedString, Task, Window};
 use icons::IconName;
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
-    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
+    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
+    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+    LanguageModelToolUseId, Role,
 };
 use project::Project;
 use std::sync::Arc;

@@ -400,6 +402,7 @@ impl ToolUseState {
         tool_name: Arc<str>,
         output: Result<ToolResultOutput>,
         configured_model: Option<&ConfiguredModel>,
+        completion_mode: CompletionMode,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);

@@ -426,7 +429,10 @@ impl ToolUseState {
         // Protect from overly large output
         let tool_output_limit = configured_model
-            .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
+            .map(|model| {
+                model.model.max_token_count_for_mode(completion_mode.into()) as usize
+                    * BYTES_PER_TOKEN_ESTIMATE
+            })
             .unwrap_or(usize::MAX);
 
         let content = match tool_result {
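A side effect of threading the mode into `ToolUseState::insert_tool_output` (above) is that the byte cap on stored tool output now tracks the mode-appropriate token limit. A rough sketch of that calculation follows; the real `BYTES_PER_TOKEN_ESTIMATE` constant is defined elsewhere in the agent crate and its value is not shown in this diff, so the value and token counts below are placeholders:

```rust
// Placeholder value; the real constant lives in the agent crate.
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

/// `max_tokens_for_mode` is whatever `max_token_count_for_mode` returned for
/// the configured model, or `None` when no model is configured.
fn tool_output_limit(max_tokens_for_mode: Option<u64>) -> usize {
    max_tokens_for_mode
        .map(|max| max as usize * BYTES_PER_TOKEN_ESTIMATE)
        .unwrap_or(usize::MAX)
}

fn main() {
    // Normal mode: the smaller window now bounds how much tool output is kept.
    assert_eq!(tool_output_limit(Some(200_000)), 600_000);
    // No configured model: effectively unlimited, as in the diff.
    assert_eq!(tool_output_limit(None), usize::MAX);
}
```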


@@ -38,8 +38,8 @@ use language::{
     language_settings::{SoftWrap, all_language_settings},
 };
 use language_model::{
-    ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry,
-    Role,
+    ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView,
+    LanguageModelRegistry, Role,
 };
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, popover_menu::PickerPopoverMenu};

@@ -3063,7 +3063,7 @@ fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
         .default_model()?
         .model;
     let token_count = context.read(cx).token_count()?;
-    let max_token_count = model.max_token_count();
+    let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into());
     let token_state = if max_token_count.saturating_sub(token_count) == 0 {
         TokenState::NoTokensLeft {
             max_token_count,


@@ -26,7 +26,7 @@ use std::time::Duration;
 use std::{fmt, io};
 use thiserror::Error;
 use util::serde::is_default;
-use zed_llm_client::CompletionRequestStatus;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;

@@ -462,6 +462,10 @@ pub trait LanguageModel: Send + Sync {
     }
 
     fn max_token_count(&self) -> u64;
+    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        None
+    }
 
     fn max_output_tokens(&self) -> Option<u64> {
         None
     }

@@ -557,6 +561,18 @@
     }
 }
 
+pub trait LanguageModelExt: LanguageModel {
+    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
+        match mode {
+            CompletionMode::Normal => self.max_token_count(),
+            CompletionMode::Max => self
+                .max_token_count_in_burn_mode()
+                .unwrap_or_else(|| self.max_token_count()),
+        }
+    }
+}
+
+impl LanguageModelExt for dyn LanguageModel {}
 
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;


@@ -730,6 +730,13 @@ impl LanguageModel for CloudLanguageModel {
         self.model.max_token_count as u64
     }
 
+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        self.model
+            .max_token_count_in_max_mode
+            .filter(|_| self.model.supports_max_mode)
+            .map(|max_token_count| max_token_count as u64)
+    }
+
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
         match &self.model.provider {
             zed_llm_client::LanguageModelProvider::Anthropic => {
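For the cloud provider, a burn-mode limit is only reported when the model metadata says max ("burn") mode is supported; otherwise `max_token_count_for_mode` falls back to the normal limit. A standalone sketch of that guard; `ModelInfo`, its field types, and the numbers are illustrative stand-ins for the `zed_llm_client` model metadata:

```rust
// Even if a max-mode token count is configured, it only applies when the
// model actually supports max ("burn") mode.
struct ModelInfo {
    max_token_count_in_max_mode: Option<u32>,
    supports_max_mode: bool,
}

fn max_token_count_in_burn_mode(model: &ModelInfo) -> Option<u64> {
    model
        .max_token_count_in_max_mode
        .filter(|_| model.supports_max_mode)
        .map(|max_token_count| max_token_count as u64)
}

fn main() {
    let supported = ModelInfo {
        max_token_count_in_max_mode: Some(1_000_000),
        supports_max_mode: true,
    };
    let unsupported = ModelInfo {
        max_token_count_in_max_mode: Some(1_000_000),
        supports_max_mode: false,
    };
    // Burn-mode limit is reported only for the model that supports it.
    assert_eq!(max_token_count_in_burn_mode(&supported), Some(1_000_000));
    assert_eq!(max_token_count_in_burn_mode(&unsupported), None);
}
```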