Start porting tests for EditFileTool

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-08 11:27:08 +02:00
parent da5f2978fd
commit 8f390d9c6d
6 changed files with 1271 additions and 1297 deletions

View file

@ -417,7 +417,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(ThinkingTool); thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(project.clone(), cx.entity())); thread.add_tool(EditFileTool::new(cx.entity()));
thread thread
}); });

View file

@ -9,5 +9,6 @@ mod tests;
pub use agent::*; pub use agent::*;
pub use native_agent_server::NativeAgentServer; pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*; pub use thread::*;
pub use tools::*; pub use tools::*;

View file

@ -1,5 +1,4 @@
use super::*; use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection; use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;

View file

@ -1,4 +1,4 @@
use crate::templates::{SystemPromptTemplate, Template, Templates}; use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff; use acp_thread::Diff;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
@ -133,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>, templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>, pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
} }
impl Thread { impl Thread {
pub fn new( pub fn new(
_project: Entity<Project>, project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
templates: Arc<Templates>, templates: Arc<Templates>,
@ -153,10 +154,19 @@ impl Thread {
project_context, project_context,
templates, templates,
selected_model: default_model, selected_model: default_model,
project,
action_log, action_log,
} }
} }
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) { pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode; self.completion_mode = mode;
} }
@ -323,10 +333,6 @@ impl Thread {
events_rx events_rx
} }
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage { pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message"); log::debug!("Building system message");
let prompt = SystemPromptTemplate { let prompt = SystemPromptTemplate {
@ -901,6 +907,29 @@ pub struct ToolCallEventStream {
} }
impl ToolCallEventStream { impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (
Self,
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, events_rx)
}
fn new( fn new(
tool_use: &LanguageModelToolUse, tool_use: &LanguageModelToolUse,
kind: acp::ToolKind, kind: acp::ToolKind,
@ -934,38 +963,3 @@ impl ToolCallEventStream {
) )
} }
} }
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
Self {
stream,
_events_rx: events_rx,
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
}
}

File diff suppressed because it is too large Load diff

View file

@ -285,8 +285,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::TestToolCallEventStream;
use super::*; use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@ -304,7 +302,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new(); let (event_stream, _) = ToolCallEventStream::test();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -313,7 +311,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, event_stream, cx)
}) })
.await; .await;
assert_eq!( assert_eq!(
@ -321,6 +319,7 @@ mod test {
"root/nonexistent_file.txt not found" "root/nonexistent_file.txt not found"
); );
} }
#[gpui::test] #[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) { async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx); init_test(cx);
@ -336,7 +335,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -344,7 +342,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "This is a small file content"); assert_eq!(result.unwrap(), "This is a small file content");
@ -367,7 +365,6 @@ mod test {
language_registry.add(Arc::new(rust_lang())); language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let content = cx let content = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -375,7 +372,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -399,7 +396,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
let content = result.unwrap(); let content = result.unwrap();
@ -438,7 +435,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -446,7 +442,7 @@ mod test {
start_line: Some(2), start_line: Some(2),
end_line: Some(4), end_line: Some(4),
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
@ -467,7 +463,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1 // start_line of 0 should be treated as 1
let result = cx let result = cx
@ -477,7 +472,7 @@ mod test {
start_line: Some(0), start_line: Some(0),
end_line: Some(2), end_line: Some(2),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1\nLine 2"); assert_eq!(result.unwrap(), "Line 1\nLine 2");
@ -490,7 +485,7 @@ mod test {
start_line: Some(1), start_line: Some(1),
end_line: Some(0), end_line: Some(0),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1"); assert_eq!(result.unwrap(), "Line 1");
@ -503,7 +498,7 @@ mod test {
start_line: Some(3), start_line: Some(3),
end_line: Some(2), end_line: Some(2),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 3"); assert_eq!(result.unwrap(), "Line 3");
@ -612,7 +607,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail // Reading a file outside the project worktree should fail
let result = cx let result = cx
@ -622,7 +616,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -638,7 +632,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -654,7 +648,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -669,7 +663,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -685,7 +679,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -700,7 +694,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -715,7 +709,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -731,7 +725,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!(result.is_ok(), "Should be able to read normal files"); assert!(result.is_ok(), "Should be able to read normal files");
@ -745,7 +739,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -826,7 +820,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone())); let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1 // Test reading allowed files in worktree1
let result = cx let result = cx
@ -836,7 +829,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -851,7 +844,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -872,7 +865,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -893,7 +886,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -911,7 +904,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -932,7 +925,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -954,7 +947,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;