From adbccb1ad09cd6d33a8d67b39adc578fb4b62468 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 25 Jun 2025 10:30:52 -0700 Subject: [PATCH] Get agent2 compiling Co-authored-by: Conrad Irwin Co-authored-by: Antonio Scandurra --- crates/agent2/src/acp.rs | 153 +++++++++++++++++++----------------- crates/agent2/src/agent2.rs | 55 +++++++------ 2 files changed, 109 insertions(+), 99 deletions(-) diff --git a/crates/agent2/src/acp.rs b/crates/agent2/src/acp.rs index e2b2e1bd1a..529b33e828 100644 --- a/crates/agent2/src/acp.rs +++ b/crates/agent2/src/acp.rs @@ -1,19 +1,15 @@ -use std::{ - io::{Cursor, Write as _}, - path::Path, - sync::{Arc, Weak}, -}; +use std::{io::Write as _, path::Path, sync::Arc}; use crate::{ - Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, - ResponseEvent, Role, Thread, ThreadEntry, ThreadId, + Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role, + Thread, ThreadEntryId, ThreadId, }; -use agentic_coding_protocol::{self as acp, TurnId}; -use anyhow::{Context as _, Result}; +use agentic_coding_protocol as acp; +use anyhow::{Context as _, Result, anyhow}; use async_trait::async_trait; use collections::HashMap; use futures::channel::mpsc::UnboundedReceiver; -use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; use parking_lot::Mutex; use project::Project; use smol::process::Child; @@ -21,18 +17,43 @@ use util::ResultExt; pub struct AcpAgent { connection: Arc, - threads: Arc>>>, + threads: Arc>>>, + project: Entity, _handler_task: Task<()>, _io_task: Task<()>, } struct AcpClientDelegate { project: Entity, - threads: Arc>>>, + threads: Arc>>>, cx: AsyncApp, // sent_buffer_versions: HashMap, HashMap>, } +impl AcpClientDelegate { + fn new(project: Entity, cx: AsyncApp) -> Self { + Self { + project, + threads: Default::default(), + cx: cx, + } + } + + fn update_thread( + &self, + thread_id: &ThreadId, + cx: &mut App, + callback: impl FnMut(&mut Thread, &mut Context) -> R, + ) -> Option { + let thread = self.threads.lock().get(&thread_id)?.clone(); + let Some(thread) = thread.upgrade() else { + self.threads.lock().remove(&thread_id); + return None; + }; + Some(thread.update(cx, callback)) + } +} + #[async_trait(?Send)] impl acp::Client for AcpClientDelegate { async fn stat(&self, params: acp::StatParams) -> Result { @@ -58,7 +79,7 @@ impl acp::Client for AcpClientDelegate { async fn stream_message_chunk( &self, - request: acp::StreamMessageChunkParams, + chunk: acp::StreamMessageChunkParams, ) -> Result { Ok(acp::StreamMessageChunkResponse) } @@ -78,25 +99,23 @@ impl acp::Client for AcpClientDelegate { })?? .await?; - buffer.update(cx, |buffer, _| { + buffer.update(cx, |buffer, cx| { let start = language::Point::new(request.line_offset.unwrap_or(0), 0); let end = match request.line_limit { None => buffer.max_point(), Some(limit) => start + language::Point::new(limit + 1, 0), }; - let content = buffer.text_for_range(start..end).collect(); - - if let Some(thread) = self.threads.lock().get(&request.thread_id) { - thread.update(cx, |thread, cx| { - thread.push_entry(ThreadEntry { - content: AgentThreadEntryContent::ReadFile { - path: request.path.clone(), - content: content.clone(), - }, - }); - }) - } + let content: String = buffer.text_for_range(start..end).collect(); + self.update_thread(&request.thread_id.into(), cx, |thread, cx| { + thread.push_entry( + AgentThreadEntryContent::ReadFile { + path: request.path.clone(), + content: content.clone(), + }, + cx, + ); + }); acp::ReadTextFileResponse { content, @@ -135,7 +154,7 @@ impl acp::Client for AcpClientDelegate { let mut base64_content = Vec::new(); let mut base64_encoder = base64::write::EncoderWriter::new( - Cursor::new(&mut base64_content), + std::io::Cursor::new(&mut base64_content), &base64::engine::general_purpose::STANDARD, ); base64_encoder.write_all(range_content)?; @@ -168,10 +187,7 @@ impl AcpAgent { let stdout = process.stdout.take().expect("process didn't have stdout"); let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate { - project, - cx: cx.clone(), - }, + AcpClientDelegate::new(project.clone(), cx.clone()), stdin, stdout, ); @@ -182,17 +198,18 @@ impl AcpAgent { }); Self { + project, connection: Arc::new(connection), - threads: Mutex::default(), + threads: Default::default(), _handler_task: cx.foreground_executor().spawn(handler_fut), _io_task: io_task, } } } -#[async_trait] +#[async_trait(?Send)] impl Agent for AcpAgent { - async fn threads(&self) -> Result> { + async fn threads(&self, cx: &mut AsyncApp) -> Result> { let response = self.connection.request(acp::GetThreadsParams).await?; response .threads @@ -207,31 +224,34 @@ impl Agent for AcpAgent { .collect() } - async fn create_thread(&self) -> Result> { + async fn create_thread(self: Arc, cx: &mut AsyncApp) -> Result> { let response = self.connection.request(acp::CreateThreadParams).await?; - let thread = Arc::new(AcpAgentThread { - id: response.thread_id.clone(), - connection: self.connection.clone(), - state: Mutex::new(AcpAgentThreadState { - turn: None, - next_turn_id: TurnId::default(), - }), - }); - self.threads - .lock() - .insert(response.thread_id, Arc::downgrade(&thread)); + let thread_id: ThreadId = response.thread_id.into(); + let agent = self.clone(); + let thread = cx.new(|_| Thread { + id: thread_id.clone(), + next_entry_id: ThreadEntryId(0), + entries: Vec::default(), + project: self.project.clone(), + agent, + })?; + self.threads.lock().insert(thread_id, thread.downgrade()); Ok(thread) } - async fn open_thread(&self, id: ThreadId) -> Result { + async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result> { todo!() } - async fn thread_entries(&self, thread_id: ThreadId) -> Result> { + async fn thread_entries( + &self, + thread_id: ThreadId, + cx: &mut AsyncApp, + ) -> Result> { let response = self .connection .request(acp::GetThreadEntriesParams { - thread_id: self.id.clone(), + thread_id: thread_id.clone().into(), }) .await?; @@ -265,18 +285,18 @@ impl Agent for AcpAgent { &self, thread_id: ThreadId, message: crate::Message, + cx: &mut AsyncApp, ) -> Result>> { - let turn_id = { - let mut state = self.state.lock(); - let turn_id = state.next_turn_id.post_inc(); - state.turn = Some(AcpAgentThreadTurn { id: turn_id }); - turn_id - }; + let thread = self + .threads + .lock() + .get(&thread_id) + .cloned() + .ok_or_else(|| anyhow!("no such thread"))?; let response = self .connection .request(acp::SendMessageParams { - thread_id: self.id.clone(), - turn_id, + thread_id: thread_id.clone().into(), message: acp::Message { role: match message.role { Role::User => acp::Role::User, @@ -301,29 +321,14 @@ impl Agent for AcpAgent { } } -pub struct AcpAgentThread { - id: acp::ThreadId, - connection: Arc, - state: Mutex, -} - -struct AcpAgentThreadState { - next_turn_id: acp::TurnId, - turn: Option, -} - -struct AcpAgentThreadTurn { - id: acp::TurnId, -} - impl From for ThreadId { fn from(thread_id: acp::ThreadId) -> Self { - Self(thread_id.0) + Self(thread_id.0.into()) } } impl From for acp::ThreadId { fn from(thread_id: ThreadId) -> Self { - acp::ThreadId(thread_id.0) + acp::ThreadId(thread_id.0.to_string()) } } diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 9c77a441cd..309fcc2728 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -13,16 +13,21 @@ use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task}; use project::Project; use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc}; -#[async_trait] +#[async_trait(?Send)] pub trait Agent: 'static { - async fn threads(&self) -> Result>; - async fn create_thread(&self) -> Result>; - async fn open_thread(&self, id: ThreadId) -> Result>; - async fn thread_entries(&self, id: ThreadId) -> Result>; + async fn threads(&self, cx: &mut AsyncApp) -> Result>; + async fn create_thread(self: Arc, cx: &mut AsyncApp) -> Result>; + async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result>; + async fn thread_entries( + &self, + id: ThreadId, + cx: &mut AsyncApp, + ) -> Result>; async fn send_thread_message( &self, thread_id: ThreadId, message: Message, + cx: &mut AsyncApp, ) -> Result>>; } @@ -53,7 +58,7 @@ impl ReadFileRequest { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ThreadId(SharedString); #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -145,20 +150,20 @@ pub struct ThreadEntry { pub content: AgentThreadEntryContent, } -pub struct ThreadStore { +pub struct ThreadStore { threads: Vec, - agent: Arc, + agent: Arc, project: Entity, } -impl ThreadStore { +impl ThreadStore { pub async fn load( - agent: Arc, + agent: Arc, project: Entity, cx: &mut AsyncApp, ) -> Result> { - let threads = agent.threads().await?; - cx.new(|cx| Self { + let threads = agent.threads(cx).await?; + cx.new(|_cx| Self { threads, agent, project, @@ -177,21 +182,13 @@ impl ThreadStore { cx: &mut Context, ) -> Task>> { let agent = self.agent.clone(); - let project = self.project.clone(); - cx.spawn(async move |_, cx| { - let agent_thread = agent.open_thread(id).await?; - Thread::load(agent_thread, project, cx).await - }) + cx.spawn(async move |_, cx| agent.open_thread(id, cx).await) } /// Creates a new thread. pub fn create_thread(&self, cx: &mut Context) -> Task>> { let agent = self.agent.clone(); - let project = self.project.clone(); - cx.spawn(async move |_, cx| { - let agent_thread = agent.create_thread().await?; - Thread::load(agent_thread, project, cx).await - }) + cx.spawn(async move |_, cx| agent.create_thread(cx).await) } } @@ -210,7 +207,7 @@ impl Thread { project: Entity, cx: &mut AsyncApp, ) -> Result> { - let entries = agent.thread_entries(thread_id.clone()).await?; + let entries = agent.thread_entries(thread_id.clone(), cx).await?; cx.new(|cx| Self::new(agent, thread_id, entries, project, cx)) } @@ -241,11 +238,19 @@ impl Thread { &self.entries } + pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context) { + self.entries.push(ThreadEntry { + id: self.next_entry_id.post_inc(), + content: entry, + }); + cx.notify(); + } + pub fn send(&mut self, message: Message, cx: &mut Context) -> Task> { let agent = self.agent.clone(); - let id = self.id; + let id = self.id.clone(); cx.spawn(async move |this, cx| { - let mut events = agent.send_thread_message(id, message).await?; + let mut events = agent.send_thread_message(id, message, cx).await?; let mut pending_event_handlers = FuturesUnordered::new(); loop {