agent: Fix max token count mismatch when not using burn mode (#34025)
Closes #31854

Release Notes:

- agent: Fixed an issue where the maximum token count would be displayed incorrectly when burn mode was not being used.
Parent: a9107dfaeb
Commit: 66a1c356bf

7 changed files with 64 additions and 17 deletions
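The mechanics of the fix: zed_llm_client 0.8.6 exposes a per-model burn-mode window (`max_token_count_in_max_mode`), a new `LanguageModelExt::max_token_count_for_mode` helper selects between the normal and burn-mode windows, and the active completion mode is threaded through `Thread`, `ToolUseState`, and the text thread editor so token limits are computed against the window for the mode actually in use.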
Cargo.lock (generated, 4 changes)

@@ -20145,9 +20145,9 @@ dependencies = [
 [[package]]
 name = "zed_llm_client"
-version = "0.8.5"
+version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
+checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc"
 dependencies = [
  "anyhow",
  "serde",
Cargo.toml

@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
 wasmtime-wasi = "29"
 which = "6.0.0"
 workspace-hack = "0.1.0"
-zed_llm_client = "= 0.8.5"
+zed_llm_client = "= 0.8.6"
 zstd = "0.11"

 [workspace.dependencies.async-stripe]
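The `=` prefix makes this an exact-version requirement: Cargo resolves precisely 0.8.6 rather than any semver-compatible newer release, keeping the manifest and the lockfile entry above in step.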
crates/agent/src/thread.rs

@@ -23,10 +23,11 @@ use gpui::{
 };
 use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
-    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
+    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
+    TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::{
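Note the `LanguageModelExt as _` import: it brings the trait's methods into scope without binding the trait's name, which is all the `max_token_count_for_mode` calls below need. A minimal standalone sketch of the pattern, with illustrative names:

// Sketch of the `use Trait as _` pattern; `Doubler` is an illustrative name.
mod ext {
    pub trait Doubler {
        fn doubled(&self) -> u64;
    }
    impl Doubler for u64 {
        fn doubled(&self) -> u64 {
            self * 2
        }
    }
}

use ext::Doubler as _; // methods become callable; the name `Doubler` stays unimported

fn main() {
    assert_eq!(21u64.doubled(), 42);
}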
@@ -1582,6 +1583,7 @@ impl Thread {
             tool_name,
             tool_output,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );

         pending_tool_use
@@ -1610,6 +1612,10 @@ impl Thread {
             prompt_id: prompt_id.clone(),
         };

+        let completion_mode = request
+            .mode
+            .unwrap_or(zed_llm_client::CompletionMode::Normal);
+
         self.last_received_chunk_at = Some(Instant::now());

         let task = cx.spawn(async move |thread, cx| {
@@ -1959,7 +1965,11 @@ impl Thread {
                     .unwrap_or(0)
                     // We know the context window was exceeded in practice, so if our estimate was
                     // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
-                    .max(model.max_token_count().saturating_add(1))
+                    .max(
+                        model
+                            .max_token_count_for_mode(completion_mode)
+                            .saturating_add(1),
+                    )
             });
             thread.exceeded_window_error = Some(ExceededWindowError {
                 model_id: model.id(),
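The `.max(...)` clamp guarantees the reported count actually exceeds the window: the provider already rejected the request, so if the local estimate came in at or under the mode-aware limit, the estimate was wrong and is raised to limit + 1, with `saturating_add` guarding against overflow at `u64::MAX`. A self-contained sketch of that adjustment, using hypothetical names:

// Sketch of the exceeded-window adjustment (names are illustrative, not Zed's API).
fn reported_token_count(estimated: u64, mode_max: u64) -> u64 {
    // The provider said the window was exceeded, so an estimate at or below
    // the window must be wrong; report at least one token over the limit.
    estimated.max(mode_max.saturating_add(1))
}

fn main() {
    assert_eq!(reported_token_count(180_000, 200_000), 200_001); // estimate too low, clamped
    assert_eq!(reported_token_count(250_000, 200_000), 250_000); // estimate kept
}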
@@ -2507,6 +2517,7 @@ impl Thread {
             hallucinated_tool_name,
             Err(anyhow!("Missing tool call: {error_message}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );

         cx.emit(ThreadEvent::MissingToolUse {
@@ -2533,6 +2544,7 @@ impl Thread {
             tool_name,
             Err(anyhow!("Error parsing input JSON: {error}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
             pending_tool_use.ui_text.clone()
@@ -2608,6 +2620,7 @@ impl Thread {
                     tool_name,
                     output,
                     thread.configured_model.as_ref(),
+                    thread.completion_mode,
                 );
                 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
             })
@@ -3084,7 +3097,9 @@ impl Thread {
             return TotalTokenUsage::default();
         };

-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());

         let index = self
             .messages
@@ -3111,7 +3126,9 @@ impl Thread {
     pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
         let model = self.configured_model.as_ref()?;

-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());

         if let Some(exceeded_error) = &self.exceeded_window_error {
             if model.model.id() == exceeded_error.model_id {
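These call sites convert the thread's settings-level completion mode into the client crate's enum with `.into()`, which implies a `From` conversion between the two types. A hedged sketch of what that mapping presumably looks like; the real impl is not part of this diff, and the enum names here are local stand-ins:

// Hypothetical sketch of the mode conversion implied by `.into()` above.
#[derive(Clone, Copy)]
enum SettingsCompletionMode { Normal, Burn }

#[derive(Clone, Copy, PartialEq, Debug)]
enum ClientCompletionMode { Normal, Max }

impl From<SettingsCompletionMode> for ClientCompletionMode {
    fn from(mode: SettingsCompletionMode) -> Self {
        match mode {
            SettingsCompletionMode::Normal => ClientCompletionMode::Normal,
            SettingsCompletionMode::Burn => ClientCompletionMode::Max, // burn mode maps to "max" mode
        }
    }
}

fn main() {
    let mode: ClientCompletionMode = SettingsCompletionMode::Burn.into();
    assert_eq!(mode, ClientCompletionMode::Max);
}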
@@ -3177,6 +3194,7 @@ impl Thread {
                 tool_name,
                 err,
                 self.configured_model.as_ref(),
+                self.completion_mode,
             );
             self.tool_finished(tool_use_id.clone(), None, true, window, cx);
         }
crates/agent/src/tool_use.rs

@@ -2,6 +2,7 @@ use crate::{
     thread::{MessageId, PromptId, ThreadId},
     thread_store::SerializedMessage,
 };
+use agent_settings::CompletionMode;
 use anyhow::Result;
 use assistant_tool::{
     AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
@@ -11,8 +12,9 @@ use futures::{FutureExt as _, future::Shared};
 use gpui::{App, Entity, SharedString, Task, Window};
 use icons::IconName;
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
-    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
+    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
+    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+    LanguageModelToolUseId, Role,
 };
 use project::Project;
 use std::sync::Arc;
@@ -400,6 +402,7 @@ impl ToolUseState {
         tool_name: Arc<str>,
         output: Result<ToolResultOutput>,
         configured_model: Option<&ConfiguredModel>,
+        completion_mode: CompletionMode,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);

@@ -426,7 +429,10 @@ impl ToolUseState {

                 // Protect from overly large output
                 let tool_output_limit = configured_model
-                    .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
+                    .map(|model| {
+                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
+                            * BYTES_PER_TOKEN_ESTIMATE
+                    })
                     .unwrap_or(usize::MAX);

                 let content = match tool_result {
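The output guard turns the mode-aware token window into a byte budget by multiplying by `BYTES_PER_TOKEN_ESTIMATE` (a rough bytes-per-token constant), falling back to `usize::MAX` when no model is configured. A self-contained sketch of that budget calculation; the constant's real value in Zed may differ:

// Sketch of the tool-output byte budget (assumed constant value).
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

fn tool_output_limit(max_token_count_for_mode: Option<u64>) -> usize {
    max_token_count_for_mode
        .map(|tokens| tokens as usize * BYTES_PER_TOKEN_ESTIMATE)
        .unwrap_or(usize::MAX) // no configured model: don't truncate
}

fn main() {
    assert_eq!(tool_output_limit(Some(200_000)), 600_000);
    assert_eq!(tool_output_limit(None), usize::MAX);
}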
crates/agent_ui/src/text_thread_editor.rs

@@ -38,8 +38,8 @@ use language::{
     language_settings::{SoftWrap, all_language_settings},
 };
 use language_model::{
-    ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry,
-    Role,
+    ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView,
+    LanguageModelRegistry, Role,
 };
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, popover_menu::PickerPopoverMenu};
@@ -3063,7 +3063,7 @@ fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
         .default_model()?
         .model;
     let token_count = context.read(cx).token_count()?;
-    let max_token_count = model.max_token_count();
+    let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into());
     let token_state = if max_token_count.saturating_sub(token_count) == 0 {
         TokenState::NoTokensLeft {
             max_token_count,
crates/language_model/src/language_model.rs

@@ -26,7 +26,7 @@ use std::time::Duration;
 use std::{fmt, io};
 use thiserror::Error;
 use util::serde::is_default;
-use zed_llm_client::CompletionRequestStatus;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};

 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -462,6 +462,10 @@ pub trait LanguageModel: Send + Sync {
     }

     fn max_token_count(&self) -> u64;
+    /// Returns the maximum token count for this model in burn mode (if `supports_burn_mode` is `false`, this returns `None`).
+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        None
+    }
     fn max_output_tokens(&self) -> Option<u64> {
         None
     }
@@ -557,6 +561,18 @@ pub trait LanguageModel: Send + Sync {
     }
 }

+pub trait LanguageModelExt: LanguageModel {
+    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
+        match mode {
+            CompletionMode::Normal => self.max_token_count(),
+            CompletionMode::Max => self
+                .max_token_count_in_burn_mode()
+                .unwrap_or_else(|| self.max_token_count()),
+        }
+    }
+}
+impl LanguageModelExt for dyn LanguageModel {}
+
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
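Because `max_token_count_for_mode` has a default body and `LanguageModelExt` is implemented for `dyn LanguageModel`, every existing model gets the mode-aware lookup for free; only models that actually support burn mode override `max_token_count_in_burn_mode`. A minimal standalone sketch of the same pattern, with simplified stand-in types:

// Standalone sketch of the extension-trait pattern used above (simplified types).
enum CompletionMode { Normal, Max }

trait LanguageModel {
    fn max_token_count(&self) -> u64;
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None // default: no burn-mode window
    }
}

trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()), // fall back to normal window
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

struct BurnCapable; // illustrative model type
impl LanguageModel for BurnCapable {
    fn max_token_count(&self) -> u64 { 200_000 }
    fn max_token_count_in_burn_mode(&self) -> Option<u64> { Some(1_000_000) }
}

fn main() {
    let model: &dyn LanguageModel = &BurnCapable;
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Normal), 200_000);
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Max), 1_000_000);
}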
crates/language_models/src/provider/cloud.rs

@@ -730,6 +730,13 @@ impl LanguageModel for CloudLanguageModel {
         self.model.max_token_count as u64
     }

+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        self.model
+            .max_token_count_in_max_mode
+            .filter(|_| self.model.supports_max_mode)
+            .map(|max_token_count| max_token_count as u64)
+    }
+
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
         match &self.model.provider {
             zed_llm_client::LanguageModelProvider::Anthropic => {
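Note the defensive `.filter(|_| self.model.supports_max_mode)`: even if the server sends a `max_token_count_in_max_mode` value, it is discarded (yielding `None`) unless the model also advertises `supports_max_mode`, so `max_token_count_for_mode` falls back to the normal window for models without burn mode.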