Checkpoint

This commit is contained in:
Antonio Scandurra 2025-08-18 14:56:10 +02:00
parent 199256e43e
commit e6e23d04f8
5 changed files with 113 additions and 64 deletions

View file

@ -539,9 +539,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
}
acp::ToolCallContent::Diff { diff } => {
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
}
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
Diff::finalized(
diff.path,
diff.old_text,
diff.new_text,
language_registry,
cx,
)
})),
}
}

View file

@ -1,4 +1,3 @@
use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
@ -21,17 +20,13 @@ pub enum Diff {
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
pub fn finalized(
path: PathBuf,
old_text: Option<String>,
new_text: String,
language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));

View file

@ -587,11 +587,12 @@ impl NativeAgentConnection {
action_log: Entity<action_log::ActionLog>,
cx: &mut Context<Thread>,
) {
let language_registry = project.read(cx).languages().clone();
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.weak_entity()));
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));

View file

@ -7,8 +7,8 @@ use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf,
project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
#[serde(default)]
diff: String,
#[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput,
}
@ -123,11 +125,15 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
pub struct EditFileTool {
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
}
impl EditFileTool {
pub fn new(thread: WeakEntity<Thread>) -> Self {
Self { thread }
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
Self {
thread,
language_registry,
}
}
fn authorize(
@ -419,7 +425,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput {
input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
@ -427,6 +432,26 @@ impl AgentTool for EditFileTool {
})
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
dbg!(&output);
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path,
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
}
/// Validate that the file path is valid, meaning:
@ -515,6 +540,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -537,7 +563,7 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool::new(thread.downgrade())).run(
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
@ -754,11 +780,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
// Stream the unformatted content
@ -811,7 +837,7 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(thread.downgrade())).run(
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
@ -857,6 +883,7 @@ mod tests {
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -896,11 +923,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
// Stream the content with trailing whitespace
@ -948,7 +975,7 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(thread.downgrade())).run(
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
@ -985,6 +1012,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1000,7 +1028,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@ -1122,6 +1150,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1137,7 +1166,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@ -1231,7 +1260,7 @@ mod tests {
cx,
)
.await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1247,7 +1276,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees
let test_cases = vec![
@ -1313,6 +1342,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1328,7 +1358,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases
let test_cases = vec![
@ -1397,6 +1427,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1412,7 +1443,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values
let modes = vec![
@ -1478,6 +1509,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1493,7 +1525,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!(
tool.initial_title(Err(json!({

View file

@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
}
};
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
emit_update(&response, &event_stream);
Ok(WebSearchToolOutput(response))
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
emit_update(&output.0, &event_stream);
Ok(())
}
}
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
}