agent2: Port edit_file
tool (#35844)
TODO: - [x] Authorization - [x] Restore tests Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
d705585a2e
commit
2526dcb5a5
20 changed files with 2075 additions and 414 deletions
|
@ -1,4 +1,5 @@
|
|||
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
||||
use crate::{SystemPromptTemplate, Template, Templates};
|
||||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_tool::{adapt_schema_to_format, ActionLog};
|
||||
|
@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
ToolCallDiff(ToolCallDiff),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
|
|||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallDiff {
|
||||
pub tool_call_id: acp::ToolCallId,
|
||||
pub diff: Entity<acp_thread::Diff>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
|
@ -125,12 +133,13 @@ pub struct Thread {
|
|||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
_project: Entity<Project>,
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
|
@ -145,10 +154,19 @@ impl Thread {
|
|||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project(&self) -> &Entity<Project> {
|
||||
&self.project
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
@ -315,10 +333,6 @@ impl Thread {
|
|||
events_rx
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> AgentMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
|
@ -490,15 +504,33 @@ impl Thread {
|
|||
}));
|
||||
};
|
||||
|
||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
let supports_images = self.selected_model.supports_images();
|
||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
let tool_result = tool_result.await.and_then(|output| {
|
||||
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
||||
if !supports_images {
|
||||
return Err(anyhow!(
|
||||
"Attempted to read an image, but this model doesn't support it.",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(output)
|
||||
});
|
||||
|
||||
match tool_result {
|
||||
Ok(output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
content: output.llm_output,
|
||||
output: Some(output.raw_output),
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
|
@ -511,24 +543,6 @@ impl Thread {
|
|||
}))
|
||||
}
|
||||
|
||||
fn run_tool(
|
||||
&self,
|
||||
tool: Arc<dyn AnyAgentTool>,
|
||||
tool_use: LanguageModelToolUse,
|
||||
event_stream: AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use_json_parse_error_event(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -572,7 +586,7 @@ impl Thread {
|
|||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
pub(crate) fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
|
@ -662,6 +676,7 @@ where
|
|||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
|
@ -685,23 +700,13 @@ where
|
|||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||
fn authorize(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||
let json_input = serde_json::json!(&input);
|
||||
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
|
@ -710,6 +715,11 @@ where
|
|||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub struct AgentToolOutput {
|
||||
llm_output: LanguageModelToolResultContent,
|
||||
raw_output: serde_json::Value,
|
||||
}
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
|
@ -721,7 +731,7 @@ pub trait AnyAgentTool {
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<AgentToolOutput>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
|
@ -756,12 +766,18 @@ where
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let input = serde_json::from_value(input)?;
|
||||
let output = cx
|
||||
.update(|cx| self.0.clone().run(input, event_stream, cx))?
|
||||
.await?;
|
||||
let raw_output = serde_json::to_value(&output)?;
|
||||
Ok(AgentToolOutput {
|
||||
llm_output: output.into(),
|
||||
raw_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -874,6 +890,12 @@ impl AgentResponseEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
|
@ -903,13 +925,41 @@ impl AgentResponseEventStream {
|
|||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl ToolCallEventStream {
|
||||
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||
#[cfg(test)]
|
||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
id: "test_id".into(),
|
||||
name: "test_tool".into(),
|
||||
raw_input: String::new(),
|
||||
input: serde_json::Value::Null,
|
||||
is_input_complete: true,
|
||||
},
|
||||
acp::ToolKind::Other,
|
||||
AgentResponseEventStream(events_tx),
|
||||
);
|
||||
|
||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||
}
|
||||
|
||||
fn new(
|
||||
tool_use: &LanguageModelToolUse,
|
||||
kind: acp::ToolKind,
|
||||
stream: AgentResponseEventStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_use_id,
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
kind,
|
||||
input: tool_use.input.clone(),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
@ -918,38 +968,52 @@ impl ToolCallEventStream {
|
|||
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||
}
|
||||
|
||||
pub fn authorize(
|
||||
&self,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream
|
||||
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||
pub fn send_diff(&self, diff: Entity<Diff>) {
|
||||
self.stream.send_tool_call_diff(ToolCallDiff {
|
||||
tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||
diff,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream.authorize_tool_call(
|
||||
&self.tool_use_id,
|
||||
title,
|
||||
self.kind.clone(),
|
||||
self.input.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct TestToolCallEventStream {
|
||||
stream: ToolCallEventStream,
|
||||
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
}
|
||||
pub struct ToolCallEventStreamReceiver(
|
||||
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
impl TestToolCallEventStream {
|
||||
pub fn new() -> Self {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
|
||||
|
||||
Self {
|
||||
stream,
|
||||
_events_rx: events_rx,
|
||||
impl ToolCallEventStreamReceiver {
|
||||
pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
|
||||
let event = self.0.next().await;
|
||||
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
|
||||
auth
|
||||
} else {
|
||||
panic!("Expected ToolCallAuthorization but got: {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> ToolCallEventStream {
|
||||
self.stream.clone()
|
||||
#[cfg(test)]
|
||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue