agent: Fix max token count mismatch when not using burn mode (#34025)
Closes #31854 Release Notes: - agent: Fixed an issue where the maximum token count would be displayed incorrectly when burn mode was not being used.
This commit is contained in:
parent
a9107dfaeb
commit
66a1c356bf
7 changed files with 64 additions and 17 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -20145,9 +20145,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.8.5"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
|
||||
checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
|
|
|
@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
|
|||
wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "= 0.8.5"
|
||||
zed_llm_client = "= 0.8.6"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
|
|
|
@ -23,10 +23,11 @@ use gpui::{
|
|||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
|
||||
PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
|
||||
LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
|
||||
TokenUsage,
|
||||
};
|
||||
use postage::stream::Stream as _;
|
||||
use project::{
|
||||
|
@ -1582,6 +1583,7 @@ impl Thread {
|
|||
tool_name,
|
||||
tool_output,
|
||||
self.configured_model.as_ref(),
|
||||
self.completion_mode,
|
||||
);
|
||||
|
||||
pending_tool_use
|
||||
|
@ -1610,6 +1612,10 @@ impl Thread {
|
|||
prompt_id: prompt_id.clone(),
|
||||
};
|
||||
|
||||
let completion_mode = request
|
||||
.mode
|
||||
.unwrap_or(zed_llm_client::CompletionMode::Normal);
|
||||
|
||||
self.last_received_chunk_at = Some(Instant::now());
|
||||
|
||||
let task = cx.spawn(async move |thread, cx| {
|
||||
|
@ -1959,7 +1965,11 @@ impl Thread {
|
|||
.unwrap_or(0)
|
||||
// We know the context window was exceeded in practice, so if our estimate was
|
||||
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
|
||||
.max(model.max_token_count().saturating_add(1))
|
||||
.max(
|
||||
model
|
||||
.max_token_count_for_mode(completion_mode)
|
||||
.saturating_add(1),
|
||||
)
|
||||
});
|
||||
thread.exceeded_window_error = Some(ExceededWindowError {
|
||||
model_id: model.id(),
|
||||
|
@ -2507,6 +2517,7 @@ impl Thread {
|
|||
hallucinated_tool_name,
|
||||
Err(anyhow!("Missing tool call: {error_message}")),
|
||||
self.configured_model.as_ref(),
|
||||
self.completion_mode,
|
||||
);
|
||||
|
||||
cx.emit(ThreadEvent::MissingToolUse {
|
||||
|
@ -2533,6 +2544,7 @@ impl Thread {
|
|||
tool_name,
|
||||
Err(anyhow!("Error parsing input JSON: {error}")),
|
||||
self.configured_model.as_ref(),
|
||||
self.completion_mode,
|
||||
);
|
||||
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
|
||||
pending_tool_use.ui_text.clone()
|
||||
|
@ -2608,6 +2620,7 @@ impl Thread {
|
|||
tool_name,
|
||||
output,
|
||||
thread.configured_model.as_ref(),
|
||||
thread.completion_mode,
|
||||
);
|
||||
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
})
|
||||
|
@ -3084,7 +3097,9 @@ impl Thread {
|
|||
return TotalTokenUsage::default();
|
||||
};
|
||||
|
||||
let max = model.model.max_token_count();
|
||||
let max = model
|
||||
.model
|
||||
.max_token_count_for_mode(self.completion_mode().into());
|
||||
|
||||
let index = self
|
||||
.messages
|
||||
|
@ -3111,7 +3126,9 @@ impl Thread {
|
|||
pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
|
||||
let model = self.configured_model.as_ref()?;
|
||||
|
||||
let max = model.model.max_token_count();
|
||||
let max = model
|
||||
.model
|
||||
.max_token_count_for_mode(self.completion_mode().into());
|
||||
|
||||
if let Some(exceeded_error) = &self.exceeded_window_error {
|
||||
if model.model.id() == exceeded_error.model_id {
|
||||
|
@ -3177,6 +3194,7 @@ impl Thread {
|
|||
tool_name,
|
||||
err,
|
||||
self.configured_model.as_ref(),
|
||||
self.completion_mode,
|
||||
);
|
||||
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::{
|
|||
thread::{MessageId, PromptId, ThreadId},
|
||||
thread_store::SerializedMessage,
|
||||
};
|
||||
use agent_settings::CompletionMode;
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{
|
||||
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
|
||||
|
@ -11,8 +12,9 @@ use futures::{FutureExt as _, future::Shared};
|
|||
use gpui::{App, Entity, SharedString, Task, Window};
|
||||
use icons::IconName;
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
|
||||
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
|
||||
ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
|
||||
LanguageModelToolUseId, Role,
|
||||
};
|
||||
use project::Project;
|
||||
use std::sync::Arc;
|
||||
|
@ -400,6 +402,7 @@ impl ToolUseState {
|
|||
tool_name: Arc<str>,
|
||||
output: Result<ToolResultOutput>,
|
||||
configured_model: Option<&ConfiguredModel>,
|
||||
completion_mode: CompletionMode,
|
||||
) -> Option<PendingToolUse> {
|
||||
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
|
||||
|
||||
|
@ -426,7 +429,10 @@ impl ToolUseState {
|
|||
|
||||
// Protect from overly large output
|
||||
let tool_output_limit = configured_model
|
||||
.map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
|
||||
.map(|model| {
|
||||
model.model.max_token_count_for_mode(completion_mode.into()) as usize
|
||||
* BYTES_PER_TOKEN_ESTIMATE
|
||||
})
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
let content = match tool_result {
|
||||
|
|
|
@ -38,8 +38,8 @@ use language::{
|
|||
language_settings::{SoftWrap, all_language_settings},
|
||||
};
|
||||
use language_model::{
|
||||
ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry,
|
||||
Role,
|
||||
ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView,
|
||||
LanguageModelRegistry, Role,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use picker::{Picker, popover_menu::PickerPopoverMenu};
|
||||
|
@ -3063,7 +3063,7 @@ fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenStat
|
|||
.default_model()?
|
||||
.model;
|
||||
let token_count = context.read(cx).token_count()?;
|
||||
let max_token_count = model.max_token_count();
|
||||
let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into());
|
||||
let token_state = if max_token_count.saturating_sub(token_count) == 0 {
|
||||
TokenState::NoTokensLeft {
|
||||
max_token_count,
|
||||
|
|
|
@ -26,7 +26,7 @@ use std::time::Duration;
|
|||
use std::{fmt, io};
|
||||
use thiserror::Error;
|
||||
use util::serde::is_default;
|
||||
use zed_llm_client::CompletionRequestStatus;
|
||||
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
|
||||
|
||||
pub use crate::model::*;
|
||||
pub use crate::rate_limiter::*;
|
||||
|
@ -462,6 +462,10 @@ pub trait LanguageModel: Send + Sync {
|
|||
}
|
||||
|
||||
fn max_token_count(&self) -> u64;
|
||||
/// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
|
||||
fn max_token_count_in_burn_mode(&self) -> Option<u64> {
|
||||
None
|
||||
}
|
||||
fn max_output_tokens(&self) -> Option<u64> {
|
||||
None
|
||||
}
|
||||
|
@ -557,6 +561,18 @@ pub trait LanguageModel: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelExt: LanguageModel {
|
||||
fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
|
||||
match mode {
|
||||
CompletionMode::Normal => self.max_token_count(),
|
||||
CompletionMode::Max => self
|
||||
.max_token_count_in_burn_mode()
|
||||
.unwrap_or_else(|| self.max_token_count()),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl LanguageModelExt for dyn LanguageModel {}
|
||||
|
||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||
fn name() -> String;
|
||||
fn description() -> String;
|
||||
|
|
|
@ -730,6 +730,13 @@ impl LanguageModel for CloudLanguageModel {
|
|||
self.model.max_token_count as u64
|
||||
}
|
||||
|
||||
fn max_token_count_in_burn_mode(&self) -> Option<u64> {
|
||||
self.model
|
||||
.max_token_count_in_max_mode
|
||||
.filter(|_| self.model.supports_max_mode)
|
||||
.map(|max_token_count| max_token_count as u64)
|
||||
}
|
||||
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
match &self.model.provider {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue