Checkpoint
This commit is contained in:
parent
199256e43e
commit
e6e23d04f8
5 changed files with 113 additions and 64 deletions
|
@ -539,9 +539,15 @@ impl ToolCallContent {
|
|||
acp::ToolCallContent::Content { content } => {
|
||||
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
|
||||
}
|
||||
acp::ToolCallContent::Diff { diff } => {
|
||||
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
|
||||
}
|
||||
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
|
||||
Diff::finalized(
|
||||
diff.path,
|
||||
diff.old_text,
|
||||
diff.new_text,
|
||||
language_registry,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{MultiBuffer, PathKey};
|
||||
|
@ -21,17 +20,13 @@ pub enum Diff {
|
|||
}
|
||||
|
||||
impl Diff {
|
||||
pub fn from_acp(
|
||||
diff: acp::Diff,
|
||||
pub fn finalized(
|
||||
path: PathBuf,
|
||||
old_text: Option<String>,
|
||||
new_text: String,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let acp::Diff {
|
||||
path,
|
||||
old_text,
|
||||
new_text,
|
||||
} = diff;
|
||||
|
||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
|
||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||
|
|
|
@ -587,11 +587,12 @@ impl NativeAgentConnection {
|
|||
action_log: Entity<action_log::ActionLog>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(EditFileTool::new(cx.weak_entity()));
|
||||
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
|
|
|
@ -7,8 +7,8 @@ use cloud_llm_client::CompletionIntent;
|
|||
use collections::HashSet;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use indoc::formatdoc;
|
||||
use language::ToPoint;
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
use language::{LanguageRegistry, ToPoint};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use paths;
|
||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||
|
@ -98,11 +98,13 @@ pub enum EditFileMode {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EditFileToolOutput {
|
||||
#[serde(alias = "original_path")]
|
||||
input_path: PathBuf,
|
||||
project_path: PathBuf,
|
||||
new_text: String,
|
||||
old_text: Arc<String>,
|
||||
#[serde(default)]
|
||||
diff: String,
|
||||
#[serde(alias = "raw_output")]
|
||||
edit_agent_output: EditAgentOutput,
|
||||
}
|
||||
|
||||
|
@ -123,11 +125,15 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
|||
|
||||
pub struct EditFileTool {
|
||||
thread: WeakEntity<Thread>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
}
|
||||
|
||||
impl EditFileTool {
|
||||
pub fn new(thread: WeakEntity<Thread>) -> Self {
|
||||
Self { thread }
|
||||
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
|
||||
Self {
|
||||
thread,
|
||||
language_registry,
|
||||
}
|
||||
}
|
||||
|
||||
fn authorize(
|
||||
|
@ -419,7 +425,6 @@ impl AgentTool for EditFileTool {
|
|||
|
||||
Ok(EditFileToolOutput {
|
||||
input_path: input.path,
|
||||
project_path: project_path.path.to_path_buf(),
|
||||
new_text: new_text.clone(),
|
||||
old_text,
|
||||
diff: unified_diff,
|
||||
|
@ -427,6 +432,26 @@ impl AgentTool for EditFileTool {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn replay(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
output: Self::Output,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Result<()> {
|
||||
dbg!(&output);
|
||||
event_stream.update_diff(cx.new(|cx| {
|
||||
Diff::finalized(
|
||||
output.input_path,
|
||||
Some(output.old_text.to_string()),
|
||||
output.new_text,
|
||||
self.language_registry.clone(),
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that the file path is valid, meaning:
|
||||
|
@ -515,6 +540,7 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -537,7 +563,7 @@ mod tests {
|
|||
path: "root/nonexistent_file.txt".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
|
@ -754,11 +780,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
Arc::new(EditFileTool::new(
|
||||
thread.downgrade(),
|
||||
language_registry.clone(),
|
||||
))
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
|
@ -811,7 +837,7 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
|
@ -857,6 +883,7 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
@ -896,11 +923,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
Arc::new(EditFileTool::new(
|
||||
thread.downgrade(),
|
||||
language_registry.clone(),
|
||||
))
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
});
|
||||
|
||||
// Stream the content with trailing whitespace
|
||||
|
@ -948,7 +975,7 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
|
@ -985,6 +1012,7 @@ mod tests {
|
|||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
@ -1000,7 +1028,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
|
||||
// Test 1: Path with .zed component should require confirmation
|
||||
|
@ -1122,6 +1150,7 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
@ -1137,7 +1166,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||
let test_cases = vec![
|
||||
|
@ -1231,7 +1260,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1247,7 +1276,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test files in different worktrees
|
||||
let test_cases = vec![
|
||||
|
@ -1313,6 +1342,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1328,7 +1358,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test edge cases
|
||||
let test_cases = vec![
|
||||
|
@ -1397,6 +1427,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1412,7 +1443,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test different EditFileMode values
|
||||
let modes = vec![
|
||||
|
@ -1478,6 +1509,7 @@ mod tests {
|
|||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1493,7 +1525,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
|
|
|
@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
|
|||
}
|
||||
};
|
||||
|
||||
let result_text = if response.results.len() == 1 {
|
||||
"1 result".to_string()
|
||||
} else {
|
||||
format!("{} results", response.results.len())
|
||||
};
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
title: Some(format!("Searched the web: {result_text}")),
|
||||
content: Some(
|
||||
response
|
||||
.results
|
||||
.iter()
|
||||
.map(|result| acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
name: result.title.clone(),
|
||||
uri: result.url.clone(),
|
||||
title: Some(result.title.clone()),
|
||||
description: Some(result.text.clone()),
|
||||
mime_type: None,
|
||||
annotations: None,
|
||||
size: None,
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
emit_update(&response, &event_stream);
|
||||
Ok(WebSearchToolOutput(response))
|
||||
})
|
||||
}
|
||||
|
||||
fn replay(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
output: Self::Output,
|
||||
event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Result<()> {
|
||||
emit_update(&output.0, &event_stream);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
|
||||
let result_text = if response.results.len() == 1 {
|
||||
"1 result".to_string()
|
||||
} else {
|
||||
format!("{} results", response.results.len())
|
||||
};
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
title: Some(format!("Searched the web: {result_text}")),
|
||||
content: Some(
|
||||
response
|
||||
.results
|
||||
.iter()
|
||||
.map(|result| acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
name: result.title.clone(),
|
||||
uri: result.url.clone(),
|
||||
title: Some(result.title.clone()),
|
||||
description: Some(result.text.clone()),
|
||||
mime_type: None,
|
||||
annotations: None,
|
||||
size: None,
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue