Start porting tests for EditFileTool

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-08 11:27:08 +02:00
parent da5f2978fd
commit 8f390d9c6d
6 changed files with 1271 additions and 1297 deletions

View file

@ -417,7 +417,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(project.clone(), cx.entity()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread
});

View file

@ -9,5 +9,6 @@ mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
pub use tools::*;

View file

@ -1,5 +1,4 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp};
use anyhow::Result;

View file

@ -1,4 +1,4 @@
use crate::templates::{SystemPromptTemplate, Template, Templates};
use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
@ -133,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
_project: Entity<Project>,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
@ -153,10 +154,19 @@ impl Thread {
project_context,
templates,
selected_model: default_model,
project,
action_log,
}
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@ -323,10 +333,6 @@ impl Thread {
events_rx
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
@ -901,6 +907,29 @@ pub struct ToolCallEventStream {
}
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (
Self,
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, events_rx)
}
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
@ -934,38 +963,3 @@ impl ToolCallEventStream {
)
}
}
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
Self {
stream,
_events_rx: events_rx,
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
}
}

File diff suppressed because it is too large Load diff

View file

@ -285,8 +285,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)]
mod test {
use crate::TestToolCallEventStream;
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@ -304,7 +302,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let (event_stream, _) = ToolCallEventStream::test();
let result = cx
.update(|cx| {
@ -313,7 +311,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, event_stream, cx)
})
.await;
assert_eq!(
@ -321,6 +319,7 @@ mod test {
"root/nonexistent_file.txt not found"
);
}
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
@ -336,7 +335,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -344,7 +342,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "This is a small file content");
@ -367,7 +365,6 @@ mod test {
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let content = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -375,7 +372,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
@ -399,7 +396,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
let content = result.unwrap();
@ -438,7 +435,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -446,7 +442,7 @@ mod test {
start_line: Some(2),
end_line: Some(4),
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
@ -467,7 +463,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1
let result = cx
@ -477,7 +472,7 @@ mod test {
start_line: Some(0),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
@ -490,7 +485,7 @@ mod test {
start_line: Some(1),
end_line: Some(0),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1");
@ -503,7 +498,7 @@ mod test {
start_line: Some(3),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 3");
@ -612,7 +607,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail
let result = cx
@ -622,7 +616,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -638,7 +632,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -654,7 +648,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -669,7 +663,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -685,7 +679,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -700,7 +694,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -715,7 +709,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -731,7 +725,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
@ -745,7 +739,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -826,7 +820,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1
let result = cx
@ -836,7 +829,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
@ -851,7 +844,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -872,7 +865,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -893,7 +886,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
@ -911,7 +904,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -932,7 +925,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -954,7 +947,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;