agent2: Port edit_file tool (#35844)

TODO:
- [x] Authorization
- [x] Restore tests

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Agus Zubiaga 2025-08-08 09:43:53 -03:00 committed by GitHub
parent d705585a2e
commit 2526dcb5a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 2075 additions and 414 deletions

View file

@ -1,4 +1,5 @@
use crate::templates::{SystemPromptTemplate, Template, Templates};
use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{adapt_schema_to_format, ActionLog};
@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
ToolCallDiff(ToolCallDiff),
Stop(acp::StopReason),
}
@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
#[derive(Debug)]
pub struct ToolCallDiff {
pub tool_call_id: acp::ToolCallId,
pub diff: Entity<acp_thread::Diff>,
}
pub struct Thread {
messages: Vec<AgentMessage>,
completion_mode: CompletionMode,
@ -125,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
_project: Entity<Project>,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
@ -145,10 +154,19 @@ impl Thread {
project_context,
templates,
selected_model: default_model,
project,
action_log,
}
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@ -315,10 +333,6 @@ impl Thread {
events_rx
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
@ -490,15 +504,33 @@ impl Thread {
}));
};
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let supports_images = self.selected_model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
Some(cx.foreground_executor().spawn(async move {
match tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
let tool_result = tool_result.await.and_then(|output| {
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
if !supports_images {
return Err(anyhow!(
"Attempted to read an image, but this model doesn't support it.",
));
}
}
Ok(output)
});
match tool_result {
Ok(output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
output: None,
content: output.llm_output,
output: Some(output.raw_output),
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
@ -511,24 +543,6 @@ impl Thread {
}))
}
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
cx.spawn(async move |_this, cx| {
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
.await
})
}
fn handle_tool_use_json_parse_error_event(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -572,7 +586,7 @@ impl Thread {
self.messages.last_mut().unwrap()
}
fn build_completion_request(
pub(crate) fn build_completion_request(
&self,
completion_intent: CompletionIntent,
cx: &mut App,
@ -662,6 +676,7 @@ where
Self: 'static + Sized,
{
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
fn name(&self) -> SharedString;
@ -685,23 +700,13 @@ where
schemars::schema_for!(Self::Input)
}
/// Allows the tool to authorize a given tool call with the user if necessary
fn authorize(
&self,
input: Self::Input,
event_stream: ToolCallEventStream,
) -> impl use<Self> + Future<Output = Result<()>> {
let json_input = serde_json::json!(&input);
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
) -> Task<Result<Self::Output>>;
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
@ -710,6 +715,11 @@ where
pub struct Erased<T>(T);
pub struct AgentToolOutput {
llm_output: LanguageModelToolResultContent,
raw_output: serde_json::Value,
}
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
@ -721,7 +731,7 @@ pub trait AnyAgentTool {
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
) -> Task<Result<AgentToolOutput>>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@ -756,12 +766,18 @@ where
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => self.0.clone().run(input, event_stream, cx),
Err(error) => Task::ready(Err(anyhow!(error))),
}
) -> Task<Result<AgentToolOutput>> {
cx.spawn(async move |cx| {
let input = serde_json::from_value(input)?;
let output = cx
.update(|cx| self.0.clone().run(input, event_stream, cx))?
.await?;
let raw_output = serde_json::to_value(&output)?;
Ok(AgentToolOutput {
llm_output: output.into(),
raw_output,
})
})
}
}
@ -874,6 +890,12 @@ impl AgentResponseEventStream {
.ok();
}
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
.ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
@ -903,13 +925,41 @@ impl AgentResponseEventStream {
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
}
impl ToolCallEventStream {
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, ToolCallEventStreamReceiver(events_rx))
}
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
) -> Self {
Self {
tool_use_id,
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream,
}
}
@ -918,38 +968,52 @@ impl ToolCallEventStream {
self.stream.send_tool_call_update(&self.tool_use_id, fields);
}
pub fn authorize(
&self,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> impl use<> + Future<Output = Result<()>> {
self.stream
.authorize_tool_call(&self.tool_use_id, title, kind, input)
pub fn send_diff(&self, diff: Entity<Diff>) {
self.stream.send_tool_call_diff(ToolCallDiff {
tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
diff,
});
}
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
self.stream.authorize_tool_call(
&self.tool_use_id,
title,
self.kind.clone(),
self.input.clone(),
)
}
}
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
pub struct ToolCallEventStreamReceiver(
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
Self {
stream,
_events_rx: events_rx,
impl ToolCallEventStreamReceiver {
pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
let event = self.0.next().await;
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
auth
} else {
panic!("Expected ToolCallAuthorization but got: {:?}", event);
}
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}