Have edit_file_tool
respect format_on_save
(#31047)
Release Notes: - Agents now automatically format after edits if `format_on_save` is enabled.
This commit is contained in:
parent
f3c2e71ca7
commit
cb112a4012
3 changed files with 419 additions and 3 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -683,6 +683,7 @@ dependencies = [
|
|||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
"open",
|
||||
"paths",
|
||||
|
|
|
@ -59,11 +59,12 @@ ui.workspace = true
|
|||
util.workspace = true
|
||||
web_search.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -8,6 +8,10 @@ use assistant_tool::{
|
|||
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput,
|
||||
ToolUseStatus,
|
||||
};
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
|
||||
use futures::StreamExt;
|
||||
|
@ -249,6 +253,40 @@ impl Tool for EditFileTool {
|
|||
}
|
||||
let agent_output = output.await?;
|
||||
|
||||
// Format buffer if format_on_save is enabled, before saving.
|
||||
// If any part of the formatting operation fails, log an error but
|
||||
// don't block the completion of the edit tool's work.
|
||||
let should_format = buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let settings = language_settings::language_settings(
|
||||
buffer.language().map(|l| l.name()),
|
||||
buffer.file(),
|
||||
cx,
|
||||
);
|
||||
!matches!(settings.format_on_save, FormatOnSave::Off)
|
||||
})
|
||||
.log_err()
|
||||
.unwrap_or(false);
|
||||
|
||||
if should_format {
|
||||
let buffers = HashSet::from_iter([buffer.clone()]);
|
||||
|
||||
if let Some(format_task) = project
|
||||
.update(cx, move |project, cx| {
|
||||
project.format(
|
||||
buffers,
|
||||
LspFormatTarget::Buffers,
|
||||
false, // Don't push to history since the tool did it.
|
||||
FormatTrigger::Save,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
format_task.await.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
|
@ -918,11 +956,15 @@ async fn build_buffer_diff(
|
|||
mod tests {
|
||||
use super::*;
|
||||
use client::TelemetrySettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use fs::{FakeFs, Fs};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language::{FakeLspAdapter, Language, LanguageConfig, LanguageMatcher};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use language_settings::{AllLanguageSettings, Formatter, FormatterList, SelectedFormatter};
|
||||
use lsp;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -1129,4 +1171,376 @@ mod tests {
|
|||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({"src": {}})).await;
|
||||
|
||||
// Create a simple file with trailing whitespace
|
||||
fs.save(
|
||||
path!("/root/src/main.rs").as_ref(),
|
||||
&"initial content".into(),
|
||||
LineEnding::Unix,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
|
||||
// First, test with remove_trailing_whitespace_on_save enabled
|
||||
cx.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
|
||||
settings.defaults.remove_trailing_whitespace_on_save = Some(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
const CONTENT_WITH_TRAILING_WHITESPACE: &str =
|
||||
"fn main() { \n println!(\"Hello!\"); \n}\n";
|
||||
|
||||
// Have the model stream content that contains trailing whitespace
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Create main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
});
|
||||
|
||||
// Stream the content with trailing whitespace
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
};
|
||||
assert!(edit_result.is_ok());
|
||||
|
||||
// Wait for any async operations (e.g. formatting) to complete
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
// Read the file to verify trailing whitespace was removed automatically
|
||||
assert_eq!(
|
||||
// Ignore carriage returns on Windows
|
||||
fs.load(path!("/root/src/main.rs").as_ref())
|
||||
.await
|
||||
.unwrap()
|
||||
.replace("\r\n", "\n"),
|
||||
"fn main() {\n println!(\"Hello!\");\n}\n",
|
||||
"Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
|
||||
);
|
||||
|
||||
// Next, test with remove_trailing_whitespace_on_save disabled
|
||||
cx.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
|
||||
settings.defaults.remove_trailing_whitespace_on_save = Some(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Stream edits again with trailing whitespace
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Update main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
});
|
||||
|
||||
// Stream the content with trailing whitespace
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
};
|
||||
assert!(edit_result.is_ok());
|
||||
|
||||
// Wait for any async operations (e.g. formatting) to complete
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
// Verify the file still has trailing whitespace
|
||||
// Read the file again - it should still have trailing whitespace
|
||||
let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
|
||||
assert_eq!(
|
||||
// Ignore carriage returns on Windows
|
||||
final_content.replace("\r\n", "\n"),
|
||||
CONTENT_WITH_TRAILING_WHITESPACE,
|
||||
"Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_format_on_save(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({"src": {}})).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
|
||||
// Set up a Rust language with LSP formatting support
|
||||
let rust_language = Arc::new(Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
));
|
||||
|
||||
// Register the language and fake LSP
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
language_registry.add(rust_language);
|
||||
|
||||
let mut fake_language_servers = language_registry.register_fake_lsp(
|
||||
"Rust",
|
||||
FakeLspAdapter {
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
document_formatting_provider: Some(lsp::OneOf::Left(true)),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
// Create the file
|
||||
fs.save(
|
||||
path!("/root/src/main.rs").as_ref(),
|
||||
&"initial content".into(),
|
||||
LineEnding::Unix,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Open the buffer to trigger LSP initialization
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/src/main.rs"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Register the buffer with language servers
|
||||
let _handle = project.update(cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&buffer, cx)
|
||||
});
|
||||
|
||||
const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
|
||||
const FORMATTED_CONTENT: &str =
|
||||
"This file was formatted by the fake formatter in the test.\n";
|
||||
|
||||
// Get the fake language server and set up formatting handler
|
||||
let fake_language_server = fake_language_servers.next().await.unwrap();
|
||||
fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
|
||||
|_, _| async move {
|
||||
Ok(Some(vec![lsp::TextEdit {
|
||||
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
|
||||
new_text: FORMATTED_CONTENT.to_string(),
|
||||
}]))
|
||||
}
|
||||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
|
||||
// First, test with format_on_save enabled
|
||||
cx.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
|
||||
settings.defaults.format_on_save = Some(FormatOnSave::On);
|
||||
settings.defaults.formatter = Some(SelectedFormatter::Auto);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Have the model stream unformatted content
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Create main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
};
|
||||
assert!(edit_result.is_ok());
|
||||
|
||||
// Wait for any async operations (e.g. formatting) to complete
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
// Read the file to verify it was formatted automatically
|
||||
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
|
||||
assert_eq!(
|
||||
// Ignore carriage returns on Windows
|
||||
new_content.replace("\r\n", "\n"),
|
||||
FORMATTED_CONTENT,
|
||||
"Code should be formatted when format_on_save is enabled"
|
||||
);
|
||||
|
||||
// Next, test with format_on_save disabled
|
||||
cx.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
|
||||
settings.defaults.format_on_save = Some(FormatOnSave::Off);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Stream unformatted edits again
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Update main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
};
|
||||
assert!(edit_result.is_ok());
|
||||
|
||||
// Wait for any async operations (e.g. formatting) to complete
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
// Verify the file is still unformatted
|
||||
assert_eq!(
|
||||
// Ignore carriage returns on Windows
|
||||
fs.load(path!("/root/src/main.rs").as_ref())
|
||||
.await
|
||||
.unwrap()
|
||||
.replace("\r\n", "\n"),
|
||||
UNFORMATTED_CONTENT,
|
||||
"Code should remain unformatted when format_on_save is disabled"
|
||||
);
|
||||
|
||||
// Finally, test with format_on_save set to a list
|
||||
cx.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
|
||||
settings.defaults.format_on_save = Some(FormatOnSave::List(FormatterList(
|
||||
vec![Formatter::LanguageServer { name: None }].into(),
|
||||
)));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Stream unformatted edits again
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Update main function with list formatter".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
};
|
||||
assert!(edit_result.is_ok());
|
||||
|
||||
// Wait for any async operations (e.g. formatting) to complete
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
// Read the file to verify it was formatted with the specified formatter
|
||||
assert_eq!(
|
||||
// Ignore carriage returns on Windows
|
||||
fs.load(path!("/root/src/main.rs").as_ref())
|
||||
.await
|
||||
.unwrap()
|
||||
.replace("\r\n", "\n"),
|
||||
FORMATTED_CONTENT,
|
||||
"Code should be formatted when format_on_save is set to a list"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue