Merge branch 'acp' of github.com:zed-industries/zed into acp
This commit is contained in:
commit
348bc52a3f
2 changed files with 30 additions and 91 deletions
|
@ -123,7 +123,7 @@ pub enum AgentThreadEntryContent {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
id: ToolCallId,
|
id: ToolCallId,
|
||||||
tool_name: Entity<Markdown>,
|
display_name: Entity<Markdown>,
|
||||||
status: ToolCallStatus,
|
status: ToolCallStatus,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,7 +271,7 @@ impl AcpThread {
|
||||||
|
|
||||||
pub fn request_tool_call(
|
pub fn request_tool_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
title: String,
|
display_name: String,
|
||||||
confirmation: acp::ToolCallConfirmation,
|
confirmation: acp::ToolCallConfirmation,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> ToolCallRequest {
|
) -> ToolCallRequest {
|
||||||
|
@ -282,22 +282,22 @@ impl AcpThread {
|
||||||
respond_tx: tx,
|
respond_tx: tx,
|
||||||
};
|
};
|
||||||
|
|
||||||
let id = self.insert_tool_call(title, status, cx);
|
let id = self.insert_tool_call(display_name, status, cx);
|
||||||
ToolCallRequest { id, outcome: rx }
|
ToolCallRequest { id, outcome: rx }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_tool_call(&mut self, title: String, cx: &mut Context<Self>) -> ToolCallId {
|
pub fn push_tool_call(&mut self, display_name: String, cx: &mut Context<Self>) -> ToolCallId {
|
||||||
let status = ToolCallStatus::Allowed {
|
let status = ToolCallStatus::Allowed {
|
||||||
status: acp::ToolCallStatus::Running,
|
status: acp::ToolCallStatus::Running,
|
||||||
content: None,
|
content: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.insert_tool_call(title, status, cx)
|
self.insert_tool_call(display_name, status, cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_tool_call(
|
fn insert_tool_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
title: String,
|
display_name: String,
|
||||||
status: ToolCallStatus,
|
status: ToolCallStatus,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> ToolCallId {
|
) -> ToolCallId {
|
||||||
|
@ -307,8 +307,13 @@ impl AcpThread {
|
||||||
AgentThreadEntryContent::ToolCall(ToolCall {
|
AgentThreadEntryContent::ToolCall(ToolCall {
|
||||||
// todo! clean up id creation
|
// todo! clean up id creation
|
||||||
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
|
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
|
||||||
tool_name: cx.new(|cx| {
|
display_name: cx.new(|cx| {
|
||||||
Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
|
Markdown::new(
|
||||||
|
display_name.into(),
|
||||||
|
Some(language_registry.clone()),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
}),
|
}),
|
||||||
status,
|
status,
|
||||||
}),
|
}),
|
||||||
|
@ -441,13 +446,11 @@ pub struct ToolCallRequest {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use futures::{FutureExt as _, channel::mpsc, select};
|
|
||||||
use gpui::{AsyncApp, TestAppContext};
|
use gpui::{AsyncApp, TestAppContext};
|
||||||
use project::FakeFs;
|
use project::FakeFs;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt;
|
use std::{env, path::Path, process::Stdio};
|
||||||
use std::{env, path::Path, process::Stdio, time::Duration};
|
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
|
@ -509,27 +512,27 @@ mod tests {
|
||||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||||
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
|
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
|
||||||
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
|
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
|
||||||
let full_turn = thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(
|
thread.send(
|
||||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
})
|
||||||
|
.await
|
||||||
run_until_tool_call(&thread, cx).await;
|
.unwrap();
|
||||||
|
thread.read_with(cx, |thread, cx| {
|
||||||
let tool_call_id = thread.read_with(cx, |thread, cx| {
|
|
||||||
let AgentThreadEntryContent::ToolCall(ToolCall {
|
let AgentThreadEntryContent::ToolCall(ToolCall {
|
||||||
id,
|
id,
|
||||||
tool_name,
|
display_name,
|
||||||
status: ToolCallStatus::Allowed { .. },
|
status: ToolCallStatus::Allowed { content, .. },
|
||||||
}) = &thread.entries().last().unwrap().content
|
}) = &thread.entries()[1].content
|
||||||
else {
|
else {
|
||||||
panic!();
|
panic!();
|
||||||
};
|
};
|
||||||
|
|
||||||
tool_name.read_with(cx, |md, _cx| {
|
display_name.read_with(cx, |md, _cx| {
|
||||||
assert_eq!(md.source(), "read_file");
|
assert_eq!(md.source(), "ReadFile");
|
||||||
});
|
});
|
||||||
|
|
||||||
// todo!
|
// todo!
|
||||||
|
@ -542,70 +545,6 @@ mod tests {
|
||||||
// });
|
// });
|
||||||
*id
|
*id
|
||||||
});
|
});
|
||||||
|
|
||||||
thread.update(cx, |thread, cx| {
|
|
||||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
|
||||||
assert!(matches!(
|
|
||||||
thread.entries().last().unwrap().content,
|
|
||||||
AgentThreadEntryContent::ToolCall(ToolCall {
|
|
||||||
status: ToolCallStatus::Allowed { .. },
|
|
||||||
..
|
|
||||||
})
|
|
||||||
));
|
|
||||||
});
|
|
||||||
|
|
||||||
full_turn.await.unwrap();
|
|
||||||
|
|
||||||
thread.read_with(cx, |thread, _| {
|
|
||||||
assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
|
|
||||||
assert!(matches!(
|
|
||||||
thread.entries[0].content,
|
|
||||||
AgentThreadEntryContent::Message(Message {
|
|
||||||
role: Role::User,
|
|
||||||
..
|
|
||||||
})
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
thread.entries[1].content,
|
|
||||||
AgentThreadEntryContent::ToolCall(ToolCall {
|
|
||||||
status: ToolCallStatus::Allowed { .. },
|
|
||||||
..
|
|
||||||
})
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
thread.entries[2].content,
|
|
||||||
AgentThreadEntryContent::Message(Message {
|
|
||||||
role: Role::Assistant,
|
|
||||||
..
|
|
||||||
})
|
|
||||||
));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
|
|
||||||
let (mut tx, mut rx) = mpsc::channel(1);
|
|
||||||
|
|
||||||
let subscription = cx.update(|cx| {
|
|
||||||
cx.subscribe(thread, move |thread, _, cx| {
|
|
||||||
if thread
|
|
||||||
.read(cx)
|
|
||||||
.entries
|
|
||||||
.iter()
|
|
||||||
.any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
|
|
||||||
{
|
|
||||||
tx.try_send(()).unwrap();
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
select! {
|
|
||||||
_ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
|
|
||||||
panic!("Timeout waiting for tool call")
|
|
||||||
}
|
|
||||||
_ = rx.next().fuse() => {
|
|
||||||
drop(subscription);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
|
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
|
||||||
|
|
|
@ -329,7 +329,7 @@ impl AcpThreadView {
|
||||||
.color(Color::Muted),
|
.color(Color::Muted),
|
||||||
)
|
)
|
||||||
.child(MarkdownElement::new(
|
.child(MarkdownElement::new(
|
||||||
tool_call.tool_name.clone(),
|
tool_call.display_name.clone(),
|
||||||
default_markdown_style(window, cx),
|
default_markdown_style(window, cx),
|
||||||
))
|
))
|
||||||
.child(div().w_full())
|
.child(div().w_full())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue