Track tool use counts (#28722)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-04-14 15:45:36 -06:00 committed by GitHub
parent 26b9c32e96
commit c8ccc472b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 16 additions and 3 deletions

1
Cargo.lock generated
View file

@ -4882,6 +4882,7 @@ dependencies = [
"chrono", "chrono",
"clap", "clap",
"client", "client",
"collections",
"context_server", "context_server",
"dap", "dap",
"env_logger 0.11.8", "env_logger 0.11.8",

View file

@ -14,6 +14,7 @@ assistant_tools.workspace = true
chrono.workspace = true chrono.workspace = true
clap.workspace = true clap.workspace = true
client.workspace = true client.workspace = true
collections.workspace = true
context_server.workspace = true context_server.workspace = true
dap.workspace = true dap.workspace = true
env_logger.workspace = true env_logger.workspace = true

View file

@ -2,6 +2,7 @@ use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress; use client::proto::LspWorkProgress;
use collections::HashMap;
use dap::DapRegistry; use dap::DapRegistry;
use futures::channel::{mpsc, oneshot}; use futures::channel::{mpsc, oneshot};
use futures::{FutureExt, StreamExt as _}; use futures::{FutureExt, StreamExt as _};
@ -63,6 +64,7 @@ pub struct RunOutput {
pub diagnostics: String, pub diagnostics: String,
pub response_count: usize, pub response_count: usize,
pub token_usage: TokenUsage, pub token_usage: TokenUsage,
pub tool_use_counts: HashMap<Arc<str>, u32>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -270,12 +272,16 @@ impl Example {
log_file.flush().log_err(); log_file.flush().log_err();
} }
let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
Mutex::new(HashMap::default()).into();
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let mut tx = Some(tx); let mut tx = Some(tx);
let _subscription = cx.subscribe(&thread, { let subscription = cx.subscribe(&thread, {
let log_file = this.log_file.clone(); let log_file = this.log_file.clone();
let name = this.name.clone(); let name = this.name.clone();
let tool_use_counts = tool_use_counts.clone();
move |thread, event: &ThreadEvent, cx| { move |thread, event: &ThreadEvent, cx| {
let mut log_file = log_file.lock().unwrap(); let mut log_file = log_file.lock().unwrap();
@ -327,8 +333,11 @@ impl Example {
writeln!(&mut log_file, "\n{}", message).log_err(); writeln!(&mut log_file, "\n{}", message).log_err();
} }
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
let message = format!("\n{}\n", tool_result.content); writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
writeln!(&mut log_file, "{}", message).log_err(); let mut tool_use_counts = tool_use_counts.lock().unwrap();
*tool_use_counts
.entry(tool_result.tool_name.clone())
.or_insert(0) += 1;
} }
} }
_ => {} _ => {}
@ -357,6 +366,7 @@ impl Example {
})? })?
.await?; .await?;
drop(subscription);
drop(lsp_open_handle_and_store); drop(lsp_open_handle_and_store);
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, _cx| {
@ -369,6 +379,7 @@ impl Example {
diagnostics, diagnostics,
response_count, response_count,
token_usage: thread.cumulative_token_usage(), token_usage: thread.cumulative_token_usage(),
tool_use_counts: tool_use_counts.lock().unwrap().clone(),
} }
}) })
}) })