Lay the groundwork to support history in agent2 (#36483)
This pull request introduces title generation and history replaying. We still need to wire up the rest of the history but this gets us very close. I extracted a lot of this code from `agent2-history` because that branch was starting to get long-lived and there were lots of changes since we started. Release Notes: - N/A
This commit is contained in:
parent
c4083b9b63
commit
6c255c1973
19 changed files with 929 additions and 328 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -191,6 +191,7 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"acp_thread",
|
"acp_thread",
|
||||||
"action_log",
|
"action_log",
|
||||||
|
"agent",
|
||||||
"agent-client-protocol",
|
"agent-client-protocol",
|
||||||
"agent_servers",
|
"agent_servers",
|
||||||
"agent_settings",
|
"agent_settings",
|
||||||
|
@ -208,6 +209,7 @@ dependencies = [
|
||||||
"env_logger 0.11.8",
|
"env_logger 0.11.8",
|
||||||
"fs",
|
"fs",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
|
"git",
|
||||||
"gpui",
|
"gpui",
|
||||||
"gpui_tokio",
|
"gpui_tokio",
|
||||||
"handlebars 4.5.0",
|
"handlebars 4.5.0",
|
||||||
|
@ -256,6 +258,7 @@ name = "agent_servers"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"acp_thread",
|
"acp_thread",
|
||||||
|
"action_log",
|
||||||
"agent-client-protocol",
|
"agent-client-protocol",
|
||||||
"agent_settings",
|
"agent_settings",
|
||||||
"agentic-coding-protocol",
|
"agentic-coding-protocol",
|
||||||
|
|
|
@ -537,9 +537,15 @@ impl ToolCallContent {
|
||||||
acp::ToolCallContent::Content { content } => {
|
acp::ToolCallContent::Content { content } => {
|
||||||
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
|
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
|
||||||
}
|
}
|
||||||
acp::ToolCallContent::Diff { diff } => {
|
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
|
||||||
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
|
Diff::finalized(
|
||||||
}
|
diff.path,
|
||||||
|
diff.old_text,
|
||||||
|
diff.new_text,
|
||||||
|
language_registry,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -682,6 +688,7 @@ pub struct AcpThread {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AcpThreadEvent {
|
pub enum AcpThreadEvent {
|
||||||
NewEntry,
|
NewEntry,
|
||||||
|
TitleUpdated,
|
||||||
EntryUpdated(usize),
|
EntryUpdated(usize),
|
||||||
EntriesRemoved(Range<usize>),
|
EntriesRemoved(Range<usize>),
|
||||||
ToolAuthorizationRequired,
|
ToolAuthorizationRequired,
|
||||||
|
@ -728,11 +735,9 @@ impl AcpThread {
|
||||||
title: impl Into<SharedString>,
|
title: impl Into<SharedString>,
|
||||||
connection: Rc<dyn AgentConnection>,
|
connection: Rc<dyn AgentConnection>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
action_log: Entity<ActionLog>,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut Context<Self>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
action_log,
|
action_log,
|
||||||
shared_buffers: Default::default(),
|
shared_buffers: Default::default(),
|
||||||
|
@ -926,6 +931,12 @@ impl AcpThread {
|
||||||
cx.emit(AcpThreadEvent::NewEntry);
|
cx.emit(AcpThreadEvent::NewEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
|
||||||
|
self.title = title;
|
||||||
|
cx.emit(AcpThreadEvent::TitleUpdated);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
|
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
|
||||||
cx.emit(AcpThreadEvent::Retry(status));
|
cx.emit(AcpThreadEvent::Retry(status));
|
||||||
}
|
}
|
||||||
|
@ -1657,7 +1668,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||||
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use project::{FakeFs, Fs};
|
use project::{FakeFs, Fs};
|
||||||
use rand::Rng as _;
|
use rand::Rng as _;
|
||||||
|
@ -2327,7 +2338,7 @@ mod tests {
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
_cwd: &Path,
|
_cwd: &Path,
|
||||||
cx: &mut gpui::App,
|
cx: &mut App,
|
||||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||||
let session_id = acp::SessionId(
|
let session_id = acp::SessionId(
|
||||||
rand::thread_rng()
|
rand::thread_rng()
|
||||||
|
@ -2337,8 +2348,16 @@ mod tests {
|
||||||
.collect::<String>()
|
.collect::<String>()
|
||||||
.into(),
|
.into(),
|
||||||
);
|
);
|
||||||
let thread =
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
let thread = cx.new(|_cx| {
|
||||||
|
AcpThread::new(
|
||||||
|
"Test",
|
||||||
|
self.clone(),
|
||||||
|
project,
|
||||||
|
action_log,
|
||||||
|
session_id.clone(),
|
||||||
|
)
|
||||||
|
});
|
||||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||||
Task::ready(Ok(thread))
|
Task::ready(Ok(thread))
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,11 +5,12 @@ use collections::IndexMap;
|
||||||
use gpui::{Entity, SharedString, Task};
|
use gpui::{Entity, SharedString, Task};
|
||||||
use language_model::LanguageModelProviderId;
|
use language_model::LanguageModelProviderId;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||||
use ui::{App, IconName};
|
use ui::{App, IconName};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct UserMessageId(Arc<str>);
|
pub struct UserMessageId(Arc<str>);
|
||||||
|
|
||||||
impl UserMessageId {
|
impl UserMessageId {
|
||||||
|
@ -208,6 +209,7 @@ impl AgentModelList {
|
||||||
mod test_support {
|
mod test_support {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use action_log::ActionLog;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use futures::{channel::oneshot, future::try_join_all};
|
use futures::{channel::oneshot, future::try_join_all};
|
||||||
use gpui::{AppContext as _, WeakEntity};
|
use gpui::{AppContext as _, WeakEntity};
|
||||||
|
@ -295,8 +297,16 @@ mod test_support {
|
||||||
cx: &mut gpui::App,
|
cx: &mut gpui::App,
|
||||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||||
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
||||||
let thread =
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
let thread = cx.new(|_cx| {
|
||||||
|
AcpThread::new(
|
||||||
|
"Test",
|
||||||
|
self.clone(),
|
||||||
|
project,
|
||||||
|
action_log,
|
||||||
|
session_id.clone(),
|
||||||
|
)
|
||||||
|
});
|
||||||
self.sessions.lock().insert(
|
self.sessions.lock().insert(
|
||||||
session_id,
|
session_id,
|
||||||
Session {
|
Session {
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use agent_client_protocol as acp;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||||
use editor::{MultiBuffer, PathKey};
|
use editor::{MultiBuffer, PathKey};
|
||||||
|
@ -21,17 +20,13 @@ pub enum Diff {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Diff {
|
impl Diff {
|
||||||
pub fn from_acp(
|
pub fn finalized(
|
||||||
diff: acp::Diff,
|
path: PathBuf,
|
||||||
|
old_text: Option<String>,
|
||||||
|
new_text: String,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let acp::Diff {
|
|
||||||
path,
|
|
||||||
old_text,
|
|
||||||
new_text,
|
|
||||||
} = diff;
|
|
||||||
|
|
||||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||||
|
|
||||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||||
|
|
|
@ -2,6 +2,7 @@ use agent::ThreadId;
|
||||||
use anyhow::{Context as _, Result, bail};
|
use anyhow::{Context as _, Result, bail};
|
||||||
use file_icons::FileIcons;
|
use file_icons::FileIcons;
|
||||||
use prompt_store::{PromptId, UserPromptId};
|
use prompt_store::{PromptId, UserPromptId};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
fmt,
|
fmt,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
|
@ -11,7 +12,7 @@ use std::{
|
||||||
use ui::{App, IconName, SharedString};
|
use ui::{App, IconName, SharedString};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub enum MentionUri {
|
pub enum MentionUri {
|
||||||
File {
|
File {
|
||||||
abs_path: PathBuf,
|
abs_path: PathBuf,
|
||||||
|
|
|
@ -14,6 +14,7 @@ workspace = true
|
||||||
[dependencies]
|
[dependencies]
|
||||||
acp_thread.workspace = true
|
acp_thread.workspace = true
|
||||||
action_log.workspace = true
|
action_log.workspace = true
|
||||||
|
agent.workspace = true
|
||||||
agent-client-protocol.workspace = true
|
agent-client-protocol.workspace = true
|
||||||
agent_servers.workspace = true
|
agent_servers.workspace = true
|
||||||
agent_settings.workspace = true
|
agent_settings.workspace = true
|
||||||
|
@ -26,6 +27,7 @@ collections.workspace = true
|
||||||
context_server.workspace = true
|
context_server.workspace = true
|
||||||
fs.workspace = true
|
fs.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
git.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||||
html_to_markdown.workspace = true
|
html_to_markdown.workspace = true
|
||||||
|
@ -59,6 +61,7 @@ which.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
agent = { workspace = true, "features" = ["test-support"] }
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
client = { workspace = true, "features" = ["test-support"] }
|
client = { workspace = true, "features" = ["test-support"] }
|
||||||
clock = { workspace = true, "features" = ["test-support"] }
|
clock = { workspace = true, "features" = ["test-support"] }
|
||||||
|
@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
|
||||||
editor = { workspace = true, "features" = ["test-support"] }
|
editor = { workspace = true, "features" = ["test-support"] }
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
fs = { workspace = true, "features" = ["test-support"] }
|
fs = { workspace = true, "features" = ["test-support"] }
|
||||||
|
git = { workspace = true, "features" = ["test-support"] }
|
||||||
gpui = { workspace = true, "features" = ["test-support"] }
|
gpui = { workspace = true, "features" = ["test-support"] }
|
||||||
gpui_tokio.workspace = true
|
gpui_tokio.workspace = true
|
||||||
language = { workspace = true, "features" = ["test-support"] }
|
language = { workspace = true, "features" = ["test-support"] }
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||||
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||||
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
UserMessageContent, WebSearchTool, templates::Templates,
|
||||||
};
|
};
|
||||||
use acp_thread::AgentModelSelector;
|
use acp_thread::AgentModelSelector;
|
||||||
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::AgentSettings;
|
use agent_settings::AgentSettings;
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
@ -427,18 +428,19 @@ impl NativeAgent {
|
||||||
) {
|
) {
|
||||||
self.models.refresh_list(cx);
|
self.models.refresh_list(cx);
|
||||||
|
|
||||||
let default_model = LanguageModelRegistry::read_global(cx)
|
let registry = LanguageModelRegistry::read_global(cx);
|
||||||
.default_model()
|
let default_model = registry.default_model().map(|m| m.model.clone());
|
||||||
.map(|m| m.model.clone());
|
let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
|
||||||
|
|
||||||
for session in self.sessions.values_mut() {
|
for session in self.sessions.values_mut() {
|
||||||
session.thread.update(cx, |thread, cx| {
|
session.thread.update(cx, |thread, cx| {
|
||||||
if thread.model().is_none()
|
if thread.model().is_none()
|
||||||
&& let Some(model) = default_model.clone()
|
&& let Some(model) = default_model.clone()
|
||||||
{
|
{
|
||||||
thread.set_model(model);
|
thread.set_model(model, cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
thread.set_summarization_model(summarization_model.clone(), cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -462,10 +464,7 @@ impl NativeAgentConnection {
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
f: impl 'static
|
f: impl 'static
|
||||||
+ FnOnce(
|
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
|
||||||
Entity<Thread>,
|
|
||||||
&mut App,
|
|
||||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||||
agent
|
agent
|
||||||
|
@ -489,7 +488,18 @@ impl NativeAgentConnection {
|
||||||
log::trace!("Received completion event: {:?}", event);
|
log::trace!("Received completion event: {:?}", event);
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::Text(text) => {
|
ThreadEvent::UserMessage(message) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
for content in message.content {
|
||||||
|
thread.push_user_content_block(
|
||||||
|
Some(message.id.clone()),
|
||||||
|
content.into(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
ThreadEvent::AgentText(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.push_assistant_content_block(
|
thread.push_assistant_content_block(
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
@ -501,7 +511,7 @@ impl NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Thinking(text) => {
|
ThreadEvent::AgentThinking(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.push_assistant_content_block(
|
thread.push_assistant_content_block(
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
@ -513,7 +523,7 @@ impl NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||||
tool_call,
|
tool_call,
|
||||||
options,
|
options,
|
||||||
response,
|
response,
|
||||||
|
@ -536,22 +546,26 @@ impl NativeAgentConnection {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
ThreadEvent::ToolCall(tool_call) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.upsert_tool_call(tool_call, cx)
|
thread.upsert_tool_call(tool_call, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
ThreadEvent::ToolCallUpdate(update) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.update_tool_call(update, cx)
|
thread.update_tool_call(update, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Retry(status) => {
|
ThreadEvent::TitleUpdate(title) => {
|
||||||
|
acp_thread
|
||||||
|
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
||||||
|
}
|
||||||
|
ThreadEvent::Retry(status) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.update_retry_status(status, cx)
|
thread.update_retry_status(status, cx)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Stop(stop_reason) => {
|
ThreadEvent::Stop(stop_reason) => {
|
||||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
return Ok(acp::PromptResponse { stop_reason });
|
return Ok(acp::PromptResponse { stop_reason });
|
||||||
}
|
}
|
||||||
|
@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||||
};
|
};
|
||||||
|
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.set_model(model.clone());
|
thread.set_model(model.clone(), cx);
|
||||||
});
|
});
|
||||||
|
|
||||||
update_settings_file::<AgentSettings>(
|
update_settings_file::<AgentSettings>(
|
||||||
|
@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
log::debug!("Starting thread creation in async context");
|
log::debug!("Starting thread creation in async context");
|
||||||
|
|
||||||
// Generate session ID
|
let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
|
||||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
|
||||||
log::info!("Created session with ID: {}", session_id);
|
|
||||||
|
|
||||||
// Create AcpThread
|
|
||||||
let acp_thread = cx.update(|cx| {
|
|
||||||
cx.new(|cx| {
|
|
||||||
acp_thread::AcpThread::new(
|
|
||||||
"agent2",
|
|
||||||
self.clone(),
|
|
||||||
project.clone(),
|
|
||||||
session_id.clone(),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
|
||||||
|
|
||||||
// Create Thread
|
// Create Thread
|
||||||
let thread = agent.update(
|
let thread = agent.update(
|
||||||
cx,
|
cx,
|
||||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||||
// Fetch default model from registry settings
|
// Fetch default model from registry settings
|
||||||
let registry = LanguageModelRegistry::read_global(cx);
|
let registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let language_registry = project.read(cx).languages().clone();
|
||||||
|
|
||||||
// Log available models for debugging
|
// Log available models for debugging
|
||||||
let available_count = registry.available_models(cx).count();
|
let available_count = registry.available_models(cx).count();
|
||||||
|
@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
.models
|
.models
|
||||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||||
});
|
});
|
||||||
|
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let mut thread = Thread::new(
|
let mut thread = Thread::new(
|
||||||
|
@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
agent.templates.clone(),
|
agent.templates.clone(),
|
||||||
default_model,
|
default_model,
|
||||||
|
summarization_model,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||||
thread.add_tool(FindPathTool::new(project.clone()));
|
thread.add_tool(FindPathTool::new(project.clone()));
|
||||||
thread.add_tool(GrepTool::new(project.clone()));
|
thread.add_tool(GrepTool::new(project.clone()));
|
||||||
|
@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
thread.add_tool(MovePathTool::new(project.clone()));
|
thread.add_tool(MovePathTool::new(project.clone()));
|
||||||
thread.add_tool(NowTool);
|
thread.add_tool(NowTool);
|
||||||
thread.add_tool(OpenTool::new(project.clone()));
|
thread.add_tool(OpenTool::new(project.clone()));
|
||||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||||
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
||||||
thread.add_tool(ThinkingTool);
|
thread.add_tool(ThinkingTool);
|
||||||
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||||
|
@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
},
|
},
|
||||||
)??;
|
)??;
|
||||||
|
|
||||||
|
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
|
||||||
|
log::info!("Created session with ID: {}", session_id);
|
||||||
|
// Create AcpThread
|
||||||
|
let acp_thread = cx.update(|cx| {
|
||||||
|
cx.new(|_cx| {
|
||||||
|
acp_thread::AcpThread::new(
|
||||||
|
"agent2",
|
||||||
|
self.clone(),
|
||||||
|
project.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
session_id.clone(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
|
||||||
// Store the session
|
// Store the session
|
||||||
agent.update(cx, |agent, cx| {
|
agent.update(cx, |agent, cx| {
|
||||||
agent.sessions.insert(
|
agent.sessions.insert(
|
||||||
|
@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
log::info!("Cancelling on session: {}", session_id);
|
log::info!("Cancelling on session: {}", session_id);
|
||||||
self.0.update(cx, |agent, cx| {
|
self.0.update(cx, |agent, cx| {
|
||||||
if let Some(agent) = agent.sessions.get(session_id) {
|
if let Some(agent) = agent.sessions.get(session_id) {
|
||||||
agent.thread.update(cx, |thread, _cx| thread.cancel());
|
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
|
||||||
|
|
||||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
Task::ready(
|
||||||
|
self.0
|
||||||
|
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut saw_partial_tool_use = false;
|
let mut saw_partial_tool_use = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
// Look for a tool use in the thread's last message
|
// Look for a tool use in the thread's last message
|
||||||
let message = thread.last_message().unwrap();
|
let message = thread.last_message().unwrap();
|
||||||
|
@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call(
|
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
|
||||||
) -> acp::ToolCall {
|
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
.await
|
.await
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
ThreadEvent::ToolCall(tool_call) => return tool_call,
|
||||||
event => {
|
event => {
|
||||||
panic!("Unexpected event {event:?}");
|
panic!("Unexpected event {event:?}");
|
||||||
}
|
}
|
||||||
|
@ -752,7 +750,7 @@ async fn expect_tool_call(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call_update_fields(
|
async fn expect_tool_call_update_fields(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> acp::ToolCallUpdate {
|
) -> acp::ToolCallUpdate {
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
|
@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields(
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
||||||
return update;
|
return update;
|
||||||
}
|
}
|
||||||
event => {
|
event => {
|
||||||
|
@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn next_tool_call_authorization(
|
async fn next_tool_call_authorization(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> ToolCallAuthorization {
|
) -> ToolCallAuthorization {
|
||||||
loop {
|
loop {
|
||||||
let event = events
|
let event = events
|
||||||
|
@ -778,7 +776,7 @@ async fn next_tool_call_authorization(
|
||||||
.await
|
.await
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||||
let permission_kinds = tool_call_authorization
|
let permission_kinds = tool_call_authorization
|
||||||
.options
|
.options
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
let mut echo_completed = false;
|
let mut echo_completed = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event.unwrap() {
|
match event.unwrap() {
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
ThreadEvent::ToolCall(tool_call) => {
|
||||||
assert_eq!(tool_call.title, expected_tools.remove(0));
|
assert_eq!(tool_call.title, expected_tools.remove(0));
|
||||||
if tool_call.title == "Echo" {
|
if tool_call.title == "Echo" {
|
||||||
echo_id = Some(tool_call.id);
|
echo_id = Some(tool_call.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id,
|
id,
|
||||||
fields:
|
fields:
|
||||||
|
@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
// Cancel the current send and ensure that the event stream is closed, even
|
// Cancel the current send and ensure that the event stream is closed, even
|
||||||
// if one of the tools is still running.
|
// if one of the tools is still running.
|
||||||
thread.update(cx, |thread, _cx| thread.cancel());
|
thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||||
let events = events.collect::<Vec<_>>().await;
|
let events = events.collect::<Vec<_>>().await;
|
||||||
let last_event = events.last();
|
let last_event = events.last();
|
||||||
assert!(
|
assert!(
|
||||||
matches!(
|
matches!(
|
||||||
last_event,
|
last_event,
|
||||||
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
|
||||||
),
|
),
|
||||||
"unexpected event {last_event:?}"
|
"unexpected event {last_event:?}"
|
||||||
);
|
);
|
||||||
|
@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
|
@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_title_generation(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.set_summarization_model(Some(summary_model.clone()), cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
let send = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
|
||||||
|
|
||||||
|
// Ensure the summary model has been invoked to generate a title.
|
||||||
|
summary_model.send_last_completion_stream_text_chunk("Hello ");
|
||||||
|
summary_model.send_last_completion_stream_text_chunk("world\nG");
|
||||||
|
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
|
||||||
|
summary_model.end_last_completion_stream();
|
||||||
|
send.collect::<Vec<_>>().await;
|
||||||
|
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
|
||||||
|
|
||||||
|
// Send another message, ensuring no title is generated this time.
|
||||||
|
let send = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello again"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey again!");
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
assert_eq!(summary_model.pending_completions(), Vec::new());
|
||||||
|
send.collect::<Vec<_>>().await;
|
||||||
|
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
cx.update(settings::init);
|
cx.update(settings::init);
|
||||||
|
@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
||||||
let mut retry_events = Vec::new();
|
let mut retry_events = Vec::new();
|
||||||
while let Some(Ok(event)) = events.next().await {
|
while let Some(Ok(event)) = events.next().await {
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::Retry(retry_status) => {
|
ThreadEvent::Retry(retry_status) => {
|
||||||
retry_events.push(retry_status);
|
retry_events.push(retry_status);
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Stop(..) => break,
|
ThreadEvent::Stop(..) => break,
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||||
let mut retry_events = Vec::new();
|
let mut retry_events = Vec::new();
|
||||||
while let Some(Ok(event)) = events.next().await {
|
while let Some(Ok(event)) = events.next().await {
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::Retry(retry_status) => {
|
ThreadEvent::Retry(retry_status) => {
|
||||||
retry_events.push(retry_status);
|
retry_events.push(retry_status);
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Stop(..) => break,
|
ThreadEvent::Stop(..) => break,
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||||
let mut retry_events = Vec::new();
|
let mut retry_events = Vec::new();
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event {
|
match event {
|
||||||
Ok(AgentResponseEvent::Retry(retry_status)) => {
|
Ok(ThreadEvent::Retry(retry_status)) => {
|
||||||
retry_events.push(retry_status);
|
retry_events.push(retry_status);
|
||||||
}
|
}
|
||||||
Ok(AgentResponseEvent::Stop(..)) => break,
|
Ok(ThreadEvent::Stop(..)) => break,
|
||||||
Err(error) => errors.push(error),
|
Err(error) => errors.push(error),
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filters out the stop events for asserting against in tests
|
/// Filters out the stop events for asserting against in tests
|
||||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
|
||||||
result_events
|
result_events
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|event| match event.unwrap() {
|
.filter_map(|event| match event.unwrap() {
|
||||||
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
|
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
action_log,
|
action_log,
|
||||||
templates,
|
templates,
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: serde_json::Value,
|
||||||
|
_output: serde_json::Value,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||||
use cloud_llm_client::CompletionIntent;
|
use cloud_llm_client::CompletionIntent;
|
||||||
use collections::HashSet;
|
use collections::HashSet;
|
||||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||||
use indoc::formatdoc;
|
use indoc::formatdoc;
|
||||||
use language::ToPoint;
|
|
||||||
use language::language_settings::{self, FormatOnSave};
|
use language::language_settings::{self, FormatOnSave};
|
||||||
|
use language::{LanguageRegistry, ToPoint};
|
||||||
use language_model::LanguageModelToolResultContent;
|
use language_model::LanguageModelToolResultContent;
|
||||||
use paths;
|
use paths;
|
||||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||||
|
@ -98,11 +98,13 @@ pub enum EditFileMode {
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct EditFileToolOutput {
|
pub struct EditFileToolOutput {
|
||||||
|
#[serde(alias = "original_path")]
|
||||||
input_path: PathBuf,
|
input_path: PathBuf,
|
||||||
project_path: PathBuf,
|
|
||||||
new_text: String,
|
new_text: String,
|
||||||
old_text: Arc<String>,
|
old_text: Arc<String>,
|
||||||
|
#[serde(default)]
|
||||||
diff: String,
|
diff: String,
|
||||||
|
#[serde(alias = "raw_output")]
|
||||||
edit_agent_output: EditAgentOutput,
|
edit_agent_output: EditAgentOutput,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct EditFileTool {
|
pub struct EditFileTool {
|
||||||
thread: Entity<Thread>,
|
thread: WeakEntity<Thread>,
|
||||||
|
language_registry: Arc<LanguageRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EditFileTool {
|
impl EditFileTool {
|
||||||
pub fn new(thread: Entity<Thread>) -> Self {
|
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
|
||||||
Self { thread }
|
Self {
|
||||||
|
thread,
|
||||||
|
language_registry,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authorize(
|
fn authorize(
|
||||||
|
@ -167,8 +173,11 @@ impl EditFileTool {
|
||||||
|
|
||||||
// Check if path is inside the global config directory
|
// Check if path is inside the global config directory
|
||||||
// First check if it's already inside project - if not, try to canonicalize
|
// First check if it's already inside project - if not, try to canonicalize
|
||||||
let thread = self.thread.read(cx);
|
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
|
||||||
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
|
thread.project().read(cx).find_project_path(&input.path, cx)
|
||||||
|
}) else {
|
||||||
|
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||||
|
};
|
||||||
|
|
||||||
// If the path is inside the project, and it's not one of the above edge cases,
|
// If the path is inside the project, and it's not one of the above edge cases,
|
||||||
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
||||||
|
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
|
||||||
event_stream: ToolCallEventStream,
|
event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
let project = self.thread.read(cx).project().clone();
|
let Ok(project) = self
|
||||||
|
.thread
|
||||||
|
.read_with(cx, |thread, _cx| thread.project().clone())
|
||||||
|
else {
|
||||||
|
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||||
|
};
|
||||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||||
Ok(path) => path,
|
Ok(path) => path,
|
||||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||||
|
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(request) = self.thread.update(cx, |thread, cx| {
|
|
||||||
thread
|
|
||||||
.build_completion_request(CompletionIntent::ToolResults, cx)
|
|
||||||
.ok()
|
|
||||||
}) else {
|
|
||||||
return Task::ready(Err(anyhow!("Failed to build completion request")));
|
|
||||||
};
|
|
||||||
let thread = self.thread.read(cx);
|
|
||||||
let Some(model) = thread.model().cloned() else {
|
|
||||||
return Task::ready(Err(anyhow!("No language model configured")));
|
|
||||||
};
|
|
||||||
let action_log = thread.action_log().clone();
|
|
||||||
|
|
||||||
let authorize = self.authorize(&input, &event_stream, cx);
|
let authorize = self.authorize(&input, &event_stream, cx);
|
||||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||||
authorize.await?;
|
authorize.await?;
|
||||||
|
|
||||||
|
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
|
||||||
|
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
|
||||||
|
(request, thread.model().cloned(), thread.action_log().clone())
|
||||||
|
})?;
|
||||||
|
let request = request?;
|
||||||
|
let model = model.context("No language model configured")?;
|
||||||
|
|
||||||
let edit_format = EditFormat::from_model(model.clone())?;
|
let edit_format = EditFormat::from_model(model.clone())?;
|
||||||
let edit_agent = EditAgent::new(
|
let edit_agent = EditAgent::new(
|
||||||
model,
|
model,
|
||||||
|
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
|
||||||
|
|
||||||
Ok(EditFileToolOutput {
|
Ok(EditFileToolOutput {
|
||||||
input_path: input.path,
|
input_path: input.path,
|
||||||
project_path: project_path.path.to_path_buf(),
|
|
||||||
new_text: new_text.clone(),
|
new_text: new_text.clone(),
|
||||||
old_text,
|
old_text,
|
||||||
diff: unified_diff,
|
diff: unified_diff,
|
||||||
|
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: Self::Input,
|
||||||
|
output: Self::Output,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
event_stream.update_diff(cx.new(|cx| {
|
||||||
|
Diff::finalized(
|
||||||
|
output.input_path,
|
||||||
|
Some(output.old_text.to_string()),
|
||||||
|
output.new_text,
|
||||||
|
self.language_registry.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate that the file path is valid, meaning:
|
/// Validate that the file path is valid, meaning:
|
||||||
|
@ -515,6 +541,7 @@ mod tests {
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
fs.insert_tree("/root", json!({})).await;
|
fs.insert_tree("/root", json!({})).await;
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
@ -527,6 +554,7 @@ mod tests {
|
||||||
action_log,
|
action_log,
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model),
|
Some(model),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -537,7 +565,11 @@ mod tests {
|
||||||
path: "root/nonexistent_file.txt".into(),
|
path: "root/nonexistent_file.txt".into(),
|
||||||
mode: EditFileMode::Edit,
|
mode: EditFileMode::Edit,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||||
|
input,
|
||||||
|
ToolCallEventStream::test().0,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -724,6 +756,7 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -750,9 +783,10 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(
|
||||||
thread: thread.clone(),
|
thread.downgrade(),
|
||||||
})
|
language_registry.clone(),
|
||||||
|
))
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
.run(input, ToolCallEventStream::test().0, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -806,7 +840,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||||
|
input,
|
||||||
|
ToolCallEventStream::test().0,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the unformatted content
|
// Stream the unformatted content
|
||||||
|
@ -850,6 +888,7 @@ mod tests {
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
|
@ -860,6 +899,7 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -887,9 +927,10 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(
|
||||||
thread: thread.clone(),
|
thread.downgrade(),
|
||||||
})
|
language_registry.clone(),
|
||||||
|
))
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
.run(input, ToolCallEventStream::test().0, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -938,10 +979,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||||
thread: thread.clone(),
|
input,
|
||||||
})
|
ToolCallEventStream::test().0,
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the content with trailing whitespace
|
// Stream the content with trailing whitespace
|
||||||
|
@ -976,6 +1018,7 @@ mod tests {
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
|
@ -986,10 +1029,11 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
fs.insert_tree("/root", json!({})).await;
|
fs.insert_tree("/root", json!({})).await;
|
||||||
|
|
||||||
// Test 1: Path with .zed component should require confirmation
|
// Test 1: Path with .zed component should require confirmation
|
||||||
|
@ -1111,6 +1155,7 @@ mod tests {
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
fs.insert_tree("/project", json!({})).await;
|
fs.insert_tree("/project", json!({})).await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
@ -1123,10 +1168,11 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
|
|
||||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1220,7 +1266,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
@ -1233,10 +1279,11 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
|
|
||||||
// Test files in different worktrees
|
// Test files in different worktrees
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1302,6 +1349,7 @@ mod tests {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
@ -1314,10 +1362,11 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
|
|
||||||
// Test edge cases
|
// Test edge cases
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1386,6 +1435,7 @@ mod tests {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
@ -1398,10 +1448,11 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
|
|
||||||
// Test different EditFileMode values
|
// Test different EditFileMode values
|
||||||
let modes = vec![
|
let modes = vec![
|
||||||
|
@ -1467,6 +1518,7 @@ mod tests {
|
||||||
init_test(cx);
|
init_test(cx);
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
@ -1479,10 +1531,11 @@ mod tests {
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tool.initial_title(Err(json!({
|
tool.initial_title(Err(json!({
|
||||||
|
|
|
@ -319,7 +319,7 @@ mod tests {
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use util::test::TempTree;
|
use util::test::TempTree;
|
||||||
|
|
||||||
use crate::AgentResponseEvent;
|
use crate::ThreadEvent;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -396,7 +396,7 @@ mod tests {
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let event = stream_rx.try_next();
|
let event = stream_rx.try_next();
|
||||||
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
|
if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
|
||||||
auth.response.send(auth.options[0].id.clone()).unwrap();
|
auth.response.send(auth.options[0].id.clone()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let result_text = if response.results.len() == 1 {
|
emit_update(&response, &event_stream);
|
||||||
"1 result".to_string()
|
|
||||||
} else {
|
|
||||||
format!("{} results", response.results.len())
|
|
||||||
};
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
title: Some(format!("Searched the web: {result_text}")),
|
|
||||||
content: Some(
|
|
||||||
response
|
|
||||||
.results
|
|
||||||
.iter()
|
|
||||||
.map(|result| acp::ToolCallContent::Content {
|
|
||||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
|
||||||
name: result.title.clone(),
|
|
||||||
uri: result.url.clone(),
|
|
||||||
title: Some(result.title.clone()),
|
|
||||||
description: Some(result.text.clone()),
|
|
||||||
mime_type: None,
|
|
||||||
annotations: None,
|
|
||||||
size: None,
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
Ok(WebSearchToolOutput(response))
|
Ok(WebSearchToolOutput(response))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: Self::Input,
|
||||||
|
output: Self::Output,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
emit_update(&output.0, &event_stream);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
|
||||||
|
let result_text = if response.results.len() == 1 {
|
||||||
|
"1 result".to_string()
|
||||||
|
} else {
|
||||||
|
format!("{} results", response.results.len())
|
||||||
|
};
|
||||||
|
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||||
|
title: Some(format!("Searched the web: {result_text}")),
|
||||||
|
content: Some(
|
||||||
|
response
|
||||||
|
.results
|
||||||
|
.iter()
|
||||||
|
.map(|result| acp::ToolCallContent::Content {
|
||||||
|
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||||
|
name: result.title.clone(),
|
||||||
|
uri: result.url.clone(),
|
||||||
|
title: Some(result.title.clone()),
|
||||||
|
description: Some(result.text.clone()),
|
||||||
|
mime_type: None,
|
||||||
|
annotations: None,
|
||||||
|
size: None,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ doctest = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
acp_thread.workspace = true
|
acp_thread.workspace = true
|
||||||
|
action_log.workspace = true
|
||||||
agent-client-protocol.workspace = true
|
agent-client-protocol.workspace = true
|
||||||
agent_settings.workspace = true
|
agent_settings.workspace = true
|
||||||
agentic-coding-protocol.workspace = true
|
agentic-coding-protocol.workspace = true
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// Translates old acp agents into the new schema
|
// Translates old acp agents into the new schema
|
||||||
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
|
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||||
AcpThread::new(self.name, self.clone(), project, session_id, cx)
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
|
||||||
});
|
});
|
||||||
current_thread.replace(thread.downgrade());
|
current_thread.replace(thread.downgrade());
|
||||||
thread
|
thread
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol::{self as acp, Agent as _};
|
use agent_client_protocol::{self as acp, Agent as _};
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let session_id = response.session_id;
|
let session_id = response.session_id;
|
||||||
|
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|_cx| {
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
self.server_name,
|
self.server_name,
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project,
|
project,
|
||||||
|
action_log,
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
cx,
|
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
mod mcp_server;
|
mod mcp_server;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|
||||||
|
use action_log::ActionLog;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use context_server::listener::McpServerTool;
|
use context_server::listener::McpServerTool;
|
||||||
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
|
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
|
||||||
|
@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||||
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
let thread = cx.new(|_cx| {
|
||||||
|
AcpThread::new(
|
||||||
|
"Claude Code",
|
||||||
|
self.clone(),
|
||||||
|
project,
|
||||||
|
action_log,
|
||||||
|
session_id.clone(),
|
||||||
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
thread_tx.send(thread.downgrade())?;
|
thread_tx.send(thread.downgrade())?;
|
||||||
|
|
|
@ -303,8 +303,13 @@ impl AcpThreadView {
|
||||||
let action_log_subscription =
|
let action_log_subscription =
|
||||||
cx.observe(&action_log, |_, _, cx| cx.notify());
|
cx.observe(&action_log, |_, _, cx| cx.notify());
|
||||||
|
|
||||||
this.list_state
|
let count = thread.read(cx).entries().len();
|
||||||
.splice(0..0, thread.read(cx).entries().len());
|
this.list_state.splice(0..0, count);
|
||||||
|
this.entry_view_state.update(cx, |view_state, cx| {
|
||||||
|
for ix in 0..count {
|
||||||
|
view_state.sync_entry(ix, &thread, window, cx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
|
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
|
||||||
|
|
||||||
|
@ -808,6 +813,7 @@ impl AcpThreadView {
|
||||||
self.thread_retry_status.take();
|
self.thread_retry_status.take();
|
||||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||||
}
|
}
|
||||||
|
AcpThreadEvent::TitleUpdated => {}
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
@ -2816,12 +2822,15 @@ impl AcpThreadView {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
let current_mode = thread.completion_mode();
|
let current_mode = thread.completion_mode();
|
||||||
thread.set_completion_mode(match current_mode {
|
thread.set_completion_mode(
|
||||||
CompletionMode::Burn => CompletionMode::Normal,
|
match current_mode {
|
||||||
CompletionMode::Normal => CompletionMode::Burn,
|
CompletionMode::Burn => CompletionMode::Normal,
|
||||||
});
|
CompletionMode::Normal => CompletionMode::Burn,
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3572,8 +3581,9 @@ impl AcpThreadView {
|
||||||
))
|
))
|
||||||
.on_click({
|
.on_click({
|
||||||
cx.listener(move |this, _, _window, cx| {
|
cx.listener(move |this, _, _window, cx| {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(CompletionMode::Burn);
|
thread
|
||||||
|
.set_completion_mode(CompletionMode::Burn, cx);
|
||||||
});
|
});
|
||||||
this.resume_chat(cx);
|
this.resume_chat(cx);
|
||||||
})
|
})
|
||||||
|
@ -4156,12 +4166,13 @@ pub(crate) mod tests {
|
||||||
cx: &mut gpui::App,
|
cx: &mut gpui::App,
|
||||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||||
Task::ready(Ok(cx.new(|cx| {
|
Task::ready(Ok(cx.new(|cx| {
|
||||||
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
"SaboteurAgentConnection",
|
"SaboteurAgentConnection",
|
||||||
self,
|
self,
|
||||||
project,
|
project,
|
||||||
|
action_log,
|
||||||
SessionId("test".into()),
|
SessionId("test".into()),
|
||||||
cx,
|
|
||||||
)
|
)
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
|
@ -199,24 +199,21 @@ impl AgentDiffPane {
|
||||||
let action_log = thread.action_log(cx).clone();
|
let action_log = thread.action_log(cx).clone();
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
_subscriptions: [
|
_subscriptions: vec![
|
||||||
Some(
|
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
||||||
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
this.update_excerpts(window, cx)
|
||||||
this.update_excerpts(window, cx)
|
}),
|
||||||
}),
|
|
||||||
),
|
|
||||||
match &thread {
|
match &thread {
|
||||||
AgentDiffThread::Native(thread) => {
|
AgentDiffThread::Native(thread) => cx
|
||||||
Some(cx.subscribe(thread, |this, _thread, event, cx| {
|
.subscribe(thread, |this, _thread, event, cx| {
|
||||||
this.handle_thread_event(event, cx)
|
this.handle_native_thread_event(event, cx)
|
||||||
}))
|
}),
|
||||||
}
|
AgentDiffThread::AcpThread(thread) => cx
|
||||||
AgentDiffThread::AcpThread(_) => None,
|
.subscribe(thread, |this, _thread, event, cx| {
|
||||||
|
this.handle_acp_thread_event(event, cx)
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
]
|
],
|
||||||
.into_iter()
|
|
||||||
.flatten()
|
|
||||||
.collect(),
|
|
||||||
title: SharedString::default(),
|
title: SharedString::default(),
|
||||||
multibuffer,
|
multibuffer,
|
||||||
editor,
|
editor,
|
||||||
|
@ -324,13 +321,20 @@ impl AgentDiffPane {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
||||||
match event {
|
match event {
|
||||||
ThreadEvent::SummaryGenerated => self.update_title(cx),
|
ThreadEvent::SummaryGenerated => self.update_title(cx),
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
|
||||||
|
match event {
|
||||||
|
AcpThreadEvent::TitleUpdated => self.update_title(cx),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
|
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
|
||||||
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
|
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
|
||||||
self.editor.update(cx, |editor, cx| {
|
self.editor.update(cx, |editor, cx| {
|
||||||
|
@ -1523,7 +1527,8 @@ impl AgentDiff {
|
||||||
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
|
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
|
||||||
self.update_reviewing_editors(workspace, window, cx);
|
self.update_reviewing_editors(workspace, window, cx);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::EntriesRemoved(_)
|
AcpThreadEvent::TitleUpdated
|
||||||
|
| AcpThreadEvent::EntriesRemoved(_)
|
||||||
| AcpThreadEvent::ToolAuthorizationRequired
|
| AcpThreadEvent::ToolAuthorizationRequired
|
||||||
| AcpThreadEvent::Retry(_) => {}
|
| AcpThreadEvent::Retry(_) => {}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue