agent2: Token count (#36496)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
parent
6825715503
commit
5fb68cb8be
9 changed files with 321 additions and 26 deletions
|
@ -6,6 +6,7 @@ mod terminal;
|
||||||
pub use connection::*;
|
pub use connection::*;
|
||||||
pub use diff::*;
|
pub use diff::*;
|
||||||
pub use mention::*;
|
pub use mention::*;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
pub use terminal::*;
|
pub use terminal::*;
|
||||||
|
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
|
@ -664,6 +665,12 @@ impl PlanEntry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub struct TokenUsage {
|
||||||
|
pub max_tokens: u64,
|
||||||
|
pub used_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RetryStatus {
|
pub struct RetryStatus {
|
||||||
pub last_error: SharedString,
|
pub last_error: SharedString,
|
||||||
|
@ -683,12 +690,14 @@ pub struct AcpThread {
|
||||||
send_task: Option<Task<()>>,
|
send_task: Option<Task<()>>,
|
||||||
connection: Rc<dyn AgentConnection>,
|
connection: Rc<dyn AgentConnection>,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
|
token_usage: Option<TokenUsage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AcpThreadEvent {
|
pub enum AcpThreadEvent {
|
||||||
NewEntry,
|
NewEntry,
|
||||||
TitleUpdated,
|
TitleUpdated,
|
||||||
|
TokenUsageUpdated,
|
||||||
EntryUpdated(usize),
|
EntryUpdated(usize),
|
||||||
EntriesRemoved(Range<usize>),
|
EntriesRemoved(Range<usize>),
|
||||||
ToolAuthorizationRequired,
|
ToolAuthorizationRequired,
|
||||||
|
@ -748,6 +757,7 @@ impl AcpThread {
|
||||||
send_task: None,
|
send_task: None,
|
||||||
connection,
|
connection,
|
||||||
session_id,
|
session_id,
|
||||||
|
token_usage: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -787,6 +797,10 @@ impl AcpThread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn token_usage(&self) -> Option<&TokenUsage> {
|
||||||
|
self.token_usage.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn has_pending_edit_tool_calls(&self) -> bool {
|
pub fn has_pending_edit_tool_calls(&self) -> bool {
|
||||||
for entry in self.entries.iter().rev() {
|
for entry in self.entries.iter().rev() {
|
||||||
match entry {
|
match entry {
|
||||||
|
@ -937,6 +951,11 @@ impl AcpThread {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
|
||||||
|
self.token_usage = usage;
|
||||||
|
cx.emit(AcpThreadEvent::TokenUsageUpdated);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
|
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
|
||||||
cx.emit(AcpThreadEvent::Retry(status));
|
cx.emit(AcpThreadEvent::Retry(status));
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||||
use ui::{App, IconName};
|
use ui::{App, IconName};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||||
pub struct UserMessageId(Arc<str>);
|
pub struct UserMessageId(Arc<str>);
|
||||||
|
|
||||||
impl UserMessageId {
|
impl UserMessageId {
|
||||||
|
|
|
@ -66,6 +66,7 @@ zstd.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
agent = { workspace = true, "features" = ["test-support"] }
|
agent = { workspace = true, "features" = ["test-support"] }
|
||||||
|
assistant_context = { workspace = true, "features" = ["test-support"] }
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
client = { workspace = true, "features" = ["test-support"] }
|
client = { workspace = true, "features" = ["test-support"] }
|
||||||
clock = { workspace = true, "features" = ["test-support"] }
|
clock = { workspace = true, "features" = ["test-support"] }
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
|
use crate::HistoryStore;
|
||||||
use crate::{
|
use crate::{
|
||||||
ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
|
ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
|
||||||
templates::Templates,
|
UserMessageContent, templates::Templates,
|
||||||
};
|
};
|
||||||
use crate::{HistoryStore, ThreadsDatabase};
|
|
||||||
use acp_thread::{AcpThread, AgentModelSelector};
|
use acp_thread::{AcpThread, AgentModelSelector};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
@ -673,6 +673,11 @@ impl NativeAgentConnection {
|
||||||
thread.update_tool_call(update, cx)
|
thread.update_tool_call(update, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
|
ThreadEvent::TokenUsageUpdate(usage) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.update_token_usage(Some(usage), cx)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
ThreadEvent::TitleUpdate(title) => {
|
ThreadEvent::TitleUpdate(title) => {
|
||||||
acp_thread
|
acp_thread
|
||||||
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
||||||
|
@ -895,10 +900,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
|
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
|
||||||
self.0.update(cx, |agent, _cx| {
|
self.0.update(cx, |agent, _cx| {
|
||||||
agent
|
agent.sessions.get(session_id).map(|session| {
|
||||||
.sessions
|
Rc::new(NativeAgentSessionEditor {
|
||||||
.get(session_id)
|
thread: session.thread.clone(),
|
||||||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
acp_thread: session.acp_thread.clone(),
|
||||||
|
}) as _
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -907,14 +914,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
struct NativeAgentSessionEditor {
|
||||||
|
thread: Entity<Thread>,
|
||||||
|
acp_thread: WeakEntity<AcpThread>,
|
||||||
|
}
|
||||||
|
|
||||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(
|
match self.thread.update(cx, |thread, cx| {
|
||||||
self.0
|
thread.truncate(message_id.clone(), cx)?;
|
||||||
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
|
Ok(thread.latest_token_usage())
|
||||||
)
|
}) {
|
||||||
|
Ok(usage) => {
|
||||||
|
self.acp_thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.update_token_usage(usage, cx);
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
}
|
||||||
|
Err(error) => Task::ready(Err(error)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||||
|
use acp_thread::UserMessageId;
|
||||||
use agent::thread_store;
|
use agent::thread_store;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::{AgentProfileId, CompletionMode};
|
use agent_settings::{AgentProfileId, CompletionMode};
|
||||||
|
@ -42,7 +43,7 @@ pub struct DbThread {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub cumulative_token_usage: language_model::TokenUsage,
|
pub cumulative_token_usage: language_model::TokenUsage,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub request_token_usage: Vec<language_model::TokenUsage>,
|
pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub model: Option<DbLanguageModel>,
|
pub model: Option<DbLanguageModel>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -67,7 +68,10 @@ impl DbThread {
|
||||||
|
|
||||||
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
|
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
for msg in thread.messages {
|
let mut request_token_usage = HashMap::default();
|
||||||
|
|
||||||
|
let mut last_user_message_id = None;
|
||||||
|
for (ix, msg) in thread.messages.into_iter().enumerate() {
|
||||||
let message = match msg.role {
|
let message = match msg.role {
|
||||||
language_model::Role::User => {
|
language_model::Role::User => {
|
||||||
let mut content = Vec::new();
|
let mut content = Vec::new();
|
||||||
|
@ -93,9 +97,12 @@ impl DbThread {
|
||||||
content.push(UserMessageContent::Text(msg.context));
|
content.push(UserMessageContent::Text(msg.context));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let id = UserMessageId::new();
|
||||||
|
last_user_message_id = Some(id.clone());
|
||||||
|
|
||||||
crate::Message::User(UserMessage {
|
crate::Message::User(UserMessage {
|
||||||
// MessageId from old format can't be meaningfully converted, so generate a new one
|
// MessageId from old format can't be meaningfully converted, so generate a new one
|
||||||
id: acp_thread::UserMessageId::new(),
|
id,
|
||||||
content,
|
content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -154,6 +161,12 @@ impl DbThread {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(last_user_message_id) = &last_user_message_id
|
||||||
|
&& let Some(token_usage) = thread.request_token_usage.get(ix).copied()
|
||||||
|
{
|
||||||
|
request_token_usage.insert(last_user_message_id.clone(), token_usage);
|
||||||
|
}
|
||||||
|
|
||||||
crate::Message::Agent(AgentMessage {
|
crate::Message::Agent(AgentMessage {
|
||||||
content,
|
content,
|
||||||
tool_results,
|
tool_results,
|
||||||
|
@ -175,7 +188,7 @@ impl DbThread {
|
||||||
summary: thread.detailed_summary_state,
|
summary: thread.detailed_summary_state,
|
||||||
initial_project_snapshot: thread.initial_project_snapshot,
|
initial_project_snapshot: thread.initial_project_snapshot,
|
||||||
cumulative_token_usage: thread.cumulative_token_usage,
|
cumulative_token_usage: thread.cumulative_token_usage,
|
||||||
request_token_usage: thread.request_token_usage,
|
request_token_usage,
|
||||||
model: thread.model,
|
model: thread.model,
|
||||||
completion_mode: thread.completion_mode,
|
completion_mode: thread.completion_mode,
|
||||||
profile: thread.profile,
|
profile: thread.profile,
|
||||||
|
|
|
@ -1117,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_truncate(cx: &mut TestAppContext) {
|
async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
@ -1137,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
Hello
|
Hello
|
||||||
"}
|
"}
|
||||||
);
|
);
|
||||||
|
assert_eq!(thread.latest_token_usage(), None);
|
||||||
});
|
});
|
||||||
|
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||||
|
language_model::TokenUsage {
|
||||||
|
input_tokens: 32_000,
|
||||||
|
output_tokens: 16_000,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
},
|
||||||
|
));
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1154,6 +1163,13 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
Hey!
|
Hey!
|
||||||
"}
|
"}
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
thread.latest_token_usage(),
|
||||||
|
Some(acp_thread::TokenUsage {
|
||||||
|
used_tokens: 32_000 + 16_000,
|
||||||
|
max_tokens: 1_000_000,
|
||||||
|
})
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
thread
|
thread
|
||||||
|
@ -1162,6 +1178,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(thread.to_markdown(), "");
|
assert_eq!(thread.to_markdown(), "");
|
||||||
|
assert_eq!(thread.latest_token_usage(), None);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Ensure we can still send a new message after truncation.
|
// Ensure we can still send a new message after truncation.
|
||||||
|
@ -1182,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||||
|
language_model::TokenUsage {
|
||||||
|
input_tokens: 40_000,
|
||||||
|
output_tokens: 20_000,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
},
|
||||||
|
));
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1196,9 +1221,126 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
Ahoy!
|
Ahoy!
|
||||||
"}
|
"}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
thread.latest_token_usage(),
|
||||||
|
Some(acp_thread::TokenUsage {
|
||||||
|
used_tokens: 40_000 + 20_000,
|
||||||
|
max_tokens: 1_000_000,
|
||||||
|
})
|
||||||
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||||
|
language_model::TokenUsage {
|
||||||
|
input_tokens: 32_000,
|
||||||
|
output_tokens: 16_000,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let assert_first_message_state = |cx: &mut TestAppContext| {
|
||||||
|
thread.clone().read_with(cx, |thread, _| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Message 1
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Message 1 response
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
thread.latest_token_usage(),
|
||||||
|
Some(acp_thread::TokenUsage {
|
||||||
|
used_tokens: 32_000 + 16_000,
|
||||||
|
max_tokens: 1_000_000,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_first_message_state(cx);
|
||||||
|
|
||||||
|
let second_message_id = UserMessageId::new();
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.send(second_message_id.clone(), ["Message 2"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||||
|
language_model::TokenUsage {
|
||||||
|
input_tokens: 40_000,
|
||||||
|
output_tokens: 20_000,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
thread.read_with(cx, |thread, _| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Message 1
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Message 1 response
|
||||||
|
|
||||||
|
## User
|
||||||
|
|
||||||
|
Message 2
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Message 2 response
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
thread.latest_token_usage(),
|
||||||
|
Some(acp_thread::TokenUsage {
|
||||||
|
used_tokens: 40_000 + 20_000,
|
||||||
|
max_tokens: 1_000_000,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| thread.truncate(second_message_id, cx))
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
assert_first_message_state(cx);
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_title_generation(cx: &mut TestAppContext) {
|
async fn test_title_generation(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
|
|
@ -13,7 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::adapt_schema_to_format;
|
use assistant_tool::adapt_schema_to_format;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||||
use collections::IndexMap;
|
use collections::{HashMap, IndexMap};
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{
|
use futures::{
|
||||||
FutureExt,
|
FutureExt,
|
||||||
|
@ -24,8 +24,8 @@ use futures::{
|
||||||
use git::repository::DiffType;
|
use git::repository::DiffType;
|
||||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
|
||||||
LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
|
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
|
||||||
LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
|
LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
|
||||||
|
@ -481,6 +481,7 @@ pub enum ThreadEvent {
|
||||||
ToolCall(acp::ToolCall),
|
ToolCall(acp::ToolCall),
|
||||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||||
ToolCallAuthorization(ToolCallAuthorization),
|
ToolCallAuthorization(ToolCallAuthorization),
|
||||||
|
TokenUsageUpdate(acp_thread::TokenUsage),
|
||||||
TitleUpdate(SharedString),
|
TitleUpdate(SharedString),
|
||||||
Retry(acp_thread::RetryStatus),
|
Retry(acp_thread::RetryStatus),
|
||||||
Stop(acp::StopReason),
|
Stop(acp::StopReason),
|
||||||
|
@ -509,8 +510,7 @@ pub struct Thread {
|
||||||
pending_message: Option<AgentMessage>,
|
pending_message: Option<AgentMessage>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
tool_use_limit_reached: bool,
|
tool_use_limit_reached: bool,
|
||||||
#[allow(unused)]
|
request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
|
||||||
request_token_usage: Vec<TokenUsage>,
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
cumulative_token_usage: TokenUsage,
|
cumulative_token_usage: TokenUsage,
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
|
@ -548,7 +548,7 @@ impl Thread {
|
||||||
pending_message: None,
|
pending_message: None,
|
||||||
tools: BTreeMap::default(),
|
tools: BTreeMap::default(),
|
||||||
tool_use_limit_reached: false,
|
tool_use_limit_reached: false,
|
||||||
request_token_usage: Vec::new(),
|
request_token_usage: HashMap::default(),
|
||||||
cumulative_token_usage: TokenUsage::default(),
|
cumulative_token_usage: TokenUsage::default(),
|
||||||
initial_project_snapshot: {
|
initial_project_snapshot: {
|
||||||
let project_snapshot = Self::project_snapshot(project.clone(), cx);
|
let project_snapshot = Self::project_snapshot(project.clone(), cx);
|
||||||
|
@ -951,6 +951,15 @@ impl Thread {
|
||||||
self.flush_pending_message(cx);
|
self.flush_pending_message(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
|
||||||
|
let Some(last_user_message) = self.last_user_message() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
self.request_token_usage
|
||||||
|
.insert(last_user_message.id.clone(), update);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
|
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
|
||||||
self.cancel(cx);
|
self.cancel(cx);
|
||||||
let Some(position) = self.messages.iter().position(
|
let Some(position) = self.messages.iter().position(
|
||||||
|
@ -958,11 +967,31 @@ impl Thread {
|
||||||
) else {
|
) else {
|
||||||
return Err(anyhow!("Message not found"));
|
return Err(anyhow!("Message not found"));
|
||||||
};
|
};
|
||||||
self.messages.truncate(position);
|
|
||||||
|
for message in self.messages.drain(position..) {
|
||||||
|
match message {
|
||||||
|
Message::User(message) => {
|
||||||
|
self.request_token_usage.remove(&message.id);
|
||||||
|
}
|
||||||
|
Message::Agent(_) | Message::Resume => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cx.notify();
|
cx.notify();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
|
||||||
|
let last_user_message = self.last_user_message()?;
|
||||||
|
let tokens = self.request_token_usage.get(&last_user_message.id)?;
|
||||||
|
let model = self.model.clone()?;
|
||||||
|
|
||||||
|
Some(acp_thread::TokenUsage {
|
||||||
|
max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
|
||||||
|
used_tokens: tokens.total_tokens(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn resume(
|
pub fn resume(
|
||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
|
@ -1148,6 +1177,21 @@ impl Thread {
|
||||||
)) => {
|
)) => {
|
||||||
*tool_use_limit_reached = true;
|
*tool_use_limit_reached = true;
|
||||||
}
|
}
|
||||||
|
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||||
|
let usage = acp_thread::TokenUsage {
|
||||||
|
max_tokens: model.max_token_count_for_mode(
|
||||||
|
request
|
||||||
|
.mode
|
||||||
|
.unwrap_or(cloud_llm_client::CompletionMode::Normal),
|
||||||
|
),
|
||||||
|
used_tokens: token_usage.total_tokens(),
|
||||||
|
};
|
||||||
|
|
||||||
|
this.update(cx, |this, _cx| this.update_token_usage(token_usage))
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
event_stream.send_token_usage_update(usage);
|
||||||
|
}
|
||||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
||||||
*refusal = true;
|
*refusal = true;
|
||||||
return Ok(FuturesUnordered::default());
|
return Ok(FuturesUnordered::default());
|
||||||
|
@ -1532,6 +1576,16 @@ impl Thread {
|
||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
fn last_user_message(&self) -> Option<&UserMessage> {
|
||||||
|
self.messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find_map(|message| match message {
|
||||||
|
Message::User(user_message) => Some(user_message),
|
||||||
|
Message::Agent(_) => None,
|
||||||
|
Message::Resume => None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn pending_message(&mut self) -> &mut AgentMessage {
|
fn pending_message(&mut self) -> &mut AgentMessage {
|
||||||
self.pending_message.get_or_insert_default()
|
self.pending_message.get_or_insert_default()
|
||||||
|
@ -2051,6 +2105,12 @@ impl ThreadEventStream {
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
|
||||||
|
self.0
|
||||||
|
.unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
||||||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||||
}
|
}
|
||||||
|
|
|
@ -816,7 +816,7 @@ impl AcpThreadView {
|
||||||
self.thread_retry_status.take();
|
self.thread_retry_status.take();
|
||||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||||
}
|
}
|
||||||
AcpThreadEvent::TitleUpdated => {}
|
AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated => {}
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
@ -2794,6 +2794,7 @@ impl AcpThreadView {
|
||||||
.child(
|
.child(
|
||||||
h_flex()
|
h_flex()
|
||||||
.gap_1()
|
.gap_1()
|
||||||
|
.children(self.render_token_usage(cx))
|
||||||
.children(self.profile_selector.clone())
|
.children(self.profile_selector.clone())
|
||||||
.children(self.model_selector.clone())
|
.children(self.model_selector.clone())
|
||||||
.child(self.render_send_button(cx)),
|
.child(self.render_send_button(cx)),
|
||||||
|
@ -2816,6 +2817,44 @@ impl AcpThreadView {
|
||||||
.thread(acp_thread.session_id(), cx)
|
.thread(acp_thread.session_id(), cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_token_usage(&self, cx: &mut Context<Self>) -> Option<Div> {
|
||||||
|
let thread = self.thread()?.read(cx);
|
||||||
|
let usage = thread.token_usage()?;
|
||||||
|
let is_generating = thread.status() != ThreadStatus::Idle;
|
||||||
|
|
||||||
|
let used = crate::text_thread_editor::humanize_token_count(usage.used_tokens);
|
||||||
|
let max = crate::text_thread_editor::humanize_token_count(usage.max_tokens);
|
||||||
|
|
||||||
|
Some(
|
||||||
|
h_flex()
|
||||||
|
.flex_shrink_0()
|
||||||
|
.gap_0p5()
|
||||||
|
.mr_1()
|
||||||
|
.child(
|
||||||
|
Label::new(used)
|
||||||
|
.size(LabelSize::Small)
|
||||||
|
.color(Color::Muted)
|
||||||
|
.map(|label| {
|
||||||
|
if is_generating {
|
||||||
|
label
|
||||||
|
.with_animation(
|
||||||
|
"used-tokens-label",
|
||||||
|
Animation::new(Duration::from_secs(2))
|
||||||
|
.repeat()
|
||||||
|
.with_easing(pulsating_between(0.6, 1.)),
|
||||||
|
|label, delta| label.alpha(delta),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
} else {
|
||||||
|
label.into_any_element()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
|
||||||
|
.child(Label::new(max).size(LabelSize::Small).color(Color::Muted)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn toggle_burn_mode(
|
fn toggle_burn_mode(
|
||||||
&mut self,
|
&mut self,
|
||||||
_: &ToggleBurnMode,
|
_: &ToggleBurnMode,
|
||||||
|
|
|
@ -1526,6 +1526,7 @@ impl AgentDiff {
|
||||||
self.update_reviewing_editors(workspace, window, cx);
|
self.update_reviewing_editors(workspace, window, cx);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::TitleUpdated
|
AcpThreadEvent::TitleUpdated
|
||||||
|
| AcpThreadEvent::TokenUsageUpdated
|
||||||
| AcpThreadEvent::EntriesRemoved(_)
|
| AcpThreadEvent::EntriesRemoved(_)
|
||||||
| AcpThreadEvent::ToolAuthorizationRequired
|
| AcpThreadEvent::ToolAuthorizationRequired
|
||||||
| AcpThreadEvent::Retry(_) => {}
|
| AcpThreadEvent::Retry(_) => {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue