Start porting tests for EditFileTool

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-08 11:27:08 +02:00
parent da5f2978fd
commit 8f390d9c6d
6 changed files with 1271 additions and 1297 deletions

View file

@ -417,7 +417,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(project.clone(), cx.entity()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread
});

View file

@ -9,5 +9,6 @@ mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
pub use tools::*;

View file

@ -1,5 +1,4 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp};
use anyhow::Result;

View file

@ -1,4 +1,4 @@
use crate::templates::{SystemPromptTemplate, Template, Templates};
use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
@ -133,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
_project: Entity<Project>,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
@ -153,10 +154,19 @@ impl Thread {
project_context,
templates,
selected_model: default_model,
project,
action_log,
}
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@ -323,10 +333,6 @@ impl Thread {
events_rx
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
@ -901,6 +907,29 @@ pub struct ToolCallEventStream {
}
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (
Self,
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, events_rx)
}
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
@ -934,38 +963,3 @@ impl ToolCallEventStream {
)
}
}
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
Self {
stream,
_events_rx: events_rx,
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
}
}

View file

@ -75,11 +75,6 @@ pub struct EditFileToolInput {
/// When a file already exists or you just created it, prefer editing
/// it as opposed to recreating it from scratch.
pub mode: EditFileMode,
/// The new content for the file (required for create and overwrite modes)
/// For edit mode, this field is not used - edits happen through the edit agent
#[serde(default)]
pub content: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
@ -91,13 +86,12 @@ pub enum EditFileMode {
}
pub struct EditFileTool {
project: Entity<Project>,
thread: Entity<Thread>,
}
impl EditFileTool {
pub fn new(project: Entity<Project>, thread: Entity<Thread>) -> Self {
Self { project, thread }
pub fn new(thread: Entity<Thread>) -> Self {
Self { thread }
}
fn authorize(
@ -136,7 +130,8 @@ impl EditFileTool {
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
let project_path = self.project.read(cx).find_project_path(&input.path, cx);
let thread = self.thread.read(cx);
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -170,12 +165,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let project_path = match resolve_path(&input, self.project.clone(), cx) {
let project = self.thread.read(cx).project().clone();
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project = self.project.clone();
let request = self.thread.update(cx, |thread, cx| {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
});
@ -410,54 +405,46 @@ fn resolve_path(
}
}
// todo! restore tests
// #[cfg(test)]
// mod tests {
// use super::*;
// use ::fs::Fs;
// use client::TelemetrySettings;
// use gpui::{TestAppContext, UpdateGlobal};
// use language_model::fake_provider::FakeLanguageModel;
// use serde_json::json;
// use settings::SettingsStore;
// use std::fs;
// use util::path;
#[cfg(test)]
mod tests {
use crate::Templates;
// #[gpui::test]
// async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
// init_test(cx);
use super::*;
use assistant_tool::ActionLog;
use client::TelemetrySettings;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
use std::rc::Rc;
use util::path;
// let fs = project::FakeFs::new(cx.executor());
// fs.insert_tree("/root", json!({})).await;
// let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// let action_log = cx.new(|_| ActionLog::new(project.clone()));
// let model = Arc::new(FakeLanguageModel::default());
// let result = cx
// .update(|cx| {
// let input = serde_json::to_value(EditFileToolInput {
// display_description: "Some edit".into(),
// path: "root/nonexistent_file.txt".into(),
// mode: EditFileMode::Edit,
// })
// .unwrap();
// Arc::new(EditFileTool)
// .run(
// input,
// Arc::default(),
// project.clone(),
// action_log,
// model,
// None,
// cx,
// )
// .output
// })
// .await;
// assert_eq!(
// result.unwrap_err().to_string(),
// "Can't edit file: path not found"
// );
// }
#[gpui::test]
async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread =
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
let result = cx
.update(|cx| {
let input = EditFileToolInput {
display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(
result.unwrap_err().to_string(),
"Can't edit file: path not found"
);
}
// #[gpui::test]
// async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
@ -618,16 +605,16 @@ fn resolve_path(
// );
// }
// fn init_test(cx: &mut TestAppContext) {
// cx.update(|cx| {
// let settings_store = SettingsStore::test(cx);
// cx.set_global(settings_store);
// language::init(cx);
// TelemetrySettings::register(cx);
// agent_settings::AgentSettings::register(cx);
// Project::init_settings(cx);
// });
// }
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
TelemetrySettings::register(cx);
agent_settings::AgentSettings::register(cx);
Project::init_settings(cx);
});
}
// fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) {
// cx.update(|cx| {
@ -1619,4 +1606,4 @@ fn resolve_path(
// );
// });
// }
// }
}

View file

@ -285,8 +285,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)]
mod test {
use crate::TestToolCallEventStream;
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@ -304,7 +302,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let (event_stream, _) = ToolCallEventStream::test();
let result = cx
.update(|cx| {
@ -313,7 +311,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, event_stream, cx)
})
.await;
assert_eq!(
@ -321,6 +319,7 @@ mod test {
"root/nonexistent_file.txt not found"
);
}
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
@ -336,7 +335,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -344,7 +342,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "This is a small file content");
@ -367,7 +365,6 @@ mod test {
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let content = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -375,7 +372,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
@ -399,7 +396,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
let content = result.unwrap();
@ -438,7 +435,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -446,7 +442,7 @@ mod test {
start_line: Some(2),
end_line: Some(4),
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
@ -467,7 +463,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1
let result = cx
@ -477,7 +472,7 @@ mod test {
start_line: Some(0),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
@ -490,7 +485,7 @@ mod test {
start_line: Some(1),
end_line: Some(0),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1");
@ -503,7 +498,7 @@ mod test {
start_line: Some(3),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 3");
@ -612,7 +607,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail
let result = cx
@ -622,7 +616,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -638,7 +632,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -654,7 +648,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -669,7 +663,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -685,7 +679,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -700,7 +694,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -715,7 +709,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -731,7 +725,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
@ -745,7 +739,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -826,7 +820,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1
let result = cx
@ -836,7 +829,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
@ -851,7 +844,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -872,7 +865,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -893,7 +886,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
@ -911,7 +904,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -932,7 +925,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -954,7 +947,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;