diff --git a/Cargo.lock b/Cargo.lock index 37dedc85da..e90baba517 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -658,6 +658,7 @@ dependencies = [ "collections", "derive_more", "gpui", + "language_model", "parking_lot", "project", "serde", @@ -671,11 +672,16 @@ dependencies = [ "anyhow", "assistant_tool", "chrono", + "collections", + "futures 0.3.31", "gpui", + "language_model", "project", + "rand 0.8.5", "schemars", "serde", "serde_json", + "util", ] [[package]] @@ -3128,6 +3134,7 @@ dependencies = [ "extension", "futures 0.3.31", "gpui", + "language_model", "log", "parking_lot", "postage", diff --git a/assets/settings/default.json b/assets/settings/default.json index edcf18a6c7..534d564441 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -600,6 +600,13 @@ "provider": "zed.dev", // The model to use. "model": "claude-3-5-sonnet-latest" + }, + // The model to use when applying edits from the assistant. + "editor_model": { + // The provider to use. + "provider": "zed.dev", + // The model to use. + "model": "claude-3-5-sonnet-latest" } }, // The settings for slash commands. diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index e50a4cea6f..63d715d8a7 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -186,8 +186,12 @@ fn init_language_model_settings(cx: &mut App) { fn update_active_language_model_from_settings(cx: &mut App) { let settings = AssistantSettings::get_global(cx); - let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone()); - let model_id = LanguageModelId::from(settings.default_model.model.clone()); + let active_model_provider_name = + LanguageModelProviderId::from(settings.default_model.provider.clone()); + let active_model_id = LanguageModelId::from(settings.default_model.model.clone()); + let editor_provider_name = + LanguageModelProviderId::from(settings.editor_model.provider.clone()); + let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone()); let inline_alternatives = settings .inline_alternatives .iter() @@ -199,7 +203,8 @@ fn update_active_language_model_from_settings(cx: &mut App) { }) .collect::>(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.select_active_model(&provider_name, &model_id, cx); + registry.select_active_model(&active_model_provider_name, &active_model_id, cx); + registry.select_editor_model(&editor_provider_name, &editor_model_id, cx); registry.select_inline_alternative_models(inline_alternatives, cx); }); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index eab5add89d..8d5d703a13 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -297,7 +297,8 @@ impl AssistantPanel { &LanguageModelRegistry::global(cx), window, |this, _, event: &language_model::Event, window, cx| match event { - language_model::Event::ActiveModelChanged => { + language_model::Event::ActiveModelChanged + | language_model::Event::EditorModelChanged => { this.completion_provider_changed(window, cx); } language_model::Event::ProviderStateChanged => { diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index f9ef4d0679..05e75d2c76 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -652,7 +652,7 @@ impl ActiveThread { ) .child(message_content), ), - Role::Assistant => div() + Role::Assistant => v_flex() .id(("message-container", ix)) .child(message_content) .map(|parent| { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index feac911a38..6bfbda90dc 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -623,6 +623,7 @@ impl Thread { } pub fn use_pending_tools(&mut self, cx: &mut Context) { + let request = self.to_completion_request(RequestKind::Chat, cx); let pending_tool_uses = self .tool_use .pending_tool_uses() @@ -633,7 +634,7 @@ impl Thread { for tool_use in pending_tool_uses { if let Some(tool) = self.tools.tool(&tool_use.name, cx) { - let task = tool.run(tool_use.input, self.project.clone(), cx); + let task = tool.run(tool_use.input, &request.messages, self.project.clone(), cx); self.insert_tool_output(tool_use.id.clone(), task, cx); } diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index d12f4a23f0..125c2fbff3 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -62,6 +62,7 @@ pub struct AssistantSettings { pub default_width: Pixels, pub default_height: Pixels, pub default_model: LanguageModelSelection, + pub editor_model: LanguageModelSelection, pub inline_alternatives: Vec, pub using_outdated_settings_version: bool, pub enable_experimental_live_diffs: bool, @@ -162,6 +163,7 @@ impl AssistantSettingsContent { }) } }), + editor_model: None, inline_alternatives: None, enable_experimental_live_diffs: None, }, @@ -182,6 +184,7 @@ impl AssistantSettingsContent { .id() .to_string(), }), + editor_model: None, inline_alternatives: None, enable_experimental_live_diffs: None, }, @@ -310,6 +313,7 @@ impl Default for VersionedAssistantSettingsContent { default_width: None, default_height: None, default_model: None, + editor_model: None, inline_alternatives: None, enable_experimental_live_diffs: None, }) @@ -340,6 +344,8 @@ pub struct AssistantSettingsContentV2 { default_height: Option, /// The default model to use when creating new chats. default_model: Option, + /// The model to use when applying edits from the assistant. + editor_model: Option, /// Additional models with which to generate alternatives when performing inline assists. inline_alternatives: Option>, /// Enable experimental live diffs in the assistant panel. @@ -470,6 +476,7 @@ impl Settings for AssistantSettings { value.default_height.map(Into::into), ); merge(&mut settings.default_model, value.default_model); + merge(&mut settings.editor_model, value.editor_model); merge(&mut settings.inline_alternatives, value.inline_alternatives); merge( &mut settings.enable_experimental_live_diffs, @@ -528,6 +535,10 @@ mod tests { provider: "test-provider".into(), model: "gpt-99".into(), }), + editor_model: Some(LanguageModelSelection { + provider: "test-provider".into(), + model: "gpt-99".into(), + }), inline_alternatives: None, enabled: None, button: None, diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml index 85eb36e306..f30c6d3134 100644 --- a/crates/assistant_tool/Cargo.toml +++ b/crates/assistant_tool/Cargo.toml @@ -15,6 +15,7 @@ path = "src/assistant_tool.rs" anyhow.workspace = true collections.workspace = true derive_more.workspace = true +language_model.workspace = true gpui.workspace = true parking_lot.workspace = true project.workspace = true diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 9980d9a47a..5002866287 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use anyhow::Result; use gpui::{App, Entity, SharedString, Task}; +use language_model::LanguageModelRequestMessage; use project::Project; pub use crate::tool_registry::*; @@ -44,6 +45,7 @@ pub trait Tool: 'static + Send + Sync { fn run( self: Arc, input: serde_json::Value, + messages: &[LanguageModelRequestMessage], project: Entity, cx: &mut App, ) -> Task>; diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 394cfe7e8e..22000f3921 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -15,8 +15,18 @@ path = "src/assistant_tools.rs" anyhow.workspace = true assistant_tool.workspace = true chrono.workspace = true +collections.workspace = true +futures.workspace = true gpui.workspace = true +language_model.workspace = true project.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +util.workspace = true + +[dev-dependencies] +rand.workspace = true +collections = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index d1e9081c23..1b0343430c 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -1,3 +1,4 @@ +mod edit_files_tool; mod list_worktrees_tool; mod now_tool; mod read_file_tool; @@ -5,6 +6,7 @@ mod read_file_tool; use assistant_tool::ToolRegistry; use gpui::App; +use crate::edit_files_tool::EditFilesTool; use crate::list_worktrees_tool::ListWorktreesTool; use crate::now_tool::NowTool; use crate::read_file_tool::ReadFileTool; @@ -16,4 +18,5 @@ pub fn init(cx: &mut App) { registry.register_tool(NowTool); registry.register_tool(ListWorktreesTool); registry.register_tool(ReadFileTool); + registry.register_tool(EditFilesTool); } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs new file mode 100644 index 0000000000..87255febc4 --- /dev/null +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -0,0 +1,155 @@ +mod edit_action; + +use collections::HashSet; +use std::{path::Path, sync::Arc}; + +use anyhow::{anyhow, Result}; +use assistant_tool::Tool; +use edit_action::{EditAction, EditActionParser}; +use futures::StreamExt; +use gpui::{App, Entity, Task}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, +}; +use project::{Project, ProjectPath, WorktreeId}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct EditFilesToolInput { + /// The ID of the worktree in which the files reside. + pub worktree_id: usize, + /// Instruct how to modify the files. + pub edit_instructions: String, +} + +pub struct EditFilesTool; + +impl Tool for EditFilesTool { + fn name(&self) -> String { + "edit-files".into() + } + + fn description(&self) -> String { + include_str!("./edit_files_tool/description.md").into() + } + + fn input_schema(&self) -> serde_json::Value { + let schema = schemars::schema_for!(EditFilesToolInput); + serde_json::to_value(&schema).unwrap() + } + + fn run( + self: Arc, + input: serde_json::Value, + messages: &[LanguageModelRequestMessage], + project: Entity, + cx: &mut App, + ) -> Task> { + let input = match serde_json::from_value::(input) { + Ok(input) => input, + Err(err) => return Task::ready(Err(anyhow!(err))), + }; + + let model_registry = LanguageModelRegistry::read_global(cx); + let Some(model) = model_registry.editor_model() else { + return Task::ready(Err(anyhow!("No editor model configured"))); + }; + + let mut messages = messages.to_vec(); + if let Some(last_message) = messages.last_mut() { + // Strip out tool use from the last message because we're in the middle of executing a tool call. + last_message + .content + .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_))) + } + messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![ + include_str!("./edit_files_tool/edit_prompt.md").into(), + input.edit_instructions.into(), + ], + cache: false, + }); + + cx.spawn(|mut cx| async move { + let request = LanguageModelRequest { + messages, + tools: vec![], + stop: vec![], + temperature: None, + }; + + let mut parser = EditActionParser::new(); + + let stream = model.stream_completion_text(request, &cx); + let mut chunks = stream.await?; + + let mut changed_buffers = HashSet::default(); + let mut applied_edits = 0; + + while let Some(chunk) = chunks.stream.next().await { + for action in parser.parse_chunk(&chunk?) { + let project_path = ProjectPath { + worktree_id: WorktreeId::from_usize(input.worktree_id), + path: Path::new(action.file_path()).into(), + }; + + let buffer = project + .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + + let diff = buffer + .read_with(&cx, |buffer, cx| { + let new_text = match action { + EditAction::Replace { old, new, .. } => { + // TODO: Replace in background? + buffer.text().replace(&old, &new) + } + EditAction::Write { content, .. } => content, + }; + + buffer.diff(new_text, cx) + })? + .await; + + let _clock = + buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?; + + changed_buffers.insert(buffer); + + applied_edits += 1; + } + } + + // Save each buffer once at the end + for buffer in changed_buffers { + project + .update(&mut cx, |project, cx| project.save_buffer(buffer, cx))? + .await?; + } + + let errors = parser.errors(); + + if errors.is_empty() { + Ok("Successfully applied all edits".into()) + } else { + let error_message = errors + .iter() + .map(|e| e.to_string()) + .collect::>() + .join("\n"); + + if applied_edits > 0 { + Err(anyhow!( + "Applied {} edit(s), but some blocks failed to parse:\n{}", + applied_edits, + error_message + )) + } else { + Err(anyhow!(error_message)) + } + } + }) + } +} diff --git a/crates/assistant_tools/src/edit_files_tool/description.md b/crates/assistant_tools/src/edit_files_tool/description.md new file mode 100644 index 0000000000..4f61ecd3cc --- /dev/null +++ b/crates/assistant_tools/src/edit_files_tool/description.md @@ -0,0 +1,3 @@ +Edit files in a worktree by providing its id and a description of how to modify the code to complete the request. + +Make instructions unambiguous and complete. Explain all needed code changes clearly and completely, but concisely. Just show the changes needed. DO NOT show the entire updated function/file/etc! diff --git a/crates/assistant_tools/src/edit_files_tool/edit_action.rs b/crates/assistant_tools/src/edit_files_tool/edit_action.rs new file mode 100644 index 0000000000..2749418806 --- /dev/null +++ b/crates/assistant_tools/src/edit_files_tool/edit_action.rs @@ -0,0 +1,807 @@ +use util::ResultExt; + +/// Represents an edit action to be performed on a file. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EditAction { + /// Replace specific content in a file with new content + Replace { + file_path: String, + old: String, + new: String, + }, + /// Write content to a file (create or overwrite) + Write { file_path: String, content: String }, +} + +impl EditAction { + pub fn file_path(&self) -> &str { + match self { + EditAction::Replace { file_path, .. } => file_path, + EditAction::Write { file_path, .. } => file_path, + } + } +} + +/// Parses edit actions from an LLM response. +/// See system.md for more details on the format. +#[derive(Debug)] +pub struct EditActionParser { + state: State, + pre_fence_line: Vec, + marker_ix: usize, + line: usize, + column: usize, + old_bytes: Vec, + new_bytes: Vec, + errors: Vec, +} + +#[derive(Debug, PartialEq, Eq)] +enum State { + /// Anywhere outside an action + Default, + /// After opening ```, in optional language tag + OpenFence, + /// In SEARCH marker + SearchMarker, + /// In search block or divider + SearchBlock, + /// In replace block or REPLACE marker + ReplaceBlock, + /// In closing ``` + CloseFence, +} + +impl EditActionParser { + /// Creates a new `EditActionParser` + pub fn new() -> Self { + Self { + state: State::Default, + pre_fence_line: Vec::new(), + marker_ix: 0, + line: 1, + column: 0, + old_bytes: Vec::new(), + new_bytes: Vec::new(), + errors: Vec::new(), + } + } + + /// Processes a chunk of input text and returns any completed edit actions. + /// + /// This method can be called repeatedly with fragments of input. The parser + /// maintains its state between calls, allowing you to process streaming input + /// as it becomes available. Actions are only inserted once they are fully parsed. + /// + /// If a block fails to parse, it will simply be skipped and an error will be recorded. + /// All errors can be accessed through the `EditActionsParser::errors` method. + pub fn parse_chunk(&mut self, input: &str) -> Vec { + use State::*; + + const FENCE: &[u8] = b"\n```"; + const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH\n"; + const DIVIDER: &[u8] = b"=======\n"; + const NL_DIVIDER: &[u8] = b"\n=======\n"; + const REPLACE_MARKER: &[u8] = b">>>>>>> REPLACE"; + const NL_REPLACE_MARKER: &[u8] = b"\n>>>>>>> REPLACE"; + + let mut actions = Vec::new(); + + for byte in input.bytes() { + // Update line and column tracking + if byte == b'\n' { + self.line += 1; + self.column = 0; + } else { + self.column += 1; + } + + match self.state { + Default => match match_marker(byte, FENCE, &mut self.marker_ix) { + MarkerMatch::Complete => { + self.to_state(OpenFence); + } + MarkerMatch::Partial => {} + MarkerMatch::None => { + if self.marker_ix > 0 { + self.marker_ix = 0; + self.pre_fence_line.clear(); + } + + if byte != b'\n' { + self.pre_fence_line.push(byte); + } + } + }, + OpenFence => { + // skip language tag + if byte == b'\n' { + self.to_state(SearchMarker); + } + } + SearchMarker => { + if self.expect_marker(byte, SEARCH_MARKER) { + self.to_state(SearchBlock); + } + } + SearchBlock => { + if collect_until_marker( + byte, + DIVIDER, + NL_DIVIDER, + &mut self.marker_ix, + &mut self.old_bytes, + ) { + self.to_state(ReplaceBlock); + } + } + ReplaceBlock => { + if collect_until_marker( + byte, + REPLACE_MARKER, + NL_REPLACE_MARKER, + &mut self.marker_ix, + &mut self.new_bytes, + ) { + self.to_state(CloseFence); + } + } + CloseFence => { + if self.expect_marker(byte, FENCE) { + if let Some(action) = self.action() { + actions.push(action); + } + self.reset(); + } + } + }; + } + + actions + } + + /// Returns a reference to the errors encountered during parsing. + pub fn errors(&self) -> &[ParseError] { + &self.errors + } + + fn action(&mut self) -> Option { + if self.old_bytes.is_empty() && self.new_bytes.is_empty() { + self.push_error(ParseErrorKind::NoOp); + return None; + } + + let file_path = String::from_utf8(std::mem::take(&mut self.pre_fence_line)).log_err()?; + let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?; + + if self.old_bytes.is_empty() { + Some(EditAction::Write { file_path, content }) + } else { + let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?; + + Some(EditAction::Replace { + file_path, + old, + new: content, + }) + } + } + + fn expect_marker(&mut self, byte: u8, marker: &'static [u8]) -> bool { + match match_marker(byte, marker, &mut self.marker_ix) { + MarkerMatch::Complete => true, + MarkerMatch::Partial => false, + MarkerMatch::None => { + self.push_error(ParseErrorKind::ExpectedMarker { + expected: marker, + found: byte, + }); + self.reset(); + false + } + } + } + + fn to_state(&mut self, state: State) { + self.state = state; + self.marker_ix = 0; + } + + fn reset(&mut self) { + self.pre_fence_line.clear(); + self.old_bytes.clear(); + self.new_bytes.clear(); + self.to_state(State::Default); + } + + fn push_error(&mut self, kind: ParseErrorKind) { + self.errors.push(ParseError { + line: self.line, + column: self.column, + kind, + }); + } +} + +#[derive(Debug)] +enum MarkerMatch { + None, + Partial, + Complete, +} + +fn match_marker(byte: u8, marker: &[u8], marker_ix: &mut usize) -> MarkerMatch { + if byte == marker[*marker_ix] { + *marker_ix += 1; + + if *marker_ix >= marker.len() { + MarkerMatch::Complete + } else { + MarkerMatch::Partial + } + } else { + MarkerMatch::None + } +} + +fn collect_until_marker( + byte: u8, + marker: &[u8], + nl_marker: &[u8], + marker_ix: &mut usize, + buf: &mut Vec, +) -> bool { + let marker = if buf.is_empty() { + // do not require another newline if block is empty + marker + } else { + nl_marker + }; + + match match_marker(byte, marker, marker_ix) { + MarkerMatch::Complete => true, + MarkerMatch::Partial => false, + MarkerMatch::None => { + if *marker_ix > 0 { + buf.extend_from_slice(&marker[..*marker_ix]); + *marker_ix = 0; + + // The beginning of marker might match current byte + match match_marker(byte, marker, marker_ix) { + MarkerMatch::Complete => return true, + MarkerMatch::Partial => return false, + MarkerMatch::None => { /* no match, keep collecting */ } + } + } + + buf.push(byte); + + false + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ParseError { + line: usize, + column: usize, + kind: ParseErrorKind, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum ParseErrorKind { + ExpectedMarker { expected: &'static [u8], found: u8 }, + NoOp, +} + +impl std::fmt::Display for ParseErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ParseErrorKind::ExpectedMarker { expected, found } => { + write!( + f, + "Expected marker {:?}, found {:?}", + String::from_utf8_lossy(expected), + *found as char + ) + } + ParseErrorKind::NoOp => { + write!(f, "No search or replace") + } + } + } +} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "input:{}:{}: {}", self.line, self.column, self.kind) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[test] + fn test_simple_edit_action() { + let input = r#"src/main.rs +``` +<<<<<<< SEARCH +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn original() {}".to_string(), + new: "fn replacement() {}".to_string(), + } + ); + } + + #[test] + fn test_with_language_tag() { + let input = r#"src/main.rs +```rust +<<<<<<< SEARCH +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn original() {}".to_string(), + new: "fn replacement() {}".to_string(), + } + ); + } + + #[test] + fn test_with_surrounding_text() { + let input = r#"Here's a modification I'd like to make to the file: + +src/main.rs +```rust +<<<<<<< SEARCH +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE +``` + +This change makes the function better. +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn original() {}".to_string(), + new: "fn replacement() {}".to_string(), + } + ); + } + + #[test] + fn test_multiple_edit_actions() { + let input = r#"First change: +src/main.rs +``` +<<<<<<< SEARCH +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE +``` + +Second change: +src/utils.rs +```rust +<<<<<<< SEARCH +fn old_util() -> bool { false } +======= +fn new_util() -> bool { true } +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 2); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn original() {}".to_string(), + new: "fn replacement() {}".to_string(), + } + ); + assert_eq!( + actions[1], + EditAction::Replace { + file_path: "src/utils.rs".to_string(), + old: "fn old_util() -> bool { false }".to_string(), + new: "fn new_util() -> bool { true }".to_string(), + } + ); + } + + #[test] + fn test_multiline() { + let input = r#"src/main.rs +```rust +<<<<<<< SEARCH +fn original() { + println!("This is the original function"); + let x = 42; + if x > 0 { + println!("Positive number"); + } +} +======= +fn replacement() { + println!("This is the replacement function"); + let x = 100; + if x > 50 { + println!("Large number"); + } else { + println!("Small number"); + } +} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn original() {\n println!(\"This is the original function\");\n let x = 42;\n if x > 0 {\n println!(\"Positive number\");\n }\n}".to_string(), + new: "fn replacement() {\n println!(\"This is the replacement function\");\n let x = 100;\n if x > 50 {\n println!(\"Large number\");\n } else {\n println!(\"Small number\");\n }\n}".to_string(), + } + ); + } + + #[test] + fn test_write_action() { + let input = r#"Create a new main.rs file: + +src/main.rs +```rust +<<<<<<< SEARCH +======= +fn new_function() { + println!("This function is being added"); +} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Write { + file_path: "src/main.rs".to_string(), + content: "fn new_function() {\n println!(\"This function is being added\");\n}" + .to_string(), + } + ); + } + + #[test] + fn test_empty_replace() { + let input = r#"src/main.rs +```rust +<<<<<<< SEARCH +fn this_will_be_deleted() { + println!("Deleting this function"); +} +======= +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn this_will_be_deleted() {\n println!(\"Deleting this function\");\n}" + .to_string(), + new: "".to_string(), + } + ); + } + + #[test] + fn test_empty_both() { + let input = r#"src/main.rs +```rust +<<<<<<< SEARCH +======= +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + // Should not create an action when both sections are empty + assert_eq!(actions.len(), 0); + + // Check that the NoOp error was added + assert_eq!(parser.errors().len(), 1); + match parser.errors()[0].kind { + ParseErrorKind::NoOp => {} + _ => panic!("Expected NoOp error"), + } + } + + #[test] + fn test_resumability() { + let input_part1 = r#"src/main.rs +```rust +<<<<<<< SEARCH +fn ori"#; + + let input_part2 = r#"ginal() {} +======= +fn replacement() {}"#; + + let input_part3 = r#" +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions1 = parser.parse_chunk(input_part1); + assert_eq!(actions1.len(), 0); + + let actions2 = parser.parse_chunk(input_part2); + // No actions should be complete yet + assert_eq!(actions2.len(), 0); + + let actions3 = parser.parse_chunk(input_part3); + // The third chunk should complete the action + assert_eq!(actions3.len(), 1); + assert_eq!( + actions3[0], + EditAction::Replace { + file_path: "src/main.rs".to_string(), + old: "fn original() {}".to_string(), + new: "fn replacement() {}".to_string(), + } + ); + } + + #[test] + fn test_parser_state_preservation() { + let mut parser = EditActionParser::new(); + let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n"); + + // Check parser is in the correct state + assert_eq!(parser.state, State::SearchBlock); + assert_eq!(parser.pre_fence_line, b"src/main.rs"); + + // Continue parsing + let actions2 = parser.parse_chunk("original code\n=======\n"); + assert_eq!(parser.state, State::ReplaceBlock); + assert_eq!(parser.old_bytes, b"original code"); + + let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n"); + + // After complete parsing, state should reset + assert_eq!(parser.state, State::Default); + assert!(parser.pre_fence_line.is_empty()); + assert!(parser.old_bytes.is_empty()); + assert!(parser.new_bytes.is_empty()); + + assert_eq!(actions1.len(), 0); + assert_eq!(actions2.len(), 0); + assert_eq!(actions3.len(), 1); + } + + #[test] + fn test_invalid_search_marker() { + let input = r#"src/main.rs +```rust +<<<<<<< WRONG_MARKER +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + assert_eq!(actions.len(), 0); + + assert_eq!(parser.errors().len(), 1); + let error = &parser.errors()[0]; + + assert_eq!(error.line, 3); + assert_eq!(error.column, 9); + assert_eq!( + error.kind, + ParseErrorKind::ExpectedMarker { + expected: b"<<<<<<< SEARCH\n", + found: b'W' + } + ); + } + + #[test] + fn test_missing_closing_fence() { + let input = r#"src/main.rs +```rust +<<<<<<< SEARCH +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE + + +src/utils.rs +```rust +<<<<<<< SEARCH +fn utils_func() {} +======= +fn new_utils_func() {} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(input); + + // Only the second block should be parsed + assert_eq!(actions.len(), 1); + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "src/utils.rs".to_string(), + old: "fn utils_func() {}".to_string(), + new: "fn new_utils_func() {}".to_string(), + } + ); + + // The parser should continue after an error + assert_eq!(parser.state, State::Default); + } + + const SYSTEM_PROMPT: &str = include_str!("./edit_prompt.md"); + + #[test] + fn test_parse_examples_in_system_prompt() { + let mut parser = EditActionParser::new(); + let actions = parser.parse_chunk(SYSTEM_PROMPT); + assert_examples_in_system_prompt(&actions, parser.errors()); + } + + #[gpui::test(iterations = 10)] + fn test_random_chunking_of_system_prompt(mut rng: StdRng) { + let mut parser = EditActionParser::new(); + let mut remaining = SYSTEM_PROMPT; + let mut actions = Vec::with_capacity(5); + + while !remaining.is_empty() { + let chunk_size = rng.gen_range(1..=std::cmp::min(remaining.len(), 100)); + + let (chunk, rest) = remaining.split_at(chunk_size); + + actions.extend(parser.parse_chunk(chunk)); + remaining = rest; + } + + assert_examples_in_system_prompt(&actions, parser.errors()); + } + + fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) { + assert_eq!(actions.len(), 5); + + assert_eq!( + actions[0], + EditAction::Replace { + file_path: "mathweb/flask/app.py".to_string(), + old: "from flask import Flask".to_string(), + new: "import math\nfrom flask import Flask".to_string(), + } + ); + + assert_eq!( + actions[1], + EditAction::Replace { + file_path: "mathweb/flask/app.py".to_string(), + old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(), + new: "".to_string(), + } + ); + + assert_eq!( + actions[2], + EditAction::Replace { + file_path: "mathweb/flask/app.py".to_string(), + old: " return str(factorial(n))".to_string(), + new: " return str(math.factorial(n))".to_string(), + } + ); + + assert_eq!( + actions[3], + EditAction::Write { + file_path: "hello.py".to_string(), + content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")" + .to_string(), + } + ); + + assert_eq!( + actions[4], + EditAction::Replace { + file_path: "main.py".to_string(), + old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(), + new: "from hello import hello".to_string(), + } + ); + + // Ensure we have no parsing errors + assert!(errors.is_empty(), "Parsing errors found: {:?}", errors); + } + + #[test] + fn test_print_error() { + let input = r#"src/main.rs +```rust +<<<<<<< WRONG_MARKER +fn original() {} +======= +fn replacement() {} +>>>>>>> REPLACE +``` +"#; + + let mut parser = EditActionParser::new(); + parser.parse_chunk(input); + + assert_eq!(parser.errors().len(), 1); + let error = &parser.errors()[0]; + let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH\n", found 'W'"#; + + assert_eq!(format!("{}", error), expected_error); + } +} diff --git a/crates/assistant_tools/src/edit_files_tool/edit_prompt.md b/crates/assistant_tools/src/edit_files_tool/edit_prompt.md new file mode 100644 index 0000000000..834b811981 --- /dev/null +++ b/crates/assistant_tools/src/edit_files_tool/edit_prompt.md @@ -0,0 +1,134 @@ +Act as an expert software developer. +Always use best practices when coding. +Respect and use existing conventions, libraries, etc that are already present in the code base. + +Take requests for changes to the supplied code. +Always reply to the user in the same language they are using. + +Once you understand the request, you MUST: + +1. Decide if you need to propose *SEARCH/REPLACE* edits to any files that haven't been added to the chat. You can create new files without asking! + +2. Think step-by-step and explain the needed changes in a few short sentences. + +3. Describe each change with a *SEARCH/REPLACE block* per the examples below. + +All changes to files must use this *SEARCH/REPLACE block* format. +ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*! + +# Example conversations: + +## USER: Change get_factorial() to use math.factorial + +## ASSISTANT: To make this change we need to modify `mathweb/flask/app.py` to: + +1. Import the math package. +2. Remove the existing factorial() function. +3. Update get_factorial() to call math.factorial instead. + +Here are the *SEARCH/REPLACE* blocks: + +mathweb/flask/app.py +```python +<<<<<<< SEARCH +from flask import Flask +======= +import math +from flask import Flask +>>>>>>> REPLACE +``` + +mathweb/flask/app.py +```python +<<<<<<< SEARCH +def factorial(n): + "compute factorial" + + if n == 0: + return 1 + else: + return n * factorial(n-1) + +======= +>>>>>>> REPLACE +``` + +mathweb/flask/app.py +```python +<<<<<<< SEARCH + return str(factorial(n)) +======= + return str(math.factorial(n)) +>>>>>>> REPLACE +``` + + +## USER: Refactor hello() into its own file. + +## ASSISTANT: To make this change we need to modify `main.py` and make a new file `hello.py`: + +1. Make a new hello.py file with hello() in it. +2. Remove hello() from main.py and replace it with an import. + +Here are the *SEARCH/REPLACE* blocks: + +hello.py +```python +<<<<<<< SEARCH +======= +def hello(): + "print a greeting" + + print("hello") +>>>>>>> REPLACE +``` + +main.py +```python +<<<<<<< SEARCH +def hello(): + "print a greeting" + + print("hello") +======= +from hello import hello +>>>>>>> REPLACE +``` +# *SEARCH/REPLACE block* Rules: + +Every *SEARCH/REPLACE block* must use this format: +1. The *FULL* file path alone on a line, verbatim. No bold asterisks, no quotes around it, no escaping of characters, etc. +2. The opening fence and code language, eg: ```python +3. The start of search block: <<<<<<< SEARCH +4. A contiguous chunk of lines to search for in the existing source code +5. The dividing line: ======= +6. The lines to replace into the source code +7. The end of the replace block: >>>>>>> REPLACE +8. The closing fence: ``` + +Use the *FULL* file path, as shown to you by the user. + +Every *SEARCH* section must *EXACTLY MATCH* the existing file content, character for character, including all comments, docstrings, etc. +If the file contains code or other data wrapped/escaped in json/xml/quotes or other containers, you need to propose edits to the literal contents of the file, including the container markup. + +*SEARCH/REPLACE* blocks will *only* replace the first match occurrence. +Including multiple unique *SEARCH/REPLACE* blocks if needed. +Include enough lines in each SEARCH section to uniquely match each set of lines that need to change. + +Keep *SEARCH/REPLACE* blocks concise. +Break large *SEARCH/REPLACE* blocks into a series of smaller blocks that each change a small portion of the file. +Include just the changing lines, and a few surrounding lines if needed for uniqueness. +Do not include long runs of unchanging lines in *SEARCH/REPLACE* blocks. + +Only create *SEARCH/REPLACE* blocks for files that the user has added to the chat! + +To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location, 1 to insert it in the new location. + +Pay attention to which filenames the user wants you to edit, especially if they are asking you to create a new file. + +If you want to put code in a new file, use a *SEARCH/REPLACE block* with: +- A new file path, including dir name if needed +- An empty `SEARCH` section +- The new file's contents in the `REPLACE` section + +ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*! diff --git a/crates/assistant_tools/src/list_worktrees_tool.rs b/crates/assistant_tools/src/list_worktrees_tool.rs index eace3b6665..d30f987424 100644 --- a/crates/assistant_tools/src/list_worktrees_tool.rs +++ b/crates/assistant_tools/src/list_worktrees_tool.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::Result; use assistant_tool::Tool; use gpui::{App, Entity, Task}; +use language_model::LanguageModelRequestMessage; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -34,6 +35,7 @@ impl Tool for ListWorktreesTool { fn run( self: Arc, _input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], project: Entity, cx: &mut App, ) -> Task> { diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index 23fe6aa43d..80332a184e 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -4,6 +4,7 @@ use anyhow::{anyhow, Result}; use assistant_tool::Tool; use chrono::{Local, Utc}; use gpui::{App, Entity, Task}; +use language_model::LanguageModelRequestMessage; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -42,6 +43,7 @@ impl Tool for NowTool { fn run( self: Arc, input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], _project: Entity, _cx: &mut App, ) -> Task> { diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 88775a3024..82df2d499d 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use anyhow::{anyhow, Result}; use assistant_tool::Tool; use gpui::{App, Entity, Task}; +use language_model::LanguageModelRequestMessage; use project::{Project, ProjectPath, WorktreeId}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -37,6 +38,7 @@ impl Tool for ReadFileTool { fn run( self: Arc, input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], project: Entity, cx: &mut App, ) -> Task> { @@ -56,7 +58,16 @@ impl Tool for ReadFileTool { })? .await?; - cx.update(|cx| buffer.read(cx).text()) + buffer.read_with(&cx, |buffer, _cx| { + if buffer + .file() + .map_or(false, |file| file.disk_state().exists()) + { + Ok(buffer.text()) + } else { + Err(anyhow!("File does not exist")) + } + })? }) } } diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 5306b314d6..6e3feaf2ee 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -21,6 +21,7 @@ context_server_settings.workspace = true extension.workspace = true futures.workspace = true gpui.workspace = true +language_model.workspace = true log.workspace = true parking_lot.workspace = true postage.workspace = true diff --git a/crates/context_server/src/context_server_tool.rs b/crates/context_server/src/context_server_tool.rs index 5601a0d245..899db58c7d 100644 --- a/crates/context_server/src/context_server_tool.rs +++ b/crates/context_server/src/context_server_tool.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Result}; use assistant_tool::{Tool, ToolSource}; use gpui::{App, Entity, Task}; +use language_model::LanguageModelRequestMessage; use project::Project; use crate::manager::ContextServerManager; @@ -58,6 +59,7 @@ impl Tool for ContextServerTool { fn run( self: Arc, input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], _project: Entity, cx: &mut App, ) -> Task> { diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 792fc8b41b..c3c6b3bb03 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -18,6 +18,7 @@ impl Global for GlobalLanguageModelRegistry {} #[derive(Default)] pub struct LanguageModelRegistry { active_model: Option, + editor_model: Option, providers: BTreeMap>, inline_alternatives: Vec>, } @@ -29,6 +30,7 @@ pub struct ActiveModel { pub enum Event { ActiveModelChanged, + EditorModelChanged, ProviderStateChanged, AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), @@ -128,6 +130,22 @@ impl LanguageModelRegistry { } } + pub fn select_editor_model( + &mut self, + provider: &LanguageModelProviderId, + model_id: &LanguageModelId, + cx: &mut Context, + ) { + let Some(provider) = self.provider(provider) else { + return; + }; + + let models = provider.provided_models(cx); + if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { + self.set_editor_model(Some(model), cx); + } + } + pub fn set_active_provider( &mut self, provider: Option>, @@ -162,6 +180,28 @@ impl LanguageModelRegistry { } } + pub fn set_editor_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + if let Some(model) = model { + let provider_id = model.provider_id(); + if let Some(provider) = self.providers.get(&provider_id).cloned() { + self.editor_model = Some(ActiveModel { + provider, + model: Some(model), + }); + cx.emit(Event::EditorModelChanged); + } else { + log::warn!("Active model's provider not found in registry"); + } + } else { + self.editor_model = None; + cx.emit(Event::EditorModelChanged); + } + } + pub fn active_provider(&self) -> Option> { Some(self.active_model.as_ref()?.provider.clone()) } @@ -170,6 +210,10 @@ impl LanguageModelRegistry { self.active_model.as_ref()?.model.clone() } + pub fn editor_model(&self) -> Option> { + self.editor_model.as_ref()?.model.clone() + } + /// Selects and sets the inline alternatives for language models based on /// provider name and id. pub fn select_inline_alternative_models(