diff --git a/Cargo.lock b/Cargo.lock index a0df4de529..65c6dfc732 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -683,6 +683,7 @@ dependencies = [ "language_model", "language_models", "log", + "lsp", "markdown", "open", "paths", diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 8f81f1a695..8abe78a98f 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -36,6 +36,7 @@ itertools.workspace = true language.workspace = true language_model.workspace = true log.workspace = true +lsp.workspace = true markdown.workspace = true open.workspace = true paths.workspace = true @@ -64,6 +65,7 @@ workspace.workspace = true zed_llm_client.workspace = true [dev-dependencies] +lsp = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 11ae95396a..150fdc5eca 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -18,16 +18,21 @@ use gpui::{ use indoc::formatdoc; use language::{ Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope, - TextBuffer, language_settings::SoftWrap, + TextBuffer, + language_settings::{self, FormatOnSave, SoftWrap}, }; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; -use project::{Project, ProjectPath}; +use project::{ + Project, ProjectPath, + lsp_store::{FormatTrigger, LspFormatTarget}, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use std::{ cmp::Reverse, + collections::HashSet, ops::Range, path::{Path, PathBuf}, sync::Arc, @@ -189,8 +194,10 @@ impl Tool for EditFileTool { }); let card_clone = card.clone(); + let action_log_clone = action_log.clone(); let task = cx.spawn(async move |cx: &mut AsyncApp| { - let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new()); + let edit_agent = + EditAgent::new(model, project.clone(), action_log_clone, Templates::new()); let buffer = project .update(cx, |project, cx| { @@ -244,19 +251,53 @@ impl Tool for EditFileTool { } let agent_output = output.await?; + // If format_on_save is enabled, format the buffer + let format_on_save_enabled = buffer + .read_with(cx, |buffer, cx| { + let settings = language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + !matches!(settings.format_on_save, FormatOnSave::Off) + }) + .unwrap_or(false); + + if format_on_save_enabled { + let format_task = project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, // Don't push to history since the tool did it. + FormatTrigger::Save, + cx, + ) + })?; + format_task.await.log_err(); + } + project .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? .await?; + // Notify the action log that we've edited the buffer (*after* formatting has completed). + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; + let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let new_text = cx.background_spawn({ - let new_snapshot = new_snapshot.clone(); - async move { new_snapshot.text() } - }); - let diff = cx.background_spawn(async move { - language::unified_diff(&old_snapshot.text(), &new_snapshot.text()) - }); - let (new_text, diff) = futures::join!(new_text, diff); + let (new_text, diff) = cx + .background_spawn({ + let new_snapshot = new_snapshot.clone(); + let old_text = old_text.clone(); + async move { + let new_text = new_snapshot.text(); + let diff = language::unified_diff(&old_text, &new_text); + + (new_text, diff) + } + }) + .await; let output = EditFileToolOutput { original_path: project_path.path.to_path_buf(), @@ -1099,8 +1140,8 @@ async fn build_buffer_diff( mod tests { use super::*; use client::TelemetrySettings; - use fs::FakeFs; - use gpui::TestAppContext; + use fs::{FakeFs, Fs}; + use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; @@ -1310,4 +1351,340 @@ mod tests { Project::init_settings(cx); }); } + + #[gpui::test] + async fn test_format_on_save(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + // Set up a Rust language with LSP formatting support + let rust_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Rust".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Rust", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + document_formatting_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Create the file + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; + const FORMATTED_CONTENT: &str = + "This file was formatted by the fake formatter in the test.\n"; + + // Get the fake language server and set up formatting handler + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::({ + |_, _| async move { + Ok(Some(vec![lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), + new_text: FORMATTED_CONTENT.to_string(), + }])) + } + }); + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + + // First, test with format_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::On); + settings.defaults.formatter = + Some(language::language_settings::SelectedFormatter::Auto); + }, + ); + }); + }); + + // Have the model stream unformatted content + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify it was formatted automatically + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is enabled" + ); + + let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count()); + + assert_eq!( + stale_buffer_count, 0, + "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \ + This causes the agent to think the file was modified externally when it was just formatted.", + stale_buffer_count + ); + + // Next, test with format_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::Off); + }, + ); + }); + }); + + // Stream unformatted edits again + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file was not formatted + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + UNFORMATTED_CONTENT, + "Code should not be formatted when format_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + // Create a simple file with trailing whitespace + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + + // First, test with remove_trailing_whitespace_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(true); + }, + ); + }); + }); + + const CONTENT_WITH_TRAILING_WHITESPACE: &str = + "fn main() { \n println!(\"Hello!\"); \n}\n"; + + // Have the model stream content that contains trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify trailing whitespace was removed automatically + assert_eq!( + // Ignore carriage returns on Windows + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + "fn main() {\n println!(\"Hello!\");\n}\n", + "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" + ); + + // Next, test with remove_trailing_whitespace_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(false); + }, + ); + }); + }); + + // Stream edits again with trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file still has trailing whitespace + // Read the file again - it should still have trailing whitespace + let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + final_content.replace("\r\n", "\n"), + CONTENT_WITH_TRAILING_WHITESPACE, + "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" + ); + } }