Lay the groundwork to support history in agent2 (#36483)
This pull request introduces title generation and history replaying. We still need to wire up the rest of the history but this gets us very close. I extracted a lot of this code from `agent2-history` because that branch was starting to get long-lived and there were lots of changes since we started. Release Notes: - N/A
This commit is contained in:
parent
c4083b9b63
commit
6c255c1973
19 changed files with 929 additions and 328 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -191,6 +191,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"acp_thread",
|
||||
"action_log",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
|
@ -208,6 +209,7 @@ dependencies = [
|
|||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"git",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
|
@ -256,6 +258,7 @@ name = "agent_servers"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp_thread",
|
||||
"action_log",
|
||||
"agent-client-protocol",
|
||||
"agent_settings",
|
||||
"agentic-coding-protocol",
|
||||
|
|
|
@ -537,9 +537,15 @@ impl ToolCallContent {
|
|||
acp::ToolCallContent::Content { content } => {
|
||||
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
|
||||
}
|
||||
acp::ToolCallContent::Diff { diff } => {
|
||||
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
|
||||
}
|
||||
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
|
||||
Diff::finalized(
|
||||
diff.path,
|
||||
diff.old_text,
|
||||
diff.new_text,
|
||||
language_registry,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -682,6 +688,7 @@ pub struct AcpThread {
|
|||
#[derive(Debug)]
|
||||
pub enum AcpThreadEvent {
|
||||
NewEntry,
|
||||
TitleUpdated,
|
||||
EntryUpdated(usize),
|
||||
EntriesRemoved(Range<usize>),
|
||||
ToolAuthorizationRequired,
|
||||
|
@ -728,11 +735,9 @@ impl AcpThread {
|
|||
title: impl Into<SharedString>,
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
session_id: acp::SessionId,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
Self {
|
||||
action_log,
|
||||
shared_buffers: Default::default(),
|
||||
|
@ -926,6 +931,12 @@ impl AcpThread {
|
|||
cx.emit(AcpThreadEvent::NewEntry);
|
||||
}
|
||||
|
||||
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
|
||||
self.title = title;
|
||||
cx.emit(AcpThreadEvent::TitleUpdated);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
|
||||
cx.emit(AcpThreadEvent::Retry(status));
|
||||
}
|
||||
|
@ -1657,7 +1668,7 @@ mod tests {
|
|||
use super::*;
|
||||
use anyhow::anyhow;
|
||||
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
||||
use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
|
||||
use indoc::indoc;
|
||||
use project::{FakeFs, Fs};
|
||||
use rand::Rng as _;
|
||||
|
@ -2327,7 +2338,7 @@ mod tests {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::App,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(
|
||||
rand::thread_rng()
|
||||
|
@ -2337,8 +2348,16 @@ mod tests {
|
|||
.collect::<String>()
|
||||
.into(),
|
||||
);
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_cx| {
|
||||
AcpThread::new(
|
||||
"Test",
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
)
|
||||
});
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
|
|
@ -5,11 +5,12 @@ use collections::IndexMap;
|
|||
use gpui::{Entity, SharedString, Task};
|
||||
use language_model::LanguageModelProviderId;
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct UserMessageId(Arc<str>);
|
||||
|
||||
impl UserMessageId {
|
||||
|
@ -208,6 +209,7 @@ impl AgentModelList {
|
|||
mod test_support {
|
||||
use std::sync::Arc;
|
||||
|
||||
use action_log::ActionLog;
|
||||
use collections::HashMap;
|
||||
use futures::{channel::oneshot, future::try_join_all};
|
||||
use gpui::{AppContext as _, WeakEntity};
|
||||
|
@ -295,8 +297,16 @@ mod test_support {
|
|||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_cx| {
|
||||
AcpThread::new(
|
||||
"Test",
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
)
|
||||
});
|
||||
self.sessions.lock().insert(
|
||||
session_id,
|
||||
Session {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{MultiBuffer, PathKey};
|
||||
|
@ -21,17 +20,13 @@ pub enum Diff {
|
|||
}
|
||||
|
||||
impl Diff {
|
||||
pub fn from_acp(
|
||||
diff: acp::Diff,
|
||||
pub fn finalized(
|
||||
path: PathBuf,
|
||||
old_text: Option<String>,
|
||||
new_text: String,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let acp::Diff {
|
||||
path,
|
||||
old_text,
|
||||
new_text,
|
||||
} = diff;
|
||||
|
||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
|
||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||
|
|
|
@ -2,6 +2,7 @@ use agent::ThreadId;
|
|||
use anyhow::{Context as _, Result, bail};
|
||||
use file_icons::FileIcons;
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt,
|
||||
ops::Range,
|
||||
|
@ -11,7 +12,7 @@ use std::{
|
|||
use ui::{App, IconName, SharedString};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MentionUri {
|
||||
File {
|
||||
abs_path: PathBuf,
|
||||
|
|
|
@ -14,6 +14,7 @@ workspace = true
|
|||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
action_log.workspace = true
|
||||
agent.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
|
@ -26,6 +27,7 @@ collections.workspace = true
|
|||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
git.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
|
@ -59,6 +61,7 @@ which.workspace = true
|
|||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
agent = { workspace = true, "features" = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
clock = { workspace = true, "features" = ["test-support"] }
|
||||
|
@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
|
|||
editor = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
git = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
use crate::{
|
||||
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
||||
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
||||
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||
UserMessageContent, WebSearchTool, templates::Templates,
|
||||
};
|
||||
use acp_thread::AgentModelSelector;
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
|
@ -427,18 +428,19 @@ impl NativeAgent {
|
|||
) {
|
||||
self.models.refresh_list(cx);
|
||||
|
||||
let default_model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|m| m.model.clone());
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let default_model = registry.default_model().map(|m| m.model.clone());
|
||||
let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
|
||||
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, cx| {
|
||||
if thread.model().is_none()
|
||||
&& let Some(model) = default_model.clone()
|
||||
{
|
||||
thread.set_model(model);
|
||||
thread.set_model(model, cx);
|
||||
cx.notify();
|
||||
}
|
||||
thread.set_summarization_model(summarization_model.clone(), cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -462,10 +464,7 @@ impl NativeAgentConnection {
|
|||
session_id: acp::SessionId,
|
||||
cx: &mut App,
|
||||
f: impl 'static
|
||||
+ FnOnce(
|
||||
Entity<Thread>,
|
||||
&mut App,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
||||
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||
agent
|
||||
|
@ -489,7 +488,18 @@ impl NativeAgentConnection {
|
|||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
ThreadEvent::UserMessage(message) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
for content in message.content {
|
||||
thread.push_user_content_block(
|
||||
Some(message.id.clone()),
|
||||
content.into(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})?;
|
||||
}
|
||||
ThreadEvent::AgentText(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
|
@ -501,7 +511,7 @@ impl NativeAgentConnection {
|
|||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
ThreadEvent::AgentThinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
|
@ -513,7 +523,7 @@ impl NativeAgentConnection {
|
|||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
|
@ -536,22 +546,26 @@ impl NativeAgentConnection {
|
|||
})
|
||||
.detach();
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
ThreadEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
ThreadEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Retry(status) => {
|
||||
ThreadEvent::TitleUpdate(title) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
||||
}
|
||||
ThreadEvent::Retry(status) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_retry_status(status, cx)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
ThreadEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
}
|
||||
|
@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection {
|
|||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.set_model(model.clone());
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_model(model.clone(), cx);
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
|
@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
cx.spawn(async move |cx| {
|
||||
log::debug!("Starting thread creation in async context");
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||
|
||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
|
||||
// Create Thread
|
||||
let thread = agent.update(
|
||||
cx,
|
||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
|
||||
// Log available models for debugging
|
||||
let available_count = registry.available_models(cx).count();
|
||||
|
@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||
});
|
||||
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
|
@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
summarization_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
|
@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(NowTool);
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||
|
@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
},
|
||||
)??;
|
||||
|
||||
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
session_id.clone(),
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
|
@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
log::info!("Cancelling on session: {}", session_id);
|
||||
self.0.update(cx, |agent, cx| {
|
||||
if let Some(agent) = agent.sessions.get(session_id) {
|
||||
agent.thread.update(cx, |thread, _cx| thread.cancel());
|
||||
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
|
|||
|
||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
||||
Task::ready(
|
||||
self.0
|
||||
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
|||
|
||||
let mut saw_partial_tool_use = false;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
||||
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let message = thread.last_message().unwrap();
|
||||
|
@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
|||
);
|
||||
}
|
||||
|
||||
async fn expect_tool_call(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||
) -> acp::ToolCall {
|
||||
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
|
||||
let event = events
|
||||
.next()
|
||||
.await
|
||||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
match event {
|
||||
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
||||
ThreadEvent::ToolCall(tool_call) => return tool_call,
|
||||
event => {
|
||||
panic!("Unexpected event {event:?}");
|
||||
}
|
||||
|
@ -752,7 +750,7 @@ async fn expect_tool_call(
|
|||
}
|
||||
|
||||
async fn expect_tool_call_update_fields(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||
) -> acp::ToolCallUpdate {
|
||||
let event = events
|
||||
.next()
|
||||
|
@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields(
|
|||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
match event {
|
||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
||||
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
||||
return update;
|
||||
}
|
||||
event => {
|
||||
|
@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields(
|
|||
}
|
||||
|
||||
async fn next_tool_call_authorization(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||
) -> ToolCallAuthorization {
|
||||
loop {
|
||||
let event = events
|
||||
|
@ -778,7 +776,7 @@ async fn next_tool_call_authorization(
|
|||
.await
|
||||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||
let permission_kinds = tool_call_authorization
|
||||
.options
|
||||
.iter()
|
||||
|
@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
let mut echo_completed = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event.unwrap() {
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
ThreadEvent::ToolCall(tool_call) => {
|
||||
assert_eq!(tool_call.title, expected_tools.remove(0));
|
||||
if tool_call.title == "Echo" {
|
||||
echo_id = Some(tool_call.id);
|
||||
}
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||
acp::ToolCallUpdate {
|
||||
id,
|
||||
fields:
|
||||
|
@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
|
||||
// Cancel the current send and ensure that the event stream is closed, even
|
||||
// if one of the tools is still running.
|
||||
thread.update(cx, |thread, _cx| thread.cancel());
|
||||
thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||
let events = events.collect::<Vec<_>>().await;
|
||||
let last_event = events.last();
|
||||
assert!(
|
||||
matches!(
|
||||
last_event,
|
||||
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
||||
Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
|
||||
),
|
||||
"unexpected event {last_event:?}"
|
||||
);
|
||||
|
@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
|||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
||||
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
|
@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_title_generation(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let summary_model = Arc::new(FakeLanguageModel::default());
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summarization_model(Some(summary_model.clone()), cx)
|
||||
});
|
||||
|
||||
let send = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
|
||||
|
||||
// Ensure the summary model has been invoked to generate a title.
|
||||
summary_model.send_last_completion_stream_text_chunk("Hello ");
|
||||
summary_model.send_last_completion_stream_text_chunk("world\nG");
|
||||
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
|
||||
summary_model.end_last_completion_stream();
|
||||
send.collect::<Vec<_>>().await;
|
||||
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
|
||||
|
||||
// Send another message, ensuring no title is generated this time.
|
||||
let send = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hello again"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey again!");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(summary_model.pending_completions(), Vec::new());
|
||||
send.collect::<Vec<_>>().await;
|
||||
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.update(settings::init);
|
||||
|
@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
|||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
|||
let mut retry_events = Vec::new();
|
||||
while let Some(Ok(event)) = events.next().await {
|
||||
match event {
|
||||
AgentResponseEvent::Retry(retry_status) => {
|
||||
ThreadEvent::Retry(retry_status) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
AgentResponseEvent::Stop(..) => break,
|
||||
ThreadEvent::Stop(..) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
|||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
|||
let mut retry_events = Vec::new();
|
||||
while let Some(Ok(event)) = events.next().await {
|
||||
match event {
|
||||
AgentResponseEvent::Retry(retry_status) => {
|
||||
ThreadEvent::Retry(retry_status) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
AgentResponseEvent::Stop(..) => break,
|
||||
ThreadEvent::Stop(..) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
|||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
|||
let mut retry_events = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(AgentResponseEvent::Retry(retry_status)) => {
|
||||
Ok(ThreadEvent::Retry(retry_status)) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
Ok(AgentResponseEvent::Stop(..)) => break,
|
||||
Ok(ThreadEvent::Stop(..)) => break,
|
||||
Err(error) => errors.push(error),
|
||||
_ => {}
|
||||
}
|
||||
|
@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
|||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
||||
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
|
@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
action_log,
|
||||
templates,
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn replay(
|
||||
&self,
|
||||
_input: serde_json::Value,
|
||||
_output: serde_json::Value,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use indoc::formatdoc;
|
||||
use language::ToPoint;
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
use language::{LanguageRegistry, ToPoint};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use paths;
|
||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||
|
@ -98,11 +98,13 @@ pub enum EditFileMode {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EditFileToolOutput {
|
||||
#[serde(alias = "original_path")]
|
||||
input_path: PathBuf,
|
||||
project_path: PathBuf,
|
||||
new_text: String,
|
||||
old_text: Arc<String>,
|
||||
#[serde(default)]
|
||||
diff: String,
|
||||
#[serde(alias = "raw_output")]
|
||||
edit_agent_output: EditAgentOutput,
|
||||
}
|
||||
|
||||
|
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
|||
}
|
||||
|
||||
pub struct EditFileTool {
|
||||
thread: Entity<Thread>,
|
||||
thread: WeakEntity<Thread>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
}
|
||||
|
||||
impl EditFileTool {
|
||||
pub fn new(thread: Entity<Thread>) -> Self {
|
||||
Self { thread }
|
||||
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
|
||||
Self {
|
||||
thread,
|
||||
language_registry,
|
||||
}
|
||||
}
|
||||
|
||||
fn authorize(
|
||||
|
@ -167,8 +173,11 @@ impl EditFileTool {
|
|||
|
||||
// Check if path is inside the global config directory
|
||||
// First check if it's already inside project - if not, try to canonicalize
|
||||
let thread = self.thread.read(cx);
|
||||
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
|
||||
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
|
||||
thread.project().read(cx).find_project_path(&input.path, cx)
|
||||
}) else {
|
||||
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||
};
|
||||
|
||||
// If the path is inside the project, and it's not one of the above edge cases,
|
||||
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
||||
|
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
|
|||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let project = self.thread.read(cx).project().clone();
|
||||
let Ok(project) = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| thread.project().clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||
};
|
||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||
Ok(path) => path,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||
|
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
|
|||
});
|
||||
}
|
||||
|
||||
let Some(request) = self.thread.update(cx, |thread, cx| {
|
||||
thread
|
||||
.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||
.ok()
|
||||
}) else {
|
||||
return Task::ready(Err(anyhow!("Failed to build completion request")));
|
||||
};
|
||||
let thread = self.thread.read(cx);
|
||||
let Some(model) = thread.model().cloned() else {
|
||||
return Task::ready(Err(anyhow!("No language model configured")));
|
||||
};
|
||||
let action_log = thread.action_log().clone();
|
||||
|
||||
let authorize = self.authorize(&input, &event_stream, cx);
|
||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
authorize.await?;
|
||||
|
||||
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
|
||||
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
|
||||
(request, thread.model().cloned(), thread.action_log().clone())
|
||||
})?;
|
||||
let request = request?;
|
||||
let model = model.context("No language model configured")?;
|
||||
|
||||
let edit_format = EditFormat::from_model(model.clone())?;
|
||||
let edit_agent = EditAgent::new(
|
||||
model,
|
||||
|
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
|
|||
|
||||
Ok(EditFileToolOutput {
|
||||
input_path: input.path,
|
||||
project_path: project_path.path.to_path_buf(),
|
||||
new_text: new_text.clone(),
|
||||
old_text,
|
||||
diff: unified_diff,
|
||||
|
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn replay(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
output: Self::Output,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Result<()> {
|
||||
event_stream.update_diff(cx.new(|cx| {
|
||||
Diff::finalized(
|
||||
output.input_path,
|
||||
Some(output.old_text.to_string()),
|
||||
output.new_text,
|
||||
self.language_registry.clone(),
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that the file path is valid, meaning:
|
||||
|
@ -515,6 +541,7 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -527,6 +554,7 @@ mod tests {
|
|||
action_log,
|
||||
Templates::new(),
|
||||
Some(model),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -537,7 +565,11 @@ mod tests {
|
|||
path: "root/nonexistent_file.txt".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
|
@ -724,6 +756,7 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -750,9 +783,10 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool {
|
||||
thread: thread.clone(),
|
||||
})
|
||||
Arc::new(EditFileTool::new(
|
||||
thread.downgrade(),
|
||||
language_registry.clone(),
|
||||
))
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
});
|
||||
|
||||
|
@ -806,7 +840,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
|
@ -850,6 +888,7 @@ mod tests {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
|
@ -860,6 +899,7 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -887,9 +927,10 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool {
|
||||
thread: thread.clone(),
|
||||
})
|
||||
Arc::new(EditFileTool::new(
|
||||
thread.downgrade(),
|
||||
language_registry.clone(),
|
||||
))
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
});
|
||||
|
||||
|
@ -938,10 +979,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool {
|
||||
thread: thread.clone(),
|
||||
})
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Stream the content with trailing whitespace
|
||||
|
@ -976,6 +1018,7 @@ mod tests {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
|
@ -986,10 +1029,11 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
|
||||
// Test 1: Path with .zed component should require confirmation
|
||||
|
@ -1111,6 +1155,7 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
@ -1123,10 +1168,11 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||
let test_cases = vec![
|
||||
|
@ -1220,7 +1266,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1233,10 +1279,11 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test files in different worktrees
|
||||
let test_cases = vec![
|
||||
|
@ -1302,6 +1349,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1314,10 +1362,11 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test edge cases
|
||||
let test_cases = vec![
|
||||
|
@ -1386,6 +1435,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1398,10 +1448,11 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test different EditFileMode values
|
||||
let modes = vec![
|
||||
|
@ -1467,6 +1518,7 @@ mod tests {
|
|||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
@ -1479,10 +1531,11 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
|
|
|
@ -319,7 +319,7 @@ mod tests {
|
|||
use theme::ThemeSettings;
|
||||
use util::test::TempTree;
|
||||
|
||||
use crate::AgentResponseEvent;
|
||||
use crate::ThreadEvent;
|
||||
|
||||
use super::*;
|
||||
|
||||
|
@ -396,7 +396,7 @@ mod tests {
|
|||
});
|
||||
cx.run_until_parked();
|
||||
let event = stream_rx.try_next();
|
||||
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
|
||||
if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
|
||||
auth.response.send(auth.options[0].id.clone()).unwrap();
|
||||
}
|
||||
|
||||
|
|
|
@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
|
|||
}
|
||||
};
|
||||
|
||||
let result_text = if response.results.len() == 1 {
|
||||
"1 result".to_string()
|
||||
} else {
|
||||
format!("{} results", response.results.len())
|
||||
};
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
title: Some(format!("Searched the web: {result_text}")),
|
||||
content: Some(
|
||||
response
|
||||
.results
|
||||
.iter()
|
||||
.map(|result| acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
name: result.title.clone(),
|
||||
uri: result.url.clone(),
|
||||
title: Some(result.title.clone()),
|
||||
description: Some(result.text.clone()),
|
||||
mime_type: None,
|
||||
annotations: None,
|
||||
size: None,
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
emit_update(&response, &event_stream);
|
||||
Ok(WebSearchToolOutput(response))
|
||||
})
|
||||
}
|
||||
|
||||
fn replay(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
output: Self::Output,
|
||||
event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Result<()> {
|
||||
emit_update(&output.0, &event_stream);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
|
||||
let result_text = if response.results.len() == 1 {
|
||||
"1 result".to_string()
|
||||
} else {
|
||||
format!("{} results", response.results.len())
|
||||
};
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
title: Some(format!("Searched the web: {result_text}")),
|
||||
content: Some(
|
||||
response
|
||||
.results
|
||||
.iter()
|
||||
.map(|result| acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
name: result.title.clone(),
|
||||
uri: result.url.clone(),
|
||||
title: Some(result.title.clone()),
|
||||
description: Some(result.text.clone()),
|
||||
mime_type: None,
|
||||
annotations: None,
|
||||
size: None,
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ doctest = false
|
|||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
action_log.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Translates old acp agents into the new schema
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
|
@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
|
|||
cx.update(|cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||
AcpThread::new(self.name, self.clone(), project, session_id, cx)
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
|
||||
});
|
||||
current_thread.replace(thread.downgrade());
|
||||
thread
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp, Agent as _};
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
|
@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection {
|
|||
})?;
|
||||
|
||||
let session_id = response.session_id;
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||
let thread = cx.new(|_cx| {
|
||||
AcpThread::new(
|
||||
self.server_name,
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
mod mcp_server;
|
||||
pub mod tools;
|
||||
|
||||
use action_log::ActionLog;
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
|
||||
|
@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
}
|
||||
});
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||
let thread = cx.new(|_cx| {
|
||||
AcpThread::new(
|
||||
"Claude Code",
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
)
|
||||
})?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
|
|
@ -303,8 +303,13 @@ impl AcpThreadView {
|
|||
let action_log_subscription =
|
||||
cx.observe(&action_log, |_, _, cx| cx.notify());
|
||||
|
||||
this.list_state
|
||||
.splice(0..0, thread.read(cx).entries().len());
|
||||
let count = thread.read(cx).entries().len();
|
||||
this.list_state.splice(0..0, count);
|
||||
this.entry_view_state.update(cx, |view_state, cx| {
|
||||
for ix in 0..count {
|
||||
view_state.sync_entry(ix, &thread, window, cx);
|
||||
}
|
||||
});
|
||||
|
||||
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
|
||||
|
||||
|
@ -808,6 +813,7 @@ impl AcpThreadView {
|
|||
self.thread_retry_status.take();
|
||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||
}
|
||||
AcpThreadEvent::TitleUpdated => {}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
@ -2816,12 +2822,15 @@ impl AcpThreadView {
|
|||
return;
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
let current_mode = thread.completion_mode();
|
||||
thread.set_completion_mode(match current_mode {
|
||||
CompletionMode::Burn => CompletionMode::Normal,
|
||||
CompletionMode::Normal => CompletionMode::Burn,
|
||||
});
|
||||
thread.set_completion_mode(
|
||||
match current_mode {
|
||||
CompletionMode::Burn => CompletionMode::Normal,
|
||||
CompletionMode::Normal => CompletionMode::Burn,
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -3572,8 +3581,9 @@ impl AcpThreadView {
|
|||
))
|
||||
.on_click({
|
||||
cx.listener(move |this, _, _window, cx| {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread
|
||||
.set_completion_mode(CompletionMode::Burn, cx);
|
||||
});
|
||||
this.resume_chat(cx);
|
||||
})
|
||||
|
@ -4156,12 +4166,13 @@ pub(crate) mod tests {
|
|||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
Task::ready(Ok(cx.new(|cx| {
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
AcpThread::new(
|
||||
"SaboteurAgentConnection",
|
||||
self,
|
||||
project,
|
||||
action_log,
|
||||
SessionId("test".into()),
|
||||
cx,
|
||||
)
|
||||
})))
|
||||
}
|
||||
|
|
|
@ -199,24 +199,21 @@ impl AgentDiffPane {
|
|||
let action_log = thread.action_log(cx).clone();
|
||||
|
||||
let mut this = Self {
|
||||
_subscriptions: [
|
||||
Some(
|
||||
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
||||
this.update_excerpts(window, cx)
|
||||
}),
|
||||
),
|
||||
_subscriptions: vec![
|
||||
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
||||
this.update_excerpts(window, cx)
|
||||
}),
|
||||
match &thread {
|
||||
AgentDiffThread::Native(thread) => {
|
||||
Some(cx.subscribe(thread, |this, _thread, event, cx| {
|
||||
this.handle_thread_event(event, cx)
|
||||
}))
|
||||
}
|
||||
AgentDiffThread::AcpThread(_) => None,
|
||||
AgentDiffThread::Native(thread) => cx
|
||||
.subscribe(thread, |this, _thread, event, cx| {
|
||||
this.handle_native_thread_event(event, cx)
|
||||
}),
|
||||
AgentDiffThread::AcpThread(thread) => cx
|
||||
.subscribe(thread, |this, _thread, event, cx| {
|
||||
this.handle_acp_thread_event(event, cx)
|
||||
}),
|
||||
},
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
],
|
||||
title: SharedString::default(),
|
||||
multibuffer,
|
||||
editor,
|
||||
|
@ -324,13 +321,20 @@ impl AgentDiffPane {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
||||
fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
||||
match event {
|
||||
ThreadEvent::SummaryGenerated => self.update_title(cx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
|
||||
match event {
|
||||
AcpThreadEvent::TitleUpdated => self.update_title(cx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
|
||||
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
|
@ -1523,7 +1527,8 @@ impl AgentDiff {
|
|||
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
AcpThreadEvent::EntriesRemoved(_)
|
||||
AcpThreadEvent::TitleUpdated
|
||||
| AcpThreadEvent::EntriesRemoved(_)
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Retry(_) => {}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue