diff --git a/crates/assistant2/src/system_prompt.md b/assets/prompts/assistant_system_prompt.hbs similarity index 87% rename from crates/assistant2/src/system_prompt.md rename to assets/prompts/assistant_system_prompt.hbs index 7b7871922b..6ac1b785c1 100644 --- a/crates/assistant2/src/system_prompt.md +++ b/assets/prompts/assistant_system_prompt.hbs @@ -10,3 +10,9 @@ You should only perform actions that modify the user’s system if explicitly re - If the user clearly requests that you perform an action, carry out the action directly without explaining why you are doing so. Be concise and direct in your responses. + +The user has opened a project that contains the following top-level directories/files: + +{{#each worktree_root_names}} +- {{this}} +{{/each}} diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 70ad76fadc..3b9c45bec0 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -112,7 +112,7 @@ impl AssistantPanel { log::info!("[assistant2-debug] initializing ThreadStore"); let thread_store = workspace.update(&mut cx, |workspace, cx| { let project = workspace.project().clone(); - ThreadStore::new(project, tools.clone(), cx) + ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx) })??; log::info!("[assistant2-debug] finished initializing ThreadStore"); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 3c8b833ab3..acca0b65e7 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use anyhow::Result; +use anyhow::{Context as _, Result}; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; @@ -13,9 +13,10 @@ use language_model::{ Role, StopReason, }; use project::Project; +use prompt_store::PromptBuilder; use scripting_tool::{ScriptingSession, ScriptingTool}; use serde::{Deserialize, Serialize}; -use util::{post_inc, TryFutureExt as _}; +use util::{post_inc, ResultExt, TryFutureExt as _}; use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; @@ -74,6 +75,7 @@ pub struct Thread { completion_count: usize, pending_completions: Vec, project: Entity, + prompt_builder: Arc, tools: Arc, tool_use: ToolUseState, scripting_session: Entity, @@ -84,6 +86,7 @@ impl Thread { pub fn new( project: Entity, tools: Arc, + prompt_builder: Arc, cx: &mut Context, ) -> Self { let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx)); @@ -100,6 +103,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), project, + prompt_builder, tools, tool_use: ToolUseState::new(), scripting_session, @@ -112,6 +116,7 @@ impl Thread { saved: SavedThread, project: Entity, tools: Arc, + prompt_builder: Arc, cx: &mut Context, ) -> Self { let next_message_id = MessageId( @@ -147,6 +152,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), project, + prompt_builder, tools, tool_use, scripting_session, @@ -373,14 +379,25 @@ impl Thread { pub fn to_completion_request( &self, request_kind: RequestKind, - _cx: &App, + cx: &App, ) -> LanguageModelRequest { + let worktree_root_names = self + .project + .read(cx) + .worktree_root_names(cx) + .map(ToString::to_string) + .collect::>(); + let system_prompt = self + .prompt_builder + .generate_assistant_system_prompt(worktree_root_names) + .context("failed to generate assistant system prompt") + .log_err() + .unwrap_or_default(); + let mut request = LanguageModelRequest { messages: vec![LanguageModelRequestMessage { role: Role::System, - content: vec![MessageContent::Text( - include_str!("./system_prompt.md").to_string(), - )], + content: vec![MessageContent::Text(system_prompt)], cache: true, }], tools: Vec::new(), diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index 2200d914f3..7657ef9624 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -16,6 +16,7 @@ use heed::types::{SerdeBincode, SerdeJson}; use heed::Database; use language_model::{LanguageModelToolUseId, Role}; use project::Project; +use prompt_store::PromptBuilder; use serde::{Deserialize, Serialize}; use util::ResultExt as _; @@ -28,6 +29,7 @@ pub fn init(cx: &mut App) { pub struct ThreadStore { project: Entity, tools: Arc, + prompt_builder: Arc, context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, threads: Vec, @@ -37,6 +39,7 @@ impl ThreadStore { pub fn new( project: Entity, tools: Arc, + prompt_builder: Arc, cx: &mut App, ) -> Result> { let this = cx.new(|cx| { @@ -48,6 +51,7 @@ impl ThreadStore { let this = Self { project, tools, + prompt_builder, context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), @@ -77,7 +81,14 @@ impl ThreadStore { } pub fn create_thread(&mut self, cx: &mut Context) -> Entity { - cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx)) + cx.new(|cx| { + Thread::new( + self.project.clone(), + self.tools.clone(), + self.prompt_builder.clone(), + cx, + ) + }) } pub fn open_thread( @@ -101,6 +112,7 @@ impl ThreadStore { thread, this.project.clone(), this.tools.clone(), + this.prompt_builder.clone(), cx, ) }) diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 1b0343430c..0d3453cd2a 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -1,5 +1,4 @@ mod edit_files_tool; -mod list_worktrees_tool; mod now_tool; mod read_file_tool; @@ -7,7 +6,6 @@ use assistant_tool::ToolRegistry; use gpui::App; use crate::edit_files_tool::EditFilesTool; -use crate::list_worktrees_tool::ListWorktreesTool; use crate::now_tool::NowTool; use crate::read_file_tool::ReadFileTool; @@ -16,7 +14,6 @@ pub fn init(cx: &mut App) { let registry = ToolRegistry::global(cx); registry.register_tool(NowTool); - registry.register_tool(ListWorktreesTool); registry.register_tool(ReadFileTool); registry.register_tool(EditFilesTool); } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 87255febc4..95875f5199 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -1,25 +1,33 @@ mod edit_action; -use collections::HashSet; -use std::{path::Path, sync::Arc}; - -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use assistant_tool::Tool; +use collections::HashSet; use edit_action::{EditAction, EditActionParser}; use futures::StreamExt; use gpui::{App, Entity, Task}; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; -use project::{Project, ProjectPath, WorktreeId}; +use project::{Project, ProjectPath}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::sync::Arc; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct EditFilesToolInput { - /// The ID of the worktree in which the files reside. - pub worktree_id: usize, - /// Instruct how to modify the files. + /// High-level edit instructions. These will be interpreted by a smaller model, + /// so explain the edits you want that model to make and to which files need changing. + /// The description should be concise and clear. We will show this description to the user + /// as well. + /// + /// + /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'". + /// + /// + /// + /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age". + /// pub edit_instructions: String, } @@ -90,10 +98,25 @@ impl Tool for EditFilesTool { while let Some(chunk) = chunks.stream.next().await { for action in parser.parse_chunk(&chunk?) { - let project_path = ProjectPath { - worktree_id: WorktreeId::from_usize(input.worktree_id), - path: Path::new(action.file_path()).into(), - }; + let project_path = project.read_with(&cx, |project, cx| { + let worktree_root_name = action + .file_path() + .components() + .next() + .context("Invalid path")?; + let worktree = project + .worktree_for_root_name( + &worktree_root_name.as_os_str().to_string_lossy(), + cx, + ) + .context("Directory not found in project")?; + anyhow::Ok(ProjectPath { + worktree_id: worktree.read(cx).id(), + path: Arc::from( + action.file_path().strip_prefix(worktree_root_name).unwrap(), + ), + }) + })??; let buffer = project .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))? diff --git a/crates/assistant_tools/src/edit_files_tool/description.md b/crates/assistant_tools/src/edit_files_tool/description.md index 4f61ecd3cc..a049076bd1 100644 --- a/crates/assistant_tools/src/edit_files_tool/description.md +++ b/crates/assistant_tools/src/edit_files_tool/description.md @@ -1,3 +1,5 @@ -Edit files in a worktree by providing its id and a description of how to modify the code to complete the request. +Edit files in the current project. -Make instructions unambiguous and complete. Explain all needed code changes clearly and completely, but concisely. Just show the changes needed. DO NOT show the entire updated function/file/etc! +When using this tool, you should suggest one coherent edit that can be made to the codebase. + +When the set of edits you want to make is large or complex, feel free to invoke this tool multiple times, each time focusing on a specific change you wanna make. diff --git a/crates/assistant_tools/src/edit_files_tool/edit_action.rs b/crates/assistant_tools/src/edit_files_tool/edit_action.rs index 34e1f11d29..d118ca1a7f 100644 --- a/crates/assistant_tools/src/edit_files_tool/edit_action.rs +++ b/crates/assistant_tools/src/edit_files_tool/edit_action.rs @@ -1,3 +1,4 @@ +use std::path::{Path, PathBuf}; use util::ResultExt; /// Represents an edit action to be performed on a file. @@ -5,16 +6,16 @@ use util::ResultExt; pub enum EditAction { /// Replace specific content in a file with new content Replace { - file_path: String, + file_path: PathBuf, old: String, new: String, }, /// Write content to a file (create or overwrite) - Write { file_path: String, content: String }, + Write { file_path: PathBuf, content: String }, } impl EditAction { - pub fn file_path(&self) -> &str { + pub fn file_path(&self) -> &Path { match self { EditAction::Replace { file_path, .. } => file_path, EditAction::Write { file_path, .. } => file_path, @@ -180,7 +181,7 @@ impl EditActionParser { pop_carriage_return(&mut pre_fence_line); } - let file_path = String::from_utf8(pre_fence_line).log_err()?; + let file_path = PathBuf::from(String::from_utf8(pre_fence_line).log_err()?); let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?; if self.old_bytes.is_empty() { @@ -374,7 +375,7 @@ fn replacement() {} assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn original() {}".to_string(), new: "fn replacement() {}".to_string(), } @@ -401,7 +402,7 @@ fn replacement() {} assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn original() {}".to_string(), new: "fn replacement() {}".to_string(), } @@ -432,7 +433,7 @@ This change makes the function better. assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn original() {}".to_string(), new: "fn replacement() {}".to_string(), } @@ -470,7 +471,7 @@ fn new_util() -> bool { true } assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn original() {}".to_string(), new: "fn replacement() {}".to_string(), } @@ -478,7 +479,7 @@ fn new_util() -> bool { true } assert_eq!( actions[1], EditAction::Replace { - file_path: "src/utils.rs".to_string(), + file_path: PathBuf::from("src/utils.rs"), old: "fn old_util() -> bool { false }".to_string(), new: "fn new_util() -> bool { true }".to_string(), } @@ -519,7 +520,7 @@ fn replacement() { assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn original() {\n println!(\"This is the original function\");\n let x = 42;\n if x > 0 {\n println!(\"Positive number\");\n }\n}".to_string(), new: "fn replacement() {\n println!(\"This is the replacement function\");\n let x = 100;\n if x > 50 {\n println!(\"Large number\");\n } else {\n println!(\"Small number\");\n }\n}".to_string(), } @@ -549,7 +550,7 @@ fn new_function() { assert_eq!( actions[0], EditAction::Write { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), content: "fn new_function() {\n println!(\"This function is being added\");\n}" .to_string(), } @@ -576,7 +577,7 @@ fn this_will_be_deleted() { assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn this_will_be_deleted() {\n println!(\"Deleting this function\");\n}" .to_string(), new: "".to_string(), @@ -589,7 +590,7 @@ fn this_will_be_deleted() { assert_eq!( actions[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn this_will_be_deleted() {\r\n println!(\"Deleting this function\");\r\n}" .to_string(), @@ -655,7 +656,7 @@ fn replacement() {}"#; assert_eq!( actions3[0], EditAction::Replace { - file_path: "src/main.rs".to_string(), + file_path: PathBuf::from("src/main.rs"), old: "fn original() {}".to_string(), new: "fn replacement() {}".to_string(), } @@ -747,7 +748,7 @@ fn new_utils_func() {} assert_eq!( actions[0], EditAction::Replace { - file_path: "src/utils.rs".to_string(), + file_path: PathBuf::from("src/utils.rs"), old: "fn utils_func() {}".to_string(), new: "fn new_utils_func() {}".to_string(), } @@ -795,7 +796,7 @@ fn new_utils_func() {} assert_eq!( actions[0], EditAction::Replace { - file_path: "mathweb/flask/app.py".to_string(), + file_path: PathBuf::from("mathweb/flask/app.py"), old: "from flask import Flask".to_string(), new: "import math\nfrom flask import Flask".to_string(), } @@ -804,7 +805,7 @@ fn new_utils_func() {} assert_eq!( actions[1], EditAction::Replace { - file_path: "mathweb/flask/app.py".to_string(), + file_path: PathBuf::from("mathweb/flask/app.py"), old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(), new: "".to_string(), } @@ -813,7 +814,7 @@ fn new_utils_func() {} assert_eq!( actions[2], EditAction::Replace { - file_path: "mathweb/flask/app.py".to_string(), + file_path: PathBuf::from("mathweb/flask/app.py"), old: " return str(factorial(n))".to_string(), new: " return str(math.factorial(n))".to_string(), } @@ -822,7 +823,7 @@ fn new_utils_func() {} assert_eq!( actions[3], EditAction::Write { - file_path: "hello.py".to_string(), + file_path: PathBuf::from("hello.py"), content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")" .to_string(), } @@ -831,7 +832,7 @@ fn new_utils_func() {} assert_eq!( actions[4], EditAction::Replace { - file_path: "main.py".to_string(), + file_path: PathBuf::from("main.py"), old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(), new: "from hello import hello".to_string(), } diff --git a/crates/assistant_tools/src/list_worktrees_tool.rs b/crates/assistant_tools/src/list_worktrees_tool.rs deleted file mode 100644 index d30f987424..0000000000 --- a/crates/assistant_tools/src/list_worktrees_tool.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::sync::Arc; - -use anyhow::Result; -use assistant_tool::Tool; -use gpui::{App, Entity, Task}; -use language_model::LanguageModelRequestMessage; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct ListWorktreesToolInput {} - -pub struct ListWorktreesTool; - -impl Tool for ListWorktreesTool { - fn name(&self) -> String { - "list-worktrees".into() - } - - fn description(&self) -> String { - "Lists all worktrees in the current project. Use this tool when you need to find available worktrees and their IDs.".into() - } - - fn input_schema(&self) -> serde_json::Value { - serde_json::json!( - { - "type": "object", - "properties": {}, - "required": [] - } - ) - } - - fn run( - self: Arc, - _input: serde_json::Value, - _messages: &[LanguageModelRequestMessage], - project: Entity, - cx: &mut App, - ) -> Task> { - cx.spawn(|cx| async move { - cx.update(|cx| { - #[derive(Debug, Serialize)] - struct WorktreeInfo { - id: usize, - root_name: String, - root_dir: Option, - } - - let worktrees = project.update(cx, |project, cx| { - project - .visible_worktrees(cx) - .map(|worktree| { - worktree.read_with(cx, |worktree, _cx| WorktreeInfo { - id: worktree.id().to_usize(), - root_dir: worktree - .root_dir() - .map(|root_dir| root_dir.to_string_lossy().to_string()), - root_name: worktree.root_name().to_string(), - }) - }) - .collect::>() - }); - - if worktrees.is_empty() { - return Ok("No worktrees found in the current project.".to_string()); - } - - let mut result = String::from("Worktrees in the current project:\n\n"); - for worktree in worktrees { - result.push_str(&serde_json::to_string(&worktree)?); - } - - Ok(result) - })? - }) - } -} diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 82df2d499d..4c08bbe4e7 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -5,17 +5,24 @@ use anyhow::{anyhow, Result}; use assistant_tool::Tool; use gpui::{App, Entity, Task}; use language_model::LanguageModelRequestMessage; -use project::{Project, ProjectPath, WorktreeId}; +use project::{Project, ProjectPath}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct ReadFileToolInput { - /// The ID of the worktree in which the file resides. - pub worktree_id: usize, - /// The path to the file to read. + /// The relative path of the file to read. /// - /// This path is relative to the worktree root, it must not be an absolute path. + /// This path should never be absolute, and the first component + /// of the path should always be a top-level directory in a project. + /// + /// For example, if the project has the following top-level directories: + /// + /// - directory1 + /// - directory2 + /// + /// If you wanna access `file.txt` in `directory1`, you should use the path `directory1/file.txt`. + /// If you wanna access `file.txt` in `directory2`, you should use the path `directory2/file.txt`. pub path: Arc, } @@ -27,7 +34,7 @@ impl Tool for ReadFileTool { } fn description(&self) -> String { - "Reads the content of a file specified by a worktree ID and path. Use this tool when you need to access the contents of a file in the project.".into() + include_str!("./read_file_tool/description.md").into() } fn input_schema(&self) -> serde_json::Value { @@ -47,9 +54,18 @@ impl Tool for ReadFileTool { Err(err) => return Task::ready(Err(anyhow!(err))), }; + let Some(worktree_root_name) = input.path.components().next() else { + return Task::ready(Err(anyhow!("Invalid path"))); + }; + let Some(worktree) = project + .read(cx) + .worktree_for_root_name(&worktree_root_name.as_os_str().to_string_lossy(), cx) + else { + return Task::ready(Err(anyhow!("Directory not found in the project"))); + }; let project_path = ProjectPath { - worktree_id: WorktreeId::from_usize(input.worktree_id), - path: input.path, + worktree_id: worktree.read(cx).id(), + path: Arc::from(input.path.strip_prefix(worktree_root_name).unwrap()), }; cx.spawn(|cx| async move { let buffer = cx diff --git a/crates/assistant_tools/src/read_file_tool/description.md b/crates/assistant_tools/src/read_file_tool/description.md new file mode 100644 index 0000000000..ff023a6bb7 --- /dev/null +++ b/crates/assistant_tools/src/read_file_tool/description.md @@ -0,0 +1 @@ +Reads the content of the given file in the project. diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 5ffadd023c..c5e1a27ddb 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1589,6 +1589,11 @@ impl Project { self.worktree_store.read(cx).visible_worktrees(cx) } + pub fn worktree_for_root_name(&self, root_name: &str, cx: &App) -> Option> { + self.visible_worktrees(cx) + .find(|tree| tree.read(cx).root_name() == root_name) + } + pub fn worktree_root_names<'a>(&'a self, cx: &'a App) -> impl Iterator { self.visible_worktrees(cx) .map(|tree| tree.read(cx).root_name()) diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 4fafbce2a3..deb9465563 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -11,6 +11,11 @@ use std::{ops::Range, path::PathBuf, sync::Arc, time::Duration}; use text::LineEnding; use util::ResultExt; +#[derive(Serialize)] +pub struct AssistantSystemPromptContext { + pub worktree_root_names: Vec, +} + #[derive(Serialize)] pub struct ContentPromptDiagnosticContext { pub line_number: usize, @@ -216,6 +221,18 @@ impl PromptBuilder { Ok(()) } + pub fn generate_assistant_system_prompt( + &self, + worktree_root_names: Vec, + ) -> Result { + let prompt = AssistantSystemPromptContext { + worktree_root_names, + }; + self.handlebars + .lock() + .render("assistant_system_prompt", &prompt) + } + pub fn generate_inline_transformation_prompt( &self, user_prompt: String,