diff --git a/Cargo.lock b/Cargo.lock index 4881cb020a..5566faf25e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4961,6 +4961,7 @@ dependencies = [ "assistant_tools", "async-trait", "async-watch", + "buffer_diff", "chrono", "clap", "client", diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 250c1490e5..2063c7d2a0 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -56,7 +56,10 @@ use crate::symbol_info_tool::SymbolInfoTool; use crate::terminal_tool::TerminalTool; use crate::thinking_tool::ThinkingTool; +pub use create_file_tool::CreateFileToolInput; +pub use edit_file_tool::EditFileToolInput; pub use path_search_tool::PathSearchToolInput; +pub use read_file_tool::ReadFileToolInput; pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index b0f989b6ad..0046ca50a9 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -11,6 +11,7 @@ assistant_tool.workspace = true assistant_tools.workspace = true async-trait.workspace = true async-watch.workspace = true +buffer_diff.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 84515bf45d..39b2d7d57c 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -1,6 +1,7 @@ use std::{ error::Error, fmt::{self, Debug}, + path::Path, sync::{Arc, Mutex}, time::Duration, }; @@ -12,6 +13,8 @@ use crate::{ use agent::ThreadEvent; use anyhow::{Result, anyhow}; use async_trait::async_trait; +use buffer_diff::DiffHunkStatus; +use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; @@ -234,9 +237,9 @@ impl ExampleContext { let mut tool_metrics = tool_metrics.lock().unwrap(); if let Some(tool_result) = thread.tool_result(&tool_use_id) { let message = if tool_result.is_error { - format!("TOOL FAILED: {}", tool_use.name) + format!("✖︎ {}", tool_use.name) } else { - format!("TOOL FINISHED: {}", tool_use.name) + format!("✔︎ {}", tool_use.name) }; println!("{log_prefix}{message}"); tool_metrics @@ -320,6 +323,36 @@ impl ExampleContext { Ok(response) } + + pub fn edits(&self) -> HashMap, FileEdits> { + self.app + .read_entity(&self.agent_thread, |thread, cx| { + let action_log = thread.action_log().read(cx); + HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map( + |(buffer, diff)| { + let snapshot = buffer.read(cx).snapshot(); + + let file = snapshot.file().unwrap(); + let diff = diff.read(cx); + let base_text = diff.base_text().text(); + + let hunks = diff + .hunks(&snapshot, cx) + .map(|hunk| FileEditHunk { + base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(), + text: snapshot + .text_for_range(hunk.range.clone()) + .collect::(), + status: hunk.status(), + }) + .collect(); + + (file.path().clone(), FileEdits { hunks }) + }, + )) + }) + .unwrap() + } } #[derive(Debug)] @@ -344,6 +377,10 @@ impl Response { }); cx.assert_some(result, format!("called `{}`", tool_name)) } + + pub fn tool_uses(&self) -> impl Iterator { + self.messages.iter().flat_map(|msg| &msg.tool_use) + } } #[derive(Debug)] @@ -355,17 +392,37 @@ pub struct Message { #[derive(Debug)] pub struct ToolUse { - name: String, + pub name: String, value: serde_json::Value, } impl ToolUse { - pub fn expect_input(&self, cx: &mut ExampleContext) -> Result + pub fn parse_input(&self) -> Result where Input: for<'de> serde::Deserialize<'de>, { - let result = - serde_json::from_value::(self.value.clone()).map_err(|err| anyhow!(err)); - cx.log_assertion(result, format!("valid `{}` input", &self.name)) + serde_json::from_value::(self.value.clone()).map_err(|err| anyhow!(err)) + } +} + +#[derive(Debug)] +pub struct FileEdits { + hunks: Vec, +} + +#[derive(Debug)] +struct FileEditHunk { + base_text: String, + text: String, + status: DiffHunkStatus, +} + +impl FileEdits { + pub fn has_added_line(&self, line: &str) -> bool { + self.hunks.iter().any(|hunk| { + hunk.status == DiffHunkStatus::added_none() + && hunk.base_text.is_empty() + && hunk.text.contains(line) + }) } } diff --git a/crates/eval/src/examples/add_arg_to_trait_method.rs b/crates/eval/src/examples/add_arg_to_trait_method.rs new file mode 100644 index 0000000000..5c3fb788f0 --- /dev/null +++ b/crates/eval/src/examples/add_arg_to_trait_method.rs @@ -0,0 +1,147 @@ +use std::{collections::HashSet, path::Path}; + +use anyhow::Result; +use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput}; +use async_trait::async_trait; + +use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer}; + +pub struct AddArgToTraitMethod; + +#[async_trait(?Send)] +impl Example for AddArgToTraitMethod { + fn meta(&self) -> ExampleMetadata { + ExampleMetadata { + name: "add_arg_to_trait_method".to_string(), + url: "https://github.com/zed-industries/zed.git".to_string(), + revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(), + language_server: Some(LanguageServer { + file_extension: "rs".to_string(), + allow_preexisting_diagnostics: false, + }), + max_assertions: None, + } + } + + async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { + const FILENAME: &str = "assistant_tool.rs"; + cx.push_user_message(format!( + r#" + Add a `window: Option` argument to the `Tool::run` trait method in {FILENAME}, + and update all the implementations of the trait and call sites accordingly. + "# + )); + + let response = cx.run_to_end().await?; + + // Reads files before it edits them + + let mut read_files = HashSet::new(); + + for tool_use in response.tool_uses() { + match tool_use.name.as_str() { + "read_file" => { + if let Ok(input) = tool_use.parse_input::() { + read_files.insert(input.path); + } + } + "create_file" => { + if let Ok(input) = tool_use.parse_input::() { + read_files.insert(input.path); + } + } + "edit_file" => { + if let Ok(input) = tool_use.parse_input::() { + cx.assert( + read_files.contains(input.path.to_str().unwrap()), + format!( + "Read before edit: {}", + &input.path.file_stem().unwrap().to_str().unwrap() + ), + ) + .ok(); + } + } + _ => {} + } + } + + // Adds ignored argument to all but `batch_tool` + + let add_ignored_window_paths = &[ + "code_action_tool", + "code_symbols_tool", + "contents_tool", + "copy_path_tool", + "create_directory_tool", + "create_file_tool", + "delete_path_tool", + "diagnostics_tool", + "edit_file_tool", + "fetch_tool", + "grep_tool", + "list_directory_tool", + "move_path_tool", + "now_tool", + "open_tool", + "path_search_tool", + "read_file_tool", + "rename_tool", + "symbol_info_tool", + "terminal_tool", + "thinking_tool", + "web_search_tool", + ]; + + let edits = cx.edits(); + + for tool_name in add_ignored_window_paths { + let path_str = format!("crates/assistant_tools/src/{}.rs", tool_name); + let edits = edits.get(Path::new(&path_str)); + + let ignored = edits.map_or(false, |edits| { + edits.has_added_line(" _window: Option,\n") + }); + let uningored = edits.map_or(false, |edits| { + edits.has_added_line(" window: Option,\n") + }); + + cx.assert(ignored || uningored, format!("Argument: {}", tool_name)) + .ok(); + + cx.assert(ignored, format!("`_` prefix: {}", tool_name)) + .ok(); + } + + // Adds unignored argument to `batch_tool` + + let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs")); + + cx.assert( + batch_tool_edits.map_or(false, |edits| { + edits.has_added_line(" window: Option,\n") + }), + "Argument: batch_tool", + ) + .ok(); + + Ok(()) + } + + fn diff_assertions(&self) -> Vec { + vec![ + JudgeAssertion { + id: "batch tool passes window to each".to_string(), + description: + "batch_tool is modified to pass a clone of the window to each tool it calls." + .to_string(), + }, + JudgeAssertion { + id: "tool tests updated".to_string(), + description: + "tool tests are updated to pass the new `window` argument (`None` is ok)." + .to_string(), + }, + ] + } +} diff --git a/crates/eval/src/examples/file_search.rs b/crates/eval/src/examples/file_search.rs index 2649c87506..bbee8f008c 100644 --- a/crates/eval/src/examples/file_search.rs +++ b/crates/eval/src/examples/file_search.rs @@ -33,7 +33,7 @@ impl Example for FileSearchExample { let response = cx.run_turn().await?; let tool_use = response.expect_tool("path_search", cx)?; - let input = tool_use.expect_input::(cx)?; + let input = tool_use.parse_input::()?; let glob = input.glob; cx.assert( diff --git a/crates/eval/src/examples/mod.rs b/crates/eval/src/examples/mod.rs index 5e30f8ec80..83e44b6bfd 100644 --- a/crates/eval/src/examples/mod.rs +++ b/crates/eval/src/examples/mod.rs @@ -11,10 +11,14 @@ use util::serde::default_true; use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion}; +mod add_arg_to_trait_method; mod file_search; pub fn all(examples_dir: &Path) -> Vec> { - let mut threads: Vec> = vec![Rc::new(file_search::FileSearchExample)]; + let mut threads: Vec> = vec![ + Rc::new(file_search::FileSearchExample), + Rc::new(add_arg_to_trait_method::AddArgToTraitMethod), + ]; for example_path in list_declarative_examples(examples_dir).unwrap() { threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));