Wire up find_path tool in agent2 (#35799)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
11efa32fa7
commit
90fa921756
18 changed files with 669 additions and 247 deletions
|
@ -1,16 +1,16 @@
|
|||
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_tool::ActionLog;
|
||||
use assistant_tool::{adapt_schema_to_format, ActionLog};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
stream::FuturesUnordered,
|
||||
};
|
||||
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||
|
@ -19,7 +19,7 @@ use log;
|
|||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
|
||||
use util::{markdown::MarkdownCodeBlock, ResultExt};
|
||||
|
@ -276,7 +276,17 @@ impl Thread {
|
|||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.send_tool_call_result(&tool_result);
|
||||
event_stream.send_tool_call_update(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
||||
|
@ -426,6 +436,8 @@ impl Thread {
|
|||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||
|
||||
self.pending_tool_uses
|
||||
.insert(tool_use.id.clone(), tool_use.clone());
|
||||
let last_message = self.last_assistant_message();
|
||||
|
@ -443,8 +455,9 @@ impl Thread {
|
|||
true
|
||||
}
|
||||
});
|
||||
|
||||
if push_new_tool_use {
|
||||
event_stream.send_tool_call(&tool_use);
|
||||
event_stream.send_tool_call(tool.as_ref(), &tool_use);
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
|
@ -462,37 +475,36 @@ impl Thread {
|
|||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
let tool_result =
|
||||
self.run_tool(tool.clone(), tool_use.clone(), event_stream.clone(), cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
let Some(tool) = tool else {
|
||||
let content = format!("No tool named {} exists", tool_use.name);
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
return Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}));
|
||||
};
|
||||
|
||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn run_tool(
|
||||
|
@ -502,20 +514,14 @@ impl Thread {
|
|||
event_stream: AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
let needs_authorization = tool.needs_authorization(tool_use.input.clone(), cx);
|
||||
cx.spawn(async move |_this, cx| {
|
||||
if needs_authorization? {
|
||||
event_stream.authorize_tool_call(&tool_use).await?;
|
||||
}
|
||||
|
||||
event_stream.send_tool_call_update(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
cx.update(|cx| tool.run(tool_use.input, cx))?.await
|
||||
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -584,7 +590,7 @@ impl Thread {
|
|||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
|
@ -651,9 +657,10 @@ pub trait AgentTool
|
|||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
|
@ -664,17 +671,33 @@ where
|
|||
)
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
|
||||
/// The initial tool title to display. Can be updated during the tool run.
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString;
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
fn input_schema(&self) -> Schema {
|
||||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Returns true if the tool needs the users's authorization
|
||||
/// before running.
|
||||
fn needs_authorization(&self, input: Self::Input, cx: &App) -> bool;
|
||||
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||
fn authorize(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||
let json_input = serde_json::json!(&input);
|
||||
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
|
@ -686,9 +709,15 @@ pub struct Erased<T>(T);
|
|||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
|
@ -703,22 +732,30 @@ where
|
|||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn kind(&self) -> agent_client_protocol::ToolKind {
|
||||
self.0.kind()
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
|
||||
let parsed_input = serde_json::from_value(input)?;
|
||||
Ok(self.0.initial_title(parsed_input))
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
let mut json = serde_json::to_value(self.0.input_schema())?;
|
||||
adapt_schema_to_format(&mut json, format)?;
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool> {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => Ok(self.0.needs_authorization(input, cx)),
|
||||
Err(error) => Err(anyhow!(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
|
@ -744,21 +781,16 @@ impl AgentResponseEventStream {
|
|||
|
||||
fn authorize_tool_call(
|
||||
&self,
|
||||
tool_use: &LanguageModelToolUse,
|
||||
id: &LanguageModelToolUseId,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
||||
ToolCallAuthorization {
|
||||
tool_call: acp::ToolCall {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
title: tool_use.name.to_string(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
},
|
||||
tool_call: Self::initial_tool_call(id, title, kind, input),
|
||||
options: vec![
|
||||
acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("always_allow".into()),
|
||||
|
@ -788,20 +820,41 @@ impl AgentResponseEventStream {
|
|||
}
|
||||
}
|
||||
|
||||
fn send_tool_call(&self, tool_use: &LanguageModelToolUse) {
|
||||
fn send_tool_call(
|
||||
&self,
|
||||
tool: Option<&Arc<dyn AnyAgentTool>>,
|
||||
tool_use: &LanguageModelToolUse,
|
||||
) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
title: tool_use.name.to_string(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
})))
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
|
||||
&tool_use.id,
|
||||
tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
|
||||
.map(|i| i.into())
|
||||
.unwrap_or_else(|| tool_use.name.to_string()),
|
||||
tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
|
||||
tool_use.input.clone(),
|
||||
))))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn initial_tool_call(
|
||||
id: &LanguageModelToolUseId,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> acp::ToolCall {
|
||||
acp::ToolCall {
|
||||
id: acp::ToolCallId(id.to_string().into()),
|
||||
title,
|
||||
kind,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(input),
|
||||
raw_output: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn send_tool_call_update(
|
||||
&self,
|
||||
tool_use_id: &LanguageModelToolUseId,
|
||||
|
@ -817,38 +870,6 @@ impl AgentResponseEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_tool_call_result(&self, tool_result: &LanguageModelToolResult) {
|
||||
let status = if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
};
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
|
||||
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
|
||||
acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::Image(acp::ImageContent {
|
||||
annotations: None,
|
||||
data: source.to_string(),
|
||||
mime_type: ImageFormat::Png.mime_type().to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
};
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(status),
|
||||
content: Some(vec![content]),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
|
@ -874,3 +895,32 @@ impl AgentResponseEventStream {
|
|||
self.0.unbounded_send(Err(error)).ok();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl ToolCallEventStream {
|
||||
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||
Self {
|
||||
tool_use_id,
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
|
||||
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||
}
|
||||
|
||||
pub fn authorize(
|
||||
&self,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream
|
||||
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue