Wire up find_path tool in agent2 (#35799)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
11efa32fa7
commit
90fa921756
18 changed files with 669 additions and 247 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -138,9 +138,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "agent-client-protocol"
|
name = "agent-client-protocol"
|
||||||
version = "0.0.21"
|
version = "0.0.23"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b7ae3c22c23b64a5c3b7fc8a86fcc7c494e989bd2cd66fdce14a58cfc8078381"
|
checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
|
|
|
@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||||
#
|
#
|
||||||
|
|
||||||
agentic-coding-protocol = "0.0.10"
|
agentic-coding-protocol = "0.0.10"
|
||||||
agent-client-protocol = "0.0.21"
|
agent-client-protocol = { version = "0.0.23" }
|
||||||
aho-corasick = "1.1"
|
aho-corasick = "1.1"
|
||||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||||
any_vec = "0.14"
|
any_vec = "0.14"
|
||||||
|
|
|
@ -167,6 +167,7 @@ pub struct ToolCall {
|
||||||
pub status: ToolCallStatus,
|
pub status: ToolCallStatus,
|
||||||
pub locations: Vec<acp::ToolCallLocation>,
|
pub locations: Vec<acp::ToolCallLocation>,
|
||||||
pub raw_input: Option<serde_json::Value>,
|
pub raw_input: Option<serde_json::Value>,
|
||||||
|
pub raw_output: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolCall {
|
impl ToolCall {
|
||||||
|
@ -195,6 +196,7 @@ impl ToolCall {
|
||||||
locations: tool_call.locations,
|
locations: tool_call.locations,
|
||||||
status,
|
status,
|
||||||
raw_input: tool_call.raw_input,
|
raw_input: tool_call.raw_input,
|
||||||
|
raw_output: tool_call.raw_output,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,6 +213,7 @@ impl ToolCall {
|
||||||
content,
|
content,
|
||||||
locations,
|
locations,
|
||||||
raw_input,
|
raw_input,
|
||||||
|
raw_output,
|
||||||
} = fields;
|
} = fields;
|
||||||
|
|
||||||
if let Some(kind) = kind {
|
if let Some(kind) = kind {
|
||||||
|
@ -241,6 +244,10 @@ impl ToolCall {
|
||||||
if let Some(raw_input) = raw_input {
|
if let Some(raw_input) = raw_input {
|
||||||
self.raw_input = Some(raw_input);
|
self.raw_input = Some(raw_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(raw_output) = raw_output {
|
||||||
|
self.raw_output = Some(raw_output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
||||||
|
@ -1547,6 +1554,7 @@ mod tests {
|
||||||
content: vec![],
|
content: vec![],
|
||||||
locations: vec![],
|
locations: vec![],
|
||||||
raw_input: None,
|
raw_input: None,
|
||||||
|
raw_output: None,
|
||||||
}),
|
}),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -1659,6 +1667,7 @@ mod tests {
|
||||||
}],
|
}],
|
||||||
locations: vec![],
|
locations: vec![],
|
||||||
raw_input: None,
|
raw_input: None,
|
||||||
|
raw_output: None,
|
||||||
}),
|
}),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
|
|
@ -39,7 +39,6 @@ ui.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
uuid.workspace = true
|
uuid.workspace = true
|
||||||
watch.workspace = true
|
watch.workspace = true
|
||||||
worktree.workspace = true
|
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::ToolCallAuthorization;
|
|
||||||
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
||||||
|
use crate::{FindPathTool, ThinkingTool, ToolCallAuthorization};
|
||||||
use acp_thread::ModelSelector;
|
use acp_thread::ModelSelector;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
|
@ -412,7 +412,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let thread = cx.new(|_| Thread::new(project, agent.project_context.clone(), action_log, agent.templates.clone(), default_model));
|
let thread = cx.new(|_| {
|
||||||
|
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log, agent.templates.clone(), default_model);
|
||||||
|
thread.add_tool(ThinkingTool);
|
||||||
|
thread.add_tool(FindPathTool::new(project.clone()));
|
||||||
|
thread
|
||||||
|
});
|
||||||
|
|
||||||
Ok(thread)
|
Ok(thread)
|
||||||
},
|
},
|
||||||
)??;
|
)??;
|
||||||
|
|
|
@ -10,3 +10,4 @@ mod tests;
|
||||||
pub use agent::*;
|
pub use agent::*;
|
||||||
pub use native_agent_server::NativeAgentServer;
|
pub use native_agent_server::NativeAgentServer;
|
||||||
pub use thread::*;
|
pub use thread::*;
|
||||||
|
pub use tools::*;
|
||||||
|
|
|
@ -33,19 +33,6 @@ pub trait Template: Sized {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[expect(
|
|
||||||
dead_code,
|
|
||||||
reason = "Marked as unused by Rust 1.89 and left as is as of 07 Aug 2025 to let AI team address it."
|
|
||||||
)]
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct GlobTemplate {
|
|
||||||
pub project_roots: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Template for GlobTemplate {
|
|
||||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub struct SystemPromptTemplate<'a> {
|
pub struct SystemPromptTemplate<'a> {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
Find paths on disk with glob patterns.
|
|
||||||
|
|
||||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
|
||||||
|
|
||||||
{{project_roots}}
|
|
||||||
|
|
||||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
|
||||||
sure to anchor them with one of the directories listed above.
|
|
|
@ -270,14 +270,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
vec![
|
vec![
|
||||||
MessageContent::ToolResult(LanguageModelToolResult {
|
MessageContent::ToolResult(LanguageModelToolResult {
|
||||||
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
|
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
|
||||||
tool_name: tool_call_auth_1.tool_call.title.into(),
|
tool_name: ToolRequiringPermission.name().into(),
|
||||||
is_error: false,
|
is_error: false,
|
||||||
content: "Allowed".into(),
|
content: "Allowed".into(),
|
||||||
output: None
|
output: None
|
||||||
}),
|
}),
|
||||||
MessageContent::ToolResult(LanguageModelToolResult {
|
MessageContent::ToolResult(LanguageModelToolResult {
|
||||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||||
tool_name: tool_call_auth_2.tool_call.title.into(),
|
tool_name: ToolRequiringPermission.name().into(),
|
||||||
is_error: true,
|
is_error: true,
|
||||||
content: "Permission to run tool denied by user".into(),
|
content: "Permission to run tool denied by user".into(),
|
||||||
output: None
|
output: None
|
||||||
|
@ -286,6 +286,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_id_1".into(),
|
||||||
|
name: "nonexistent_tool".into(),
|
||||||
|
raw_input: "{}".into(),
|
||||||
|
input: json!({}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
let tool_call = expect_tool_call(&mut events).await;
|
||||||
|
assert_eq!(tool_call.title, "nonexistent_tool");
|
||||||
|
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
|
||||||
|
let update = expect_tool_call_update(&mut events).await;
|
||||||
|
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn expect_tool_call(
|
||||||
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
|
) -> acp::ToolCall {
|
||||||
|
let event = events
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.expect("no tool call authorization event received")
|
||||||
|
.unwrap();
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
||||||
|
event => {
|
||||||
|
panic!("Unexpected event {event:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn expect_tool_call_update(
|
||||||
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
|
) -> acp::ToolCallUpdate {
|
||||||
|
let event = events
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.expect("no tool call authorization event received")
|
||||||
|
.unwrap();
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
|
||||||
|
event => {
|
||||||
|
panic!("Unexpected event {event:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn next_tool_call_authorization(
|
async fn next_tool_call_authorization(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
) -> ToolCallAuthorization {
|
) -> ToolCallAuthorization {
|
||||||
|
@ -582,6 +639,77 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let input = json!({ "content": "Thinking hard!" });
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "1".into(),
|
||||||
|
name: ThinkingTool.name().into(),
|
||||||
|
raw_input: input.to_string(),
|
||||||
|
input,
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let tool_call = expect_tool_call(&mut events).await;
|
||||||
|
assert_eq!(
|
||||||
|
tool_call,
|
||||||
|
acp::ToolCall {
|
||||||
|
id: acp::ToolCallId("1".into()),
|
||||||
|
title: "Thinking".into(),
|
||||||
|
kind: acp::ToolKind::Think,
|
||||||
|
status: acp::ToolCallStatus::Pending,
|
||||||
|
content: vec![],
|
||||||
|
locations: vec![],
|
||||||
|
raw_input: Some(json!({ "content": "Thinking hard!" })),
|
||||||
|
raw_output: None,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
let update = expect_tool_call_update(&mut events).await;
|
||||||
|
assert_eq!(
|
||||||
|
update,
|
||||||
|
acp::ToolCallUpdate {
|
||||||
|
id: acp::ToolCallId("1".into()),
|
||||||
|
fields: acp::ToolCallUpdateFields {
|
||||||
|
status: Some(acp::ToolCallStatus::InProgress,),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
let update = expect_tool_call_update(&mut events).await;
|
||||||
|
assert_eq!(
|
||||||
|
update,
|
||||||
|
acp::ToolCallUpdate {
|
||||||
|
id: acp::ToolCallId("1".into()),
|
||||||
|
fields: acp::ToolCallUpdateFields {
|
||||||
|
content: Some(vec!["Thinking hard!".into()]),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
let update = expect_tool_call_update(&mut events).await;
|
||||||
|
assert_eq!(
|
||||||
|
update,
|
||||||
|
acp::ToolCallUpdate {
|
||||||
|
id: acp::ToolCallId("1".into()),
|
||||||
|
fields: acp::ToolCallUpdateFields {
|
||||||
|
status: Some(acp::ToolCallStatus::Completed),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/// Filters out the stop events for asserting against in tests
|
/// Filters out the stop events for asserting against in tests
|
||||||
fn stop_events(
|
fn stop_events(
|
||||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
|
|
|
@ -19,11 +19,20 @@ impl AgentTool for EchoTool {
|
||||||
"echo".into()
|
"echo".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
|
fn kind(&self) -> acp::ToolKind {
|
||||||
false
|
acp::ToolKind::Other
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
fn initial_title(&self, _: Self::Input) -> SharedString {
|
||||||
|
"Echo".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: Self::Input,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
Task::ready(Ok(input.text))
|
Task::ready(Ok(input.text))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,11 +53,20 @@ impl AgentTool for DelayTool {
|
||||||
"delay".into()
|
"delay".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
|
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||||
false
|
format!("Delay {}ms", input.ms).into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
fn kind(&self) -> acp::ToolKind {
|
||||||
|
acp::ToolKind::Other
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: Self::Input,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
{
|
{
|
||||||
|
@ -71,16 +89,28 @@ impl AgentTool for ToolRequiringPermission {
|
||||||
"tool_requiring_permission".into()
|
"tool_requiring_permission".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
|
fn kind(&self) -> acp::ToolKind {
|
||||||
true
|
acp::ToolKind::Other
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||||
|
"This tool requires permission".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: Self::Input,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
{
|
{
|
||||||
cx.foreground_executor()
|
let auth_check = self.authorize(input, event_stream);
|
||||||
.spawn(async move { Ok("Allowed".to_string()) })
|
cx.foreground_executor().spawn(async move {
|
||||||
|
auth_check.await?;
|
||||||
|
Ok("Allowed".to_string())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,11 +126,20 @@ impl AgentTool for InfiniteTool {
|
||||||
"infinite".into()
|
"infinite".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
|
fn kind(&self) -> acp::ToolKind {
|
||||||
false
|
acp::ToolKind::Other
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||||
|
"This is the tool that never ends... it just goes on and on my friends!".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
_input: Self::Input,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
future::pending::<()>().await;
|
future::pending::<()>().await;
|
||||||
unreachable!()
|
unreachable!()
|
||||||
|
@ -137,11 +176,20 @@ impl AgentTool for WordListTool {
|
||||||
"word_list".into()
|
"word_list".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
|
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||||
false
|
"List of random words".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
fn kind(&self) -> acp::ToolKind {
|
||||||
|
acp::ToolKind::Other
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
_input: Self::Input,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
Task::ready(Ok("ok".to_string()))
|
Task::ready(Ok("ok".to_string()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use assistant_tool::ActionLog;
|
use assistant_tool::{adapt_schema_to_format, ActionLog};
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
stream::FuturesUnordered,
|
stream::FuturesUnordered,
|
||||||
};
|
};
|
||||||
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
|
use gpui::{App, Context, Entity, SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||||
|
@ -19,7 +19,7 @@ use log;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
use schemars::{JsonSchema, Schema};
|
use schemars::{JsonSchema, Schema};
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
|
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
|
||||||
use util::{markdown::MarkdownCodeBlock, ResultExt};
|
use util::{markdown::MarkdownCodeBlock, ResultExt};
|
||||||
|
@ -276,7 +276,17 @@ impl Thread {
|
||||||
while let Some(tool_result) = tool_uses.next().await {
|
while let Some(tool_result) = tool_uses.next().await {
|
||||||
log::info!("Tool finished {:?}", tool_result);
|
log::info!("Tool finished {:?}", tool_result);
|
||||||
|
|
||||||
event_stream.send_tool_call_result(&tool_result);
|
event_stream.send_tool_call_update(
|
||||||
|
&tool_result.tool_use_id,
|
||||||
|
acp::ToolCallUpdateFields {
|
||||||
|
status: Some(if tool_result.is_error {
|
||||||
|
acp::ToolCallStatus::Failed
|
||||||
|
} else {
|
||||||
|
acp::ToolCallStatus::Completed
|
||||||
|
}),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, _cx| {
|
.update(cx, |thread, _cx| {
|
||||||
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
||||||
|
@ -426,6 +436,8 @@ impl Thread {
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Option<Task<LanguageModelToolResult>> {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
||||||
|
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||||
|
|
||||||
self.pending_tool_uses
|
self.pending_tool_uses
|
||||||
.insert(tool_use.id.clone(), tool_use.clone());
|
.insert(tool_use.id.clone(), tool_use.clone());
|
||||||
let last_message = self.last_assistant_message();
|
let last_message = self.last_assistant_message();
|
||||||
|
@ -443,8 +455,9 @@ impl Thread {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if push_new_tool_use {
|
if push_new_tool_use {
|
||||||
event_stream.send_tool_call(&tool_use);
|
event_stream.send_tool_call(tool.as_ref(), &tool_use);
|
||||||
last_message
|
last_message
|
||||||
.content
|
.content
|
||||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||||
|
@ -462,37 +475,36 @@ impl Thread {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
let Some(tool) = tool else {
|
||||||
let tool_result =
|
|
||||||
self.run_tool(tool.clone(), tool_use.clone(), event_stream.clone(), cx);
|
|
||||||
Some(cx.foreground_executor().spawn(async move {
|
|
||||||
match tool_result.await {
|
|
||||||
Ok(tool_output) => LanguageModelToolResult {
|
|
||||||
tool_use_id: tool_use.id,
|
|
||||||
tool_name: tool_use.name,
|
|
||||||
is_error: false,
|
|
||||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
|
||||||
output: None,
|
|
||||||
},
|
|
||||||
Err(error) => LanguageModelToolResult {
|
|
||||||
tool_use_id: tool_use.id,
|
|
||||||
tool_name: tool_use.name,
|
|
||||||
is_error: true,
|
|
||||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
|
||||||
output: None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
let content = format!("No tool named {} exists", tool_use.name);
|
let content = format!("No tool named {} exists", tool_use.name);
|
||||||
Some(Task::ready(LanguageModelToolResult {
|
return Some(Task::ready(LanguageModelToolResult {
|
||||||
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
||||||
tool_use_id: tool_use.id,
|
tool_use_id: tool_use.id,
|
||||||
tool_name: tool_use.name,
|
tool_name: tool_use.name,
|
||||||
is_error: true,
|
is_error: true,
|
||||||
output: None,
|
output: None,
|
||||||
}))
|
}));
|
||||||
}
|
};
|
||||||
|
|
||||||
|
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||||
|
Some(cx.foreground_executor().spawn(async move {
|
||||||
|
match tool_result.await {
|
||||||
|
Ok(tool_output) => LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_use.id,
|
||||||
|
tool_name: tool_use.name,
|
||||||
|
is_error: false,
|
||||||
|
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||||
|
output: None,
|
||||||
|
},
|
||||||
|
Err(error) => LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_use.id,
|
||||||
|
tool_name: tool_use.name,
|
||||||
|
is_error: true,
|
||||||
|
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||||
|
output: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_tool(
|
fn run_tool(
|
||||||
|
@ -502,20 +514,14 @@ impl Thread {
|
||||||
event_stream: AgentResponseEventStream,
|
event_stream: AgentResponseEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
let needs_authorization = tool.needs_authorization(tool_use.input.clone(), cx);
|
|
||||||
cx.spawn(async move |_this, cx| {
|
cx.spawn(async move |_this, cx| {
|
||||||
if needs_authorization? {
|
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||||
event_stream.authorize_tool_call(&tool_use).await?;
|
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||||
}
|
status: Some(acp::ToolCallStatus::InProgress),
|
||||||
|
..Default::default()
|
||||||
event_stream.send_tool_call_update(
|
});
|
||||||
&tool_use.id,
|
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||||
acp::ToolCallUpdateFields {
|
.await
|
||||||
status: Some(acp::ToolCallStatus::InProgress),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
);
|
|
||||||
cx.update(|cx| tool.run(tool_use.input, cx))?.await
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -584,7 +590,7 @@ impl Thread {
|
||||||
name: tool_name,
|
name: tool_name,
|
||||||
description: tool.description(cx).to_string(),
|
description: tool.description(cx).to_string(),
|
||||||
input_schema: tool
|
input_schema: tool
|
||||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
.input_schema(self.selected_model.tool_input_format())
|
||||||
.log_err()?,
|
.log_err()?,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -651,9 +657,10 @@ pub trait AgentTool
|
||||||
where
|
where
|
||||||
Self: 'static + Sized,
|
Self: 'static + Sized,
|
||||||
{
|
{
|
||||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||||
|
|
||||||
fn name(&self) -> SharedString;
|
fn name(&self) -> SharedString;
|
||||||
|
|
||||||
fn description(&self, _cx: &mut App) -> SharedString {
|
fn description(&self, _cx: &mut App) -> SharedString {
|
||||||
let schema = schemars::schema_for!(Self::Input);
|
let schema = schemars::schema_for!(Self::Input);
|
||||||
SharedString::new(
|
SharedString::new(
|
||||||
|
@ -664,17 +671,33 @@ where
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn kind(&self) -> acp::ToolKind;
|
||||||
|
|
||||||
|
/// The initial tool title to display. Can be updated during the tool run.
|
||||||
|
fn initial_title(&self, input: Self::Input) -> SharedString;
|
||||||
|
|
||||||
/// Returns the JSON schema that describes the tool's input.
|
/// Returns the JSON schema that describes the tool's input.
|
||||||
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
|
fn input_schema(&self) -> Schema {
|
||||||
schemars::schema_for!(Self::Input)
|
schemars::schema_for!(Self::Input)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the tool needs the users's authorization
|
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||||
/// before running.
|
fn authorize(
|
||||||
fn needs_authorization(&self, input: Self::Input, cx: &App) -> bool;
|
&self,
|
||||||
|
input: Self::Input,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||||
|
let json_input = serde_json::json!(&input);
|
||||||
|
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||||
|
}
|
||||||
|
|
||||||
/// Runs the tool with the provided input.
|
/// Runs the tool with the provided input.
|
||||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: Self::Input,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>>;
|
||||||
|
|
||||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||||
Arc::new(Erased(Arc::new(self)))
|
Arc::new(Erased(Arc::new(self)))
|
||||||
|
@ -686,9 +709,15 @@ pub struct Erased<T>(T);
|
||||||
pub trait AnyAgentTool {
|
pub trait AnyAgentTool {
|
||||||
fn name(&self) -> SharedString;
|
fn name(&self) -> SharedString;
|
||||||
fn description(&self, cx: &mut App) -> SharedString;
|
fn description(&self, cx: &mut App) -> SharedString;
|
||||||
|
fn kind(&self) -> acp::ToolKind;
|
||||||
|
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
|
||||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||||
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool>;
|
fn run(
|
||||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||||
|
@ -703,22 +732,30 @@ where
|
||||||
self.0.description(cx)
|
self.0.description(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn kind(&self) -> agent_client_protocol::ToolKind {
|
||||||
|
self.0.kind()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
|
||||||
|
let parsed_input = serde_json::from_value(input)?;
|
||||||
|
Ok(self.0.initial_title(parsed_input))
|
||||||
|
}
|
||||||
|
|
||||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
let mut json = serde_json::to_value(self.0.input_schema())?;
|
||||||
|
adapt_schema_to_format(&mut json, format)?;
|
||||||
|
Ok(json)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool> {
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||||
match parsed_input {
|
match parsed_input {
|
||||||
Ok(input) => Ok(self.0.needs_authorization(input, cx)),
|
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||||
Err(error) => Err(anyhow!(error)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
|
||||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
|
||||||
match parsed_input {
|
|
||||||
Ok(input) => self.0.clone().run(input, cx),
|
|
||||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -744,21 +781,16 @@ impl AgentResponseEventStream {
|
||||||
|
|
||||||
fn authorize_tool_call(
|
fn authorize_tool_call(
|
||||||
&self,
|
&self,
|
||||||
tool_use: &LanguageModelToolUse,
|
id: &LanguageModelToolUseId,
|
||||||
|
title: String,
|
||||||
|
kind: acp::ToolKind,
|
||||||
|
input: serde_json::Value,
|
||||||
) -> impl use<> + Future<Output = Result<()>> {
|
) -> impl use<> + Future<Output = Result<()>> {
|
||||||
let (response_tx, response_rx) = oneshot::channel();
|
let (response_tx, response_rx) = oneshot::channel();
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
||||||
ToolCallAuthorization {
|
ToolCallAuthorization {
|
||||||
tool_call: acp::ToolCall {
|
tool_call: Self::initial_tool_call(id, title, kind, input),
|
||||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
|
||||||
title: tool_use.name.to_string(),
|
|
||||||
kind: acp::ToolKind::Other,
|
|
||||||
status: acp::ToolCallStatus::Pending,
|
|
||||||
content: vec![],
|
|
||||||
locations: vec![],
|
|
||||||
raw_input: Some(tool_use.input.clone()),
|
|
||||||
},
|
|
||||||
options: vec![
|
options: vec![
|
||||||
acp::PermissionOption {
|
acp::PermissionOption {
|
||||||
id: acp::PermissionOptionId("always_allow".into()),
|
id: acp::PermissionOptionId("always_allow".into()),
|
||||||
|
@ -788,20 +820,41 @@ impl AgentResponseEventStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_tool_call(&self, tool_use: &LanguageModelToolUse) {
|
fn send_tool_call(
|
||||||
|
&self,
|
||||||
|
tool: Option<&Arc<dyn AnyAgentTool>>,
|
||||||
|
tool_use: &LanguageModelToolUse,
|
||||||
|
) {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
|
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
|
||||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
&tool_use.id,
|
||||||
title: tool_use.name.to_string(),
|
tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
|
||||||
kind: acp::ToolKind::Other,
|
.map(|i| i.into())
|
||||||
status: acp::ToolCallStatus::Pending,
|
.unwrap_or_else(|| tool_use.name.to_string()),
|
||||||
content: vec![],
|
tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
|
||||||
locations: vec![],
|
tool_use.input.clone(),
|
||||||
raw_input: Some(tool_use.input.clone()),
|
))))
|
||||||
})))
|
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn initial_tool_call(
|
||||||
|
id: &LanguageModelToolUseId,
|
||||||
|
title: String,
|
||||||
|
kind: acp::ToolKind,
|
||||||
|
input: serde_json::Value,
|
||||||
|
) -> acp::ToolCall {
|
||||||
|
acp::ToolCall {
|
||||||
|
id: acp::ToolCallId(id.to_string().into()),
|
||||||
|
title,
|
||||||
|
kind,
|
||||||
|
status: acp::ToolCallStatus::Pending,
|
||||||
|
content: vec![],
|
||||||
|
locations: vec![],
|
||||||
|
raw_input: Some(input),
|
||||||
|
raw_output: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn send_tool_call_update(
|
fn send_tool_call_update(
|
||||||
&self,
|
&self,
|
||||||
tool_use_id: &LanguageModelToolUseId,
|
tool_use_id: &LanguageModelToolUseId,
|
||||||
|
@ -817,38 +870,6 @@ impl AgentResponseEventStream {
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_tool_call_result(&self, tool_result: &LanguageModelToolResult) {
|
|
||||||
let status = if tool_result.is_error {
|
|
||||||
acp::ToolCallStatus::Failed
|
|
||||||
} else {
|
|
||||||
acp::ToolCallStatus::Completed
|
|
||||||
};
|
|
||||||
let content = match &tool_result.content {
|
|
||||||
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
|
|
||||||
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
|
|
||||||
acp::ToolCallContent::Content {
|
|
||||||
content: acp::ContentBlock::Image(acp::ImageContent {
|
|
||||||
annotations: None,
|
|
||||||
data: source.to_string(),
|
|
||||||
mime_type: ImageFormat::Png.mime_type().to_string(),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
self.0
|
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
|
||||||
acp::ToolCallUpdate {
|
|
||||||
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
|
|
||||||
fields: acp::ToolCallUpdateFields {
|
|
||||||
status: Some(status),
|
|
||||||
content: Some(vec![content]),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)))
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send_stop(&self, reason: StopReason) {
|
fn send_stop(&self, reason: StopReason) {
|
||||||
match reason {
|
match reason {
|
||||||
StopReason::EndTurn => {
|
StopReason::EndTurn => {
|
||||||
|
@ -874,3 +895,32 @@ impl AgentResponseEventStream {
|
||||||
self.0.unbounded_send(Err(error)).ok();
|
self.0.unbounded_send(Err(error)).ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ToolCallEventStream {
|
||||||
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
stream: AgentResponseEventStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCallEventStream {
|
||||||
|
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||||
|
Self {
|
||||||
|
tool_use_id,
|
||||||
|
stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
|
||||||
|
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authorize(
|
||||||
|
&self,
|
||||||
|
title: String,
|
||||||
|
kind: acp::ToolKind,
|
||||||
|
input: serde_json::Value,
|
||||||
|
) -> impl use<> + Future<Output = Result<()>> {
|
||||||
|
self.stream
|
||||||
|
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1 +1,5 @@
|
||||||
mod glob;
|
mod find_path_tool;
|
||||||
|
mod thinking_tool;
|
||||||
|
|
||||||
|
pub use find_path_tool::*;
|
||||||
|
pub use thinking_tool::*;
|
||||||
|
|
231
crates/agent2/src/tools/find_path_tool.rs
Normal file
231
crates/agent2/src/tools/find_path_tool.rs
Normal file
|
@ -0,0 +1,231 @@
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||||
|
use project::Project;
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt::Write;
|
||||||
|
use std::{cmp, path::PathBuf, sync::Arc};
|
||||||
|
use util::paths::PathMatcher;
|
||||||
|
|
||||||
|
use crate::{AgentTool, ToolCallEventStream};
|
||||||
|
|
||||||
|
/// Fast file path pattern matching tool that works with any codebase size
|
||||||
|
///
|
||||||
|
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||||
|
/// - Returns matching file paths sorted alphabetically
|
||||||
|
/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths.
|
||||||
|
/// - Use this tool when you need to find files by name patterns
|
||||||
|
/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
pub struct FindPathToolInput {
|
||||||
|
/// The glob to match against every path in the project.
|
||||||
|
///
|
||||||
|
/// <example>
|
||||||
|
/// If the project has the following root directories:
|
||||||
|
///
|
||||||
|
/// - directory1/a/something.txt
|
||||||
|
/// - directory2/a/things.txt
|
||||||
|
/// - directory3/a/other.txt
|
||||||
|
///
|
||||||
|
/// You can get back the first two paths by providing a glob of "*thing*.txt"
|
||||||
|
/// </example>
|
||||||
|
pub glob: String,
|
||||||
|
|
||||||
|
/// Optional starting position for paginated results (0-based).
|
||||||
|
/// When not provided, starts from the beginning.
|
||||||
|
#[serde(default)]
|
||||||
|
pub offset: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct FindPathToolOutput {
|
||||||
|
paths: Vec<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
const RESULTS_PER_PAGE: usize = 50;
|
||||||
|
|
||||||
|
pub struct FindPathTool {
|
||||||
|
project: Entity<Project>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FindPathTool {
|
||||||
|
pub fn new(project: Entity<Project>) -> Self {
|
||||||
|
Self { project }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentTool for FindPathTool {
|
||||||
|
type Input = FindPathToolInput;
|
||||||
|
|
||||||
|
fn name(&self) -> SharedString {
|
||||||
|
"find_path".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kind(&self) -> acp::ToolKind {
|
||||||
|
acp::ToolKind::Search
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||||
|
format!("Find paths matching “`{}`”", input.glob).into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: Self::Input,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
|
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
|
||||||
|
|
||||||
|
cx.background_spawn(async move {
|
||||||
|
let matches = search_paths_task.await?;
|
||||||
|
let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
|
||||||
|
..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
|
||||||
|
|
||||||
|
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||||
|
title: Some(if paginated_matches.len() == 0 {
|
||||||
|
"No matches".into()
|
||||||
|
} else if paginated_matches.len() == 1 {
|
||||||
|
"1 match".into()
|
||||||
|
} else {
|
||||||
|
format!("{} matches", paginated_matches.len())
|
||||||
|
}),
|
||||||
|
content: Some(
|
||||||
|
paginated_matches
|
||||||
|
.iter()
|
||||||
|
.map(|path| acp::ToolCallContent::Content {
|
||||||
|
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||||
|
uri: format!("file://{}", path.display()),
|
||||||
|
name: path.to_string_lossy().into(),
|
||||||
|
annotations: None,
|
||||||
|
description: None,
|
||||||
|
mime_type: None,
|
||||||
|
size: None,
|
||||||
|
title: None,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
raw_output: Some(serde_json::json!({
|
||||||
|
"paths": &matches,
|
||||||
|
})),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
if matches.is_empty() {
|
||||||
|
Ok("No matches found".into())
|
||||||
|
} else {
|
||||||
|
let mut message = format!("Found {} total matches.", matches.len());
|
||||||
|
if matches.len() > RESULTS_PER_PAGE {
|
||||||
|
write!(
|
||||||
|
&mut message,
|
||||||
|
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||||
|
input.offset + 1,
|
||||||
|
input.offset + paginated_matches.len()
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
|
||||||
|
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
|
||||||
|
let path_matcher = match PathMatcher::new([
|
||||||
|
// Sometimes models try to search for "". In this case, return all paths in the project.
|
||||||
|
if glob.is_empty() { "*" } else { glob },
|
||||||
|
]) {
|
||||||
|
Ok(matcher) => matcher,
|
||||||
|
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
|
||||||
|
};
|
||||||
|
let snapshots: Vec<_> = project
|
||||||
|
.read(cx)
|
||||||
|
.worktrees(cx)
|
||||||
|
.map(|worktree| worktree.read(cx).snapshot())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
cx.background_spawn(async move {
|
||||||
|
Ok(snapshots
|
||||||
|
.iter()
|
||||||
|
.flat_map(|snapshot| {
|
||||||
|
let root_name = PathBuf::from(snapshot.root_name());
|
||||||
|
snapshot
|
||||||
|
.entries(false, 0)
|
||||||
|
.map(move |entry| root_name.join(&entry.path))
|
||||||
|
.filter(|path| path_matcher.is_match(&path))
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
use gpui::TestAppContext;
|
||||||
|
use project::{FakeFs, Project};
|
||||||
|
use settings::SettingsStore;
|
||||||
|
use util::path;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_find_path_tool(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree(
|
||||||
|
"/root",
|
||||||
|
serde_json::json!({
|
||||||
|
"apple": {
|
||||||
|
"banana": {
|
||||||
|
"carrot": "1",
|
||||||
|
},
|
||||||
|
"bandana": {
|
||||||
|
"carbonara": "2",
|
||||||
|
},
|
||||||
|
"endive": "3"
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
|
||||||
|
let matches = cx
|
||||||
|
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
matches,
|
||||||
|
&[
|
||||||
|
PathBuf::from("root/apple/banana/carrot"),
|
||||||
|
PathBuf::from("root/apple/bandana/carbonara")
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
let matches = cx
|
||||||
|
.update(|cx| search_paths("**/car*", project.clone(), cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
matches,
|
||||||
|
&[
|
||||||
|
PathBuf::from("root/apple/banana/carrot"),
|
||||||
|
PathBuf::from("root/apple/bandana/carbonara")
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
|
cx.update(|cx| {
|
||||||
|
let settings_store = SettingsStore::test(cx);
|
||||||
|
cx.set_global(settings_store);
|
||||||
|
language::init(cx);
|
||||||
|
Project::init_settings(cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,84 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
|
||||||
use project::Project;
|
|
||||||
use schemars::JsonSchema;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use std::{path::PathBuf, sync::Arc};
|
|
||||||
use util::paths::PathMatcher;
|
|
||||||
use worktree::Snapshot as WorktreeSnapshot;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
templates::{GlobTemplate, Template, Templates},
|
|
||||||
thread::AgentTool,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Description is dynamic, see `fn description` below
|
|
||||||
#[derive(Deserialize, JsonSchema)]
|
|
||||||
struct GlobInput {
|
|
||||||
/// A POSIX glob pattern
|
|
||||||
glob: SharedString,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[expect(
|
|
||||||
dead_code,
|
|
||||||
reason = "Marked as unused by Rust 1.89 and left as is as of 07 Aug 2025 to let AI team address it."
|
|
||||||
)]
|
|
||||||
struct GlobTool {
|
|
||||||
project: Entity<Project>,
|
|
||||||
templates: Arc<Templates>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AgentTool for GlobTool {
|
|
||||||
type Input = GlobInput;
|
|
||||||
|
|
||||||
fn name(&self) -> SharedString {
|
|
||||||
"glob".into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self, cx: &mut App) -> SharedString {
|
|
||||||
let project_roots = self
|
|
||||||
.project
|
|
||||||
.read(cx)
|
|
||||||
.worktrees(cx)
|
|
||||||
.map(|worktree| worktree.read(cx).root_name().into())
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
GlobTemplate { project_roots }
|
|
||||||
.render(&self.templates)
|
|
||||||
.expect("template failed to render")
|
|
||||||
.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
|
||||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
|
||||||
Ok(matcher) => matcher,
|
|
||||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
|
||||||
};
|
|
||||||
|
|
||||||
let snapshots: Vec<WorktreeSnapshot> = self
|
|
||||||
.project
|
|
||||||
.read(cx)
|
|
||||||
.worktrees(cx)
|
|
||||||
.map(|worktree| worktree.read(cx).snapshot())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
cx.background_spawn(async move {
|
|
||||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
|
||||||
let root_name = PathBuf::from(snapshot.root_name());
|
|
||||||
snapshot
|
|
||||||
.entries(false, 0)
|
|
||||||
.map(move |entry| root_name.join(&entry.path))
|
|
||||||
.filter(|path| path_matcher.is_match(&path))
|
|
||||||
});
|
|
||||||
let output = paths
|
|
||||||
.map(|path| format!("{}\n", path.display()))
|
|
||||||
.collect::<String>();
|
|
||||||
Ok(output)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
48
crates/agent2/src/tools/thinking_tool.rs
Normal file
48
crates/agent2/src/tools/thinking_tool.rs
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use anyhow::Result;
|
||||||
|
use gpui::{App, SharedString, Task};
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::{AgentTool, ToolCallEventStream};
|
||||||
|
|
||||||
|
/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions.
|
||||||
|
/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action.
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
pub struct ThinkingToolInput {
|
||||||
|
/// Content to think about. This should be a description of what to think about or
|
||||||
|
/// a problem to solve.
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ThinkingTool;
|
||||||
|
|
||||||
|
impl AgentTool for ThinkingTool {
|
||||||
|
type Input = ThinkingToolInput;
|
||||||
|
|
||||||
|
fn name(&self) -> SharedString {
|
||||||
|
"thinking".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kind(&self) -> acp::ToolKind {
|
||||||
|
acp::ToolKind::Think
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||||
|
"Thinking".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: Self::Input,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
|
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||||
|
content: Some(vec![input.content.into()]),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
Task::ready(Ok("Finished thinking.".to_string()))
|
||||||
|
}
|
||||||
|
}
|
|
@ -280,6 +280,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams)
|
||||||
.map(into_new_tool_call_location)
|
.map(into_new_tool_call_location)
|
||||||
.collect(),
|
.collect(),
|
||||||
raw_input: None,
|
raw_input: None,
|
||||||
|
raw_output: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -297,6 +297,7 @@ impl ClaudeTool {
|
||||||
content: self.content(),
|
content: self.content(),
|
||||||
locations: self.locations(),
|
locations: self.locations(),
|
||||||
raw_input: None,
|
raw_input: None,
|
||||||
|
raw_output: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2988,6 +2988,7 @@ mod tests {
|
||||||
content: vec!["hi".into()],
|
content: vec!["hi".into()],
|
||||||
locations: vec![],
|
locations: vec![],
|
||||||
raw_input: None,
|
raw_input: None,
|
||||||
|
raw_output: None,
|
||||||
};
|
};
|
||||||
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
|
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
|
||||||
.with_permission_requests(HashMap::from_iter([(
|
.with_permission_requests(HashMap::from_iter([(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue