From 2440faf4b22085faa5deebc3fc7f24efc02b89f1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 10 Apr 2025 18:01:33 -0600 Subject: [PATCH] Actually run the eval and fix a hang when retrieving outline (#28547) Release Notes: - Fixed a regression that caused the agent to hang sometimes. --------- Co-authored-by: Thomas Mickley-Doyle Co-authored-by: Nathan Sobo Co-authored-by: Michael Sloan --- Cargo.lock | 57 +-- Cargo.toml | 4 - crates/agent/Cargo.toml | 2 +- crates/agent/src/active_thread.rs | 42 +- crates/agent/src/agent_diff.rs | 19 +- crates/agent/src/assistant_panel.rs | 10 +- crates/agent/src/message_editor.rs | 16 +- crates/agent/src/thread.rs | 167 +++----- crates/agent/src/thread_store.rs | 246 +++++++++-- crates/agent_eval/Cargo.toml | 46 --- crates/agent_eval/LICENSE-GPL | 1 - crates/agent_eval/build.rs | 52 --- crates/agent_eval/src/eval.rs | 384 ------------------ crates/agent_eval/src/get_exercise.rs | 149 ------- crates/agent_eval/src/git_commands.rs | 125 ------ crates/agent_eval/src/headless_assistant.rs | 229 ----------- crates/agent_eval/src/main.rs | 205 ---------- crates/agent_rules/Cargo.toml | 25 -- crates/agent_rules/LICENSE-GPL | 1 - crates/agent_rules/src/agent_rules.rs | 51 --- .../src/assistant_settings.rs | 2 +- .../assistant_tools/src/code_symbols_tool.rs | 8 +- crates/eval/Cargo.toml | 4 +- crates/eval/src/agent.rs | 229 ----------- crates/eval/src/eval.rs | 206 ++++++---- crates/eval/src/example.rs | 178 ++++++++ crates/gpui/src/app/async_context.rs | 24 +- crates/prompt_store/src/prompts.rs | 22 +- 28 files changed, 642 insertions(+), 1862 deletions(-) delete mode 100644 crates/agent_eval/Cargo.toml delete mode 120000 crates/agent_eval/LICENSE-GPL delete mode 100644 crates/agent_eval/build.rs delete mode 100644 crates/agent_eval/src/eval.rs delete mode 100644 crates/agent_eval/src/get_exercise.rs delete mode 100644 crates/agent_eval/src/git_commands.rs delete mode 100644 crates/agent_eval/src/headless_assistant.rs delete 
mode 100644 crates/agent_eval/src/main.rs delete mode 100644 crates/agent_rules/Cargo.toml delete mode 120000 crates/agent_rules/LICENSE-GPL delete mode 100644 crates/agent_rules/src/agent_rules.rs delete mode 100644 crates/eval/src/agent.rs create mode 100644 crates/eval/src/example.rs diff --git a/Cargo.lock b/Cargo.lock index ab3b46fd32..feafacda17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,7 +52,6 @@ dependencies = [ name = "agent" version = "0.1.0" dependencies = [ - "agent_rules", "anyhow", "assistant_context_editor", "assistant_settings", @@ -116,6 +115,7 @@ dependencies = [ "terminal_view", "text", "theme", + "thiserror 2.0.12", "time", "time_format", "ui", @@ -127,57 +127,6 @@ dependencies = [ "zed_actions", ] -[[package]] -name = "agent_eval" -version = "0.1.0" -dependencies = [ - "agent", - "anyhow", - "assistant_tool", - "assistant_tools", - "clap", - "client", - "collections", - "context_server", - "dap", - "env_logger 0.11.8", - "fs", - "futures 0.3.31", - "gpui", - "gpui_tokio", - "language", - "language_model", - "language_models", - "node_runtime", - "project", - "prompt_store", - "release_channel", - "reqwest_client", - "serde", - "serde_json", - "serde_json_lenient", - "settings", - "smol", - "tempfile", - "util", - "walkdir", - "workspace-hack", -] - -[[package]] -name = "agent_rules" -version = "0.1.0" -dependencies = [ - "anyhow", - "fs", - "gpui", - "indoc", - "prompt_store", - "util", - "workspace-hack", - "worktree", -] - [[package]] name = "ahash" version = "0.7.8" @@ -4910,14 +4859,15 @@ version = "0.1.0" dependencies = [ "agent", "anyhow", + "assistant_settings", "assistant_tool", "assistant_tools", "client", - "collections", "context_server", "dap", "env_logger 0.11.8", "fs", + "futures 0.3.31", "gpui", "gpui_tokio", "language", @@ -4930,7 +4880,6 @@ dependencies = [ "reqwest_client", "serde", "settings", - "smol", "toml 0.8.20", "workspace-hack", ] diff --git a/Cargo.toml b/Cargo.toml index 9966140754..e1299b7451 100644 --- 
a/Cargo.toml +++ b/Cargo.toml @@ -3,13 +3,11 @@ resolver = "2" members = [ "crates/activity_indicator", "crates/agent", - "crates/agent_rules", "crates/anthropic", "crates/askpass", "crates/assets", "crates/assistant", "crates/assistant_context_editor", - "crates/agent_eval", "crates/assistant_settings", "crates/assistant_slash_command", "crates/assistant_slash_commands", @@ -211,14 +209,12 @@ edition = "2024" activity_indicator = { path = "crates/activity_indicator" } agent = { path = "crates/agent" } -agent_rules = { path = "crates/agent_rules" } ai = { path = "crates/ai" } anthropic = { path = "crates/anthropic" } askpass = { path = "crates/askpass" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } assistant_context_editor = { path = "crates/assistant_context_editor" } -assistant_eval = { path = "crates/agent_eval" } assistant_settings = { path = "crates/assistant_settings" } assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_slash_commands = { path = "crates/assistant_slash_commands" } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 9031d2db1a..ae184a1f38 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -19,7 +19,6 @@ test-support = [ ] [dependencies] -agent_rules.workspace = true anyhow.workspace = true assistant_context_editor.workspace = true assistant_settings.workspace = true @@ -81,6 +80,7 @@ terminal.workspace = true terminal_view.workspace = true text.workspace = true theme.workspace = true +thiserror.workspace = true time.workspace = true time_format.workspace = true ui.workspace = true diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 55ea8a2a4e..57e2a78e95 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -4,7 +4,7 @@ use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError, ThreadEvent, ThreadFeedback, }; -use 
crate::thread_store::ThreadStore; +use crate::thread_store::{RulesLoadingError, ThreadStore}; use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus}; use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill}; use crate::{AssistantPanel, OpenActiveThreadAsMarkdown}; @@ -21,7 +21,7 @@ use gpui::{ linear_color_stop, linear_gradient, list, percentage, pulsating_between, }; use language::{Buffer, LanguageRegistry}; -use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; +use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; use project::ProjectItem as _; @@ -668,6 +668,7 @@ impl ActiveThread { let subscriptions = vec![ cx.observe(&thread, |_, _, cx| cx.notify()), cx.subscribe_in(&thread, window, Self::handle_thread_event), + cx.subscribe(&thread_store, Self::handle_rules_loading_error), ]; let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), { @@ -833,10 +834,9 @@ impl ActiveThread { | ThreadEvent::SummaryChanged => { self.save_thread(cx); } - ThreadEvent::DoneStreaming => { - let thread = self.thread.read(cx); - - if !thread.is_generating() { + ThreadEvent::Stopped(reason) => match reason { + Ok(StopReason::EndTurn | StopReason::MaxTokens) => { + let thread = self.thread.read(cx); self.show_notification( if thread.used_tools_since_last_user_message() { "Finished running tools" @@ -848,7 +848,8 @@ impl ActiveThread { cx, ); } - } + _ => {} + }, ThreadEvent::ToolConfirmationNeeded => { self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx); } @@ -925,6 +926,19 @@ impl ActiveThread { } } + fn handle_rules_loading_error( + &mut self, + _thread_store: Entity, + error: &RulesLoadingError, + cx: &mut Context, + ) { + self.last_error = Some(ThreadError::Message { + header: "Error loading rules file".into(), + 
message: error.message.clone(), + }); + cx.notify(); + } + fn show_notification( &mut self, caption: impl Into, @@ -2701,12 +2715,13 @@ impl ActiveThread { } fn render_rules_item(&self, cx: &Context) -> AnyElement { - let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref() - else { + let project_context = self.thread.read(cx).project_context(); + let project_context = project_context.borrow(); + let Some(project_context) = project_context.as_ref() else { return div().into_any(); }; - let rules_files = system_prompt_context + let rules_files = project_context .worktrees .iter() .filter_map(|worktree| worktree.rules_file.as_ref()) @@ -2796,12 +2811,13 @@ impl ActiveThread { } fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { - let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref() - else { + let project_context = self.thread.read(cx).project_context(); + let project_context = project_context.borrow(); + let Some(project_context) = project_context.as_ref() else { return; }; - let abs_paths = system_prompt_context + let abs_paths = project_context .worktrees .iter() .flat_map(|worktree| worktree.rules_file.as_ref()) diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 3b696a3e19..c3bc120ead 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -921,15 +921,16 @@ mod tests { }) .unwrap(); - let thread_store = cx.update(|cx| { - ThreadStore::new( - project.clone(), - Arc::default(), - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - .unwrap() - }); + let thread_store = cx + .update(|cx| { + ThreadStore::load( + project.clone(), + Arc::default(), + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .await; let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); diff --git 
a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 6faa131df8..e257990d7d 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -194,10 +194,12 @@ impl AssistantPanel { ) -> Task>> { cx.spawn(async move |cx| { let tools = Arc::new(ToolWorkingSet::default()); - let thread_store = workspace.update(cx, |workspace, cx| { - let project = workspace.project().clone(); - ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx) - })??; + let thread_store = workspace + .update(cx, |workspace, cx| { + let project = workspace.project().clone(); + ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx) + })? + .await; let slash_commands = Arc::new(SlashCommandWorkingSet::default()); let context_store = workspace diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 6daa5b6bbf..953c1f0f68 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -32,8 +32,8 @@ use crate::profile_selector::ProfileSelector; use crate::thread::{RequestKind, Thread, TokenUsageRatio}; use crate::thread_store::ThreadStore; use crate::{ - AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent, - ToggleContextPicker, ToggleProfileSelector, + AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ToggleContextPicker, + ToggleProfileSelector, }; pub struct MessageEditor { @@ -235,8 +235,6 @@ impl MessageEditor { let refresh_task = refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); - let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx); - let thread = self.thread.clone(); let context_store = self.context_store.clone(); let git_store = self.project.read(cx).git_store().clone(); @@ -245,16 +243,6 @@ impl MessageEditor { cx.spawn(async move |this, cx| { let checkpoint = checkpoint.await.ok(); refresh_task.await; - let 
(system_prompt_context, load_error) = system_prompt_context_task.await; - - thread - .update(cx, |thread, cx| { - thread.set_system_prompt_context(system_prompt_context); - if let Some(load_error) = load_error { - cx.emit(ThreadEvent::ShowError(load_error)); - } - }) - .log_err(); thread .update(cx, |thread, cx| { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index a66a8ada33..cba6970cc5 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -3,14 +3,12 @@ use std::io::Write; use std::ops::Range; use std::sync::Arc; -use agent_rules::load_worktree_rules_file; use anyhow::{Context as _, Result, anyhow}; use assistant_settings::AssistantSettings; use assistant_tool::{ActionLog, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap}; use feature_flags::{self, FeatureFlagAppExt}; -use fs::Fs; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; use git::repository::DiffType; @@ -21,19 +19,20 @@ use language_model::{ LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason, TokenUsage, }; +use project::Project; use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState}; -use project::{Project, Worktree}; -use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt}; +use prompt_store::PromptBuilder; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; +use thiserror::Error; use util::{ResultExt as _, TryFutureExt as _, post_inc}; use uuid::Uuid; use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, - SerializedToolUse, + SerializedToolUse, SharedProjectContext, }; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER}; @@ -247,7 +246,7 @@ pub struct Thread { 
next_message_id: MessageId, context: BTreeMap, context_by_message: HashMap>, - system_prompt_context: Option, + project_context: SharedProjectContext, checkpoints_by_message: HashMap, completion_count: usize, pending_completions: Vec, @@ -269,6 +268,7 @@ impl Thread { project: Entity, tools: Arc, prompt_builder: Arc, + system_prompt: SharedProjectContext, cx: &mut Context, ) -> Self { Self { @@ -281,7 +281,7 @@ impl Thread { next_message_id: MessageId(0), context: BTreeMap::default(), context_by_message: HashMap::default(), - system_prompt_context: None, + project_context: system_prompt, checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -310,6 +310,7 @@ impl Thread { project: Entity, tools: Arc, prompt_builder: Arc, + project_context: SharedProjectContext, cx: &mut Context, ) -> Self { let next_message_id = MessageId( @@ -350,7 +351,7 @@ impl Thread { next_message_id, context: BTreeMap::default(), context_by_message: HashMap::default(), - system_prompt_context: None, + project_context, checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -388,6 +389,10 @@ impl Thread { self.summary.clone() } + pub fn project_context(&self) -> SharedProjectContext { + self.project_context.clone() + } + pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread"); pub fn summary_or_default(&self) -> SharedString { @@ -812,86 +817,6 @@ impl Thread { }) } - pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) { - self.system_prompt_context = Some(context); - } - - pub fn system_prompt_context(&self) -> &Option { - &self.system_prompt_context - } - - pub fn load_system_prompt_context( - &self, - cx: &App, - ) -> Task<(AssistantSystemPromptContext, Option)> { - let project = self.project.read(cx); - let tasks = project - .visible_worktrees(cx) - .map(|worktree| { - Self::load_worktree_info_for_system_prompt( - project.fs().clone(), - 
worktree.read(cx), - cx, - ) - }) - .collect::>(); - - cx.spawn(async |_cx| { - let results = futures::future::join_all(tasks).await; - let mut first_err = None; - let worktrees = results - .into_iter() - .map(|(worktree, err)| { - if first_err.is_none() && err.is_some() { - first_err = err; - } - worktree - }) - .collect::>(); - (AssistantSystemPromptContext::new(worktrees), first_err) - }) - } - - fn load_worktree_info_for_system_prompt( - fs: Arc, - worktree: &Worktree, - cx: &App, - ) -> Task<(WorktreeInfoForSystemPrompt, Option)> { - let root_name = worktree.root_name().into(); - let abs_path = worktree.abs_path(); - - let rules_task = load_worktree_rules_file(fs, worktree, cx); - let Some(rules_task) = rules_task else { - return Task::ready(( - WorktreeInfoForSystemPrompt { - root_name, - abs_path, - rules_file: None, - }, - None, - )); - }; - - cx.spawn(async move |_| { - let (rules_file, rules_file_error) = match rules_task.await { - Ok(rules_file) => (Some(rules_file), None), - Err(err) => ( - None, - Some(ThreadError::Message { - header: "Error loading rules file".into(), - message: format!("{err}").into(), - }), - ), - }; - let worktree_info = WorktreeInfoForSystemPrompt { - root_name, - abs_path, - rules_file, - }; - (worktree_info, rules_file_error) - }) - } - pub fn send_to_model( &mut self, model: Arc, @@ -941,10 +866,10 @@ impl Thread { temperature: None, }; - if let Some(system_prompt_context) = self.system_prompt_context.as_ref() { + if let Some(project_context) = self.project_context.borrow().as_ref() { if let Some(system_prompt) = self .prompt_builder - .generate_assistant_system_prompt(system_prompt_context) + .generate_assistant_system_prompt(project_context) .context("failed to generate assistant system prompt") .log_err() { @@ -955,7 +880,7 @@ impl Thread { }); } } else { - log::error!("system_prompt_context not set.") + log::error!("project_context not set.") } for message in &self.messages { @@ -1215,7 +1140,7 @@ impl Thread { 
thread.cancel_last_completion(cx); } } - cx.emit(ThreadEvent::DoneStreaming); + cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); thread.auto_capture_telemetry(cx); @@ -1963,10 +1888,13 @@ impl Thread { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Error)] pub enum ThreadError { + #[error("Payment required")] PaymentRequired, + #[error("Max monthly spend reached")] MaxMonthlySpendReached, + #[error("Message {header}: {message}")] Message { header: SharedString, message: SharedString, @@ -1979,7 +1907,7 @@ pub enum ThreadEvent { StreamedCompletion, StreamedAssistantText(MessageId, String), StreamedAssistantThinking(MessageId, String), - DoneStreaming, + Stopped(Result>), MessageAdded(MessageId), MessageEdited(MessageId), MessageDeleted(MessageId), @@ -2085,9 +2013,9 @@ fn main() {{ thread.to_completion_request(RequestKind::Chat, cx) }); - assert_eq!(request.messages.len(), 1); + assert_eq!(request.messages.len(), 2); let expected_full_message = format!("{}Please explain this code", expected_context); - assert_eq!(request.messages[0].string_contents(), expected_full_message); + assert_eq!(request.messages[1].string_contents(), expected_full_message); } #[gpui::test] @@ -2178,20 +2106,20 @@ fn main() {{ }); // The request should contain all 3 messages - assert_eq!(request.messages.len(), 3); + assert_eq!(request.messages.len(), 4); // Check that the contexts are properly formatted in each message - assert!(request.messages[0].string_contents().contains("file1.rs")); - assert!(!request.messages[0].string_contents().contains("file2.rs")); - assert!(!request.messages[0].string_contents().contains("file3.rs")); - - assert!(!request.messages[1].string_contents().contains("file1.rs")); - assert!(request.messages[1].string_contents().contains("file2.rs")); + assert!(request.messages[1].string_contents().contains("file1.rs")); + assert!(!request.messages[1].string_contents().contains("file2.rs")); 
assert!(!request.messages[1].string_contents().contains("file3.rs")); assert!(!request.messages[2].string_contents().contains("file1.rs")); - assert!(!request.messages[2].string_contents().contains("file2.rs")); - assert!(request.messages[2].string_contents().contains("file3.rs")); + assert!(request.messages[2].string_contents().contains("file2.rs")); + assert!(!request.messages[2].string_contents().contains("file3.rs")); + + assert!(!request.messages[3].string_contents().contains("file1.rs")); + assert!(!request.messages[3].string_contents().contains("file2.rs")); + assert!(request.messages[3].string_contents().contains("file3.rs")); } #[gpui::test] @@ -2229,9 +2157,9 @@ fn main() {{ thread.to_completion_request(RequestKind::Chat, cx) }); - assert_eq!(request.messages.len(), 1); + assert_eq!(request.messages.len(), 2); assert_eq!( - request.messages[0].string_contents(), + request.messages[1].string_contents(), "What is the best way to learn Rust?" ); @@ -2249,13 +2177,13 @@ fn main() {{ thread.to_completion_request(RequestKind::Chat, cx) }); - assert_eq!(request.messages.len(), 2); + assert_eq!(request.messages.len(), 3); assert_eq!( - request.messages[0].string_contents(), + request.messages[1].string_contents(), "What is the best way to learn Rust?" ); assert_eq!( - request.messages[1].string_contents(), + request.messages[2].string_contents(), "Are there any good books?" 
); } @@ -2376,15 +2304,16 @@ fn main() {{ let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let thread_store = cx.update(|_, cx| { - ThreadStore::new( - project.clone(), - Arc::default(), - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - .unwrap() - }); + let thread_store = cx + .update(|_, cx| { + ThreadStore::load( + project.clone(), + Arc::default(), + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .await; let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index aa7d515e68..c8f8d239a2 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -1,37 +1,57 @@ use std::borrow::Cow; -use std::path::PathBuf; +use std::cell::{Ref, RefCell}; +use std::path::{Path, PathBuf}; +use std::rc::Rc; use std::sync::Arc; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; use context_server::manager::ContextServerManager; use context_server::{ContextServerFactoryRegistry, ContextServerTool}; +use fs::Fs; use futures::FutureExt as _; use futures::future::{self, BoxFuture, Shared}; use gpui::{ - App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task, - prelude::*, + App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, + Subscription, Task, prelude::*, }; use heed::Database; use heed::types::SerdeBincode; use language_model::{LanguageModelToolUseId, Role, TokenUsage}; -use project::Project; -use prompt_store::PromptBuilder; +use project::{Project, Worktree}; +use prompt_store::{ProjectContext, PromptBuilder, 
RulesFileContext, WorktreeContext}; use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; -use crate::thread::{ - DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId, -}; +use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId}; + +const RULES_FILE_NAMES: [&'static str; 6] = [ + ".rules", + ".cursorrules", + ".windsurfrules", + ".clinerules", + ".github/copilot-instructions.md", + "CLAUDE.md", +]; pub fn init(cx: &mut App) { ThreadsDatabase::init(cx); } +/// A system prompt shared by all threads created by this ThreadStore +#[derive(Clone, Default)] +pub struct SharedProjectContext(Rc>>); + +impl SharedProjectContext { + pub fn borrow(&self) -> Ref> { + self.0.borrow() + } +} + pub struct ThreadStore { project: Entity, tools: Arc, @@ -39,43 +59,187 @@ pub struct ThreadStore { context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, threads: Vec, + project_context: SharedProjectContext, _subscriptions: Vec, } +pub struct RulesLoadingError { + pub message: SharedString, +} + +impl EventEmitter for ThreadStore {} + impl ThreadStore { - pub fn new( + pub fn load( project: Entity, tools: Arc, prompt_builder: Arc, cx: &mut App, - ) -> Result> { - let this = cx.new(|cx| { - let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx); - let context_server_manager = cx.new(|cx| { - ContextServerManager::new(context_server_factory_registry, project.clone(), cx) - }); - let settings_subscription = - cx.observe_global::(move |this: &mut Self, cx| { - this.load_default_profile(cx); - }); + ) -> Task> { + let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx)); + let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx)); + cx.foreground_executor().spawn(async move { + reload.await; + thread_store + }) + } - let this = Self { - project, - tools, - prompt_builder, - 
context_server_manager, - context_server_tool_ids: HashMap::default(), - threads: Vec::new(), - _subscriptions: vec![settings_subscription], - }; - this.load_default_profile(cx); - this.register_context_server_handlers(cx); - this.reload(cx).detach_and_log_err(cx); - - this + fn new( + project: Entity, + tools: Arc, + prompt_builder: Arc, + cx: &mut Context, + ) -> Self { + let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx); + let context_server_manager = cx.new(|cx| { + ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); + let settings_subscription = + cx.observe_global::(move |this: &mut Self, cx| { + this.load_default_profile(cx); + }); + let project_subscription = cx.subscribe(&project, Self::handle_project_event); - Ok(this) + let this = Self { + project, + tools, + prompt_builder, + context_server_manager, + context_server_tool_ids: HashMap::default(), + threads: Vec::new(), + project_context: SharedProjectContext::default(), + _subscriptions: vec![settings_subscription, project_subscription], + }; + this.load_default_profile(cx); + this.register_context_server_handlers(cx); + this.reload(cx).detach_and_log_err(cx); + this + } + + fn handle_project_event( + &mut self, + _project: Entity, + event: &project::Event, + cx: &mut Context, + ) { + match event { + project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { + self.reload_system_prompt(cx).detach(); + } + project::Event::WorktreeUpdatedEntries(_, items) => { + if items.iter().any(|(path, _, _)| { + RULES_FILE_NAMES + .iter() + .any(|name| path.as_ref() == Path::new(name)) + }) { + self.reload_system_prompt(cx).detach(); + } + } + _ => {} + } + } + + pub fn reload_system_prompt(&self, cx: &mut Context) -> Task<()> { + let project = self.project.read(cx); + let tasks = project + .visible_worktrees(cx) + .map(|worktree| { + Self::load_worktree_info_for_system_prompt( + project.fs().clone(), + worktree.read(cx), + cx, + ) + 
}) + .collect::>(); + + cx.spawn(async move |this, cx| { + let results = futures::future::join_all(tasks).await; + let worktrees = results + .into_iter() + .map(|(worktree, rules_error)| { + if let Some(rules_error) = rules_error { + this.update(cx, |_, cx| cx.emit(rules_error)).ok(); + } + worktree + }) + .collect::>(); + this.update(cx, |this, _cx| { + *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees)); + }) + .ok(); + }) + } + + fn load_worktree_info_for_system_prompt( + fs: Arc, + worktree: &Worktree, + cx: &App, + ) -> Task<(WorktreeContext, Option)> { + let root_name = worktree.root_name().into(); + let abs_path = worktree.abs_path(); + + let rules_task = Self::load_worktree_rules_file(fs, worktree, cx); + let Some(rules_task) = rules_task else { + return Task::ready(( + WorktreeContext { + root_name, + abs_path, + rules_file: None, + }, + None, + )); + }; + + cx.spawn(async move |_| { + let (rules_file, rules_file_error) = match rules_task.await { + Ok(rules_file) => (Some(rules_file), None), + Err(err) => ( + None, + Some(RulesLoadingError { + message: format!("{err}").into(), + }), + ), + }; + let worktree_info = WorktreeContext { + root_name, + abs_path, + rules_file, + }; + (worktree_info, rules_file_error) + }) + } + + fn load_worktree_rules_file( + fs: Arc, + worktree: &Worktree, + cx: &App, + ) -> Option>> { + let selected_rules_file = RULES_FILE_NAMES + .into_iter() + .filter_map(|name| { + worktree + .entry_for_path(name) + .filter(|entry| entry.is_file()) + .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path))) + }) + .next(); + + // Note that Cline supports `.clinerules` being a directory, but that is not currently + // supported. This doesn't seem to occur often in GitHub repositories. 
+ selected_rules_file.map(|(path_in_worktree, abs_path)| { + let fs = fs.clone(); + cx.background_spawn(async move { + let abs_path = abs_path?; + let text = fs.load(&abs_path).await.with_context(|| { + format!("Failed to load assistant rules file {:?}", abs_path) + })?; + anyhow::Ok(RulesFileContext { + path_in_worktree, + abs_path: abs_path.into(), + text: text.trim().to_string(), + }) + }) + }) } pub fn context_server_manager(&self) -> Entity { @@ -107,6 +271,7 @@ impl ThreadStore { self.project.clone(), self.tools.clone(), self.prompt_builder.clone(), + self.project_context.clone(), cx, ) }) @@ -134,21 +299,12 @@ impl ThreadStore { this.project.clone(), this.tools.clone(), this.prompt_builder.clone(), + this.project_context.clone(), cx, ) }) })?; - let (system_prompt_context, load_error) = thread - .update(cx, |thread, cx| thread.load_system_prompt_context(cx))? - .await; - thread.update(cx, |thread, cx| { - thread.set_system_prompt_context(system_prompt_context); - if let Some(load_error) = load_error { - cx.emit(ThreadEvent::ShowError(load_error)); - } - })?; - Ok(thread) }) } diff --git a/crates/agent_eval/Cargo.toml b/crates/agent_eval/Cargo.toml deleted file mode 100644 index 8d17710c02..0000000000 --- a/crates/agent_eval/Cargo.toml +++ /dev/null @@ -1,46 +0,0 @@ -[package] -name = "agent_eval" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[[bin]] -name = "agent_eval" -path = "src/main.rs" - -[dependencies] -agent.workspace = true -anyhow.workspace = true -assistant_tool.workspace = true -assistant_tools.workspace = true -clap.workspace = true -client.workspace = true -collections.workspace = true -context_server.workspace = true -dap.workspace = true -env_logger.workspace = true -fs.workspace = true -futures.workspace = true -gpui.workspace = true -gpui_tokio.workspace = true -language.workspace = true -language_model.workspace = true -language_models.workspace = true 
-node_runtime.workspace = true -project.workspace = true -prompt_store.workspace = true -release_channel.workspace = true -reqwest_client.workspace = true -serde.workspace = true -serde_json.workspace = true -serde_json_lenient.workspace = true -settings.workspace = true -smol.workspace = true -tempfile.workspace = true -util.workspace = true -walkdir.workspace = true -workspace-hack.workspace = true diff --git a/crates/agent_eval/LICENSE-GPL b/crates/agent_eval/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/agent_eval/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent_eval/build.rs b/crates/agent_eval/build.rs deleted file mode 100644 index 5b955c222a..0000000000 --- a/crates/agent_eval/build.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows. - -use std::process::Command; - -fn main() { - if cfg!(target_os = "macos") { - println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7"); - - // Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+. 
- println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit"); - - // Seems to be required to enable Swift concurrency - println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift"); - - // Register exported Objective-C selectors, protocols, etc - println!("cargo:rustc-link-arg=-Wl,-ObjC"); - } - - // Populate git sha environment variable if git is available - println!("cargo:rerun-if-changed=../../.git/logs/HEAD"); - println!( - "cargo:rustc-env=TARGET={}", - std::env::var("TARGET").unwrap() - ); - if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() { - if output.status.success() { - let git_sha = String::from_utf8_lossy(&output.stdout); - let git_sha = git_sha.trim(); - - println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}"); - - if let Ok(build_profile) = std::env::var("PROFILE") { - if build_profile == "release" { - // This is currently the best way to make `cargo build ...`'s build script - // to print something to stdout without extra verbosity. - println!( - "cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var" - ); - } - } - } - } - - #[cfg(target_os = "windows")] - { - #[cfg(target_env = "msvc")] - { - // todo(windows): This is to avoid stack overflow. Remove it when solved. 
- println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024); - } - } -} diff --git a/crates/agent_eval/src/eval.rs b/crates/agent_eval/src/eval.rs deleted file mode 100644 index 8723701a1c..0000000000 --- a/crates/agent_eval/src/eval.rs +++ /dev/null @@ -1,384 +0,0 @@ -use crate::git_commands::{run_git, setup_temp_repo}; -use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant}; -use crate::{get_exercise_language, get_exercise_name}; -use agent::RequestKind; -use anyhow::{Result, anyhow}; -use collections::HashMap; -use gpui::{App, Task}; -use language_model::{LanguageModel, TokenUsage}; -use serde::{Deserialize, Serialize}; -use std::{ - fs, - io::Write, - path::{Path, PathBuf}, - sync::Arc, - time::{Duration, SystemTime}, -}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct EvalResult { - pub exercise_name: String, - pub diff: String, - pub assistant_response: String, - pub elapsed_time_ms: u128, - pub timestamp: u128, - // Token usage fields - pub input_tokens: usize, - pub output_tokens: usize, - pub total_tokens: usize, - pub tool_use_counts: usize, -} - -pub struct EvalOutput { - pub diff: String, - pub last_message: String, - pub elapsed_time: Duration, - pub assistant_response_count: usize, - pub tool_use_counts: HashMap, u32>, - pub token_usage: TokenUsage, -} - -#[derive(Deserialize)] -pub struct EvalSetup { - pub url: String, - pub base_sha: String, -} - -pub struct Eval { - pub repo_path: PathBuf, - pub eval_setup: EvalSetup, - pub user_prompt: String, -} - -impl Eval { - // Keep this method for potential future use, but mark it as intentionally unused - #[allow(dead_code)] - pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result { - let prompt_path = path.join("prompt.txt"); - let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?; - let setup_path = path.join("setup.json"); - let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?; - let eval_setup = 
serde_json_lenient::from_str_lenient::(&setup_contents)?; - - // Move this internal function inside the load method since it's only used here - fn repo_dir_name(url: &str) -> String { - url.trim_start_matches("https://") - .replace(|c: char| !c.is_alphanumeric(), "_") - } - - let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url)); - - Ok(Eval { - repo_path, - eval_setup, - user_prompt, - }) - } - - pub fn run( - self, - app_state: Arc, - model: Arc, - cx: &mut App, - ) -> Task> { - cx.spawn(async move |cx| { - run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?; - - let (assistant, done_rx) = - cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??; - - let _worktree = assistant - .update(cx, |assistant, cx| { - assistant.project.update(cx, |project, cx| { - project.create_worktree(&self.repo_path, true, cx) - }) - })? - .await?; - - let start_time = std::time::SystemTime::now(); - - let (system_prompt_context, load_error) = cx - .update(|cx| { - assistant - .read(cx) - .thread - .read(cx) - .load_system_prompt_context(cx) - })? 
- .await; - - if let Some(load_error) = load_error { - return Err(anyhow!("{:?}", load_error)); - }; - - assistant.update(cx, |assistant, cx| { - assistant.thread.update(cx, |thread, cx| { - let context = vec![]; - thread.insert_user_message(self.user_prompt.clone(), context, None, cx); - thread.set_system_prompt_context(system_prompt_context); - thread.send_to_model(model, RequestKind::Chat, cx); - }); - })?; - - done_rx.recv().await??; - - // Add this section to check untracked files - println!("Checking for untracked files:"); - let untracked = run_git( - &self.repo_path, - &["ls-files", "--others", "--exclude-standard"], - ) - .await?; - if untracked.is_empty() { - println!("No untracked files found"); - } else { - // Add all files to git so they appear in the diff - println!("Adding untracked files to git"); - run_git(&self.repo_path, &["add", "."]).await?; - } - - // get git status - let _status = run_git(&self.repo_path, &["status", "--short"]).await?; - - let elapsed_time = start_time.elapsed()?; - - // Get diff of staged changes (the files we just added) - let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?; - - // Get diff of unstaged changes - let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?; - - // Combine both diffs - let diff = if unstaged_diff.is_empty() { - staged_diff - } else if staged_diff.is_empty() { - unstaged_diff - } else { - format!( - "# Staged changes\n{}\n\n# Unstaged changes\n{}", - staged_diff, unstaged_diff - ) - }; - - assistant.update(cx, |assistant, cx| { - let thread = assistant.thread.read(cx); - let last_message = thread.messages().last().unwrap(); - if last_message.role != language_model::Role::Assistant { - return Err(anyhow!("Last message is not from assistant")); - } - let assistant_response_count = thread - .messages() - .filter(|message| message.role == language_model::Role::Assistant) - .count(); - Ok(EvalOutput { - diff, - last_message: last_message.to_string(), - elapsed_time, - 
assistant_response_count, - tool_use_counts: assistant.tool_use_counts.clone(), - token_usage: thread.cumulative_token_usage(), - }) - })? - }) - } -} - -impl EvalOutput { - // Keep this method for potential future use, but mark it as intentionally unused - #[allow(dead_code)] - pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> { - // Create the output directory if it doesn't exist - fs::create_dir_all(&output_dir)?; - - // Save the diff to a file - let diff_path = output_dir.join("diff.patch"); - let mut diff_file = fs::File::create(&diff_path)?; - diff_file.write_all(self.diff.as_bytes())?; - - // Save the last message to a file - let message_path = output_dir.join("assistant_response.txt"); - let mut message_file = fs::File::create(&message_path)?; - message_file.write_all(self.last_message.as_bytes())?; - - // Current metrics for this run - let current_metrics = serde_json::json!({ - "elapsed_time_ms": self.elapsed_time.as_millis(), - "assistant_response_count": self.assistant_response_count, - "tool_use_counts": self.tool_use_counts, - "token_usage": self.token_usage, - "eval_output_value": eval_output_value, - }); - - // Get current timestamp in milliseconds - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH)? 
- .as_millis() - .to_string(); - - // Path to metrics file - let metrics_path = output_dir.join("metrics.json"); - - // Load existing metrics if the file exists, or create a new object - let mut historical_metrics = if metrics_path.exists() { - let metrics_content = fs::read_to_string(&metrics_path)?; - serde_json::from_str::(&metrics_content) - .unwrap_or_else(|_| serde_json::json!({})) - } else { - serde_json::json!({}) - }; - - // Add new run with timestamp as key - if let serde_json::Value::Object(ref mut map) = historical_metrics { - map.insert(timestamp, current_metrics); - } - - // Write updated metrics back to file - let metrics_json = serde_json::to_string_pretty(&historical_metrics)?; - let mut metrics_file = fs::File::create(&metrics_path)?; - metrics_file.write_all(metrics_json.as_bytes())?; - - Ok(()) - } -} - -pub async fn read_instructions(exercise_path: &Path) -> Result { - let instructions_path = exercise_path.join(".docs").join("instructions.md"); - println!("Reading instructions from: {}", instructions_path.display()); - let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?; - Ok(instructions) -} - -pub async fn save_eval_results(exercise_path: &Path, results: Vec) -> Result<()> { - let eval_dir = exercise_path.join("evaluation"); - fs::create_dir_all(&eval_dir)?; - - let eval_file = eval_dir.join("evals.json"); - - println!("Saving evaluation results to: {}", eval_file.display()); - println!( - "Results to save: {} evaluations for exercise path: {}", - results.len(), - exercise_path.display() - ); - - // Check file existence before reading/writing - if eval_file.exists() { - println!("Existing evals.json file found, will update it"); - } else { - println!("No existing evals.json file found, will create new one"); - } - - // Structure to organize evaluations by test name and timestamp - let mut eval_data: serde_json::Value = if eval_file.exists() { - let content = fs::read_to_string(&eval_file)?; - 
serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({})) - } else { - serde_json::json!({}) - }; - - // Get current timestamp for this batch of results - let timestamp = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH)? - .as_millis() - .to_string(); - - // Group the new results by test name (exercise name) - for result in results { - let exercise_name = &result.exercise_name; - - println!("Adding result: exercise={}", exercise_name); - - // Ensure the exercise entry exists - if eval_data.get(exercise_name).is_none() { - eval_data[exercise_name] = serde_json::json!({}); - } - - // Ensure the timestamp entry exists as an object - if eval_data[exercise_name].get(&timestamp).is_none() { - eval_data[exercise_name][&timestamp] = serde_json::json!({}); - } - - // Add this result under the timestamp with template name as key - eval_data[exercise_name][&timestamp] = serde_json::to_value(&result)?; - } - - // Write back to file with pretty formatting - let json_content = serde_json::to_string_pretty(&eval_data)?; - match fs::write(&eval_file, json_content) { - Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()), - Err(e) => println!("✗ Failed to write results file: {}", e), - } - - Ok(()) -} - -pub async fn run_exercise_eval( - exercise_path: PathBuf, - model: Arc, - app_state: Arc, - base_sha: String, - _framework_path: PathBuf, - cx: gpui::AsyncApp, -) -> Result { - let exercise_name = get_exercise_name(&exercise_path); - let language = get_exercise_language(&exercise_path)?; - let mut instructions = read_instructions(&exercise_path).await?; - instructions.push_str(&format!( - "\n\nWhen writing the code for this prompt, use {} to achieve the goal.", - language - )); - - println!("Running evaluation for exercise: {}", exercise_name); - - // Create temporary directory with exercise files - let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?; - let temp_path = temp_dir.path().to_path_buf(); - - let local_commit_sha = 
run_git(&temp_path, &["rev-parse", "HEAD"]).await?; - - let start_time = SystemTime::now(); - - // Create a basic eval struct to work with the existing system - let eval = Eval { - repo_path: temp_path.clone(), - eval_setup: EvalSetup { - url: format!("file://{}", temp_path.display()), - base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA - }, - user_prompt: instructions.clone(), - }; - - // Run the evaluation - let eval_output = cx - .update(|cx| eval.run(app_state.clone(), model.clone(), cx))? - .await?; - - // Get diff from git - let diff = eval_output.diff.clone(); - - let elapsed_time = start_time.elapsed()?; - - // Calculate total tokens as the sum of input and output tokens - let input_tokens = eval_output.token_usage.input_tokens; - let output_tokens = eval_output.token_usage.output_tokens; - let tool_use_counts = eval_output.tool_use_counts.values().sum::(); - let total_tokens = input_tokens + output_tokens; - - // Save results to evaluation directory - let result = EvalResult { - exercise_name: exercise_name.clone(), - diff, - assistant_response: eval_output.last_message.clone(), - elapsed_time_ms: elapsed_time.as_millis(), - timestamp: SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH)? 
- .as_millis(), - // Convert u32 token counts to usize - input_tokens: input_tokens.try_into().unwrap(), - output_tokens: output_tokens.try_into().unwrap(), - total_tokens: total_tokens.try_into().unwrap(), - tool_use_counts: tool_use_counts.try_into().unwrap(), - }; - - Ok(result) -} diff --git a/crates/agent_eval/src/get_exercise.rs b/crates/agent_eval/src/get_exercise.rs deleted file mode 100644 index 56d93939da..0000000000 --- a/crates/agent_eval/src/get_exercise.rs +++ /dev/null @@ -1,149 +0,0 @@ -use anyhow::{Result, anyhow}; -use std::{ - fs, - path::{Path, PathBuf}, -}; - -pub fn get_exercise_name(exercise_path: &Path) -> String { - exercise_path - .file_name() - .unwrap_or_default() - .to_string_lossy() - .to_string() -} - -pub fn get_exercise_language(exercise_path: &Path) -> Result { - // Extract the language from path (data/python/exercises/... => python) - let parts: Vec<_> = exercise_path.components().collect(); - - for (i, part) in parts.iter().enumerate() { - if i > 0 && part.as_os_str() == "eval_code" { - if i + 1 < parts.len() { - let language = parts[i + 1].as_os_str().to_string_lossy().to_string(); - return Ok(language); - } - } - } - - Err(anyhow!( - "Could not determine language from path: {:?}", - exercise_path - )) -} - -pub fn find_exercises( - framework_path: &Path, - languages: &[&str], - max_per_language: Option, -) -> Result> { - let mut all_exercises = Vec::new(); - - println!("Searching for exercises in languages: {:?}", languages); - - for language in languages { - let language_dir = framework_path - .join("eval_code") - .join(language) - .join("exercises") - .join("practice"); - - println!("Checking language directory: {:?}", language_dir); - if !language_dir.exists() { - println!("Warning: Language directory not found: {:?}", language_dir); - continue; - } - - let mut exercises = Vec::new(); - match fs::read_dir(&language_dir) { - Ok(entries) => { - for entry_result in entries { - match entry_result { - Ok(entry) => { - let path = 
entry.path(); - - if path.is_dir() { - // Special handling for "internal" directory - if *language == "internal" { - // Check for repo_info.json to validate it's an internal exercise - let repo_info_path = path.join(".meta").join("repo_info.json"); - let instructions_path = - path.join(".docs").join("instructions.md"); - - if repo_info_path.exists() && instructions_path.exists() { - exercises.push(path); - } - } else { - // Map the language to the file extension - original code - let language_extension = match *language { - "python" => "py", - "go" => "go", - "rust" => "rs", - "typescript" => "ts", - "javascript" => "js", - "ruby" => "rb", - "php" => "php", - "bash" => "sh", - "multi" => "diff", - _ => continue, // Skip unsupported languages - }; - - // Check if this is a valid exercise with instructions and example - let instructions_path = - path.join(".docs").join("instructions.md"); - let has_instructions = instructions_path.exists(); - let example_path = path - .join(".meta") - .join(format!("example.{}", language_extension)); - let has_example = example_path.exists(); - - if has_instructions && has_example { - exercises.push(path); - } - } - } - } - Err(err) => println!("Error reading directory entry: {}", err), - } - } - } - Err(err) => println!( - "Error reading directory {}: {}", - language_dir.display(), - err - ), - } - - // Sort exercises by name for consistent selection - exercises.sort_by(|a, b| { - let a_name = a.file_name().unwrap_or_default().to_string_lossy(); - let b_name = b.file_name().unwrap_or_default().to_string_lossy(); - a_name.cmp(&b_name) - }); - - // Apply the limit if specified - if let Some(limit) = max_per_language { - if exercises.len() > limit { - println!( - "Limiting {} exercises to {} for language {}", - exercises.len(), - limit, - language - ); - exercises.truncate(limit); - } - } - - println!( - "Found {} exercises for language {}: {:?}", - exercises.len(), - language, - exercises - .iter() - .map(|p| 
p.file_name().unwrap_or_default().to_string_lossy()) - .collect::>() - ); - all_exercises.extend(exercises); - } - - Ok(all_exercises) -} diff --git a/crates/agent_eval/src/git_commands.rs b/crates/agent_eval/src/git_commands.rs deleted file mode 100644 index 89b54c3360..0000000000 --- a/crates/agent_eval/src/git_commands.rs +++ /dev/null @@ -1,125 +0,0 @@ -use anyhow::{Result, anyhow}; -use serde::Deserialize; -use std::{fs, path::Path}; -use tempfile::TempDir; -use util::command::new_smol_command; -use walkdir::WalkDir; - -#[derive(Debug, Deserialize)] -pub struct SetupConfig { - #[serde(rename = "base.sha")] - pub base_sha: String, -} - -#[derive(Debug, Deserialize)] -pub struct RepoInfo { - pub remote_url: String, - pub head_sha: String, -} - -pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { - let output = new_smol_command("git") - .current_dir(repo_path) - .args(args) - .output() - .await?; - - if output.status.success() { - Ok(String::from_utf8(output.stdout)?.trim().to_string()) - } else { - Err(anyhow!( - "Git command failed: {} with status: {}", - args.join(" "), - output.status - )) - } -} - -pub async fn read_base_sha(framework_path: &Path) -> Result { - let setup_path = framework_path.join("setup.json"); - let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?; - let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?; - Ok(setup_config.base_sha) -} - -pub async fn read_repo_info(exercise_path: &Path) -> Result { - let repo_info_path = exercise_path.join(".meta").join("repo_info.json"); - println!("Reading repo info from: {}", repo_info_path.display()); - let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?; - let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?; - - // Remove any quotes from the strings - let remote_url = repo_info.remote_url.trim_matches('"').to_string(); - let head_sha = 
repo_info.head_sha.trim_matches('"').to_string(); - - Ok(RepoInfo { - remote_url, - head_sha, - }) -} - -pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result { - let temp_dir = TempDir::new()?; - - // Check if this is an internal exercise by looking for repo_info.json - let repo_info_path = exercise_path.join(".meta").join("repo_info.json"); - if repo_info_path.exists() { - // This is an internal exercise, handle it differently - let repo_info = read_repo_info(exercise_path).await?; - - // Clone the repository to the temp directory - let url = repo_info.remote_url; - let clone_path = temp_dir.path(); - println!( - "Cloning repository from {} to {}", - url, - clone_path.display() - ); - run_git( - &std::env::current_dir()?, - &["clone", &url, &clone_path.to_string_lossy()], - ) - .await?; - - // Checkout the specified commit - println!("Checking out commit: {}", repo_info.head_sha); - run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?; - - println!("Successfully set up internal repository"); - } else { - // Original code for regular exercises - // Copy the exercise files to the temp directory, excluding .docs and .meta - for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) { - let entry = entry?; - let source_path = entry.path(); - - // Skip .docs and .meta directories completely - if source_path.starts_with(exercise_path.join(".docs")) - || source_path.starts_with(exercise_path.join(".meta")) - { - continue; - } - - if source_path.is_file() { - let relative_path = source_path.strip_prefix(exercise_path)?; - let dest_path = temp_dir.path().join(relative_path); - - // Make sure parent directories exist - if let Some(parent) = dest_path.parent() { - fs::create_dir_all(parent)?; - } - - fs::copy(source_path, dest_path)?; - } - } - - // Initialize git repo in the temp directory - run_git(temp_dir.path(), &["init"]).await?; - run_git(temp_dir.path(), &["add", "."]).await?; - run_git(temp_dir.path(), &["commit", 
"-m", "Initial commit"]).await?; - - println!("Created temp repo without .docs and .meta directories"); - } - - Ok(temp_dir) -} diff --git a/crates/agent_eval/src/headless_assistant.rs b/crates/agent_eval/src/headless_assistant.rs deleted file mode 100644 index dbaf11a150..0000000000 --- a/crates/agent_eval/src/headless_assistant.rs +++ /dev/null @@ -1,229 +0,0 @@ -use agent::{RequestKind, Thread, ThreadEvent, ThreadStore}; -use anyhow::anyhow; -use assistant_tool::ToolWorkingSet; -use client::{Client, UserStore}; -use collections::HashMap; -use dap::DapRegistry; -use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*}; -use language::LanguageRegistry; -use language_model::{ - AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, -}; -use node_runtime::NodeRuntime; -use project::{Project, RealFs}; -use prompt_store::PromptBuilder; -use settings::SettingsStore; -use smol::channel; -use std::sync::Arc; - -/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. -pub struct HeadlessAppState { - pub languages: Arc, - pub client: Arc, - pub user_store: Entity, - pub fs: Arc, - pub node_runtime: NodeRuntime, - - // Additional fields not present in `workspace::AppState`. 
- pub prompt_builder: Arc, -} - -pub struct HeadlessAssistant { - pub thread: Entity, - pub project: Entity, - #[allow(dead_code)] - pub thread_store: Entity, - pub tool_use_counts: HashMap, u32>, - pub done_tx: channel::Sender>, - _subscription: Subscription, -} - -impl HeadlessAssistant { - pub fn new( - app_state: Arc, - cx: &mut App, - ) -> anyhow::Result<(Entity, channel::Receiver>)> { - let env = None; - let project = Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - Arc::new(DapRegistry::default()), - app_state.fs.clone(), - env, - cx, - ); - - let tools = Arc::new(ToolWorkingSet::default()); - let thread_store = - ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?; - - let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); - - let (done_tx, done_rx) = channel::unbounded::>(); - - let headless_thread = cx.new(move |cx| Self { - _subscription: cx.subscribe(&thread, Self::handle_thread_event), - thread, - project, - thread_store, - tool_use_counts: HashMap::default(), - done_tx, - }); - - Ok((headless_thread, done_rx)) - } - - fn handle_thread_event( - &mut self, - thread: Entity, - event: &ThreadEvent, - cx: &mut Context, - ) { - match event { - ThreadEvent::ShowError(err) => self - .done_tx - .send_blocking(Err(anyhow!("{:?}", err))) - .unwrap(), - ThreadEvent::DoneStreaming => { - let thread = thread.read(cx); - if let Some(message) = thread.messages().last() { - println!("Message: {}", message.to_string()); - } - if thread.all_tools_finished() { - self.done_tx.send_blocking(Ok(())).unwrap() - } - } - ThreadEvent::UsePendingTools { .. 
} => {} - ThreadEvent::ToolConfirmationNeeded => { - // Automatically approve all tools that need confirmation in headless mode - println!("Tool confirmation needed - automatically approving in headless mode"); - - // Get the tools needing confirmation - let tools_needing_confirmation: Vec<_> = thread - .read(cx) - .tools_needing_confirmation() - .cloned() - .collect(); - - // Run each tool that needs confirmation - for tool_use in tools_needing_confirmation { - if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) { - thread.update(cx, |thread, cx| { - println!("Auto-approving tool: {}", tool_use.name); - - // Create a request to send to the tool - let request = thread.to_completion_request(RequestKind::Chat, cx); - let messages = Arc::new(request.messages); - - // Run the tool - thread.run_tool( - tool_use.id.clone(), - tool_use.ui_text.clone(), - tool_use.input.clone(), - &messages, - tool, - cx, - ); - }); - } - } - } - ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - .. 
- } => { - if let Some(pending_tool_use) = pending_tool_use { - println!( - "Used tool {} with input: {}", - pending_tool_use.name, pending_tool_use.input - ); - *self - .tool_use_counts - .entry(pending_tool_use.name.clone()) - .or_insert(0) += 1; - } - if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { - println!("Tool result: {:?}", tool_result); - } - } - _ => {} - } - } -} - -pub fn init(cx: &mut App) -> Arc { - release_channel::init(SemanticVersion::default(), cx); - gpui_tokio::init(cx); - - let mut settings_store = SettingsStore::new(cx); - settings_store - .set_default_settings(settings::default_settings().as_ref(), cx) - .unwrap(); - cx.set_global(settings_store); - client::init_settings(cx); - Project::init_settings(cx); - - let client = Client::production(cx); - cx.set_http_client(client.http_client().clone()); - - let git_binary_path = None; - let fs = Arc::new(RealFs::new( - git_binary_path, - cx.background_executor().clone(), - )); - - let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); - - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - - language::init(cx); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); - assistant_tools::init(client.http_client().clone(), cx); - context_server::init(cx); - let stdout_is_a_pty = false; - let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); - agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx); - - Arc::new(HeadlessAppState { - languages, - client, - user_store, - fs, - node_runtime: NodeRuntime::unavailable(), - prompt_builder, - }) -} - -pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result> { - let model_registry = LanguageModelRegistry::read_global(cx); - let model = model_registry - .available_models(cx) - .find(|model| model.id().0 == model_name); - - let Some(model) = model else { - return Err(anyhow!( - "No language 
model named {} was available. Available models: {}", - model_name, - model_registry - .available_models(cx) - .map(|model| model.id().0.clone()) - .collect::>() - .join(", ") - )); - }; - - Ok(model) -} - -pub fn authenticate_model_provider( - provider_id: LanguageModelProviderId, - cx: &mut App, -) -> Task> { - let model_registry = LanguageModelRegistry::read_global(cx); - let model_provider = model_registry.provider(&provider_id).unwrap(); - model_provider.authenticate(cx) -} diff --git a/crates/agent_eval/src/main.rs b/crates/agent_eval/src/main.rs deleted file mode 100644 index ef4cb51a33..0000000000 --- a/crates/agent_eval/src/main.rs +++ /dev/null @@ -1,205 +0,0 @@ -mod eval; -mod get_exercise; -mod git_commands; -mod headless_assistant; - -use clap::Parser; -use eval::{run_exercise_eval, save_eval_results}; -use futures::stream::{self, StreamExt}; -use get_exercise::{find_exercises, get_exercise_language, get_exercise_name}; -use git_commands::read_base_sha; -use gpui::Application; -use headless_assistant::{authenticate_model_provider, find_model}; -use language_model::LanguageModelRegistry; -use reqwest_client::ReqwestClient; -use std::{path::PathBuf, sync::Arc}; - -#[derive(Parser, Debug)] -#[command( - name = "agent_eval", - disable_version_flag = true, - before_help = "Tool eval runner" -)] -struct Args { - /// Match the names of evals to run. - #[arg(long)] - exercise_names: Vec, - /// Runs all exercises, causes the exercise_names to be ignored. - #[arg(long)] - all: bool, - /// Supported language types to evaluate (default: internal). - /// Internal is data generated from the agent panel - #[arg(long, default_value = "internal")] - languages: String, - /// Name of the model (default: "claude-3-7-sonnet-latest") - #[arg(long, default_value = "claude-3-7-sonnet-latest")] - model_name: String, - /// Name of the editor model (default: value of `--model_name`). 
- #[arg(long)] - editor_model_name: Option, - /// Number of evaluations to run concurrently (default: 3) - #[arg(short, long, default_value = "5")] - concurrency: usize, - /// Maximum number of exercises to evaluate per language - #[arg(long)] - max_exercises_per_language: Option, -} - -fn main() { - env_logger::init(); - let args = Args::parse(); - let http_client = Arc::new(ReqwestClient::new()); - let app = Application::headless().with_http_client(http_client.clone()); - - // Path to the zed-ace-framework repo - let framework_path = PathBuf::from("../zed-ace-framework") - .canonicalize() - .unwrap(); - - // Fix the 'languages' lifetime issue by creating owned Strings instead of slices - let languages: Vec = args.languages.split(',').map(|s| s.to_string()).collect(); - - println!("Using zed-ace-framework at: {:?}", framework_path); - println!("Evaluating languages: {:?}", languages); - - app.run(move |cx| { - let app_state = headless_assistant::init(cx); - - let model = find_model(&args.model_name, cx).unwrap(); - let editor_model = if let Some(model_name) = &args.editor_model_name { - find_model(model_name, cx).unwrap() - } else { - model.clone() - }; - - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model(Some(model.clone()), cx); - }); - - let model_provider_id = model.provider_id(); - let editor_model_provider_id = editor_model.provider_id(); - - let framework_path_clone = framework_path.clone(); - let languages_clone = languages.clone(); - let exercise_names = args.exercise_names.clone(); - let all_flag = args.all; - - cx.spawn(async move |cx| { - // Authenticate all model providers first - cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx)) - .unwrap() - .await - .unwrap(); - cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx)) - .unwrap() - .await - .unwrap(); - - println!("framework path: {}", framework_path_clone.display()); - - let base_sha = 
read_base_sha(&framework_path_clone).await.unwrap(); - - println!("base sha: {}", base_sha); - - let all_exercises = find_exercises( - &framework_path_clone, - &languages_clone - .iter() - .map(|s| s.as_str()) - .collect::>(), - args.max_exercises_per_language, - ) - .unwrap(); - println!("Found {} exercises total", all_exercises.len()); - - // Filter exercises if specific ones were requested - let exercises_to_run = if !exercise_names.is_empty() { - // If exercise names are specified, filter by them regardless of --all flag - all_exercises - .into_iter() - .filter(|path| { - let name = get_exercise_name(path); - exercise_names.iter().any(|filter| name.contains(filter)) - }) - .collect() - } else if all_flag { - // Only use all_flag if no exercise names are specified - all_exercises - } else { - // Default behavior (no filters) - all_exercises - }; - - println!("Will run {} exercises", exercises_to_run.len()); - - // Create exercise eval tasks - each exercise is a single task that will run templates sequentially - let exercise_tasks: Vec<_> = exercises_to_run - .into_iter() - .map(|exercise_path| { - let exercise_name = get_exercise_name(&exercise_path); - let model_clone = model.clone(); - let app_state_clone = app_state.clone(); - let base_sha_clone = base_sha.clone(); - let framework_path_clone = framework_path_clone.clone(); - let cx_clone = cx.clone(); - - async move { - println!("Processing exercise: {}", exercise_name); - let mut exercise_results = Vec::new(); - - match run_exercise_eval( - exercise_path.clone(), - model_clone.clone(), - app_state_clone.clone(), - base_sha_clone.clone(), - framework_path_clone.clone(), - cx_clone.clone(), - ) - .await - { - Ok(result) => { - println!("Completed {}", exercise_name); - exercise_results.push(result); - } - Err(err) => { - println!("Error running {}: {}", exercise_name, err); - } - } - - // Save results for this exercise - if !exercise_results.is_empty() { - if let Err(err) = - save_eval_results(&exercise_path, 
exercise_results.clone()).await - { - println!("Error saving results for {}: {}", exercise_name, err); - } else { - println!("Saved results for {}", exercise_name); - } - } - - exercise_results - } - }) - .collect(); - - println!( - "Running {} exercises with concurrency: {}", - exercise_tasks.len(), - args.concurrency - ); - - // Run exercises concurrently, with each exercise running its templates sequentially - let all_results = stream::iter(exercise_tasks) - .buffer_unordered(args.concurrency) - .flat_map(stream::iter) - .collect::>() - .await; - - println!("Completed {} evaluation runs", all_results.len()); - cx.update(|cx| cx.quit()).unwrap(); - }) - .detach(); - }); - - println!("Done running evals"); -} diff --git a/crates/agent_rules/Cargo.toml b/crates/agent_rules/Cargo.toml deleted file mode 100644 index 3ec2b53bec..0000000000 --- a/crates/agent_rules/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "agent_rules" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/agent_rules.rs" -doctest = false - -[dependencies] -anyhow.workspace = true -fs.workspace = true -gpui.workspace = true -prompt_store.workspace = true -util.workspace = true -worktree.workspace = true -workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" } - -[dev-dependencies] -indoc.workspace = true diff --git a/crates/agent_rules/LICENSE-GPL b/crates/agent_rules/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/agent_rules/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent_rules/src/agent_rules.rs b/crates/agent_rules/src/agent_rules.rs deleted file mode 100644 index faae6a086a..0000000000 --- a/crates/agent_rules/src/agent_rules.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Context as _, Result}; -use fs::Fs; -use gpui::{App, AppContext, Task}; 
-use prompt_store::SystemPromptRulesFile; -use util::maybe; -use worktree::Worktree; - -const RULES_FILE_NAMES: [&'static str; 6] = [ - ".rules", - ".cursorrules", - ".windsurfrules", - ".clinerules", - ".github/copilot-instructions.md", - "CLAUDE.md", -]; - -pub fn load_worktree_rules_file( - fs: Arc, - worktree: &Worktree, - cx: &App, -) -> Option>> { - let selected_rules_file = RULES_FILE_NAMES - .into_iter() - .filter_map(|name| { - worktree - .entry_for_path(name) - .filter(|entry| entry.is_file()) - .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path))) - }) - .next(); - - // Note that Cline supports `.clinerules` being a directory, but that is not currently - // supported. This doesn't seem to occur often in GitHub repositories. - selected_rules_file.map(|(path_in_worktree, abs_path)| { - let fs = fs.clone(); - cx.background_spawn(maybe!(async move { - let abs_path = abs_path?; - let text = fs - .load(&abs_path) - .await - .with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?; - anyhow::Ok(SystemPromptRulesFile { - path_in_worktree, - abs_path: abs_path.into(), - text: text.trim().to_string(), - }) - })) - }) -} diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index e09239b4ec..283c1e569d 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 { }, } -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub struct AssistantSettings { pub enabled: bool, pub button: bool, diff --git a/crates/assistant_tools/src/code_symbols_tool.rs b/crates/assistant_tools/src/code_symbols_tool.rs index dccff43fbf..9f0219b281 100644 --- a/crates/assistant_tools/src/code_symbols_tool.rs +++ b/crates/assistant_tools/src/code_symbols_tool.rs @@ -179,11 +179,9 @@ pub async fn file_outline( // Wait until the buffer has been fully parsed, so that 
we can read its outline. let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; - while parse_status - .recv() - .await - .map_or(false, |status| status != ParseStatus::Idle) - {} + while *parse_status.borrow() != ParseStatus::Idle { + parse_status.changed().await?; + } let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; let Some(outline) = snapshot.outline(None) else { diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index e3701c9a23..0249c24dcf 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -9,12 +9,13 @@ agent.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true +assistant_settings.workspace = true client.workspace = true -collections.workspace = true context_server.workspace = true dap.workspace = true env_logger.workspace = true fs.workspace = true +futures.workspace = true gpui.workspace = true gpui_tokio.workspace = true language.workspace = true @@ -27,7 +28,6 @@ release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true settings.workspace = true -smol.workspace = true toml.workspace = true workspace-hack.workspace = true diff --git a/crates/eval/src/agent.rs b/crates/eval/src/agent.rs deleted file mode 100644 index 636c8b5b3d..0000000000 --- a/crates/eval/src/agent.rs +++ /dev/null @@ -1,229 +0,0 @@ -use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore}; -use anyhow::anyhow; -use assistant_tool::ToolWorkingSet; -use client::{Client, UserStore}; -use collections::HashMap; -use dap::DapRegistry; -use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*}; -use language::LanguageRegistry; -use language_model::{ - AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, -}; -use node_runtime::NodeRuntime; -use project::{Project, RealFs}; -use prompt_store::PromptBuilder; -use settings::SettingsStore; -use smol::channel; -use std::sync::Arc; - -/// Subset of 
`workspace::AppState` needed by `HeadlessAssistant`, with additional fields. -pub struct AgentAppState { - pub languages: Arc, - pub client: Arc, - pub user_store: Entity, - pub fs: Arc, - pub node_runtime: NodeRuntime, - - // Additional fields not present in `workspace::AppState`. - pub prompt_builder: Arc, -} - -pub struct Agent { - // pub thread: Entity, - // pub project: Entity, - #[allow(dead_code)] - pub thread_store: Entity, - pub tool_use_counts: HashMap, u32>, - pub done_tx: channel::Sender>, - _subscription: Subscription, -} - -impl Agent { - pub fn new( - app_state: Arc, - cx: &mut App, - ) -> anyhow::Result<(Entity, channel::Receiver>)> { - let env = None; - let project = Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - Arc::new(DapRegistry::default()), - app_state.fs.clone(), - env, - cx, - ); - - let tools = Arc::new(ToolWorkingSet::default()); - let thread_store = - ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?; - - let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); - - let (done_tx, done_rx) = channel::unbounded::>(); - - let headless_thread = cx.new(move |cx| Self { - _subscription: cx.subscribe(&thread, Self::handle_thread_event), - // thread, - // project, - thread_store, - tool_use_counts: HashMap::default(), - done_tx, - }); - - Ok((headless_thread, done_rx)) - } - - fn handle_thread_event( - &mut self, - thread: Entity, - event: &ThreadEvent, - cx: &mut Context, - ) { - match event { - ThreadEvent::ShowError(err) => self - .done_tx - .send_blocking(Err(anyhow!("{:?}", err))) - .unwrap(), - ThreadEvent::DoneStreaming => { - let thread = thread.read(cx); - if let Some(message) = thread.messages().last() { - println!("Message: {}", message.to_string()); - } - if thread.all_tools_finished() { - self.done_tx.send_blocking(Ok(())).unwrap() - } - } - ThreadEvent::UsePendingTools { .. 
} => {} - ThreadEvent::ToolConfirmationNeeded => { - // Automatically approve all tools that need confirmation in headless mode - println!("Tool confirmation needed - automatically approving in headless mode"); - - // Get the tools needing confirmation - let tools_needing_confirmation: Vec<_> = thread - .read(cx) - .tools_needing_confirmation() - .cloned() - .collect(); - - // Run each tool that needs confirmation - for tool_use in tools_needing_confirmation { - if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) { - thread.update(cx, |thread, cx| { - println!("Auto-approving tool: {}", tool_use.name); - - // Create a request to send to the tool - let request = thread.to_completion_request(RequestKind::Chat, cx); - let messages = Arc::new(request.messages); - - // Run the tool - thread.run_tool( - tool_use.id.clone(), - tool_use.ui_text.clone(), - tool_use.input.clone(), - &messages, - tool, - cx, - ); - }); - } - } - } - ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - .. 
- } => { - if let Some(pending_tool_use) = pending_tool_use { - println!( - "Used tool {} with input: {}", - pending_tool_use.name, pending_tool_use.input - ); - *self - .tool_use_counts - .entry(pending_tool_use.name.clone()) - .or_insert(0) += 1; - } - if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { - println!("Tool result: {:?}", tool_result); - } - } - _ => {} - } - } -} - -pub fn init(cx: &mut App) -> Arc { - release_channel::init(SemanticVersion::default(), cx); - gpui_tokio::init(cx); - - let mut settings_store = SettingsStore::new(cx); - settings_store - .set_default_settings(settings::default_settings().as_ref(), cx) - .unwrap(); - cx.set_global(settings_store); - client::init_settings(cx); - Project::init_settings(cx); - - let client = Client::production(cx); - cx.set_http_client(client.http_client().clone()); - - let git_binary_path = None; - let fs = Arc::new(RealFs::new( - git_binary_path, - cx.background_executor().clone(), - )); - - let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); - - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - - language::init(cx); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); - assistant_tools::init(client.http_client().clone(), cx); - context_server::init(cx); - let stdout_is_a_pty = false; - let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); - agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx); - - Arc::new(AgentAppState { - languages, - client, - user_store, - fs, - node_runtime: NodeRuntime::unavailable(), - prompt_builder, - }) -} - -pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result> { - let model_registry = LanguageModelRegistry::read_global(cx); - let model = model_registry - .available_models(cx) - .find(|model| model.id().0 == model_name); - - let Some(model) = model else { - return Err(anyhow!( - "No language model 
named {} was available. Available models: {}", - model_name, - model_registry - .available_models(cx) - .map(|model| model.id().0.clone()) - .collect::>() - .join(", ") - )); - }; - - Ok(model) -} - -pub fn authenticate_model_provider( - provider_id: LanguageModelProviderId, - cx: &mut App, -) -> Task> { - let model_registry = LanguageModelRegistry::read_global(cx); - let model_provider = model_registry.provider(&provider_id).unwrap(); - model_provider.authenticate(cx) -} diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index ad3180512d..88cca63852 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -1,74 +1,22 @@ -use agent::Agent; -use anyhow::Result; -use gpui::Application; -use language_model::LanguageModelRegistry; -use reqwest_client::ReqwestClient; -use serde::Deserialize; -use std::{ - fs, - path::{Path, PathBuf}, - sync::Arc, +mod example; + +use assistant_settings::AssistantSettings; +use client::{Client, UserStore}; +pub(crate) use example::*; + +use ::fs::RealFs; +use anyhow::anyhow; +use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task}; +use language::LanguageRegistry; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; -mod agent; - -#[derive(Debug, Deserialize)] -pub struct ExampleBase { - pub path: PathBuf, - pub revision: String, -} - -#[derive(Debug)] -pub struct Example { - pub base: ExampleBase, - - /// Content of the prompt.md file - pub prompt: String, - - /// Content of the rubric.md file - pub rubric: String, -} - -impl Example { - /// Load an example from a directory containing base.toml, prompt.md, and rubric.md - pub fn load_from_directory>(dir_path: P) -> Result { - let base_path = dir_path.as_ref().join("base.toml"); - let prompt_path = dir_path.as_ref().join("prompt.md"); - let rubric_path = dir_path.as_ref().join("rubric.md"); - - let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?; - base.path = 
base.path.canonicalize()?; - - Ok(Example { - base, - prompt: fs::read_to_string(prompt_path)?, - rubric: fs::read_to_string(rubric_path)?, - }) - } - - /// Set up the example by checking out the specified Git revision - pub fn setup(&self) -> Result<()> { - use std::process::Command; - - // Check if the directory exists - let path = Path::new(&self.base.path); - anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path); - - // Change to the project directory and checkout the specified revision - let output = Command::new("git") - .current_dir(&self.base.path) - .arg("checkout") - .arg(&self.base.revision) - .output()?; - anyhow::ensure!( - output.status.success(), - "Failed to checkout revision {}: {}", - self.base.revision, - String::from_utf8_lossy(&output.stderr), - ); - - Ok(()) - } -} +use node_runtime::NodeRuntime; +use project::Project; +use prompt_store::PromptBuilder; +use reqwest_client::ReqwestClient; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; fn main() { env_logger::init(); @@ -76,10 +24,9 @@ fn main() { let app = Application::headless().with_http_client(http_client.clone()); app.run(move |cx| { - let app_state = crate::agent::init(cx); - let _agent = Agent::new(app_state, cx); + let app_state = init(cx); - let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap(); + let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.set_default_model(Some(model.clone()), cx); @@ -87,15 +34,112 @@ fn main() { let model_provider_id = model.provider_id(); - let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx); + let authenticate = authenticate_model_provider(model_provider_id.clone(), cx); - cx.spawn(async move |_cx| { + cx.spawn(async move |cx| { authenticate.await.unwrap(); - }) - .detach(); - }); - // let example = - // 
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?; - // example.setup()?; + let example = + Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?; + example.setup()?; + cx.update(|cx| example.run(model, app_state, cx))?.await?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + }); +} + +/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. +pub struct AgentAppState { + pub languages: Arc, + pub client: Arc, + pub user_store: Entity, + pub fs: Arc, + pub node_runtime: NodeRuntime, + + // Additional fields not present in `workspace::AppState`. + pub prompt_builder: Arc, +} + +pub fn init(cx: &mut App) -> Arc { + release_channel::init(SemanticVersion::default(), cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client().clone()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new( + git_binary_path, + cx.background_executor().clone(), + )); + + let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + language::init(cx); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); + assistant_tools::init(client.http_client().clone(), cx); + context_server::init(cx); + let stdout_is_a_pty = false; + let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); + agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx); + + AssistantSettings::override_global( + AssistantSettings { + always_allow_tool_actions: true, + ..AssistantSettings::get_global(cx).clone() + }, + cx, + ); + + 
Arc::new(AgentAppState { + languages, + client, + user_store, + fs, + node_runtime: NodeRuntime::unavailable(), + prompt_builder, + }) +} + +pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result> { + let model_registry = LanguageModelRegistry::read_global(cx); + let model = model_registry + .available_models(cx) + .find(|model| model.id().0 == model_name); + + let Some(model) = model else { + return Err(anyhow!( + "No language model named {} was available. Available models: {}", + model_name, + model_registry + .available_models(cx) + .map(|model| model.id().0.clone()) + .collect::>() + .join(", ") + )); + }; + + Ok(model) +} + +pub fn authenticate_model_provider( + provider_id: LanguageModelProviderId, + cx: &mut App, +) -> Task> { + let model_registry = LanguageModelRegistry::read_global(cx); + let model_provider = model_registry.provider(&provider_id).unwrap(); + model_provider.authenticate(cx) } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs new file mode 100644 index 0000000000..6b45eddb60 --- /dev/null +++ b/crates/eval/src/example.rs @@ -0,0 +1,178 @@ +use agent::{RequestKind, ThreadEvent, ThreadStore}; +use anyhow::{Result, anyhow}; +use assistant_tool::ToolWorkingSet; +use dap::DapRegistry; +use futures::channel::oneshot; +use gpui::{App, Task}; +use language_model::{LanguageModel, StopReason}; +use project::Project; +use serde::Deserialize; +use std::process::Command; +use std::sync::Arc; +use std::{ + fs, + path::{Path, PathBuf}, +}; + +use crate::AgentAppState; + +#[derive(Debug, Deserialize)] +pub struct ExampleBase { + pub path: PathBuf, + pub revision: String, +} + +#[derive(Debug)] +pub struct Example { + pub base: ExampleBase, + + /// Content of the prompt.md file + pub prompt: String, + + /// Content of the rubric.md file + pub _rubric: String, +} + +impl Example { + /// Load an example from a directory containing base.toml, prompt.md, and rubric.md + pub fn load_from_directory>(dir_path: P) -> Result { + let 
base_path = dir_path.as_ref().join("base.toml"); + let prompt_path = dir_path.as_ref().join("prompt.md"); + let rubric_path = dir_path.as_ref().join("rubric.md"); + + let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?; + base.path = base.path.canonicalize()?; + + Ok(Example { + base, + prompt: fs::read_to_string(prompt_path)?, + _rubric: fs::read_to_string(rubric_path)?, + }) + } + + /// Set up the example by checking out the specified Git revision + pub fn setup(&self) -> Result<()> { + // Check if the directory exists + let path = Path::new(&self.base.path); + anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path); + + // Change to the project directory and checkout the specified revision + let output = Command::new("git") + .current_dir(&self.base.path) + .arg("checkout") + .arg(&self.base.revision) + .output()?; + anyhow::ensure!( + output.status.success(), + "Failed to checkout revision {}: {}", + self.base.revision, + String::from_utf8_lossy(&output.stderr), + ); + + Ok(()) + } + + pub fn run( + self, + model: Arc, + app_state: Arc, + cx: &mut App, + ) -> Task> { + let project = Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + Arc::new(DapRegistry::default()), + app_state.fs.clone(), + None, + cx, + ); + + let worktree = project.update(cx, |project, cx| { + project.create_worktree(self.base.path, true, cx) + }); + + let tools = Arc::new(ToolWorkingSet::default()); + let thread_store = + ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); + + println!("USER:"); + println!("{}", self.prompt); + println!("ASSISTANT:"); + cx.spawn(async move |cx| { + worktree.await?; + let thread_store = thread_store.await; + let thread = + thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; + + let (tx, rx) = oneshot::channel(); + let mut tx = Some(tx); + + let _subscription = + 
cx.subscribe( + &thread, + move |thread, event: &ThreadEvent, cx| match event { + ThreadEvent::Stopped(reason) => match reason { + Ok(StopReason::EndTurn) => { + if let Some(tx) = tx.take() { + tx.send(Ok(())).ok(); + } + } + Ok(StopReason::MaxTokens) => { + if let Some(tx) = tx.take() { + tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok(); + } + } + Ok(StopReason::ToolUse) => {} + Err(error) => { + if let Some(tx) = tx.take() { + tx.send(Err(anyhow!(error.clone()))).ok(); + } + } + }, + ThreadEvent::ShowError(thread_error) => { + if let Some(tx) = tx.take() { + tx.send(Err(anyhow!(thread_error.clone()))).ok(); + } + } + ThreadEvent::StreamedAssistantText(_, chunk) => { + print!("{}", chunk); + } + ThreadEvent::StreamedAssistantThinking(_, chunk) => { + print!("{}", chunk); + } + ThreadEvent::UsePendingTools { tool_uses } => { + println!("\n\nUSING TOOLS:"); + for tool_use in tool_uses { + println!("{}: {}", tool_use.name, tool_use.input); + } + } + ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + .. 
+ } => { + if let Some(tool_use) = pending_tool_use { + println!("\nTOOL FINISHED: {}", tool_use.name); + } + if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { + println!("\n{}\n", tool_result.content); + } + } + _ => {} + }, + )?; + + thread.update(cx, |thread, cx| { + let context = vec![]; + thread.insert_user_message(self.prompt.clone(), context, None, cx); + thread.send_to_model(model, RequestKind::Chat, cx); + })?; + + rx.await??; + + Ok(()) + }) + } +} diff --git a/crates/gpui/src/app/async_context.rs b/crates/gpui/src/app/async_context.rs index c20a5b9066..02cc8f33b8 100644 --- a/crates/gpui/src/app/async_context.rs +++ b/crates/gpui/src/app/async_context.rs @@ -1,7 +1,7 @@ use crate::{ AnyView, AnyWindowHandle, App, AppCell, AppContext, BackgroundExecutor, BorrowAppContext, - Entity, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation, Result, Task, - VisualContext, Window, WindowHandle, + Entity, EventEmitter, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation, + Result, Subscription, Task, VisualContext, Window, WindowHandle, }; use anyhow::{Context as _, anyhow}; use derive_more::{Deref, DerefMut}; @@ -154,6 +154,26 @@ impl AsyncApp { Ok(lock.update(f)) } + /// Arrange for the given callback to be invoked whenever the given entity emits an event of a given type. + /// The callback is provided a handle to the emitting entity and a reference to the emitted event. + pub fn subscribe( + &mut self, + entity: &Entity, + mut on_event: impl FnMut(Entity, &Event, &mut App) + 'static, + ) -> Result + where + T: 'static + EventEmitter, + Event: 'static, + { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut lock = app.borrow_mut(); + let subscription = lock.subscribe(entity, on_event); + Ok(subscription) + } + /// Open a window with the given options based on the root view returned by the given function. 
pub fn open_window( &self, diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index a06fed5714..717af5fc55 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -16,17 +16,17 @@ use std::{ use text::LineEnding; use util::{ResultExt, get_system_shell}; -#[derive(Serialize)] -pub struct AssistantSystemPromptContext { - pub worktrees: Vec, +#[derive(Debug, Clone, Serialize)] +pub struct ProjectContext { + pub worktrees: Vec, pub has_rules: bool, pub os: String, pub arch: String, pub shell: String, } -impl AssistantSystemPromptContext { - pub fn new(worktrees: Vec) -> Self { +impl ProjectContext { + pub fn new(worktrees: Vec) -> Self { let has_rules = worktrees .iter() .any(|worktree| worktree.rules_file.is_some()); @@ -40,15 +40,15 @@ impl AssistantSystemPromptContext { } } -#[derive(Serialize)] -pub struct WorktreeInfoForSystemPrompt { +#[derive(Debug, Clone, Serialize)] +pub struct WorktreeContext { pub root_name: String, pub abs_path: Arc, - pub rules_file: Option, + pub rules_file: Option, } -#[derive(Serialize)] -pub struct SystemPromptRulesFile { +#[derive(Debug, Clone, Serialize)] +pub struct RulesFileContext { pub path_in_worktree: Arc, pub abs_path: Arc, pub text: String, @@ -260,7 +260,7 @@ impl PromptBuilder { pub fn generate_assistant_system_prompt( &self, - context: &AssistantSystemPromptContext, + context: &ProjectContext, ) -> Result { self.handlebars .lock()