diff --git a/Cargo.lock b/Cargo.lock index b2713cc1dc..7840702c68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -683,6 +683,7 @@ dependencies = [ "language_model", "language_models", "log", + "lsp", "markdown", "open", "paths", diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 6d6baf2d54..6d02d34770 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -59,11 +59,12 @@ ui.workspace = true util.workspace = true web_search.workspace = true which.workspace = true -workspace-hack.workspace = true workspace.workspace = true +workspace-hack.workspace = true zed_llm_client.workspace = true [dev-dependencies] +lsp = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 6c0d22704f..d97004568c 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -8,6 +8,10 @@ use assistant_tool::{ ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; +use language::language_settings::{self, FormatOnSave}; +use project::lsp_store::{FormatTrigger, LspFormatTarget}; +use std::collections::HashSet; + use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{Editor, EditorMode, MultiBuffer, PathKey}; use futures::StreamExt; @@ -249,6 +253,40 @@ impl Tool for EditFileTool { } let agent_output = output.await?; + // Format buffer if format_on_save is enabled, before saving. + // If any part of the formatting operation fails, log an error but + // don't block the completion of the edit tool's work. + let should_format = buffer + .read_with(cx, |buffer, cx| { + let settings = language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + !matches!(settings.format_on_save, FormatOnSave::Off) + }) + .log_err() + .unwrap_or(false); + + if should_format { + let buffers = HashSet::from_iter([buffer.clone()]); + + if let Some(format_task) = project + .update(cx, move |project, cx| { + project.format( + buffers, + LspFormatTarget::Buffers, + false, // Don't push to history since the tool did it. + FormatTrigger::Save, + cx, + ) + }) + .log_err() + { + format_task.await.log_err(); + } + } + project .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? .await?; @@ -918,11 +956,15 @@ async fn build_buffer_diff( mod tests { use super::*; use client::TelemetrySettings; - use fs::FakeFs; - use gpui::TestAppContext; + use fs::{FakeFs, Fs}; + use gpui::{TestAppContext, UpdateGlobal}; + use language::{FakeLspAdapter, Language, LanguageConfig, LanguageMatcher}; use language_model::fake_provider::FakeLanguageModel; + use language_settings::{AllLanguageSettings, Formatter, FormatterList, SelectedFormatter}; + use lsp; use serde_json::json; use settings::SettingsStore; + use std::sync::Arc; use util::path; #[gpui::test] @@ -1129,4 +1171,376 @@ mod tests { Project::init_settings(cx); }); } + + #[gpui::test] + async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + // Create a simple file with trailing whitespace + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + LineEnding::Unix, + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + + // First, test with remove_trailing_whitespace_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(true); + }); + }); + }); + + const CONTENT_WITH_TRAILING_WHITESPACE: &str = + "fn main() { \n println!(\"Hello!\"); \n}\n"; + + // Have the model stream content that contains trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify trailing whitespace was removed automatically + assert_eq!( + // Ignore carriage returns on Windows + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + "fn main() {\n println!(\"Hello!\");\n}\n", + "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" + ); + + // Next, test with remove_trailing_whitespace_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(false); + }); + }); + }); + + // Stream edits again with trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file still has trailing whitespace + // Read the file again - it should still have trailing whitespace + let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + final_content.replace("\r\n", "\n"), + CONTENT_WITH_TRAILING_WHITESPACE, + "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_format_on_save(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + // Set up a Rust language with LSP formatting support + let rust_language = Arc::new(Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + document_formatting_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Create the file + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + LineEnding::Unix, + ) + .await + .unwrap(); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; + const FORMATTED_CONTENT: &str = + "This file was formatted by the fake formatter in the test.\n"; + + // Get the fake language server and set up formatting handler + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::({ + |_, _| async move { + Ok(Some(vec![lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), + new_text: FORMATTED_CONTENT.to_string(), + }])) + } + }); + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + + // First, test with format_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::On); + settings.defaults.formatter = Some(SelectedFormatter::Auto); + }); + }); + }); + + // Have the model stream unformatted content + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify it was formatted automatically + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is enabled" + ); + + // Next, test with format_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::Off); + }); + }); + }); + + // Stream unformatted edits again + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file is still unformatted + assert_eq!( + // Ignore carriage returns on Windows + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + UNFORMATTED_CONTENT, + "Code should remain unformatted when format_on_save is disabled" + ); + + // Finally, test with format_on_save set to a list + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::List(FormatterList( + vec![Formatter::LanguageServer { name: None }].into(), + ))); + }); + }); + }); + + // Stream unformatted edits again + let edit_result = { + let edit_task = cx.update(|cx| { + let input = serde_json::to_value(EditFileToolInput { + display_description: "Update main function with list formatter".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }) + .unwrap(); + Arc::new(EditFileTool) + .run( + input, + Arc::default(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + .output + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify it was formatted with the specified formatter + assert_eq!( + // Ignore carriage returns on Windows + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is set to a list" + ); + } }