Lay the groundwork to create terminals in AcpThread
(#35872)
This just prepares the types so that it will be easy later to update a tool call with a terminal entity. We paused because we realized we want to simplify how terminals are created in zed, and so that warrants a dedicated pull request that can be reviewed in isolation. Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
51298b6912
commit
db901278f2
9 changed files with 292 additions and 143 deletions
|
@ -198,7 +198,7 @@ impl ToolCall {
|
|||
}
|
||||
}
|
||||
|
||||
fn update(
|
||||
fn update_fields(
|
||||
&mut self,
|
||||
fields: acp::ToolCallUpdateFields,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
|
@ -415,6 +415,39 @@ impl ToolCallContent {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum ToolCallUpdate {
|
||||
UpdateFields(acp::ToolCallUpdate),
|
||||
UpdateDiff(ToolCallUpdateDiff),
|
||||
}
|
||||
|
||||
impl ToolCallUpdate {
|
||||
fn id(&self) -> &acp::ToolCallId {
|
||||
match self {
|
||||
Self::UpdateFields(update) => &update.id,
|
||||
Self::UpdateDiff(diff) => &diff.id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<acp::ToolCallUpdate> for ToolCallUpdate {
|
||||
fn from(update: acp::ToolCallUpdate) -> Self {
|
||||
Self::UpdateFields(update)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ToolCallUpdateDiff> for ToolCallUpdate {
|
||||
fn from(diff: ToolCallUpdateDiff) -> Self {
|
||||
Self::UpdateDiff(diff)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct ToolCallUpdateDiff {
|
||||
pub id: acp::ToolCallId,
|
||||
pub diff: Entity<Diff>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Plan {
|
||||
pub entries: Vec<PlanEntry>,
|
||||
|
@ -710,36 +743,32 @@ impl AcpThread {
|
|||
|
||||
pub fn update_tool_call(
|
||||
&mut self,
|
||||
update: acp::ToolCallUpdate,
|
||||
update: impl Into<ToolCallUpdate>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
let update = update.into();
|
||||
let languages = self.project.read(cx).languages().clone();
|
||||
|
||||
let (ix, current_call) = self
|
||||
.tool_call_mut(&update.id)
|
||||
.tool_call_mut(update.id())
|
||||
.context("Tool call not found")?;
|
||||
current_call.update(update.fields, languages, cx);
|
||||
match update {
|
||||
ToolCallUpdate::UpdateFields(update) => {
|
||||
current_call.update_fields(update.fields, languages, cx);
|
||||
}
|
||||
ToolCallUpdate::UpdateDiff(update) => {
|
||||
current_call.content.clear();
|
||||
current_call
|
||||
.content
|
||||
.push(ToolCallContent::Diff { diff: update.diff });
|
||||
}
|
||||
}
|
||||
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_tool_call_diff(
|
||||
&mut self,
|
||||
tool_call_id: &acp::ToolCallId,
|
||||
diff: Entity<Diff>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
let (ix, current_call) = self
|
||||
.tool_call_mut(tool_call_id)
|
||||
.context("Tool call not found")?;
|
||||
current_call.content.clear();
|
||||
current_call.content.push(ToolCallContent::Diff { diff });
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
|
||||
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
|
||||
let status = ToolCallStatus::Allowed {
|
||||
|
|
|
@ -503,29 +503,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
},
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentThoughtChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
},
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
|
@ -551,27 +549,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCall(tool_call),
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCallUpdate(tool_call_update),
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.set_tool_call_diff(
|
||||
&tool_call_diff.tool_call_id,
|
||||
tool_call_diff.diff,
|
||||
cx,
|
||||
)
|
||||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
|
|
|
@ -306,7 +306,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
|||
let tool_call = expect_tool_call(&mut events).await;
|
||||
assert_eq!(tool_call.title, "nonexistent_tool");
|
||||
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
let update = expect_tool_call_update_fields(&mut events).await;
|
||||
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
||||
}
|
||||
|
||||
|
@ -326,7 +326,7 @@ async fn expect_tool_call(
|
|||
}
|
||||
}
|
||||
|
||||
async fn expect_tool_call_update(
|
||||
async fn expect_tool_call_update_fields(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> acp::ToolCallUpdate {
|
||||
let event = events
|
||||
|
@ -335,7 +335,9 @@ async fn expect_tool_call_update(
|
|||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
match event {
|
||||
AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
|
||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
||||
return update
|
||||
}
|
||||
event => {
|
||||
panic!("Unexpected event {event:?}");
|
||||
}
|
||||
|
@ -425,31 +427,33 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
});
|
||||
|
||||
// Wait until both tools are called.
|
||||
let mut expected_tool_calls = vec!["echo", "infinite"];
|
||||
let mut expected_tools = vec!["Echo", "Infinite Tool"];
|
||||
let mut echo_id = None;
|
||||
let mut echo_completed = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event.unwrap() {
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
assert_eq!(tool_call.title, expected_tool_calls.remove(0));
|
||||
if tool_call.title == "echo" {
|
||||
assert_eq!(tool_call.title, expected_tools.remove(0));
|
||||
if tool_call.title == "Echo" {
|
||||
echo_id = Some(tool_call.id);
|
||||
}
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
|
||||
id,
|
||||
fields:
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
..
|
||||
},
|
||||
}) if Some(&id) == echo_id.as_ref() => {
|
||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||
acp::ToolCallUpdate {
|
||||
id,
|
||||
fields:
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
..
|
||||
},
|
||||
},
|
||||
)) if Some(&id) == echo_id.as_ref() => {
|
||||
echo_completed = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if expected_tool_calls.is_empty() && echo_completed {
|
||||
if expected_tools.is_empty() && echo_completed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -647,13 +651,26 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
let input = json!({ "content": "Thinking hard!" });
|
||||
// Simulate streaming partial input.
|
||||
let input = json!({});
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "1".into(),
|
||||
name: ThinkingTool.name().into(),
|
||||
raw_input: input.to_string(),
|
||||
input,
|
||||
is_input_complete: false,
|
||||
},
|
||||
));
|
||||
|
||||
// Input streaming completed
|
||||
let input = json!({ "content": "Thinking hard!" });
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "1".into(),
|
||||
name: "thinking".into(),
|
||||
raw_input: input.to_string(),
|
||||
input,
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
|
@ -670,22 +687,35 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(json!({ "content": "Thinking hard!" })),
|
||||
raw_input: Some(json!({})),
|
||||
raw_output: None,
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
let update = expect_tool_call_update_fields(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress,),
|
||||
title: Some("Thinking".into()),
|
||||
kind: Some(acp::ToolKind::Think),
|
||||
raw_input: Some(json!({ "content": "Thinking hard!" })),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
let update = expect_tool_call_update_fields(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update_fields(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
|
@ -696,7 +726,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
},
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
let update = expect_tool_call_update_fields(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
|
|
|
@ -24,7 +24,7 @@ impl AgentTool for EchoTool {
|
|||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _: Self::Input) -> SharedString {
|
||||
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
"Echo".into()
|
||||
}
|
||||
|
||||
|
@ -55,8 +55,12 @@ impl AgentTool for DelayTool {
|
|||
"delay".into()
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
format!("Delay {}ms", input.ms).into()
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
if let Ok(input) = input {
|
||||
format!("Delay {}ms", input.ms).into()
|
||||
} else {
|
||||
"Delay".into()
|
||||
}
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
|
@ -96,7 +100,7 @@ impl AgentTool for ToolRequiringPermission {
|
|||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
"This tool requires permission".into()
|
||||
}
|
||||
|
||||
|
@ -131,8 +135,8 @@ impl AgentTool for InfiniteTool {
|
|||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"This is the tool that never ends... it just goes on and on my friends!".into()
|
||||
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
"Infinite Tool".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
|
@ -182,7 +186,7 @@ impl AgentTool for WordListTool {
|
|||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
"List of random words".into()
|
||||
}
|
||||
|
||||
|
|
|
@ -102,9 +102,8 @@ pub enum AgentResponseEvent {
|
|||
Text(String),
|
||||
Thinking(String),
|
||||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp::ToolCallUpdate),
|
||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
ToolCallDiff(ToolCallDiff),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -115,12 +114,6 @@ pub struct ToolCallAuthorization {
|
|||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallDiff {
|
||||
pub tool_call_id: acp::ToolCallId,
|
||||
pub diff: Entity<acp_thread::Diff>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
|
@ -294,7 +287,7 @@ impl Thread {
|
|||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.send_tool_call_update(
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
|
@ -474,15 +467,24 @@ impl Thread {
|
|||
}
|
||||
});
|
||||
|
||||
let mut title = SharedString::from(&tool_use.name);
|
||||
let mut kind = acp::ToolKind::Other;
|
||||
if let Some(tool) = tool.as_ref() {
|
||||
title = tool.initial_title(tool_use.input.clone());
|
||||
kind = tool.kind();
|
||||
}
|
||||
|
||||
if push_new_tool_use {
|
||||
event_stream.send_tool_call(tool.as_ref(), &tool_use);
|
||||
event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
} else {
|
||||
event_stream.send_tool_call_update(
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields {
|
||||
title: Some(title.into()),
|
||||
kind: Some(kind),
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
..Default::default()
|
||||
},
|
||||
|
@ -506,7 +508,7 @@ impl Thread {
|
|||
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
|
@ -693,7 +695,7 @@ where
|
|||
fn kind(&self) -> acp::ToolKind;
|
||||
|
||||
/// The initial tool title to display. Can be updated during the tool run.
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString;
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self) -> Schema {
|
||||
|
@ -724,7 +726,7 @@ pub trait AnyAgentTool {
|
|||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
|
||||
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
|
@ -750,9 +752,9 @@ where
|
|||
self.0.kind()
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
|
||||
let parsed_input = serde_json::from_value(input)?;
|
||||
Ok(self.0.initial_title(parsed_input))
|
||||
fn initial_title(&self, input: serde_json::Value) -> SharedString {
|
||||
let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
|
||||
self.0.initial_title(parsed_input)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
|
@ -842,17 +844,17 @@ impl AgentResponseEventStream {
|
|||
|
||||
fn send_tool_call(
|
||||
&self,
|
||||
tool: Option<&Arc<dyn AnyAgentTool>>,
|
||||
tool_use: &LanguageModelToolUse,
|
||||
id: &LanguageModelToolUseId,
|
||||
title: SharedString,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
|
||||
&tool_use.id,
|
||||
tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
|
||||
.map(|i| i.into())
|
||||
.unwrap_or_else(|| tool_use.name.to_string()),
|
||||
tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
|
||||
tool_use.input.clone(),
|
||||
id,
|
||||
title.to_string(),
|
||||
kind,
|
||||
input,
|
||||
))))
|
||||
.ok();
|
||||
}
|
||||
|
@ -875,7 +877,7 @@ impl AgentResponseEventStream {
|
|||
}
|
||||
}
|
||||
|
||||
fn send_tool_call_update(
|
||||
fn update_tool_call_fields(
|
||||
&self,
|
||||
tool_use_id: &LanguageModelToolUseId,
|
||||
fields: acp::ToolCallUpdateFields,
|
||||
|
@ -885,14 +887,21 @@ impl AgentResponseEventStream {
|
|||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.to_string().into()),
|
||||
fields,
|
||||
},
|
||||
}
|
||||
.into(),
|
||||
)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
|
||||
fn update_tool_call_diff(&self, tool_use_id: &LanguageModelToolUseId, diff: Entity<Diff>) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp_thread::ToolCallUpdateDiff {
|
||||
id: acp::ToolCallId(tool_use_id.to_string().into()),
|
||||
diff,
|
||||
}
|
||||
.into(),
|
||||
)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
|
@ -964,15 +973,13 @@ impl ToolCallEventStream {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
|
||||
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||
pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
|
||||
self.stream
|
||||
.update_tool_call_fields(&self.tool_use_id, fields);
|
||||
}
|
||||
|
||||
pub fn send_diff(&self, diff: Entity<Diff>) {
|
||||
self.stream.send_tool_call_diff(ToolCallDiff {
|
||||
tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||
diff,
|
||||
});
|
||||
pub fn update_diff(&self, diff: Entity<Diff>) {
|
||||
self.stream.update_tool_call_diff(&self.tool_use_id, diff);
|
||||
}
|
||||
|
||||
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::{AgentTool, Thread, ToolCallEventStream};
|
||||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
|
@ -20,7 +21,7 @@ use std::sync::Arc;
|
|||
use ui::SharedString;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{AgentTool, Thread, ToolCallEventStream};
|
||||
const DEFAULT_UI_TEXT: &str = "Editing file";
|
||||
|
||||
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
|
||||
///
|
||||
|
@ -78,6 +79,14 @@ pub struct EditFileToolInput {
|
|||
pub mode: EditFileMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
struct EditFileToolPartialInput {
|
||||
#[serde(default)]
|
||||
path: String,
|
||||
#[serde(default)]
|
||||
display_description: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EditFileMode {
|
||||
|
@ -182,8 +191,27 @@ impl AgentTool for EditFileTool {
|
|||
acp::ToolKind::Edit
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
input.display_description.into()
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
match input {
|
||||
Ok(input) => input.display_description.into(),
|
||||
Err(raw_input) => {
|
||||
if let Some(input) =
|
||||
serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
|
||||
{
|
||||
let description = input.display_description.trim();
|
||||
if !description.is_empty() {
|
||||
return description.to_string().into();
|
||||
}
|
||||
|
||||
let path = input.path.trim().to_string();
|
||||
if !path.is_empty() {
|
||||
return path.into();
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_UI_TEXT.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
|
@ -226,7 +254,7 @@ impl AgentTool for EditFileTool {
|
|||
.await?;
|
||||
|
||||
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
|
||||
event_stream.send_diff(diff.clone());
|
||||
event_stream.update_diff(diff.clone());
|
||||
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let old_text = cx
|
||||
|
@ -1348,6 +1376,66 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
}))),
|
||||
"src/main.rs"
|
||||
);
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
"path": "",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
}))),
|
||||
"Fix error handling"
|
||||
);
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
}))),
|
||||
"Fix error handling"
|
||||
);
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
"path": "",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
}))),
|
||||
DEFAULT_UI_TEXT
|
||||
);
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(serde_json::Value::Null)),
|
||||
DEFAULT_UI_TEXT
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
|
|
|
@ -94,8 +94,12 @@ impl AgentTool for FindPathTool {
|
|||
acp::ToolKind::Search
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
format!("Find paths matching “`{}`”", input.glob).into()
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
let mut title = "Find paths".to_string();
|
||||
if let Ok(input) = input {
|
||||
title.push_str(&format!(" matching “`{}`”", input.glob));
|
||||
}
|
||||
title.into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
|
@ -111,7 +115,7 @@ impl AgentTool for FindPathTool {
|
|||
let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
|
||||
..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
|
||||
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
title: Some(if paginated_matches.len() == 0 {
|
||||
"No matches".into()
|
||||
} else if paginated_matches.len() == 1 {
|
||||
|
|
|
@ -70,24 +70,28 @@ impl AgentTool for ReadFileTool {
|
|||
acp::ToolKind::Read
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
let path = &input.path;
|
||||
match (input.start_line, input.end_line) {
|
||||
(Some(start), Some(end)) => {
|
||||
format!(
|
||||
"[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))",
|
||||
path, start, end, path, start, end
|
||||
)
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
if let Ok(input) = input {
|
||||
let path = &input.path;
|
||||
match (input.start_line, input.end_line) {
|
||||
(Some(start), Some(end)) => {
|
||||
format!(
|
||||
"[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))",
|
||||
path, start, end, path, start, end
|
||||
)
|
||||
}
|
||||
(Some(start), None) => {
|
||||
format!(
|
||||
"[Read file `{}` (from line {})](@selection:{}:({}-{}))",
|
||||
path, start, path, start, start
|
||||
)
|
||||
}
|
||||
_ => format!("[Read file `{}`](@file:{})", path, path),
|
||||
}
|
||||
(Some(start), None) => {
|
||||
format!(
|
||||
"[Read file `{}` (from line {})](@selection:{}:({}-{}))",
|
||||
path, start, path, start, start
|
||||
)
|
||||
}
|
||||
_ => format!("[Read file `{}`](@file:{})", path, path),
|
||||
.into()
|
||||
} else {
|
||||
"Read file".into()
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
|
|
|
@ -30,7 +30,7 @@ impl AgentTool for ThinkingTool {
|
|||
acp::ToolKind::Think
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
"Thinking".into()
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ impl AgentTool for ThinkingTool {
|
|||
event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![input.content.into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue