Start porting tests for EditFileTool

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-08 11:27:08 +02:00
parent da5f2978fd
commit 8f390d9c6d
6 changed files with 1271 additions and 1297 deletions

View file

@ -417,7 +417,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(ThinkingTool); thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(project.clone(), cx.entity())); thread.add_tool(EditFileTool::new(cx.entity()));
thread thread
}); });

View file

@ -9,5 +9,6 @@ mod tests;
pub use agent::*; pub use agent::*;
pub use native_agent_server::NativeAgentServer; pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*; pub use thread::*;
pub use tools::*; pub use tools::*;

View file

@ -1,5 +1,4 @@
use super::*; use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection; use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;

View file

@ -1,4 +1,4 @@
use crate::templates::{SystemPromptTemplate, Template, Templates}; use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff; use acp_thread::Diff;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
@ -133,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>, templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>, pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
} }
impl Thread { impl Thread {
pub fn new( pub fn new(
_project: Entity<Project>, project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
templates: Arc<Templates>, templates: Arc<Templates>,
@ -153,10 +154,19 @@ impl Thread {
project_context, project_context,
templates, templates,
selected_model: default_model, selected_model: default_model,
project,
action_log, action_log,
} }
} }
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) { pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode; self.completion_mode = mode;
} }
@ -323,10 +333,6 @@ impl Thread {
events_rx events_rx
} }
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage { pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message"); log::debug!("Building system message");
let prompt = SystemPromptTemplate { let prompt = SystemPromptTemplate {
@ -901,6 +907,29 @@ pub struct ToolCallEventStream {
} }
impl ToolCallEventStream { impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (
Self,
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, events_rx)
}
fn new( fn new(
tool_use: &LanguageModelToolUse, tool_use: &LanguageModelToolUse,
kind: acp::ToolKind, kind: acp::ToolKind,
@ -934,38 +963,3 @@ impl ToolCallEventStream {
) )
} }
} }
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
Self {
stream,
_events_rx: events_rx,
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
}
}

View file

@ -75,11 +75,6 @@ pub struct EditFileToolInput {
/// When a file already exists or you just created it, prefer editing /// When a file already exists or you just created it, prefer editing
/// it as opposed to recreating it from scratch. /// it as opposed to recreating it from scratch.
pub mode: EditFileMode, pub mode: EditFileMode,
/// The new content for the file (required for create and overwrite modes)
/// For edit mode, this field is not used - edits happen through the edit agent
#[serde(default)]
pub content: Option<String>,
} }
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
@ -91,13 +86,12 @@ pub enum EditFileMode {
} }
pub struct EditFileTool { pub struct EditFileTool {
project: Entity<Project>,
thread: Entity<Thread>, thread: Entity<Thread>,
} }
impl EditFileTool { impl EditFileTool {
pub fn new(project: Entity<Project>, thread: Entity<Thread>) -> Self { pub fn new(thread: Entity<Thread>) -> Self {
Self { project, thread } Self { thread }
} }
fn authorize( fn authorize(
@ -136,7 +130,8 @@ impl EditFileTool {
// Check if path is inside the global config directory // Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize // First check if it's already inside project - if not, try to canonicalize
let project_path = self.project.read(cx).find_project_path(&input.path, cx); let thread = self.thread.read(cx);
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
// If the path is inside the project, and it's not one of the above edge cases, // If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary. // then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -170,12 +165,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let project_path = match resolve_path(&input, self.project.clone(), cx) { let project = self.thread.read(cx).project().clone();
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path, Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(), Err(err) => return Task::ready(Err(anyhow!(err))).into(),
}; };
let project = self.project.clone();
let request = self.thread.update(cx, |thread, cx| { let request = self.thread.update(cx, |thread, cx| {
thread.build_completion_request(CompletionIntent::ToolResults, cx) thread.build_completion_request(CompletionIntent::ToolResults, cx)
}); });
@ -410,54 +405,46 @@ fn resolve_path(
} }
} }
// todo! restore tests #[cfg(test)]
// #[cfg(test)] mod tests {
// mod tests { use crate::Templates;
// use super::*;
// use ::fs::Fs;
// use client::TelemetrySettings;
// use gpui::{TestAppContext, UpdateGlobal};
// use language_model::fake_provider::FakeLanguageModel;
// use serde_json::json;
// use settings::SettingsStore;
// use std::fs;
// use util::path;
// #[gpui::test] use super::*;
// async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { use assistant_tool::ActionLog;
// init_test(cx); use client::TelemetrySettings;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
use std::rc::Rc;
use util::path;
// let fs = project::FakeFs::new(cx.executor()); #[gpui::test]
// fs.insert_tree("/root", json!({})).await; async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
// let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; init_test(cx);
// let action_log = cx.new(|_| ActionLog::new(project.clone()));
// let model = Arc::new(FakeLanguageModel::default()); let fs = project::FakeFs::new(cx.executor());
// let result = cx fs.insert_tree("/root", json!({})).await;
// .update(|cx| { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// let input = serde_json::to_value(EditFileToolInput { let action_log = cx.new(|_| ActionLog::new(project.clone()));
// display_description: "Some edit".into(), let model = Arc::new(FakeLanguageModel::default());
// path: "root/nonexistent_file.txt".into(), let thread =
// mode: EditFileMode::Edit, cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
// }) let result = cx
// .unwrap(); .update(|cx| {
// Arc::new(EditFileTool) let input = EditFileToolInput {
// .run( display_description: "Some edit".into(),
// input, path: "root/nonexistent_file.txt".into(),
// Arc::default(), mode: EditFileMode::Edit,
// project.clone(), };
// action_log, Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
// model, })
// None, .await;
// cx, assert_eq!(
// ) result.unwrap_err().to_string(),
// .output "Can't edit file: path not found"
// }) );
// .await; }
// assert_eq!(
// result.unwrap_err().to_string(),
// "Can't edit file: path not found"
// );
// }
// #[gpui::test] // #[gpui::test]
// async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) { // async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
@ -618,16 +605,16 @@ fn resolve_path(
// ); // );
// } // }
// fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {
// cx.update(|cx| { cx.update(|cx| {
// let settings_store = SettingsStore::test(cx); let settings_store = SettingsStore::test(cx);
// cx.set_global(settings_store); cx.set_global(settings_store);
// language::init(cx); language::init(cx);
// TelemetrySettings::register(cx); TelemetrySettings::register(cx);
// agent_settings::AgentSettings::register(cx); agent_settings::AgentSettings::register(cx);
// Project::init_settings(cx); Project::init_settings(cx);
// }); });
// } }
// fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) { // fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) {
// cx.update(|cx| { // cx.update(|cx| {
@ -1619,4 +1606,4 @@ fn resolve_path(
// ); // );
// }); // });
// } // }
// } }

View file

@ -285,8 +285,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::TestToolCallEventStream;
use super::*; use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@ -304,7 +302,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new(); let (event_stream, _) = ToolCallEventStream::test();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -313,7 +311,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, event_stream, cx)
}) })
.await; .await;
assert_eq!( assert_eq!(
@ -321,6 +319,7 @@ mod test {
"root/nonexistent_file.txt not found" "root/nonexistent_file.txt not found"
); );
} }
#[gpui::test] #[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) { async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx); init_test(cx);
@ -336,7 +335,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -344,7 +342,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "This is a small file content"); assert_eq!(result.unwrap(), "This is a small file content");
@ -367,7 +365,6 @@ mod test {
language_registry.add(Arc::new(rust_lang())); language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let content = cx let content = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -375,7 +372,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -399,7 +396,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
let content = result.unwrap(); let content = result.unwrap();
@ -438,7 +435,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -446,7 +442,7 @@ mod test {
start_line: Some(2), start_line: Some(2),
end_line: Some(4), end_line: Some(4),
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
@ -467,7 +463,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1 // start_line of 0 should be treated as 1
let result = cx let result = cx
@ -477,7 +472,7 @@ mod test {
start_line: Some(0), start_line: Some(0),
end_line: Some(2), end_line: Some(2),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1\nLine 2"); assert_eq!(result.unwrap(), "Line 1\nLine 2");
@ -490,7 +485,7 @@ mod test {
start_line: Some(1), start_line: Some(1),
end_line: Some(0), end_line: Some(0),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1"); assert_eq!(result.unwrap(), "Line 1");
@ -503,7 +498,7 @@ mod test {
start_line: Some(3), start_line: Some(3),
end_line: Some(2), end_line: Some(2),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 3"); assert_eq!(result.unwrap(), "Line 3");
@ -612,7 +607,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail // Reading a file outside the project worktree should fail
let result = cx let result = cx
@ -622,7 +616,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -638,7 +632,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -654,7 +648,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -669,7 +663,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -685,7 +679,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -700,7 +694,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -715,7 +709,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -731,7 +725,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!(result.is_ok(), "Should be able to read normal files"); assert!(result.is_ok(), "Should be able to read normal files");
@ -745,7 +739,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -826,7 +820,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone())); let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1 // Test reading allowed files in worktree1
let result = cx let result = cx
@ -836,7 +829,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -851,7 +844,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -872,7 +865,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -893,7 +886,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -911,7 +904,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -932,7 +925,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -954,7 +947,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;