agent2: Token count (#36496)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
parent
6825715503
commit
5fb68cb8be
9 changed files with 321 additions and 26 deletions
|
@ -13,7 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use assistant_tool::adapt_schema_to_format;
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||
use collections::IndexMap;
|
||||
use collections::{HashMap, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
FutureExt,
|
||||
|
@ -24,8 +24,8 @@ use futures::{
|
|||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
|
||||
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
|
||||
LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
|
||||
|
@ -481,6 +481,7 @@ pub enum ThreadEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
TokenUsageUpdate(acp_thread::TokenUsage),
|
||||
TitleUpdate(SharedString),
|
||||
Retry(acp_thread::RetryStatus),
|
||||
Stop(acp::StopReason),
|
||||
|
@ -509,8 +510,7 @@ pub struct Thread {
|
|||
pending_message: Option<AgentMessage>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
tool_use_limit_reached: bool,
|
||||
#[allow(unused)]
|
||||
request_token_usage: Vec<TokenUsage>,
|
||||
request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
|
||||
#[allow(unused)]
|
||||
cumulative_token_usage: TokenUsage,
|
||||
#[allow(unused)]
|
||||
|
@ -548,7 +548,7 @@ impl Thread {
|
|||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
request_token_usage: Vec::new(),
|
||||
request_token_usage: HashMap::default(),
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
initial_project_snapshot: {
|
||||
let project_snapshot = Self::project_snapshot(project.clone(), cx);
|
||||
|
@ -951,6 +951,15 @@ impl Thread {
|
|||
self.flush_pending_message(cx);
|
||||
}
|
||||
|
||||
pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
|
||||
let Some(last_user_message) = self.last_user_message() else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.request_token_usage
|
||||
.insert(last_user_message.id.clone(), update);
|
||||
}
|
||||
|
||||
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
|
||||
self.cancel(cx);
|
||||
let Some(position) = self.messages.iter().position(
|
||||
|
@ -958,11 +967,31 @@ impl Thread {
|
|||
) else {
|
||||
return Err(anyhow!("Message not found"));
|
||||
};
|
||||
self.messages.truncate(position);
|
||||
|
||||
for message in self.messages.drain(position..) {
|
||||
match message {
|
||||
Message::User(message) => {
|
||||
self.request_token_usage.remove(&message.id);
|
||||
}
|
||||
Message::Agent(_) | Message::Resume => {}
|
||||
}
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
|
||||
let last_user_message = self.last_user_message()?;
|
||||
let tokens = self.request_token_usage.get(&last_user_message.id)?;
|
||||
let model = self.model.clone()?;
|
||||
|
||||
Some(acp_thread::TokenUsage {
|
||||
max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
|
||||
used_tokens: tokens.total_tokens(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resume(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -1148,6 +1177,21 @@ impl Thread {
|
|||
)) => {
|
||||
*tool_use_limit_reached = true;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
let usage = acp_thread::TokenUsage {
|
||||
max_tokens: model.max_token_count_for_mode(
|
||||
request
|
||||
.mode
|
||||
.unwrap_or(cloud_llm_client::CompletionMode::Normal),
|
||||
),
|
||||
used_tokens: token_usage.total_tokens(),
|
||||
};
|
||||
|
||||
this.update(cx, |this, _cx| this.update_token_usage(token_usage))
|
||||
.ok();
|
||||
|
||||
event_stream.send_token_usage_update(usage);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
||||
*refusal = true;
|
||||
return Ok(FuturesUnordered::default());
|
||||
|
@ -1532,6 +1576,16 @@ impl Thread {
|
|||
})
|
||||
}))
|
||||
}
|
||||
fn last_user_message(&self) -> Option<&UserMessage> {
|
||||
self.messages
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|message| match message {
|
||||
Message::User(user_message) => Some(user_message),
|
||||
Message::Agent(_) => None,
|
||||
Message::Resume => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn pending_message(&mut self) -> &mut AgentMessage {
|
||||
self.pending_message.get_or_insert_default()
|
||||
|
@ -2051,6 +2105,12 @@ impl ThreadEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
||||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue