diff --git a/Cargo.lock b/Cargo.lock index d84675186e..2ac6054797 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4901,6 +4901,37 @@ dependencies = [ "num-traits", ] +[[package]] +name = "eval" +version = "0.1.0" +dependencies = [ + "agent", + "anyhow", + "assistant_tool", + "assistant_tools", + "client", + "collections", + "context_server", + "dap", + "env_logger 0.11.8", + "fs", + "gpui", + "gpui_tokio", + "language", + "language_model", + "language_models", + "node_runtime", + "project", + "prompt_store", + "release_channel", + "reqwest_client", + "serde", + "settings", + "smol", + "toml 0.8.20", + "workspace-hack", +] + [[package]] name = "evals" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index c696c6ebe6..b7a0825915 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/editor", + "crates/eval", "crates/evals", "crates/extension", "crates/extension_api", diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 6c1b45f3e2..09267db8ba 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -21,7 +21,7 @@ use gpui::{ linear_color_stop, linear_gradient, list, percentage, pulsating_between, }; use language::{Buffer, LanguageRegistry}; -use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role}; +use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use markdown::parser::CodeBlockKind; use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences}; use project::ProjectItem as _; @@ -897,11 +897,7 @@ impl ActiveThread { self.save_thread(cx); cx.notify(); } - ThreadEvent::UsePendingTools => { - let tool_uses = self - .thread - .update(cx, |thread, cx| thread.use_pending_tools(cx)); - + ThreadEvent::UsePendingTools { tool_uses } => { for tool_use in tool_uses { self.render_tool_use_markdown( tool_use.id.clone(), @@ -913,11 +909,8 @@ impl ActiveThread { } } ThreadEvent::ToolFinished { - pending_tool_use, - canceled, - .. + pending_tool_use, .. } => { - let canceled = *canceled; if let Some(tool_use) = pending_tool_use { self.render_tool_use_markdown( tool_use.id.clone(), @@ -931,18 +924,6 @@ impl ActiveThread { cx, ); } - - if self.thread.read(cx).all_tools_finished() { - let model_registry = LanguageModelRegistry::read_global(cx); - if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() { - self.thread.update(cx, |thread, cx| { - thread.attach_tool_results(cx); - if !canceled { - thread.send_to_model(model, RequestKind::Chat, cx); - } - }); - } - } } ThreadEvent::CheckpointChanged => cx.notify(), } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f83324295a..bebae477de 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1181,7 +1181,8 @@ impl Thread { match result.as_ref() { Ok(stop_reason) => match stop_reason { StopReason::ToolUse => { - cx.emit(ThreadEvent::UsePendingTools); + let tool_uses = thread.use_pending_tools(cx); + cx.emit(ThreadEvent::UsePendingTools { tool_uses }); } StopReason::EndTurn => {} StopReason::MaxTokens => {} @@ -1369,10 +1370,7 @@ impl Thread { ) } - pub fn use_pending_tools( - &mut self, - cx: &mut Context, - ) -> impl IntoIterator + use<> { + pub fn use_pending_tools(&mut self, cx: &mut Context) -> Vec { let request = self.to_completion_request(RequestKind::Chat, cx); let messages = Arc::new(request.messages); let pending_tool_uses = self @@ -1460,18 +1458,36 @@ impl Thread { output, cx, ); - - cx.emit(ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - canceled: false, - }); + thread.tool_finished(tool_use_id, pending_tool_use, false, cx); }) .ok(); } }) } + fn tool_finished( + &mut self, + tool_use_id: LanguageModelToolUseId, + pending_tool_use: Option, + canceled: bool, + cx: &mut Context, + ) { + if self.all_tools_finished() { + let model_registry = LanguageModelRegistry::read_global(cx); + if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() { + self.attach_tool_results(cx); + if !canceled { + self.send_to_model(model, RequestKind::Chat, cx); + } + } + } + + cx.emit(ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + }); + } + pub fn attach_tool_results(&mut self, cx: &mut Context) { // Insert a user message to contain the tool results. self.insert_user_message( @@ -1495,11 +1511,12 @@ impl Thread { let mut canceled = false; for pending_tool_use in self.tool_use.cancel_pending() { canceled = true; - cx.emit(ThreadEvent::ToolFinished { - tool_use_id: pending_tool_use.id.clone(), - pending_tool_use: Some(pending_tool_use), - canceled: true, - }); + self.tool_finished( + pending_tool_use.id.clone(), + Some(pending_tool_use), + true, + cx, + ); } canceled }; @@ -1866,12 +1883,7 @@ impl Thread { self.tool_use .insert_tool_output(tool_use_id.clone(), tool_name, err, cx); - - cx.emit(ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use: None, - canceled: true, - }); + self.tool_finished(tool_use_id.clone(), None, true, cx); } } @@ -1897,14 +1909,14 @@ pub enum ThreadEvent { MessageDeleted(MessageId), SummaryGenerated, SummaryChanged, - UsePendingTools, + UsePendingTools { + tool_uses: Vec, + }, ToolFinished { #[allow(unused)] tool_use_id: LanguageModelToolUseId, /// The pending tool use that corresponds to this tool. pending_tool_use: Option, - /// Whether the tool was canceled by the user. - canceled: bool, }, CheckpointChanged, ToolConfirmationNeeded, diff --git a/crates/agent_eval/src/headless_assistant.rs b/crates/agent_eval/src/headless_assistant.rs index 41fc41b30a..dbaf11a150 100644 --- a/crates/agent_eval/src/headless_assistant.rs +++ b/crates/agent_eval/src/headless_assistant.rs @@ -95,11 +95,7 @@ impl HeadlessAssistant { self.done_tx.send_blocking(Ok(())).unwrap() } } - ThreadEvent::UsePendingTools => { - thread.update(cx, |thread, cx| { - thread.use_pending_tools(cx); - }); - } + ThreadEvent::UsePendingTools { .. } => {} ThreadEvent::ToolConfirmationNeeded => { // Automatically approve all tools that need confirmation in headless mode println!("Tool confirmation needed - automatically approving in headless mode"); @@ -152,19 +148,6 @@ impl HeadlessAssistant { if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { println!("Tool result: {:?}", tool_result); } - if thread.read(cx).all_tools_finished() { - let model_registry = LanguageModelRegistry::read_global(cx); - if let Some(model) = model_registry.default_model() { - thread.update(cx, |thread, cx| { - thread.attach_tool_results(cx); - thread.send_to_model(model.model, RequestKind::Chat, cx); - }); - } else { - println!( - "Warning: No active language model available to continue conversation" - ); - } - } } _ => {} } diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml new file mode 100644 index 0000000000..e3701c9a23 --- /dev/null +++ b/crates/eval/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "eval" +version = "0.1.0" +publish.workspace = true +edition.workspace = true + +[dependencies] +agent.workspace = true +anyhow.workspace = true +assistant_tool.workspace = true +assistant_tools.workspace = true +client.workspace = true +collections.workspace = true +context_server.workspace = true +dap.workspace = true +env_logger.workspace = true +fs.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +language.workspace = true +language_model.workspace = true +language_models.workspace = true +node_runtime.workspace = true +project.workspace = true +prompt_store.workspace = true +release_channel.workspace = true +reqwest_client.workspace = true +serde.workspace = true +settings.workspace = true +smol.workspace = true +toml.workspace = true +workspace-hack.workspace = true + +[[bin]] +name = "eval" +path = "src/eval.rs" + +[lints] +workspace = true diff --git a/crates/eval/LICENSE-GPL b/crates/eval/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/eval/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/eval/README.md b/crates/eval/README.md new file mode 100644 index 0000000000..b28806bebb --- /dev/null +++ b/crates/eval/README.md @@ -0,0 +1,7 @@ +# Eval + +This eval assumes the working directory is the root of the repository. Run it with: + +```sh +cargo run -p eval +``` diff --git a/crates/eval/examples/find_and_replace_diff_card/base.toml b/crates/eval/examples/find_and_replace_diff_card/base.toml new file mode 100644 index 0000000000..2b14a64530 --- /dev/null +++ b/crates/eval/examples/find_and_replace_diff_card/base.toml @@ -0,0 +1,2 @@ +path = "../zed_worktree" +revision = "38fcadf9481d018543c65f36ac3bafeba190179b" diff --git a/crates/eval/examples/find_and_replace_diff_card/prompt.md b/crates/eval/examples/find_and_replace_diff_card/prompt.md new file mode 100644 index 0000000000..efd23cbba3 --- /dev/null +++ b/crates/eval/examples/find_and_replace_diff_card/prompt.md @@ -0,0 +1,3 @@ +Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation. + +The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line. diff --git a/crates/eval/examples/find_and_replace_diff_card/rubric.md b/crates/eval/examples/find_and_replace_diff_card/rubric.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/eval/src/agent.rs b/crates/eval/src/agent.rs new file mode 100644 index 0000000000..636c8b5b3d --- /dev/null +++ b/crates/eval/src/agent.rs @@ -0,0 +1,229 @@ +use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore}; +use anyhow::anyhow; +use assistant_tool::ToolWorkingSet; +use client::{Client, UserStore}; +use collections::HashMap; +use dap::DapRegistry; +use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*}; +use language::LanguageRegistry; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, +}; +use node_runtime::NodeRuntime; +use project::{Project, RealFs}; +use prompt_store::PromptBuilder; +use settings::SettingsStore; +use smol::channel; +use std::sync::Arc; + +/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. +pub struct AgentAppState { + pub languages: Arc, + pub client: Arc, + pub user_store: Entity, + pub fs: Arc, + pub node_runtime: NodeRuntime, + + // Additional fields not present in `workspace::AppState`. + pub prompt_builder: Arc, +} + +pub struct Agent { + // pub thread: Entity, + // pub project: Entity, + #[allow(dead_code)] + pub thread_store: Entity, + pub tool_use_counts: HashMap, u32>, + pub done_tx: channel::Sender>, + _subscription: Subscription, +} + +impl Agent { + pub fn new( + app_state: Arc, + cx: &mut App, + ) -> anyhow::Result<(Entity, channel::Receiver>)> { + let env = None; + let project = Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + Arc::new(DapRegistry::default()), + app_state.fs.clone(), + env, + cx, + ); + + let tools = Arc::new(ToolWorkingSet::default()); + let thread_store = + ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?; + + let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); + + let (done_tx, done_rx) = channel::unbounded::>(); + + let headless_thread = cx.new(move |cx| Self { + _subscription: cx.subscribe(&thread, Self::handle_thread_event), + // thread, + // project, + thread_store, + tool_use_counts: HashMap::default(), + done_tx, + }); + + Ok((headless_thread, done_rx)) + } + + fn handle_thread_event( + &mut self, + thread: Entity, + event: &ThreadEvent, + cx: &mut Context, + ) { + match event { + ThreadEvent::ShowError(err) => self + .done_tx + .send_blocking(Err(anyhow!("{:?}", err))) + .unwrap(), + ThreadEvent::DoneStreaming => { + let thread = thread.read(cx); + if let Some(message) = thread.messages().last() { + println!("Message: {}", message.to_string()); + } + if thread.all_tools_finished() { + self.done_tx.send_blocking(Ok(())).unwrap() + } + } + ThreadEvent::UsePendingTools { .. } => {} + ThreadEvent::ToolConfirmationNeeded => { + // Automatically approve all tools that need confirmation in headless mode + println!("Tool confirmation needed - automatically approving in headless mode"); + + // Get the tools needing confirmation + let tools_needing_confirmation: Vec<_> = thread + .read(cx) + .tools_needing_confirmation() + .cloned() + .collect(); + + // Run each tool that needs confirmation + for tool_use in tools_needing_confirmation { + if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) { + thread.update(cx, |thread, cx| { + println!("Auto-approving tool: {}", tool_use.name); + + // Create a request to send to the tool + let request = thread.to_completion_request(RequestKind::Chat, cx); + let messages = Arc::new(request.messages); + + // Run the tool + thread.run_tool( + tool_use.id.clone(), + tool_use.ui_text.clone(), + tool_use.input.clone(), + &messages, + tool, + cx, + ); + }); + } + } + } + ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + .. + } => { + if let Some(pending_tool_use) = pending_tool_use { + println!( + "Used tool {} with input: {}", + pending_tool_use.name, pending_tool_use.input + ); + *self + .tool_use_counts + .entry(pending_tool_use.name.clone()) + .or_insert(0) += 1; + } + if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { + println!("Tool result: {:?}", tool_result); + } + } + _ => {} + } + } +} + +pub fn init(cx: &mut App) -> Arc { + release_channel::init(SemanticVersion::default(), cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client().clone()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new( + git_binary_path, + cx.background_executor().clone(), + )); + + let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + language::init(cx); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); + assistant_tools::init(client.http_client().clone(), cx); + context_server::init(cx); + let stdout_is_a_pty = false; + let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); + agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx); + + Arc::new(AgentAppState { + languages, + client, + user_store, + fs, + node_runtime: NodeRuntime::unavailable(), + prompt_builder, + }) +} + +pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result> { + let model_registry = LanguageModelRegistry::read_global(cx); + let model = model_registry + .available_models(cx) + .find(|model| model.id().0 == model_name); + + let Some(model) = model else { + return Err(anyhow!( + "No language model named {} was available. Available models: {}", + model_name, + model_registry + .available_models(cx) + .map(|model| model.id().0.clone()) + .collect::>() + .join(", ") + )); + }; + + Ok(model) +} + +pub fn authenticate_model_provider( + provider_id: LanguageModelProviderId, + cx: &mut App, +) -> Task> { + let model_registry = LanguageModelRegistry::read_global(cx); + let model_provider = model_registry.provider(&provider_id).unwrap(); + model_provider.authenticate(cx) +} diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs new file mode 100644 index 0000000000..ad3180512d --- /dev/null +++ b/crates/eval/src/eval.rs @@ -0,0 +1,101 @@ +use agent::Agent; +use anyhow::Result; +use gpui::Application; +use language_model::LanguageModelRegistry; +use reqwest_client::ReqwestClient; +use serde::Deserialize; +use std::{ + fs, + path::{Path, PathBuf}, + sync::Arc, +}; +mod agent; + +#[derive(Debug, Deserialize)] +pub struct ExampleBase { + pub path: PathBuf, + pub revision: String, +} + +#[derive(Debug)] +pub struct Example { + pub base: ExampleBase, + + /// Content of the prompt.md file + pub prompt: String, + + /// Content of the rubric.md file + pub rubric: String, +} + +impl Example { + /// Load an example from a directory containing base.toml, prompt.md, and rubric.md + pub fn load_from_directory>(dir_path: P) -> Result { + let base_path = dir_path.as_ref().join("base.toml"); + let prompt_path = dir_path.as_ref().join("prompt.md"); + let rubric_path = dir_path.as_ref().join("rubric.md"); + + let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?; + base.path = base.path.canonicalize()?; + + Ok(Example { + base, + prompt: fs::read_to_string(prompt_path)?, + rubric: fs::read_to_string(rubric_path)?, + }) + } + + /// Set up the example by checking out the specified Git revision + pub fn setup(&self) -> Result<()> { + use std::process::Command; + + // Check if the directory exists + let path = Path::new(&self.base.path); + anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path); + + // Change to the project directory and checkout the specified revision + let output = Command::new("git") + .current_dir(&self.base.path) + .arg("checkout") + .arg(&self.base.revision) + .output()?; + anyhow::ensure!( + output.status.success(), + "Failed to checkout revision {}: {}", + self.base.revision, + String::from_utf8_lossy(&output.stderr), + ); + + Ok(()) + } +} + +fn main() { + env_logger::init(); + let http_client = Arc::new(ReqwestClient::new()); + let app = Application::headless().with_http_client(http_client.clone()); + + app.run(move |cx| { + let app_state = crate::agent::init(cx); + let _agent = Agent::new(app_state, cx); + + let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap(); + + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model(Some(model.clone()), cx); + }); + + let model_provider_id = model.provider_id(); + + let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx); + + cx.spawn(async move |_cx| { + authenticate.await.unwrap(); + }) + .detach(); + }); + + // let example = + // Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?; + // example.setup()?; +}