From bab28560ef56092bd99e17d595523644f18f6764 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 18 Apr 2025 20:47:59 -0600 Subject: [PATCH] Systematically optimize agentic editing performance (#28961) Now that we've established a proper eval in tree, this PR is reboots of our agent loop back to a set of minimal tools and simpler prompts. We should aim to get this branch feeling subjectively competitive with what's on main and then merge it, and build from there. Let's invest in our eval and use it to drive better performance of the agent loop. How you can help: Pick an example, and then make the outcome faster or better. It's fine to even use your own subjective judgment, as our evaluation criteria likely need tuning as well at this point. Focus on making the agent work better in your own subjective experience first. Let's focus on simple/practical improvements to make this thing work better, then determine how we can craft our judgment criteria to lock those improvements in. Release Notes: - N/A --------- Co-authored-by: Max Co-authored-by: Antonio Co-authored-by: Agus Co-authored-by: Richard Co-authored-by: Max Brunsfeld Co-authored-by: Antonio Scandurra Co-authored-by: Michael Sloan --- Cargo.lock | 5 + assets/prompts/assistant_system_prompt.hbs | 203 ++---- assets/settings/default.json | 173 +++-- crates/agent/src/thread.rs | 53 +- crates/agent/src/thread_store.rs | 3 - crates/assistant_tools/Cargo.toml | 4 + crates/assistant_tools/src/assistant_tools.rs | 24 +- .../assistant_tools/src/code_symbols_tool.rs | 20 +- crates/assistant_tools/src/contents_tool.rs | 2 +- .../src/diagnostics_tool/description.md | 7 +- crates/assistant_tools/src/edit_file_tool.rs | 183 ++++++ .../src/edit_file_tool/description.md | 45 ++ .../src/list_directory_tool.rs | 2 +- .../src/list_directory_tool/description.md | 2 +- .../assistant_tools/src/path_search_tool.rs | 171 +++-- .../src/path_search_tool/description.md | 8 +- crates/assistant_tools/src/read_file_tool.rs | 243 ++++++- .../src/read_file_tool/description.md | 5 +- .../src/regex_search_tool/description.md | 9 +- crates/eval/Cargo.toml | 3 +- .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../email_verification_refactor/base.toml | 1 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../find_and_replace_diff_card/prompt.md | 2 +- .../thread_criteria.md | 3 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../eval/examples/metal_i64_support/base.toml | 1 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../virtio_block_request_refactor/base.toml | 1 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 crates/eval/src/eval.rs | 193 ++++-- crates/eval/src/example.rs | 620 +++++++++++++++--- ...judge_prompt.hbs => judge_diff_prompt.hbs} | 22 + crates/eval/src/judge_thread_prompt.hbs | 22 + crates/prompt_store/src/prompts.rs | 2 - crates/worktree/src/worktree.rs | 17 + typos.toml | 4 +- 68 files changed, 1575 insertions(+), 478 deletions(-) create mode 100644 crates/assistant_tools/src/edit_file_tool.rs create mode 100644 crates/assistant_tools/src/edit_file_tool/description.md rename crates/eval/examples/add_arp_protocol_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/auth_session_management/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/buffer_string_input_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/checkpoint_stability/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/dd_iaptic_mcp_server_integration/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/debian_image_builder/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/docs_restructure/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/email_verification_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/exif_rotation_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/expand_laravel_php_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/find_and_replace_diff_card/{criteria.md => diff_criteria.md} (100%) create mode 100644 crates/eval/examples/find_and_replace_diff_card/thread_criteria.md rename crates/eval/examples/finnish_translation/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/language_model_file_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/lhs_join_update_callbacks/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/libdevice_symbol_reexport/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/license_management/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/metal_i64_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/metrics_data_size_updates/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/nan_diff_handling/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/never_type_workaround/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/optimizer_schema_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/rate_limit_endpoints/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/replace_hold_with_drain_on_exit/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/request_to_axios_migration/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/restore_version_api_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/runtime_script_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/standardized_docker_dependency_checks/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/table_metrics_sorting/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/tax_id_validation/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/test_infrastructure/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/time_detail_merge_update/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/tool_response_handling/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/toolbar_endpoints/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/virtio_block_request_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/war_and_uri_corrections/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/window_title_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/src/{judge_prompt.hbs => judge_diff_prompt.hbs} (64%) create mode 100644 crates/eval/src/judge_thread_prompt.hbs diff --git a/Cargo.lock b/Cargo.lock index 1b872d02fa..27f9183913 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -710,17 +710,21 @@ dependencies = [ "gpui", "html_to_markdown", "http_client", + "indoc", "itertools 0.14.0", "language", "language_model", "linkme", "open", + "pretty_assertions", "project", "rand 0.8.5", "regex", "schemars", "serde", "serde_json", + "settings", + "tree-sitter-rust", "ui", "unindent", "util", @@ -4914,6 +4918,7 @@ dependencies = [ "release_channel", "reqwest_client", "serde", + "serde_json", "settings", "shellexpand 2.1.2", "telemetry", diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index 60b2cee74e..be186f1738 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -1,148 +1,77 @@ -You are an AI assistant integrated into a code editor. You have the programming ability of an expert programmer who takes pride in writing high-quality code and is driven to the point of obsession about solving problems effectively. Your goal is to do one of the following two things: +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. -1. Help users answer questions and perform tasks related to their codebase. -2. Answer general-purpose questions unrelated to their particular codebase. +## Communication -It will be up to you to decide which of these you are doing based on what the user has told you. When unclear, ask clarifying questions to understand the user's intent before proceeding. +1. Be conversational but professional. +2. Refer to the USER in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. -You should only perform actions that modify the user's system if explicitly requested by the user: -- If the user asks a question about how to accomplish a task, provide guidance or information, and use read-only tools (e.g., search) to assist. You may suggest potential actions, but do not directly modify the user's system without explicit instruction. -- If the user clearly requests that you perform an action, carry out the action directly without explaining why you are doing so. +## Searching and Reading -When answering questions, it's okay to give incomplete examples containing comments about what would go there in a real version. When being asked to directly perform tasks on the code base, you must ALWAYS make fully working code. You may never "simplify" the code by omitting or deleting functionality you know the user has requested, and you must NEVER write comments like "in a full version, this would..." - instead, you must actually implement the real version. Don't be lazy! +If you are unsure about the answer to the user's request or how to satiate their request, you should gather more information. +This can be done with additional tool calls, asking clarifying questions, etc. -Note that project files are automatically backed up. The user can always get them back later if anything goes wrong, so there's -no need to create backup files (e.g. `.bak` files) because these files will just take up unnecessary space on the user's disk. +For example, if you've performed a semantic search, and the results may not fully answer the user's request, or merit gathering more information, feel free to call more tools. Similarly, if you've performed an edit that may partially +satiate the user's query, but you're not confident, gather more information or use more tools before ending your turn. -When attempting to resolve issues around failing tests, never simply remove the failing tests. Unless the user explicitly asks you to remove tests, ALWAYS attempt to fix the code causing the tests to fail. +Bias towards not asking the user for help if you can find the answer yourself. -Ignore "TODO"-type comments unless they're relevant to the user's explicit request or the user specifically asks you to address them. It is, however, okay to include them in codebase summaries. +## Tool Use - +{{#if has_rules}} +There are project rules that apply to these root directories: +{{#each worktrees}} +{{#if rules_file}} +`{{root_name}}/{{rules_file.path_in_worktree}}`: +`````` +{{{rules_file.text}}} +`````` +{{/if}} +{{/each}} +{{/if}} {{#if has_default_user_rules}} The user has specified the following rules that should be applied: @@ -152,32 +81,8 @@ The user has specified the following rules that should be applied: Rules title: {{title}} {{/if}} `````` -{{contents}} +{{contents}}} `````` {{/each}} - {{/if}} -The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files: - -{{#each worktrees}} -- `{{root_name}}` (absolute path: `{{abs_path}}`) -{{/each}} -{{#if has_rules}} - -There are project rules that apply to these root directories: -{{#each worktrees}} -{{#if rules_file}} - -`{{root_name}}/{{rules_file.path_in_worktree}}`: - -`````` -{{{rules_file.text}}} -`````` {{/if}} -{{/each}} -{{/if}} - - -Operating System: {{os}} ({{arch}}) -Shell: {{shell}} - diff --git a/assets/settings/default.json b/assets/settings/default.json index 98ee37f213..f31feb7356 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -214,7 +214,14 @@ // The default number of lines to expand excerpts in the multibuffer by. "expand_excerpt_lines": 3, // Globs to match against file paths to determine if a file is private. - "private_files": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/secrets.yml"], + "private_files": [ + "**/.env*", + "**/*.pem", + "**/*.key", + "**/*.cert", + "**/*.crt", + "**/secrets.yml" + ], // Whether to use additional LSP queries to format (and amend) the code after // every "trigger" symbol input, defined by LSP server capabilities. "use_on_type_format": true, @@ -587,7 +594,6 @@ // // Default: main "fallback_branch_name": "main", - "scrollbar": { // When to show the scrollbar in the git panel. // @@ -660,25 +666,25 @@ "name": "Write", "enable_all_context_servers": true, "tools": { - "terminal": true, - "batch_tool": true, - "code_actions": true, - "code_symbols": true, - "contents": true, + "batch_tool": false, + "code_actions": false, + "code_symbols": false, + "contents": false, "copy_path": false, "create_file": true, "delete_path": false, "diagnostics": true, - "find_replace_file": true, + "edit_file": true, "fetch": true, - "list_directory": false, + "list_directory": true, "move_path": false, - "now": true, + "now": false, "path_search": true, "read_file": true, "regex_search": true, - "rename": true, - "symbol_info": true, + "rename": false, + "symbol_info": false, + "terminal": true, "thinking": true, "web_search": true } @@ -715,7 +721,9 @@ // The list of language servers to use (or disable) for all languages. // // This is typically customized on a per-language basis. - "language_servers": ["..."], + "language_servers": [ + "..." + ], // When to automatically save edited buffers. This setting can // take four values. // @@ -911,7 +919,9 @@ // for files that are not tracked by git, but are still important to your project. Note that globs // that are overly broad can slow down Zed's file scanning. `file_scan_exclusions` takes // precedence over these inclusions. - "file_scan_inclusions": [".env*"], + "file_scan_inclusions": [ + ".env*" + ], // Git gutter behavior configuration. "git": { // Control whether the git gutter is shown. May take 2 values: @@ -963,7 +973,15 @@ // Any addition to this list will be merged with the default list. // Globs are matched relative to the worktree root, // except when starting with a slash (/) or equivalent in Windows. - "disabled_globs": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/.dev.vars", "**/secrets.yml"], + "disabled_globs": [ + "**/.env*", + "**/*.pem", + "**/*.key", + "**/*.cert", + "**/*.crt", + "**/.dev.vars", + "**/secrets.yml" + ], // When to show edit predictions previews in buffer. // This setting takes two possible values: // 1. Display predictions inline when there are no language server completions available. @@ -1096,7 +1114,12 @@ // Default directories to search for virtual environments, relative // to the current working directory. We recommend overriding this // in your project's settings, rather than globally. - "directories": [".env", "env", ".venv", "venv"], + "directories": [ + ".env", + "env", + ".venv", + "venv" + ], // Can also be `csh`, `fish`, `nushell` and `power_shell` "activate_script": "default" } @@ -1160,8 +1183,15 @@ // } // "file_types": { - "JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json"], - "Shell Script": [".env.*"] + "JSONC": [ + "**/.zed/**/*.json", + "**/zed/**/*.json", + "**/Zed/**/*.json", + "**/.vscode/**/*.json" + ], + "Shell Script": [ + ".env.*" + ] }, // By default use a recent system version of node, or install our own. // You can override this to use a version of node that is not in $PATH with: @@ -1234,10 +1264,15 @@ // Different settings for specific languages. "languages": { "Astro": { - "language_servers": ["astro-language-server", "..."], + "language_servers": [ + "astro-language-server", + "..." + ], "prettier": { "allowed": true, - "plugins": ["prettier-plugin-astro"] + "plugins": [ + "prettier-plugin-astro" + ] } }, "Blade": { @@ -1273,10 +1308,19 @@ "ensure_final_newline_on_save": false }, "Elixir": { - "language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."] + "language_servers": [ + "elixir-ls", + "!next-ls", + "!lexical", + "..." + ] }, "Erlang": { - "language_servers": ["erlang-ls", "!elp", "..."] + "language_servers": [ + "erlang-ls", + "!elp", + "..." + ] }, "Git Commit": { "allow_rewrap": "anywhere" @@ -1292,7 +1336,12 @@ } }, "HEEX": { - "language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."] + "language_servers": [ + "elixir-ls", + "!next-ls", + "!lexical", + "..." + ] }, "HTML": { "prettier": { @@ -1302,11 +1351,17 @@ "Java": { "prettier": { "allowed": true, - "plugins": ["prettier-plugin-java"] + "plugins": [ + "prettier-plugin-java" + ] } }, "JavaScript": { - "language_servers": ["!typescript-language-server", "vtsls", "..."], + "language_servers": [ + "!typescript-language-server", + "vtsls", + "..." + ], "prettier": { "allowed": true } @@ -1324,7 +1379,10 @@ "LaTeX": { "format_on_save": "on", "formatter": "language_server", - "language_servers": ["texlab", "..."], + "language_servers": [ + "texlab", + "..." + ], "prettier": { "allowed": false } @@ -1339,10 +1397,16 @@ } }, "PHP": { - "language_servers": ["phpactor", "!intelephense", "..."], + "language_servers": [ + "phpactor", + "!intelephense", + "..." + ], "prettier": { "allowed": true, - "plugins": ["@prettier/plugin-php"], + "plugins": [ + "@prettier/plugin-php" + ], "parser": "php" } }, @@ -1350,7 +1414,12 @@ "allow_rewrap": "anywhere" }, "Ruby": { - "language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "..."] + "language_servers": [ + "solargraph", + "!ruby-lsp", + "!rubocop", + "..." + ] }, "SCSS": { "prettier": { @@ -1360,21 +1429,36 @@ "SQL": { "prettier": { "allowed": true, - "plugins": ["prettier-plugin-sql"] + "plugins": [ + "prettier-plugin-sql" + ] } }, "Starlark": { - "language_servers": ["starpls", "!buck2-lsp", "..."] + "language_servers": [ + "starpls", + "!buck2-lsp", + "..." + ] }, "Svelte": { - "language_servers": ["svelte-language-server", "..."], + "language_servers": [ + "svelte-language-server", + "..." + ], "prettier": { "allowed": true, - "plugins": ["prettier-plugin-svelte"] + "plugins": [ + "prettier-plugin-svelte" + ] } }, "TSX": { - "language_servers": ["!typescript-language-server", "vtsls", "..."], + "language_servers": [ + "!typescript-language-server", + "vtsls", + "..." + ], "prettier": { "allowed": true } @@ -1385,13 +1469,20 @@ } }, "TypeScript": { - "language_servers": ["!typescript-language-server", "vtsls", "..."], + "language_servers": [ + "!typescript-language-server", + "vtsls", + "..." + ], "prettier": { "allowed": true } }, "Vue.js": { - "language_servers": ["vue-language-server", "..."], + "language_servers": [ + "vue-language-server", + "..." + ], "prettier": { "allowed": true } @@ -1399,7 +1490,9 @@ "XML": { "prettier": { "allowed": true, - "plugins": ["@prettier/plugin-xml"] + "plugins": [ + "@prettier/plugin-xml" + ] } }, "YAML": { @@ -1408,7 +1501,10 @@ } }, "Zig": { - "language_servers": ["zls", "..."] + "language_servers": [ + "zls", + "..." + ] } }, // Different settings for specific language models. @@ -1562,7 +1658,6 @@ // } // ] "ssh_connections": [], - // Configures context servers for use in the Assistant. "context_servers": {}, "debugger": { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index a7044b0100..c50c001cd2 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -313,6 +313,9 @@ pub struct Thread { feedback: Option, message_feedback: HashMap, last_auto_capture_at: Option, + request_callback: Option< + Box])>, + >, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -365,6 +368,7 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + request_callback: None, } } @@ -434,9 +438,18 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + request_callback: None, } } + pub fn set_request_callback( + &mut self, + callback: impl 'static + + FnMut(&LanguageModelRequest, &[Result]), + ) { + self.request_callback = Some(Box::new(callback)); + } + pub fn id(&self) -> &ThreadId { &self.id } @@ -1083,15 +1096,6 @@ impl Thread { content.push(stale_message.into()); } - if action_log.has_edited_files_since_project_diagnostics_check() { - content.push( - "\n\nWhen you're done making changes, make sure to check project diagnostics \ - and fix all errors AND warnings you introduced! \ - DO NOT mention you're going to do this until you're done." - .into(), - ); - } - if !content.is_empty() { let context_message = LanguageModelRequestMessage { role: Role::User, @@ -1110,6 +1114,11 @@ impl Thread { cx: &mut Context, ) { let pending_completion_id = post_inc(&mut self.completion_count); + let mut request_callback_parameters = if self.request_callback.is_some() { + Some((request.clone(), Vec::new())) + } else { + None + }; let prompt_id = self.last_prompt_id.clone(); let task = cx.spawn(async move |thread, cx| { let stream_completion_future = model.stream_completion_with_usage(request, &cx); @@ -1117,6 +1126,7 @@ impl Thread { thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { let (mut events, usage) = stream_completion_future.await?; + let mut stop_reason = StopReason::EndTurn; let mut current_token_usage = TokenUsage::default(); @@ -1129,6 +1139,11 @@ impl Thread { } while let Some(event) = events.next().await { + if let Some((_, response_events)) = request_callback_parameters.as_mut() { + response_events + .push(event.as_ref().map_err(|error| error.to_string()).cloned()); + } + let event = event?; thread.update(cx, |thread, cx| { @@ -1293,6 +1308,14 @@ impl Thread { } cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); + if let Some((request_callback, (request, response_events))) = thread + .request_callback + .as_mut() + .zip(request_callback_parameters.as_ref()) + { + request_callback(request, response_events); + } + thread.auto_capture_telemetry(cx); if let Ok(initial_usage) = initial_token_usage { @@ -1587,17 +1610,11 @@ impl Thread { }); } + /// Insert an empty message to be populated with tool results upon send. pub fn attach_tool_results(&mut self, cx: &mut Context) { + // TODO: Don't insert a dummy user message here. Ensure this works with the thinking model. // Insert a user message to contain the tool results. - self.insert_user_message( - // TODO: Sending up a user message without any content results in the model sending back - // responses that also don't have any content. We currently don't handle this case well, - // so for now we provide some text to keep the model on track. - "Here are the tool results.", - Vec::new(), - None, - cx, - ); + self.insert_user_message("Here are the tool results.", Vec::new(), None, cx); } /// Cancels the last pending completion, if there are any pending. diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 74787016fd..e2f9e3d2de 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -279,14 +279,12 @@ impl ThreadStore { cx: &App, ) -> Task<(WorktreeContext, Option)> { let root_name = worktree.root_name().into(); - let abs_path = worktree.abs_path(); let rules_task = Self::load_worktree_rules_file(fs, worktree, cx); let Some(rules_task) = rules_task else { return Task::ready(( WorktreeContext { root_name, - abs_path, rules_file: None, }, None, @@ -305,7 +303,6 @@ impl ThreadStore { }; let worktree_info = WorktreeContext { root_name, - abs_path, rules_file, }; (worktree_info, rules_file_error) diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index be116a6534..eaaeff1e47 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -22,6 +22,7 @@ futures.workspace = true gpui.workspace = true html_to_markdown.workspace = true http_client.workspace = true +indoc.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true @@ -45,5 +46,8 @@ gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } rand.workspace = true +pretty_assertions.workspace = true +settings = { workspace = true, features = ["test-support"] } +tree-sitter-rust.workspace = true workspace = { workspace = true, features = ["test-support"] } unindent.workspace = true diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 33e06466e2..b68a273b4e 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -7,8 +7,8 @@ mod create_directory_tool; mod create_file_tool; mod delete_path_tool; mod diagnostics_tool; +mod edit_file_tool; mod fetch_tool; -mod find_replace_file_tool; mod list_directory_tool; mod move_path_tool; mod now_tool; @@ -42,8 +42,8 @@ use crate::create_directory_tool::CreateDirectoryTool; use crate::create_file_tool::CreateFileTool; use crate::delete_path_tool::DeletePathTool; use crate::diagnostics_tool::DiagnosticsTool; +use crate::edit_file_tool::EditFileTool; use crate::fetch_tool::FetchTool; -use crate::find_replace_file_tool::FindReplaceFileTool; use crate::list_directory_tool::ListDirectoryTool; use crate::now_tool::NowTool; use crate::open_tool::OpenTool; @@ -59,28 +59,28 @@ pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); let registry = ToolRegistry::global(cx); + registry.register_tool(TerminalTool); registry.register_tool(BatchTool); - registry.register_tool(CodeActionTool); - registry.register_tool(CodeSymbolsTool); - registry.register_tool(ContentsTool); - registry.register_tool(CopyPathTool); registry.register_tool(CreateDirectoryTool); registry.register_tool(CreateFileTool); + registry.register_tool(CopyPathTool); registry.register_tool(DeletePathTool); - registry.register_tool(DiagnosticsTool); - registry.register_tool(FetchTool::new(http_client)); - registry.register_tool(FindReplaceFileTool); - registry.register_tool(ListDirectoryTool); + registry.register_tool(EditFileTool); + registry.register_tool(SymbolInfoTool); + registry.register_tool(CodeActionTool); registry.register_tool(MovePathTool); + registry.register_tool(DiagnosticsTool); + registry.register_tool(ListDirectoryTool); registry.register_tool(NowTool); registry.register_tool(OpenTool); + registry.register_tool(CodeSymbolsTool); + registry.register_tool(ContentsTool); registry.register_tool(PathSearchTool); registry.register_tool(ReadFileTool); registry.register_tool(RegexSearchTool); registry.register_tool(RenameTool); - registry.register_tool(SymbolInfoTool); - registry.register_tool(TerminalTool); registry.register_tool(ThinkingTool); + registry.register_tool(FetchTool::new(http_client)); cx.observe_flag::({ move |is_enabled, cx| { diff --git a/crates/assistant_tools/src/code_symbols_tool.rs b/crates/assistant_tools/src/code_symbols_tool.rs index 4743c88720..9cd63ab02c 100644 --- a/crates/assistant_tools/src/code_symbols_tool.rs +++ b/crates/assistant_tools/src/code_symbols_tool.rs @@ -147,7 +147,7 @@ impl Tool for CodeSymbolsTool { }; cx.spawn(async move |cx| match input.path { - Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await, + Some(path) => file_outline(project, path, action_log, regex, cx).await, None => project_symbols(project, regex, input.offset, cx).await, }) .into() @@ -159,7 +159,6 @@ pub async fn file_outline( path: String, action_log: Entity, regex: Option, - offset: u32, cx: &mut AsyncApp, ) -> anyhow::Result { let buffer = { @@ -195,7 +194,8 @@ pub async fn file_outline( .into_iter() .map(|item| item.to_point(&snapshot)), regex, - offset, + 0, + usize::MAX, ) .await } @@ -294,11 +294,10 @@ async fn project_symbols( async fn render_outline( items: impl IntoIterator>, regex: Option, - offset: u32, + offset: usize, + results_per_page: usize, ) -> Result { - const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize; - - let mut items = items.into_iter().skip(offset as usize); + let mut items = items.into_iter().skip(offset); let entries = items .by_ref() @@ -307,7 +306,7 @@ async fn render_outline( .as_ref() .is_none_or(|regex| regex.is_match(&item.text)) }) - .take(RESULTS_PER_PAGE_USIZE) + .take(results_per_page) .collect::>(); let has_more = items.next().is_some(); @@ -338,7 +337,10 @@ async fn render_outline( Ok(output) } -fn render_entries(output: &mut String, items: impl IntoIterator>) -> u32 { +fn render_entries( + output: &mut String, + items: impl IntoIterator>, +) -> usize { let mut entries_rendered = 0; for item in items { diff --git a/crates/assistant_tools/src/contents_tool.rs b/crates/assistant_tools/src/contents_tool.rs index 5281cfa7c7..183fa29e8c 100644 --- a/crates/assistant_tools/src/contents_tool.rs +++ b/crates/assistant_tools/src/contents_tool.rs @@ -228,7 +228,7 @@ impl Tool for ContentsTool { } else { // File is too big, so return its outline and a suggestion to // read again with a line number range specified. - let outline = file_outline(project, file_path, action_log, None, 0, cx).await?; + let outline = file_outline(project, file_path, action_log, None, cx).await?; Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline.")) } diff --git a/crates/assistant_tools/src/diagnostics_tool/description.md b/crates/assistant_tools/src/diagnostics_tool/description.md index ab250e09ad..90dc00f1e4 100644 --- a/crates/assistant_tools/src/diagnostics_tool/description.md +++ b/crates/assistant_tools/src/diagnostics_tool/description.md @@ -15,6 +15,7 @@ To get a project-wide diagnostic summary: {} -IMPORTANT: When you're done making changes, you **MUST** get the **project** diagnostics (input: `{}`) at the end of your edits so you can fix any problems you might have introduced. **DO NOT** tell the user you're done before doing this! - -You may only attempt to fix these up to 3 times. If you have tried 3 times to fix them, and there are still problems remaining, you must not continue trying to fix them, and must instead tell the user that there are problems remaining - and ask if the user would like you to attempt to solve them further. + +- If you think you can fix a diagnostic, make 1-2 attempts and then give up. +- Don't remove code you've generated just because you can't fix an error. The user can help you fix it. + diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs new file mode 100644 index 0000000000..136dd60bfe --- /dev/null +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -0,0 +1,183 @@ +use crate::{replace::replace_with_flexible_indent, schema::json_schema_for}; +use anyhow::{Context as _, Result, anyhow}; +use assistant_tool::{ActionLog, Tool, ToolResult}; +use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{path::PathBuf, sync::Arc}; +use ui::IconName; + +use crate::replace::replace_exact; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct EditFileToolInput { + /// The full path of the file to modify in the project. + /// + /// WARNING: When specifying which file path need changing, you MUST + /// start each path with one of the project's root directories. + /// + /// The following examples assume we have two root directories in the project: + /// - backend + /// - frontend + /// + /// + /// `backend/src/main.rs` + /// + /// Notice how the file path starts with root-1. Without that, the path + /// would be ambiguous and the call would fail! + /// + /// + /// + /// `frontend/db.js` + /// + pub path: PathBuf, + + /// A user-friendly markdown description of what's being replaced. This will be shown in the UI. + /// + /// Fix API endpoint URLs + /// Update copyright year in `page_footer` + pub display_description: String, + + /// The text to replace. + pub old_string: String, + + /// The text to replace it with. + pub new_string: String, +} + +pub struct EditFileTool; + +impl Tool for EditFileTool { + fn name(&self) -> String { + "edit_file".into() + } + + fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + false + } + + fn description(&self) -> String { + include_str!("edit_file_tool/description.md").to_string() + } + + fn icon(&self) -> IconName { + IconName::Pencil + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) + } + + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => input.display_description, + Err(_) => "Edit file".to_string(), + } + } + + fn run( + self: Arc, + input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], + project: Entity, + action_log: Entity, + cx: &mut App, + ) -> ToolResult { + let input = match serde_json::from_value::(input) { + Ok(input) => input, + Err(err) => return Task::ready(Err(anyhow!(err))).into(), + }; + + cx.spawn(async move |cx: &mut AsyncApp| { + let project_path = project.read_with(cx, |project, cx| { + project + .find_project_path(&input.path, cx) + .context("Path not found in project") + })??; + + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + if input.old_string.is_empty() { + return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file.")); + } + + if input.old_string == input.new_string { + return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made.")); + } + + let result = cx + .background_spawn(async move { + // Try to match exactly + let diff = replace_exact(&input.old_string, &input.new_string, &snapshot) + .await + // If that fails, try being flexible about indentation + .or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?; + + if diff.edits.is_empty() { + return None; + } + + let old_text = snapshot.text(); + + Some((old_text, diff)) + }) + .await; + + let Some((old_text, diff)) = result else { + let err = buffer.read_with(cx, |buffer, _cx| { + let file_exists = buffer + .file() + .map_or(false, |file| file.disk_state().exists()); + + if !file_exists { + anyhow!("{} does not exist", input.path.display()) + } else if buffer.is_empty() { + anyhow!( + "{} is empty, so the provided `old_string` wasn't found.", + input.path.display() + ) + } else { + anyhow!("Failed to match the provided `old_string`") + } + })?; + + return Err(err) + }; + + let snapshot = cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx) + }); + let snapshot = buffer.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(); + buffer.apply_diff(diff, cx); + buffer.finalize_last_transaction(); + buffer.snapshot() + }); + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx) + }); + snapshot + })?; + + project.update( cx, |project, cx| { + project.save_buffer(buffer, cx) + })?.await?; + + let diff_str = cx.background_spawn(async move { + let new_text = snapshot.text(); + language::unified_diff(&old_text, &new_text) + }).await; + + + Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str)) + + }).into() + } +} diff --git a/crates/assistant_tools/src/edit_file_tool/description.md b/crates/assistant_tools/src/edit_file_tool/description.md new file mode 100644 index 0000000000..51f2db7808 --- /dev/null +++ b/crates/assistant_tools/src/edit_file_tool/description.md @@ -0,0 +1,45 @@ +This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files. + +Before using this tool: + +1. Use the `read_file` tool to understand the file's contents and context + +2. Verify the directory path is correct (only applicable when creating new files): + - Use the `list_directory` tool to verify the parent directory exists and is the correct location + +To make a file edit, provide the following: +1. path: The full path to the file you wish to modify in the project. This path must include the root directory in the project. +2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation) +3. new_string: The edited text, which will replace the old_string in the file. + +The tool will replace ONE occurrence of old_string with new_string in the specified file. + +CRITICAL REQUIREMENTS FOR USING THIS TOOL: + +1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means: + - Include AT LEAST 3-5 lines of context BEFORE the change point + - Include AT LEAST 3-5 lines of context AFTER the change point + - Include all whitespace, indentation, and surrounding code exactly as it appears in the file + +2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances: + - Make separate calls to this tool for each instance + - Each call must uniquely identify its specific instance using extensive context + +3. VERIFICATION: Before using this tool: + - Check how many instances of the target text exist in the file + - If multiple instances exist, gather enough context to uniquely identify each one + - Plan separate tool calls for each instance + +WARNING: If you do not follow these requirements: + - The tool will fail if old_string matches multiple locations + - The tool will fail if old_string doesn't match exactly (including whitespace) + - You may change the wrong instance if you don't include enough context + +When making edits: + - Ensure the edit results in idiomatic, correct code + - Do not leave the code in a broken state + - Always use fully-qualified project paths (starting with the name of one of the project's root directories) + +If you want to create a new file, use the `create_file` tool instead of this tool. Don't pass an empty `old_string`. + +Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each. diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index 9db00d765d..ef0c2838e8 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -12,7 +12,7 @@ use util::markdown::MarkdownString; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct ListDirectoryToolInput { - /// The relative path of the directory to list. + /// The fully-qualified path of the directory to list in the project. /// /// This path should never be absolute, and the first component /// of the path should always be a root directory in a project. diff --git a/crates/assistant_tools/src/list_directory_tool/description.md b/crates/assistant_tools/src/list_directory_tool/description.md index a7d364ae63..1daf3e3a9f 100644 --- a/crates/assistant_tools/src/list_directory_tool/description.md +++ b/crates/assistant_tools/src/list_directory_tool/description.md @@ -1 +1 @@ -Lists files and directories in a given path. +Lists files and directories in a given path. Prefer the `regex_search` or `path_search` tools when searching the codebase. diff --git a/crates/assistant_tools/src/path_search_tool.rs b/crates/assistant_tools/src/path_search_tool.rs index 17b85f8278..ea19cb1dee 100644 --- a/crates/assistant_tools/src/path_search_tool.rs +++ b/crates/assistant_tools/src/path_search_tool.rs @@ -6,14 +6,14 @@ use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat} use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::{path::PathBuf, sync::Arc}; +use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc}; use ui::IconName; use util::paths::PathMatcher; use worktree::Snapshot; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct PathSearchToolInput { - /// The glob to search all project paths for. + /// The glob to match against every path in the project. /// /// /// If the project has the following root directories: @@ -76,66 +76,125 @@ impl Tool for PathSearchTool { Ok(input) => (input.offset, input.glob), Err(err) => return Task::ready(Err(anyhow!(err))).into(), }; - - let path_matcher = match PathMatcher::new([ - // Sometimes models try to search for "". In this case, return all paths in the project. - if glob.is_empty() { "*" } else { &glob }, - ]) { - Ok(matcher) => matcher, - Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))).into(), - }; - let snapshots: Vec = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect(); - + let offset = offset as usize; + let task = search_paths(&glob, project, cx); cx.background_spawn(async move { - let mut matches = Vec::new(); - - for worktree in snapshots { - let root_name = worktree.root_name(); - - // Don't consider ignored entries. - for entry in worktree.entries(false, 0) { - if path_matcher.is_match(&entry.path) { - matches.push( - PathBuf::from(root_name) - .join(&entry.path) - .to_string_lossy() - .to_string(), - ); - } - } - } + let matches = task.await?; + let paginated_matches = &matches[cmp::min(offset, matches.len()) + ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())]; if matches.is_empty() { - Ok(format!("No paths in the project matched the glob {glob:?}")) + Ok("No matches found".to_string()) } else { - // Sort to group entries in the same directory together. - matches.sort(); - - let total_matches = matches.len(); - let response = if total_matches > RESULTS_PER_PAGE + offset as usize { - let paginated_matches: Vec<_> = matches - .into_iter() - .skip(offset as usize) - .take(RESULTS_PER_PAGE) - .collect(); - - format!( - "Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}", - total_matches, + let mut message = format!("Found {} total matches.", matches.len()); + if matches.len() > RESULTS_PER_PAGE { + write!( + &mut message, + "\nShowing results {}-{} (provide 'offset' parameter for more results):", offset + 1, - offset as usize + paginated_matches.len(), - paginated_matches.join("\n") + offset + paginated_matches.len() ) - } else { - matches.join("\n") - }; - - Ok(response) + .unwrap(); + } + for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) { + write!(&mut message, "\n{}", mat.display()).unwrap(); + } + Ok(message) } - }).into() + }) + .into() + } +} + +fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task>> { + let path_matcher = match PathMatcher::new([ + // Sometimes models try to search for "". In this case, return all paths in the project. + if glob.is_empty() { "*" } else { glob }, + ]) { + Ok(matcher) => matcher, + Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), + }; + let snapshots: Vec = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + Ok(snapshots + .iter() + .flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }) + .collect()) + }) +} + +#[cfg(test)] +mod test { + use super::*; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_path_search_tool(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + serde_json::json!({ + "apple": { + "banana": { + "carrot": "1", + }, + "bandana": { + "carbonara": "2", + }, + "endive": "3" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let matches = cx + .update(|cx| search_paths("root/**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + + let matches = cx + .update(|cx| search_paths("**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); } } diff --git a/crates/assistant_tools/src/path_search_tool/description.md b/crates/assistant_tools/src/path_search_tool/description.md index 129aaa7c8e..73345bff8e 100644 --- a/crates/assistant_tools/src/path_search_tool/description.md +++ b/crates/assistant_tools/src/path_search_tool/description.md @@ -1,3 +1,7 @@ -Returns paths in the project which match the given glob. +Fast file pattern matching tool that works with any codebase size -Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. +- Supports glob patterns like "**/*.js" or "src/**/*.ts" +- Returns matching file paths sorted alphabetically +- Prefer the `regex_search` tool to this tool when searching for symbols unless you have specific information about paths. +- Use this tool when you need to find files by name patterns +- Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 5fe5cf9e97..87e6fd96a7 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -1,14 +1,14 @@ -use std::sync::Arc; - use crate::{code_symbols_tool::file_outline, schema::json_schema_for}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; use gpui::{App, Entity, Task}; +use indoc::formatdoc; use itertools::Itertools; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::sync::Arc; use ui::IconName; use util::markdown::MarkdownString; @@ -95,11 +95,24 @@ impl Tool for ReadFileTool { }; let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!("Path {} not found in project", &input.path,))).into(); + return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into(); }; + let Some(worktree) = project + .read(cx) + .worktree_for_id(project_path.worktree_id, cx) + else { + return Task::ready(Err(anyhow!("Worktree not found for project path"))).into(); + }; + let exists = worktree.update(cx, |worktree, cx| { + worktree.file_exists(&project_path.path, cx) + }); let file_path = input.path.clone(); cx.spawn(async move |cx| { + if !exists.await? { + return Err(anyhow!("{} not found", file_path)) + } + let buffer = cx .update(|cx| { project.update(cx, |project, cx| project.open_buffer(project_path, cx)) @@ -141,11 +154,231 @@ impl Tool for ReadFileTool { } else { // File is too big, so return an error with the outline // and a suggestion to read again with line numbers. - let outline = file_outline(project, file_path, action_log, None, 0, cx).await?; + let outline = file_outline(project, file_path, action_log, None, cx).await?; + Ok(formatdoc! {" + This file was too big to read all at once. Here is an outline of its symbols: - Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the implementations of symbols in the outline.")) + {outline} + + Using the line numbers in this outline, you can call this tool again while specifying + the start_line and end_line fields to see the implementations of symbols in the outline." + }) } } }).into() } } + +#[cfg(test)] +mod test { + use super::*; + use gpui::{AppContext, TestAppContext}; + use language::{Language, LanguageConfig, LanguageMatcher}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_read_nonexistent_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/nonexistent_file.txt" + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "root/nonexistent_file.txt not found" + ); + } + + #[gpui::test] + async fn test_read_small_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "small_file.txt": "This is a small file content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/small_file.txt" + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + assert_eq!(result.unwrap(), "This is a small file content"); + } + + #[gpui::test] + async fn test_read_large_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::>().join("\n") + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(Arc::new(rust_lang())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/large_file.rs" + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log.clone(), cx) + .output + }) + .await; + let content = result.unwrap(); + assert_eq!( + content.lines().skip(2).take(6).collect::>(), + vec![ + "struct Test0 [L1-4]", + " a [L2]", + " b [L3]", + "struct Test1 [L5-8]", + " a [L6]", + " b [L7]", + ] + ); + + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/large_file.rs", + "offset": 1 + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + let content = result.unwrap(); + let expected_content = (0..1000) + .flat_map(|i| { + vec![ + format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4), + format!(" a [L{}]", i * 4 + 2), + format!(" b [L{}]", i * 4 + 3), + ] + }) + .collect::>(); + pretty_assertions::assert_eq!( + content + .lines() + .skip(2) + .take(expected_content.len()) + .collect::>(), + expected_content + ); + } + + #[gpui::test] + async fn test_read_file_with_line_range(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/multiline.txt", + "start_line": 2, + "end_line": 4 + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + assert_eq!(result.unwrap(), "Line 2\nLine 3"); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query( + r#" + (line_comment) @annotation + + (struct_item + "struct" @context + name: (_) @name) @item + (enum_item + "enum" @context + name: (_) @name) @item + (enum_variant + name: (_) @name) @item + (field_declaration + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @name + "for"? @context + type: (_) @name + body: (_ "{" (_)* "}")) @item + (function_item + "fn" @context + name: (_) @name) @item + (mod_item + "mod" @context + name: (_) @name) @item + "#, + ) + .unwrap() + } +} diff --git a/crates/assistant_tools/src/read_file_tool/description.md b/crates/assistant_tools/src/read_file_tool/description.md index b14898a2c3..7bcebc0334 100644 --- a/crates/assistant_tools/src/read_file_tool/description.md +++ b/crates/assistant_tools/src/read_file_tool/description.md @@ -1,6 +1,3 @@ Reads the content of the given file in the project. -If the file is too big to read all at once, and neither a start line -nor an end line was specified, then this returns an outline of the -file's symbols (with line numbers) instead of the file's contents, -so that it can be called again with line ranges. +- Never attempt to read a path that hasn't been previously mentioned. diff --git a/crates/assistant_tools/src/regex_search_tool/description.md b/crates/assistant_tools/src/regex_search_tool/description.md index 160ecedebb..674dd6043b 100644 --- a/crates/assistant_tools/src/regex_search_tool/description.md +++ b/crates/assistant_tools/src/regex_search_tool/description.md @@ -1,7 +1,6 @@ Searches the entire project for the given regular expression. -Returns a list of paths that matched the query. For each path, it returns some excerpts of the matched text. - -Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages. - -This tool is not aware of semantics and does not use any information from language servers, so it should only be used when no available semantic tool (e.g. one that uses language servers) could fit a particular use case instead. +- Prefer this tool when searching for files containing symbols in the project. +- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.) +- Use this tool when you need to find files containing specific patterns +- Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages. diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index d8c97f9a11..e494ce296d 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -28,7 +28,7 @@ language.workspace = true language_extension.workspace = true language_model.workspace = true language_models.workspace = true -languages.workspace = true +languages = { workspace = true, features = ["load-grammars"] } node_runtime.workspace = true paths.workspace = true project.workspace = true @@ -36,6 +36,7 @@ prompt_store.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true +serde_json.workspace = true settings.workspace = true shellexpand.workspace = true telemetry.workspace = true diff --git a/crates/eval/examples/add_arp_protocol_support/criteria.md b/crates/eval/examples/add_arp_protocol_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/add_arp_protocol_support/criteria.md rename to crates/eval/examples/add_arp_protocol_support/diff_criteria.md diff --git a/crates/eval/examples/auth_session_management/criteria.md b/crates/eval/examples/auth_session_management/diff_criteria.md similarity index 100% rename from crates/eval/examples/auth_session_management/criteria.md rename to crates/eval/examples/auth_session_management/diff_criteria.md diff --git a/crates/eval/examples/buffer_string_input_support/criteria.md b/crates/eval/examples/buffer_string_input_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/buffer_string_input_support/criteria.md rename to crates/eval/examples/buffer_string_input_support/diff_criteria.md diff --git a/crates/eval/examples/checkpoint_stability/criteria.md b/crates/eval/examples/checkpoint_stability/diff_criteria.md similarity index 100% rename from crates/eval/examples/checkpoint_stability/criteria.md rename to crates/eval/examples/checkpoint_stability/diff_criteria.md diff --git a/crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md b/crates/eval/examples/dd_iaptic_mcp_server_integration/diff_criteria.md similarity index 100% rename from crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md rename to crates/eval/examples/dd_iaptic_mcp_server_integration/diff_criteria.md diff --git a/crates/eval/examples/debian_image_builder/criteria.md b/crates/eval/examples/debian_image_builder/diff_criteria.md similarity index 100% rename from crates/eval/examples/debian_image_builder/criteria.md rename to crates/eval/examples/debian_image_builder/diff_criteria.md diff --git a/crates/eval/examples/docs_restructure/criteria.md b/crates/eval/examples/docs_restructure/diff_criteria.md similarity index 100% rename from crates/eval/examples/docs_restructure/criteria.md rename to crates/eval/examples/docs_restructure/diff_criteria.md diff --git a/crates/eval/examples/email_verification_refactor/base.toml b/crates/eval/examples/email_verification_refactor/base.toml index a8851fddda..04c26ca6b9 100644 --- a/crates/eval/examples/email_verification_refactor/base.toml +++ b/crates/eval/examples/email_verification_refactor/base.toml @@ -1,3 +1,4 @@ url = "https://github.com/dani-garcia/vaultwarden.git" revision = "3a1f1bae002bebf26ce3a38b879c1ba26529af1e" language_extension = "rs" +allow_preexisting_diagnostics = true diff --git a/crates/eval/examples/email_verification_refactor/criteria.md b/crates/eval/examples/email_verification_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/email_verification_refactor/criteria.md rename to crates/eval/examples/email_verification_refactor/diff_criteria.md diff --git a/crates/eval/examples/exif_rotation_support/criteria.md b/crates/eval/examples/exif_rotation_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/exif_rotation_support/criteria.md rename to crates/eval/examples/exif_rotation_support/diff_criteria.md diff --git a/crates/eval/examples/expand_laravel_php_support/criteria.md b/crates/eval/examples/expand_laravel_php_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/expand_laravel_php_support/criteria.md rename to crates/eval/examples/expand_laravel_php_support/diff_criteria.md diff --git a/crates/eval/examples/find_and_replace_diff_card/criteria.md b/crates/eval/examples/find_and_replace_diff_card/diff_criteria.md similarity index 100% rename from crates/eval/examples/find_and_replace_diff_card/criteria.md rename to crates/eval/examples/find_and_replace_diff_card/diff_criteria.md diff --git a/crates/eval/examples/find_and_replace_diff_card/prompt.md b/crates/eval/examples/find_and_replace_diff_card/prompt.md index efd23cbba3..a4c2cfdb0c 100644 --- a/crates/eval/examples/find_and_replace_diff_card/prompt.md +++ b/crates/eval/examples/find_and_replace_diff_card/prompt.md @@ -1,3 +1,3 @@ -Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation. +Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should implement the `Render` trait. The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line. diff --git a/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md b/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md new file mode 100644 index 0000000000..ee5c640a44 --- /dev/null +++ b/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md @@ -0,0 +1,3 @@ +1. The first tool call should be to path search including "find_replace_file_tool.rs" in the string. (*Not* regex_search, for example, or reading the file based on a guess at the path.) This is because we gave the model a filename and it needs to turn that into a real path. +2. After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path. +3. When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information on what path the Render trait might be in. diff --git a/crates/eval/examples/finnish_translation/criteria.md b/crates/eval/examples/finnish_translation/diff_criteria.md similarity index 100% rename from crates/eval/examples/finnish_translation/criteria.md rename to crates/eval/examples/finnish_translation/diff_criteria.md diff --git a/crates/eval/examples/language_model_file_support/criteria.md b/crates/eval/examples/language_model_file_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/language_model_file_support/criteria.md rename to crates/eval/examples/language_model_file_support/diff_criteria.md diff --git a/crates/eval/examples/lhs_join_update_callbacks/criteria.md b/crates/eval/examples/lhs_join_update_callbacks/diff_criteria.md similarity index 100% rename from crates/eval/examples/lhs_join_update_callbacks/criteria.md rename to crates/eval/examples/lhs_join_update_callbacks/diff_criteria.md diff --git a/crates/eval/examples/libdevice_symbol_reexport/criteria.md b/crates/eval/examples/libdevice_symbol_reexport/diff_criteria.md similarity index 100% rename from crates/eval/examples/libdevice_symbol_reexport/criteria.md rename to crates/eval/examples/libdevice_symbol_reexport/diff_criteria.md diff --git a/crates/eval/examples/license_management/criteria.md b/crates/eval/examples/license_management/diff_criteria.md similarity index 100% rename from crates/eval/examples/license_management/criteria.md rename to crates/eval/examples/license_management/diff_criteria.md diff --git a/crates/eval/examples/metal_i64_support/base.toml b/crates/eval/examples/metal_i64_support/base.toml index 01b0703231..4648f148b8 100644 --- a/crates/eval/examples/metal_i64_support/base.toml +++ b/crates/eval/examples/metal_i64_support/base.toml @@ -1,3 +1,4 @@ url = "https://github.com/huggingface/candle.git" revision = "3164a19a5dc18f5e0f7a063ae85a0cfd289e98f1" language_extension = "rs" +allow_preexisting_diagnostics = true diff --git a/crates/eval/examples/metal_i64_support/criteria.md b/crates/eval/examples/metal_i64_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/metal_i64_support/criteria.md rename to crates/eval/examples/metal_i64_support/diff_criteria.md diff --git a/crates/eval/examples/metrics_data_size_updates/criteria.md b/crates/eval/examples/metrics_data_size_updates/diff_criteria.md similarity index 100% rename from crates/eval/examples/metrics_data_size_updates/criteria.md rename to crates/eval/examples/metrics_data_size_updates/diff_criteria.md diff --git a/crates/eval/examples/nan_diff_handling/criteria.md b/crates/eval/examples/nan_diff_handling/diff_criteria.md similarity index 100% rename from crates/eval/examples/nan_diff_handling/criteria.md rename to crates/eval/examples/nan_diff_handling/diff_criteria.md diff --git a/crates/eval/examples/never_type_workaround/criteria.md b/crates/eval/examples/never_type_workaround/diff_criteria.md similarity index 100% rename from crates/eval/examples/never_type_workaround/criteria.md rename to crates/eval/examples/never_type_workaround/diff_criteria.md diff --git a/crates/eval/examples/optimizer_schema_refactor/criteria.md b/crates/eval/examples/optimizer_schema_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/optimizer_schema_refactor/criteria.md rename to crates/eval/examples/optimizer_schema_refactor/diff_criteria.md diff --git a/crates/eval/examples/rate_limit_endpoints/criteria.md b/crates/eval/examples/rate_limit_endpoints/diff_criteria.md similarity index 100% rename from crates/eval/examples/rate_limit_endpoints/criteria.md rename to crates/eval/examples/rate_limit_endpoints/diff_criteria.md diff --git a/crates/eval/examples/replace_hold_with_drain_on_exit/criteria.md b/crates/eval/examples/replace_hold_with_drain_on_exit/diff_criteria.md similarity index 100% rename from crates/eval/examples/replace_hold_with_drain_on_exit/criteria.md rename to crates/eval/examples/replace_hold_with_drain_on_exit/diff_criteria.md diff --git a/crates/eval/examples/request_to_axios_migration/criteria.md b/crates/eval/examples/request_to_axios_migration/diff_criteria.md similarity index 100% rename from crates/eval/examples/request_to_axios_migration/criteria.md rename to crates/eval/examples/request_to_axios_migration/diff_criteria.md diff --git a/crates/eval/examples/restore_version_api_support/criteria.md b/crates/eval/examples/restore_version_api_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/restore_version_api_support/criteria.md rename to crates/eval/examples/restore_version_api_support/diff_criteria.md diff --git a/crates/eval/examples/runtime_script_refactor/criteria.md b/crates/eval/examples/runtime_script_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/runtime_script_refactor/criteria.md rename to crates/eval/examples/runtime_script_refactor/diff_criteria.md diff --git a/crates/eval/examples/standardized_docker_dependency_checks/criteria.md b/crates/eval/examples/standardized_docker_dependency_checks/diff_criteria.md similarity index 100% rename from crates/eval/examples/standardized_docker_dependency_checks/criteria.md rename to crates/eval/examples/standardized_docker_dependency_checks/diff_criteria.md diff --git a/crates/eval/examples/table_metrics_sorting/criteria.md b/crates/eval/examples/table_metrics_sorting/diff_criteria.md similarity index 100% rename from crates/eval/examples/table_metrics_sorting/criteria.md rename to crates/eval/examples/table_metrics_sorting/diff_criteria.md diff --git a/crates/eval/examples/tax_id_validation/criteria.md b/crates/eval/examples/tax_id_validation/diff_criteria.md similarity index 100% rename from crates/eval/examples/tax_id_validation/criteria.md rename to crates/eval/examples/tax_id_validation/diff_criteria.md diff --git a/crates/eval/examples/test_infrastructure/criteria.md b/crates/eval/examples/test_infrastructure/diff_criteria.md similarity index 100% rename from crates/eval/examples/test_infrastructure/criteria.md rename to crates/eval/examples/test_infrastructure/diff_criteria.md diff --git a/crates/eval/examples/time_detail_merge_update/criteria.md b/crates/eval/examples/time_detail_merge_update/diff_criteria.md similarity index 100% rename from crates/eval/examples/time_detail_merge_update/criteria.md rename to crates/eval/examples/time_detail_merge_update/diff_criteria.md diff --git a/crates/eval/examples/tool_response_handling/criteria.md b/crates/eval/examples/tool_response_handling/diff_criteria.md similarity index 100% rename from crates/eval/examples/tool_response_handling/criteria.md rename to crates/eval/examples/tool_response_handling/diff_criteria.md diff --git a/crates/eval/examples/toolbar_endpoints/criteria.md b/crates/eval/examples/toolbar_endpoints/diff_criteria.md similarity index 100% rename from crates/eval/examples/toolbar_endpoints/criteria.md rename to crates/eval/examples/toolbar_endpoints/diff_criteria.md diff --git a/crates/eval/examples/virtio_block_request_refactor/base.toml b/crates/eval/examples/virtio_block_request_refactor/base.toml index a207fb3a10..58fdc0a963 100644 --- a/crates/eval/examples/virtio_block_request_refactor/base.toml +++ b/crates/eval/examples/virtio_block_request_refactor/base.toml @@ -1,3 +1,4 @@ url = "https://github.com/firecracker-microvm/firecracker.git" revision = "5eaa6e08e350cd38c8102848913a096312e59097" language_extension = "rs" +allow_preexisting_diagnostics = true diff --git a/crates/eval/examples/virtio_block_request_refactor/criteria.md b/crates/eval/examples/virtio_block_request_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/virtio_block_request_refactor/criteria.md rename to crates/eval/examples/virtio_block_request_refactor/diff_criteria.md diff --git a/crates/eval/examples/war_and_uri_corrections/criteria.md b/crates/eval/examples/war_and_uri_corrections/diff_criteria.md similarity index 100% rename from crates/eval/examples/war_and_uri_corrections/criteria.md rename to crates/eval/examples/war_and_uri_corrections/diff_criteria.md diff --git a/crates/eval/examples/window_title_support/criteria.md b/crates/eval/examples/window_title_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/window_title_support/criteria.md rename to crates/eval/examples/window_title_support/diff_criteria.md diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 49c08201a8..5b465a5e74 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -9,8 +9,7 @@ use ::fs::RealFs; use anyhow::{Result, anyhow}; use clap::Parser; use extension::ExtensionHostProxy; -use futures::future; -use futures::stream::StreamExt; +use futures::{StreamExt, future}; use gpui::http_client::{Uri, read_proxy_from_env}; use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal}; use gpui_tokio::Tokio; @@ -183,7 +182,7 @@ fn main() { println!( "{}Logging to: {}", example.log_prefix, - example.output_file_path.display() + example.example_output_directory().display() ); let repo_url = example.base.url.clone(); @@ -192,7 +191,7 @@ fn main() { if !repo_path.join(".git").is_dir() { println!( - "{:>(); + }); let results = futures::stream::iter(tasks) .buffer_unordered(concurrency) - .collect::>>, Example)>>() + .collect::>() .await; println!("\n\n"); @@ -259,26 +256,41 @@ fn main() { println!("========================================"); println!(""); - let mut judge_scores = Vec::new(); + let mut diff_scores = Vec::new(); + let mut thread_scores = Vec::new(); + let mut error_count = 0; for (result, example) in results { match result { Err(err) => { println!("💥 {}{:?}", example.log_prefix, err); + error_count += 1; } Ok(judge_results) => { for judge_result in judge_results { match judge_result { Ok(judge_output) => { const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"]; - let score: u32 = judge_output.score; - let score_index = (score.min(5)) as usize; + let diff_score: u32 = judge_output.diff.score; + let score_index = (diff_score.min(5)) as usize; println!( - "{} {}{}", - SCORES[score_index], example.log_prefix, judge_output.score, + "{} {}{} (Diff)", + SCORES[score_index], + example.log_prefix, + judge_output.diff.score, ); - judge_scores.push(judge_output.score); + diff_scores.push(judge_output.diff.score); + + if let Some(thread) = judge_output.thread { + let process_score: u32 = thread.score; + let score_index = (process_score.min(5)) as usize; + println!( + "{} {}{} (Thread)", + SCORES[score_index], example.log_prefix, thread.score, + ); + thread_scores.push(thread.score); + } } Err(err) => { println!("💥 {}{:?}", example.log_prefix, err); @@ -290,17 +302,39 @@ fn main() { println!( "{} > {}", " ".repeat(max_name_width), - example.output_file_path.display() + example.example_output_directory().display() ); } - let score_count = judge_scores.len(); - let average_score = judge_scores + let diff_score_count = diff_scores.len(); + let average_diff_score = diff_scores .into_iter() .map(|score| score as f32) .sum::() - / (score_count as f32); - println!("\nAverage score: {average_score}"); + / (diff_score_count as f32); + + if error_count > 0 { + println!("\n{error_count} examples failed to run!"); + } + + if diff_score_count > 0 { + println!("\nAverage code diff score: {average_diff_score}"); + } + + let thread_score_count = thread_scores.len(); + + // We might have gotten no thread scores if we weren't asked to judge the thread. + if thread_score_count > 0 { + let average_thread_score = thread_scores + .into_iter() + .map(|score| score as f32) + .sum::() + / (thread_score_count as f32); + + if diff_score_count > 0 { + println!("\nAverage thread score: {average_thread_score}"); + } + } std::thread::sleep(std::time::Duration::from_secs(2)); @@ -322,45 +356,11 @@ async fn run_example( let run_output = cx .update(|cx| example.run(model.clone(), app_state.clone(), cx))? .await?; - let diff = example.repository_diff().await?; - // Run judge for each repetition - let mut results = Vec::new(); - for round in 0..judge_repetitions { - let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await; + let judge_tasks = (0..judge_repetitions) + .map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx)); - if let Ok(judge_output) = &judge_result { - let cohort_id = example - .output_file_path - .parent() - .and_then(|p| p.file_name()) - .map(|name| name.to_string_lossy().to_string()) - .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string()); - - let path = std::path::Path::new("."); - let commit_id = get_current_commit_id(path).await.unwrap_or_default(); - - telemetry::event!( - "Agent Eval Completed", - cohort_id = cohort_id, - example_name = example.name.clone(), - round = round, - score = judge_output.score, - analysis = judge_output.analysis, - tool_use_counts = run_output.tool_use_counts, - response_count = run_output.response_count, - token_usage = run_output.token_usage, - model = model.telemetry_id(), - model_provider = model.provider_id().to_string(), - repository_url = example.base.url.clone(), - repository_revision = example.base.revision.clone(), - diagnostics_summary = run_output.diagnostics, - commit_id = commit_id - ); - } - - results.push(judge_result); - } + let results = future::join_all(judge_tasks).await; app_state.client.telemetry().flush_events(); @@ -537,3 +537,68 @@ pub fn get_current_commit_id_sync(repo_path: &Path) -> String { get_current_commit_id(repo_path).await.unwrap_or_default() }) } + +async fn run_judge_repetition( + example: Example, + model: Arc, + run_output: &RunOutput, + round: u32, + cx: &AsyncApp, +) -> Result { + let judge_result = example.judge(model.clone(), &run_output, round, cx).await; + + if let Ok(judge_output) = &judge_result { + let cohort_id = example + .run_directory_path + .file_name() + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string()); + + let path = std::path::Path::new("."); + let commit_id = get_current_commit_id(path).await.unwrap_or_default(); + + if let Some(thread) = &judge_output.thread { + telemetry::event!( + "Agent Eval Completed", + cohort_id = cohort_id, + example_name = example.name.clone(), + round = round, + diff_score = judge_output.diff.score, + diff_analysis = judge_output.diff.analysis, + thread_score = thread.score, + thread_analysis = thread.analysis, + tool_use_counts = run_output.tool_use_counts, + response_count = run_output.response_count, + token_usage = run_output.token_usage, + model = model.telemetry_id(), + model_provider = model.provider_id().to_string(), + repository_url = example.base.url.clone(), + repository_revision = example.base.revision.clone(), + diagnostics_before = run_output.diagnostics_before, + diagnostics_after = run_output.diagnostics_after, + commit_id = commit_id + ); + } else { + telemetry::event!( + "Agent Eval Completed", + cohort_id = cohort_id, + example_name = example.name.clone(), + round = round, + diff_score = judge_output.diff.score, + diff_analysis = judge_output.diff.analysis, + tool_use_counts = run_output.tool_use_counts, + response_count = run_output.response_count, + token_usage = run_output.token_usage, + model = model.telemetry_id(), + model_provider = model.provider_id().to_string(), + repository_url = example.base.url.clone(), + repository_revision = example.base.revision.clone(), + diagnostics_before = run_output.diagnostics_before, + diagnostics_after = run_output.diagnostics_after, + commit_id = commit_id + ); + } + } + + judge_result +} diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 36f0fe7fdf..6b6fcb7290 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -10,14 +10,16 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, Task}; use handlebars::Handlebars; use language::{DiagnosticSeverity, OffsetRangeExt}; use language_model::{ - LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, - StopReason, TokenUsage, + LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, + MessageContent, Role, StopReason, TokenUsage, }; use project::{LspStore, Project, ProjectPath}; use serde::{Deserialize, Serialize}; +use std::cell::RefCell; use std::fmt::Write as _; use std::fs::File; use std::io::Write as _; +use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::time::Duration; use std::{ @@ -45,6 +47,19 @@ pub struct ExampleBase { pub insert_id: Option, #[serde(default = "default_true")] pub require_lsp: bool, + #[serde(default)] + pub allow_preexisting_diagnostics: bool, +} + +impl ExampleBase { + pub fn repo_name(&self) -> String { + self.url + .split('/') + .next_back() + .unwrap_or(&"") + .trim_end_matches(".git") + .into() + } } #[derive(Clone, Debug)] @@ -54,14 +69,12 @@ pub struct Example { pub base: ExampleBase, /// Content of `prompt.md` pub prompt: String, - /// Content of `criteria.md` - pub criteria: String, - /// Markdown output file to append to - pub output_file: Option>>, - /// Path to the output run directory. - pub run_dir: PathBuf, - /// Path to markdown output file - pub output_file_path: PathBuf, + /// Content of `diff_criteria.md` + pub diff_criteria: String, + /// Content of `thread_criteria.md`, if that file exists (it's optional) + pub thread_criteria: Option, + /// Path to the directory containing the requests and responses for the agentic loop + pub run_directory_path: PathBuf, /// Prefix used for logging that identifies this example pub log_prefix: String, } @@ -69,41 +82,65 @@ pub struct Example { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct RunOutput { pub repository_diff: String, - pub diagnostics: String, + pub ran_diagnostics_check: bool, + pub diagnostics_before: Option, + pub diagnostics_after: Option, pub response_count: usize, pub token_usage: TokenUsage, pub tool_use_counts: HashMap, u32>, + pub last_request: LanguageModelRequest, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeInput { +pub struct JudgeDiffInput { pub repository_diff: String, + pub ran_diagnostics_check: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub diagnostics_before: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub diagnostics_after: Option, pub criteria: String, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeOutput { +pub struct JudgeThreadInput { + pub messages: String, + pub criteria: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeResponse { pub analysis: String, pub score: u32, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeOutput { + pub thread: Option, + pub diff: JudgeResponse, +} + impl Example { /// Load an example from a directory containing base.toml, prompt.md, and criteria.md pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result { let name = Self::name_from_path(dir_path); let base_path = dir_path.join("base.toml"); let prompt_path = dir_path.join("prompt.md"); - let criteria_path = dir_path.join("criteria.md"); - let output_file_path = run_dir.join(format!("{}.md", name)); + let diff_criteria_path = dir_path.join("diff_criteria.md"); + let thread_criteria_path = dir_path.join("thread_criteria.md"); + let thread_criteria = if thread_criteria_path.exists() { + Some(fs::read_to_string(thread_criteria_path.clone())?) + } else { + None + }; Ok(Example { name: name.clone(), base: toml::from_str(&fs::read_to_string(&base_path)?)?, prompt: fs::read_to_string(prompt_path.clone())?, - criteria: fs::read_to_string(criteria_path.clone())?, - run_dir: run_dir.to_path_buf(), - output_file: None, - output_file_path, + thread_criteria, + diff_criteria: fs::read_to_string(diff_criteria_path.clone())?, + run_directory_path: run_dir.to_path_buf(), log_prefix: name, }) } @@ -111,10 +148,13 @@ impl Example { pub fn set_repetition_number(&mut self, repetition_number: u32) { if repetition_number > 0 { self.name = format!("{}-{}", self.name, repetition_number); - self.output_file_path = self.run_dir.join(format!("{}.md", self.name)); } } + pub fn example_output_directory(&self) -> PathBuf { + self.run_directory_path.join(&self.name) + } + pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) { self.log_prefix = format!( "{}{: Arc> { - self.output_file - .clone() - .expect("Output file not created. Call setup() first.") - } - pub fn run( &self, model: Arc, @@ -305,6 +337,11 @@ impl Example { None }; + let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?; + if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics { + return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`")); + } + if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() { return Err(anyhow!("Setup only mode")); } @@ -312,15 +349,32 @@ impl Example { let thread_store = thread_store.await?; let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; + let last_request = Rc::new(RefCell::new(None)); - { - let output_file_ref = this.output_file(); - let mut output_file = output_file_ref.lock().unwrap(); - writeln!(&mut output_file, "👤 USER:").log_err(); - writeln!(&mut output_file, "{}", this.prompt).log_err(); - writeln!(&mut output_file, "🤖 ASSISTANT:").log_err(); - output_file.flush().log_err(); - } + thread.update(cx, |thread, _cx| { + let mut request_count = 0; + let example_dir_path = this.example_output_directory(); + + let last_request = Rc::clone(&last_request); + thread.set_request_callback(move |request, response_events| { + *last_request.borrow_mut() = Some(request.clone()); + + request_count += 1; + let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md")); + let last_messages_file_path = example_dir_path.join("last.messages.md"); + let request_markdown = RequestMarkdown::new(request); + let response_events_markdown = response_events_to_markdown(response_events); + + let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown); + fs::write(messages_file_path, messages.clone()).expect("failed to write messages file"); + fs::write(last_messages_file_path, messages).expect("failed to write last messages file"); + + if request_count == 1 { + let tools_file_path = example_dir_path.join("tools.md"); + fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file"); + } + }); + })?; let tool_use_counts: Arc, u32>>> = Mutex::new(HashMap::default()).into(); @@ -332,8 +386,6 @@ impl Example { }); let event_handler_task = cx.spawn({ - // Need to clone the Arc here because the reference from output_file() won't live long enough - let output_file = this.output_file.clone().unwrap(); let log_prefix = this.log_prefix.clone(); let tool_use_counts = tool_use_counts.clone(); let thread = thread.downgrade(); @@ -349,8 +401,6 @@ impl Example { return Err(anyhow!("ThreadEvent channel ended early")); }; - let mut output_file = output_file.lock().unwrap(); - match event { ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { @@ -371,18 +421,7 @@ impl Example { ThreadEvent::ShowError(thread_error) => { break Err(anyhow!(thread_error.clone())); } - ThreadEvent::StreamedAssistantText(_, chunk) => { - write!(&mut output_file, "{}", chunk).log_err(); - } - ThreadEvent::StreamedAssistantThinking(_, chunk) => { - write!(&mut output_file, "{}", chunk).log_err(); - } - ThreadEvent::UsePendingTools { tool_uses } => { - writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err(); - for tool_use in tool_uses { - writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input) - .log_err(); - } + ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => { } ThreadEvent::ToolFinished { tool_use_id, @@ -398,8 +437,6 @@ impl Example { format!("TOOL FINISHED: {}", tool_use.name) }; println!("{log_prefix}{message}"); - writeln!(&mut output_file, "\n{}", message).log_err(); - writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err(); let mut tool_use_counts = tool_use_counts.lock().unwrap(); *tool_use_counts .entry(tool_result.tool_name.clone()) @@ -407,7 +444,6 @@ impl Example { } else { let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); println!("{log_prefix}{message}"); - writeln!(&mut output_file, "\n{}", message).log_err(); } } })?; @@ -428,8 +464,6 @@ impl Example { } } } - - output_file.flush().log_err(); } } }); @@ -451,21 +485,35 @@ impl Example { println!("{}Getting repository diff", this.log_prefix); let repository_diff = this.repository_diff().await?; - let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name)); + let example_output_dir = this.example_output_directory(); + let repository_diff_path = example_output_dir.join("patch.diff"); let mut repository_diff_output_file = File::create(&repository_diff_path)?; writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err(); println!("{}Getting diagnostics", this.log_prefix); - let diagnostics = cx + let diagnostics_after = cx .update(move |cx| { cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await) })? .await?; println!("{}Got diagnostics", this.log_prefix); + let Some(last_request) = last_request.borrow_mut().take() else { + return Err(anyhow!("No requests ran.")); + }; + drop(subscription); drop(lsp_open_handle_and_store); + if let Some(diagnostics_before) = &diagnostics_before { + fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?; + } + + if let Some(diagnostics_after) = &diagnostics_after { + fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?; + } + + thread.update(cx, |thread, _cx| { let response_count = thread .messages() @@ -473,31 +521,38 @@ impl Example { .count(); RunOutput { repository_diff, - diagnostics, + ran_diagnostics_check: this.base.require_lsp, + diagnostics_before, + diagnostics_after, response_count, token_usage: thread.cumulative_token_usage(), tool_use_counts: tool_use_counts.lock().unwrap().clone(), + last_request, } }) }) } - pub async fn judge( + async fn judge_diff( &self, model: Arc, - repository_diff: String, - judge_repetitions: u32, + run_output: &RunOutput, + judge_number: u32, cx: &AsyncApp, - ) -> Result { - let judge_prompt = include_str!("judge_prompt.hbs"); - let judge_prompt_name = "judge_prompt"; - let mut handlebars = Handlebars::new(); - handlebars.register_template_string(judge_prompt_name, judge_prompt)?; - let prompt = handlebars.render( - judge_prompt_name, - &JudgeInput { - repository_diff, - criteria: self.criteria.clone(), + ) -> Result<(String, JudgeResponse)> { + let judge_diff_prompt = include_str!("judge_diff_prompt.hbs"); + let judge_diff_prompt_name = "judge_diff_prompt"; + let mut hbs = Handlebars::new(); + hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?; + + let diff_prompt = hbs.render( + judge_diff_prompt_name, + &JudgeDiffInput { + repository_diff: run_output.repository_diff.clone(), + ran_diagnostics_check: run_output.ran_diagnostics_check, + diagnostics_before: run_output.diagnostics_before.clone(), + diagnostics_after: run_output.diagnostics_after.clone(), + criteria: self.diff_criteria.clone(), }, )?; @@ -506,7 +561,7 @@ impl Example { prompt_id: None, messages: vec![LanguageModelRequestMessage { role: Role::User, - content: vec![MessageContent::Text(prompt)], + content: vec![MessageContent::Text(diff_prompt)], cache: false, }], temperature: None, @@ -514,24 +569,106 @@ impl Example { stop: Vec::new(), }; - let response = send_language_model_request(model, request, cx).await?; + let diff_response = send_language_model_request(model, request, cx).await?; + let diff_output = JudgeResponse::parse(&diff_response)?; - let judge_file_path = self.run_dir.join(format!( - "{}_judge_{}.md", - self.name, // This is the eval_name - judge_repetitions - )); + println!( + "{}Judge #{judge_number} - Diff score: {}", + self.log_prefix, diff_output.score + ); - let mut judge_output_file = File::create(&judge_file_path)?; - writeln!(&mut judge_output_file, "{}", &response).log_err(); - - parse_judge_output(&response) + Ok((diff_response, diff_output)) } - pub async fn repository_diff(&self) -> Result { + async fn judge_thread( + &self, + model: Arc, + run_output: &RunOutput, + judge_number: u32, + cx: &AsyncApp, + ) -> Result<(String, Option)> { + if let Some(criteria) = self.thread_criteria.clone() { + let judge_thread_prompt = include_str!("judge_thread_prompt.hbs"); + let judge_thread_prompt_name = "judge_thread_prompt"; + let mut hbs = Handlebars::new(); + hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?; + + let request_markdown = RequestMarkdown::new(&run_output.last_request); + let thread_prompt = hbs.render( + judge_thread_prompt_name, + &JudgeThreadInput { + messages: request_markdown.messages, + criteria, + }, + )?; + + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(thread_prompt)], + cache: false, + }], + temperature: None, + tools: Vec::new(), + stop: Vec::new(), + }; + + let thread_response = send_language_model_request(model, request, cx).await?; + let thread_output = JudgeResponse::parse(&thread_response)?; + + println!( + "{}Judge #{judge_number} - Thread score: {}", + self.log_prefix, thread_output.score + ); + + Ok((thread_response, Some(thread_output))) + } else { + let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string(); + Ok((msg, None)) + } + } + + pub async fn judge( + &self, + model: Arc, + run_output: &RunOutput, + judge_number: u32, + cx: &AsyncApp, + ) -> Result { + let mut output_file = File::create( + self.example_output_directory() + .join(format!("judge_{}.md", judge_number)), + ) + .expect("failed to create judge.md"); + + println!("{}Running judge #{judge_number}", self.log_prefix); + + let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx); + let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx); + + let (diff_result, thread_result) = futures::join!(diff_task, thread_task); + + let (diff_response, diff_output) = diff_result?; + let (thread_response, thread_output) = thread_result?; + + writeln!( + &mut output_file, + "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}", + ) + .log_err(); + + Ok(JudgeOutput { + thread: thread_output, + diff: diff_output, + }) + } + + async fn repository_diff(&self) -> Result { let worktree_path = self.worktree_path(); - run_git(&worktree_path, &["add", "-N"]).await?; - run_git(&worktree_path, &["diff"]).await + run_git(&worktree_path, &["add", "."]).await?; + run_git(&worktree_path, &["diff", "--staged"]).await } } @@ -599,7 +736,10 @@ fn has_pending_lang_server_work(lsp_store: &Entity, cx: &App) -> bool .any(|(_, status)| !status.pending_work.is_empty()) } -async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> Result { +async fn query_lsp_diagnostics( + project: Entity, + cx: &mut AsyncApp, +) -> Result> { let paths_with_diagnostics = project.update(cx, |project, cx| { project .diagnostic_summaries(true, cx) @@ -608,6 +748,10 @@ async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> R .collect::>() })?; + if paths_with_diagnostics.is_empty() { + return Ok(None); + } + let mut output = String::new(); for project_path in paths_with_diagnostics { let buffer = project @@ -633,16 +777,18 @@ async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> R )?; } } - anyhow::Ok(output) + anyhow::Ok(Some(output)) } -fn parse_judge_output(response: &str) -> Result { - let analysis = get_tag("analysis", response)?.to_string(); - let score = get_tag("score", response)? - .parse() - .context("error parsing score")?; +impl JudgeResponse { + fn parse(response: &str) -> Result { + let analysis = get_tag("analysis", response)?.to_string(); + let score = get_tag("score", response)? + .parse() + .context("error parsing score")?; - Ok(JudgeOutput { analysis, score }) + Ok(Self { analysis, score }) + } } fn get_tag(name: &'static str, response: &str) -> Result { @@ -724,9 +870,135 @@ pub async fn send_language_model_request( } } +struct RequestMarkdown { + tools: String, + messages: String, +} + +impl RequestMarkdown { + fn new(request: &LanguageModelRequest) -> Self { + let mut tools = String::new(); + let mut messages = String::new(); + + // Print the tools + if !request.tools.is_empty() { + for tool in &request.tools { + write!(&mut tools, "# {}\n\n", tool.name).unwrap(); + write!(&mut tools, "{}\n\n", tool.description).unwrap(); + write!( + &mut tools, + "```json\n{}\n```\n\n", + serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default() + ) + .unwrap(); + } + } + + // Print the messages + for message in &request.messages { + let role_str = match message.role { + Role::User => "👤 USER", + Role::Assistant => "🤖 ASSISTANT", + Role::System => "⚙️ SYSTEM", + }; + + messages.push_str(&format!("# {}\n\n", role_str)); + + for content in &message.content { + match content { + MessageContent::Text(text) => { + messages.push_str(text); + messages.push_str("\n\n"); + } + MessageContent::Image(_) => { + messages.push_str("[IMAGE DATA]\n\n"); + } + MessageContent::ToolUse(tool_use) => { + messages.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input)); + } + MessageContent::ToolResult(tool_result) => { + messages.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + messages.push_str("**ERROR:**\n"); + } + messages.push_str(&format!("{}\n", tool_result.content)); + } + } + } + } + + Self { tools, messages } + } +} + +fn response_events_to_markdown( + response_events: &[std::result::Result], +) -> String { + let mut response = String::new(); + // Print the response events if any + response.push_str("# Response\n\n"); + let mut text_buffer = String::new(); + let mut thinking_buffer = String::new(); + + let flush_buffers = + |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| { + if !text_buffer.is_empty() { + output.push_str(&format!("**Text**:\n{}\n\n", text_buffer)); + text_buffer.clear(); + } + if !thinking_buffer.is_empty() { + output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer)); + thinking_buffer.clear(); + } + }; + + for event in response_events { + match event { + Ok(LanguageModelCompletionEvent::Text(text)) => { + text_buffer.push_str(text); + } + Ok(LanguageModelCompletionEvent::Thinking(text)) => { + thinking_buffer.push_str(text); + } + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!("**Stop**: {:?}\n\n", reason)); + } + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + response.push_str(&format!("```json\n{}\n```\n\n", tool_use.input)); + } + Ok( + LanguageModelCompletionEvent::UsageUpdate(_) + | LanguageModelCompletionEvent::StartMessage { .. }, + ) => {} + Err(error) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!("**Error**: {}\n\n", error)); + } + } + } + + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + + response +} + #[cfg(test)] mod test { use super::*; + use handlebars::Handlebars; #[test] fn test_parse_judge_output() { @@ -736,7 +1008,7 @@ mod test { "# .unindent(); - let output = parse_judge_output(&response).unwrap(); + let output = JudgeResponse::parse(&response).unwrap(); assert_eq!( output.analysis, "The model did a good job but there were still compilations errors." @@ -756,8 +1028,158 @@ mod test { "# .unindent(); - let output = parse_judge_output(&response).unwrap(); + let output = JudgeResponse::parse(&response).unwrap(); assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2"); assert_eq!(output.score, 1); } + + #[test] + fn test_judge_prompt_with_diagnostics() { + // Case 1: Both diagnostics before and after are present + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: Some("Error at line 10: variable not found".to_string()), + diagnostics_after: Some("Error at line 15: missing semicolon".to_string()), + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + Error at line 10: variable not found + + + + Error at line 15: missing semicolon + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + } + + #[test] + fn test_judge_prompt_with_empty_diagnostics() { + // Case 2: Diagnostics check run but no diagnostics found + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: None, + diagnostics_after: None, + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + No diagnostics before applying the edits. + + + + No diagnostics after applying the edits. + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + } + + #[test] + fn test_judge_prompt_with_mixed_diagnostics() { + let templates = templates(); + + // Case 3: Before diagnostics present, after diagnostics absent + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: Some("Error at line 10: variable not found".to_string()), + diagnostics_after: None, + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + Error at line 10: variable not found + + + + No diagnostics after applying the edits. + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + + // Case 4: Before diagnostics absent, after diagnostics present + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: None, + diagnostics_after: Some("Error at line 15: missing semicolon".to_string()), + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + No diagnostics before applying the edits. + + + + Error at line 15: missing semicolon + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + } + + #[test] + fn test_judge_prompt_without_diagnostics() { + let templates = templates(); + + // Case 5: No diagnostics check run + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: false, + diagnostics_before: None, + diagnostics_after: None, + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap(); + + // Check for the message when no diagnostics were performed + let diagnostics_message = "No diagnostic checks were performed."; + + assert!(rendered.contains(diagnostics_message)); + assert!(!rendered.contains("")); + assert!(!rendered.contains("")); + } + + const JUDGE_PROMPT_NAME: &str = "judge_prompt"; + + fn templates() -> Handlebars<'static> { + let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string(); + language::LineEnding::normalize(&mut judge_prompt); + let mut handlebars = Handlebars::new(); + handlebars + .register_template_string(JUDGE_PROMPT_NAME, judge_prompt) + .unwrap(); + handlebars + } } diff --git a/crates/eval/src/judge_prompt.hbs b/crates/eval/src/judge_diff_prompt.hbs similarity index 64% rename from crates/eval/src/judge_prompt.hbs rename to crates/eval/src/judge_diff_prompt.hbs index 862cc0985c..4e9aaf68c0 100644 --- a/crates/eval/src/judge_prompt.hbs +++ b/crates/eval/src/judge_diff_prompt.hbs @@ -10,6 +10,28 @@ Use the following criteria to score the above changes. {{criteria}} +{{#if ran_diagnostics_check}} +Take into account the diagnostics before and after applying the change: + + +{{#if diagnostics_before}} +{{{diagnostics_before}}} +{{else}} +No diagnostics before applying the edits. +{{/if}} + + + +{{#if diagnostics_after}} +{{{diagnostics_after}}} +{{else}} +No diagnostics after applying the edits. +{{/if}} + +{{else}} +No diagnostic checks were performed. +{{/if}} + Based on these criteria, give the test output a score between 0 and 5. The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats. diff --git a/crates/eval/src/judge_thread_prompt.hbs b/crates/eval/src/judge_thread_prompt.hbs new file mode 100644 index 0000000000..a84ce8e698 --- /dev/null +++ b/crates/eval/src/judge_thread_prompt.hbs @@ -0,0 +1,22 @@ +You are an expert software developer tasked with evaluating an AI agent's messages and tool calls in this conversation: + + +{{{messages}}} + + +Use the following criteria to score the above messages. + + +{{criteria}} + + +Based on these criteria, give the messages a score between 0 and 5. +The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats. + +- 5 means: messages meet all criteria +- 0 means: messages don't meet any criteria + +``` +{YOUR ANALYSIS HERE} +{YOUR SCORE HERE} +``` diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 6b143440c2..9dd3df3523 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -60,7 +60,6 @@ pub struct DefaultUserRulesContext { #[derive(Debug, Clone, Serialize)] pub struct WorktreeContext { pub root_name: String, - pub abs_path: Arc, pub rules_file: Option, } @@ -403,7 +402,6 @@ mod test { fn test_assistant_system_prompt_renders() { let worktrees = vec![WorktreeContext { root_name: "path".into(), - abs_path: Path::new("/some/path").into(), rules_file: Some(RulesFileContext { path_in_worktree: Path::new(".rules").into(), abs_path: Path::new("/some/path/.rules").into(), diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 38e39fd774..b00b1d6b63 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -806,6 +806,23 @@ impl Worktree { } } + pub fn file_exists(&self, path: &Path, cx: &Context) -> Task> { + match self { + Worktree::Local(this) => { + let fs = this.fs.clone(); + let path = this.absolutize(path); + cx.background_spawn(async move { + let path = path?; + let metadata = fs.metadata(&path).await?; + Ok(metadata.map_or(false, |metadata| !metadata.is_dir)) + }) + } + Worktree::Remote(_) => Task::ready(Err(anyhow!( + "remote worktrees can't yet check file existence" + ))), + } + } + pub fn load_file(&self, path: &Path, cx: &Context) -> Task> { match self { Worktree::Local(this) => this.load_file(path, cx), diff --git a/typos.toml b/typos.toml index 72bc3e8ccf..4952c61c6a 100644 --- a/typos.toml +++ b/typos.toml @@ -45,9 +45,7 @@ extend-exclude = [ # Spellcheck triggers on `|Fixe[sd]|` regex part. "script/danger/dangerfile.ts", # Eval examples for prompts and criteria - "crates/eval/examples/checkpoint_stability/criteria.md", - "crates/eval/examples/tax_id_validation/prompt.md", - "crates/eval/examples/tax_id_validation/criteria.md" + "crates/eval/examples/", ] [default]