Get agent2 compiling

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Max Brunsfeld 2025-06-25 10:30:52 -07:00
parent f4e2d38c29
commit adbccb1ad0
2 changed files with 109 additions and 99 deletions

View file

@ -1,19 +1,15 @@
use std::{ use std::{io::Write as _, path::Path, sync::Arc};
io::{Cursor, Write as _},
path::Path,
sync::{Arc, Weak},
};
use crate::{ use crate::{
Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role,
ResponseEvent, Role, Thread, ThreadEntry, ThreadId, Thread, ThreadEntryId, ThreadId,
}; };
use agentic_coding_protocol::{self as acp, TurnId}; use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait; use async_trait::async_trait;
use collections::HashMap; use collections::HashMap;
use futures::channel::mpsc::UnboundedReceiver; use futures::channel::mpsc::UnboundedReceiver;
use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use parking_lot::Mutex; use parking_lot::Mutex;
use project::Project; use project::Project;
use smol::process::Child; use smol::process::Child;
@ -21,18 +17,43 @@ use util::ResultExt;
pub struct AcpAgent { pub struct AcpAgent {
connection: Arc<acp::AgentConnection>, connection: Arc<acp::AgentConnection>,
threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>, threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
project: Entity<Project>,
_handler_task: Task<()>, _handler_task: Task<()>,
_io_task: Task<()>, _io_task: Task<()>,
} }
struct AcpClientDelegate { struct AcpClientDelegate {
project: Entity<Project>, project: Entity<Project>,
threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>, threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
cx: AsyncApp, cx: AsyncApp,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>, // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
} }
impl AcpClientDelegate {
fn new(project: Entity<Project>, cx: AsyncApp) -> Self {
Self {
project,
threads: Default::default(),
cx: cx,
}
}
fn update_thread<R>(
&self,
thread_id: &ThreadId,
cx: &mut App,
callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
) -> Option<R> {
let thread = self.threads.lock().get(&thread_id)?.clone();
let Some(thread) = thread.upgrade() else {
self.threads.lock().remove(&thread_id);
return None;
};
Some(thread.update(cx, callback))
}
}
#[async_trait(?Send)] #[async_trait(?Send)]
impl acp::Client for AcpClientDelegate { impl acp::Client for AcpClientDelegate {
async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> { async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
@ -58,7 +79,7 @@ impl acp::Client for AcpClientDelegate {
async fn stream_message_chunk( async fn stream_message_chunk(
&self, &self,
request: acp::StreamMessageChunkParams, chunk: acp::StreamMessageChunkParams,
) -> Result<acp::StreamMessageChunkResponse> { ) -> Result<acp::StreamMessageChunkResponse> {
Ok(acp::StreamMessageChunkResponse) Ok(acp::StreamMessageChunkResponse)
} }
@ -78,25 +99,23 @@ impl acp::Client for AcpClientDelegate {
})?? })??
.await?; .await?;
buffer.update(cx, |buffer, _| { buffer.update(cx, |buffer, cx| {
let start = language::Point::new(request.line_offset.unwrap_or(0), 0); let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
let end = match request.line_limit { let end = match request.line_limit {
None => buffer.max_point(), None => buffer.max_point(),
Some(limit) => start + language::Point::new(limit + 1, 0), Some(limit) => start + language::Point::new(limit + 1, 0),
}; };
let content = buffer.text_for_range(start..end).collect(); let content: String = buffer.text_for_range(start..end).collect();
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
if let Some(thread) = self.threads.lock().get(&request.thread_id) { thread.push_entry(
thread.update(cx, |thread, cx| { AgentThreadEntryContent::ReadFile {
thread.push_entry(ThreadEntry { path: request.path.clone(),
content: AgentThreadEntryContent::ReadFile { content: content.clone(),
path: request.path.clone(), },
content: content.clone(), cx,
}, );
}); });
})
}
acp::ReadTextFileResponse { acp::ReadTextFileResponse {
content, content,
@ -135,7 +154,7 @@ impl acp::Client for AcpClientDelegate {
let mut base64_content = Vec::new(); let mut base64_content = Vec::new();
let mut base64_encoder = base64::write::EncoderWriter::new( let mut base64_encoder = base64::write::EncoderWriter::new(
Cursor::new(&mut base64_content), std::io::Cursor::new(&mut base64_content),
&base64::engine::general_purpose::STANDARD, &base64::engine::general_purpose::STANDARD,
); );
base64_encoder.write_all(range_content)?; base64_encoder.write_all(range_content)?;
@ -168,10 +187,7 @@ impl AcpAgent {
let stdout = process.stdout.take().expect("process didn't have stdout"); let stdout = process.stdout.take().expect("process didn't have stdout");
let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
AcpClientDelegate { AcpClientDelegate::new(project.clone(), cx.clone()),
project,
cx: cx.clone(),
},
stdin, stdin,
stdout, stdout,
); );
@ -182,17 +198,18 @@ impl AcpAgent {
}); });
Self { Self {
project,
connection: Arc::new(connection), connection: Arc::new(connection),
threads: Mutex::default(), threads: Default::default(),
_handler_task: cx.foreground_executor().spawn(handler_fut), _handler_task: cx.foreground_executor().spawn(handler_fut),
_io_task: io_task, _io_task: io_task,
} }
} }
} }
#[async_trait] #[async_trait(?Send)]
impl Agent for AcpAgent { impl Agent for AcpAgent {
async fn threads(&self) -> Result<Vec<AgentThreadSummary>> { async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
let response = self.connection.request(acp::GetThreadsParams).await?; let response = self.connection.request(acp::GetThreadsParams).await?;
response response
.threads .threads
@ -207,31 +224,34 @@ impl Agent for AcpAgent {
.collect() .collect()
} }
async fn create_thread(&self) -> Result<Arc<Self::Thread>> { async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
let response = self.connection.request(acp::CreateThreadParams).await?; let response = self.connection.request(acp::CreateThreadParams).await?;
let thread = Arc::new(AcpAgentThread { let thread_id: ThreadId = response.thread_id.into();
id: response.thread_id.clone(), let agent = self.clone();
connection: self.connection.clone(), let thread = cx.new(|_| Thread {
state: Mutex::new(AcpAgentThreadState { id: thread_id.clone(),
turn: None, next_entry_id: ThreadEntryId(0),
next_turn_id: TurnId::default(), entries: Vec::default(),
}), project: self.project.clone(),
}); agent,
self.threads })?;
.lock() self.threads.lock().insert(thread_id, thread.downgrade());
.insert(response.thread_id, Arc::downgrade(&thread));
Ok(thread) Ok(thread)
} }
async fn open_thread(&self, id: ThreadId) -> Result<Thread> { async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
todo!() todo!()
} }
async fn thread_entries(&self, thread_id: ThreadId) -> Result<Vec<AgentThreadEntryContent>> { async fn thread_entries(
&self,
thread_id: ThreadId,
cx: &mut AsyncApp,
) -> Result<Vec<AgentThreadEntryContent>> {
let response = self let response = self
.connection .connection
.request(acp::GetThreadEntriesParams { .request(acp::GetThreadEntriesParams {
thread_id: self.id.clone(), thread_id: thread_id.clone().into(),
}) })
.await?; .await?;
@ -265,18 +285,18 @@ impl Agent for AcpAgent {
&self, &self,
thread_id: ThreadId, thread_id: ThreadId,
message: crate::Message, message: crate::Message,
cx: &mut AsyncApp,
) -> Result<UnboundedReceiver<Result<ResponseEvent>>> { ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
let turn_id = { let thread = self
let mut state = self.state.lock(); .threads
let turn_id = state.next_turn_id.post_inc(); .lock()
state.turn = Some(AcpAgentThreadTurn { id: turn_id }); .get(&thread_id)
turn_id .cloned()
}; .ok_or_else(|| anyhow!("no such thread"))?;
let response = self let response = self
.connection .connection
.request(acp::SendMessageParams { .request(acp::SendMessageParams {
thread_id: self.id.clone(), thread_id: thread_id.clone().into(),
turn_id,
message: acp::Message { message: acp::Message {
role: match message.role { role: match message.role {
Role::User => acp::Role::User, Role::User => acp::Role::User,
@ -301,29 +321,14 @@ impl Agent for AcpAgent {
} }
} }
pub struct AcpAgentThread {
id: acp::ThreadId,
connection: Arc<acp::AgentConnection>,
state: Mutex<AcpAgentThreadState>,
}
struct AcpAgentThreadState {
next_turn_id: acp::TurnId,
turn: Option<AcpAgentThreadTurn>,
}
struct AcpAgentThreadTurn {
id: acp::TurnId,
}
impl From<acp::ThreadId> for ThreadId { impl From<acp::ThreadId> for ThreadId {
fn from(thread_id: acp::ThreadId) -> Self { fn from(thread_id: acp::ThreadId) -> Self {
Self(thread_id.0) Self(thread_id.0.into())
} }
} }
impl From<ThreadId> for acp::ThreadId { impl From<ThreadId> for acp::ThreadId {
fn from(thread_id: ThreadId) -> Self { fn from(thread_id: ThreadId) -> Self {
acp::ThreadId(thread_id.0) acp::ThreadId(thread_id.0.to_string())
} }
} }

View file

@ -13,16 +13,21 @@ use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
use project::Project; use project::Project;
use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc}; use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
#[async_trait] #[async_trait(?Send)]
pub trait Agent: 'static { pub trait Agent: 'static {
async fn threads(&self) -> Result<Vec<AgentThreadSummary>>; async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
async fn create_thread(&self) -> Result<Entity<Thread>>; async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
async fn open_thread(&self, id: ThreadId) -> Result<Entity<Thread>>; async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
async fn thread_entries(&self, id: ThreadId) -> Result<Vec<AgentThreadEntryContent>>; async fn thread_entries(
&self,
id: ThreadId,
cx: &mut AsyncApp,
) -> Result<Vec<AgentThreadEntryContent>>;
async fn send_thread_message( async fn send_thread_message(
&self, &self,
thread_id: ThreadId, thread_id: ThreadId,
message: Message, message: Message,
cx: &mut AsyncApp,
) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>; ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
} }
@ -53,7 +58,7 @@ impl ReadFileRequest {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString); pub struct ThreadId(SharedString);
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Debug, PartialEq, Eq)]
@ -145,20 +150,20 @@ pub struct ThreadEntry {
pub content: AgentThreadEntryContent, pub content: AgentThreadEntryContent,
} }
pub struct ThreadStore<T: Agent> { pub struct ThreadStore {
threads: Vec<AgentThreadSummary>, threads: Vec<AgentThreadSummary>,
agent: Arc<T>, agent: Arc<dyn Agent>,
project: Entity<Project>, project: Entity<Project>,
} }
impl<T: Agent> ThreadStore<T> { impl ThreadStore {
pub async fn load( pub async fn load(
agent: Arc<T>, agent: Arc<dyn Agent>,
project: Entity<Project>, project: Entity<Project>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Result<Entity<Self>> { ) -> Result<Entity<Self>> {
let threads = agent.threads().await?; let threads = agent.threads(cx).await?;
cx.new(|cx| Self { cx.new(|_cx| Self {
threads, threads,
agent, agent,
project, project,
@ -177,21 +182,13 @@ impl<T: Agent> ThreadStore<T> {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> { ) -> Task<Result<Entity<Thread>>> {
let agent = self.agent.clone(); let agent = self.agent.clone();
let project = self.project.clone(); cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
cx.spawn(async move |_, cx| {
let agent_thread = agent.open_thread(id).await?;
Thread::load(agent_thread, project, cx).await
})
} }
/// Creates a new thread. /// Creates a new thread.
pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> { pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
let agent = self.agent.clone(); let agent = self.agent.clone();
let project = self.project.clone(); cx.spawn(async move |_, cx| agent.create_thread(cx).await)
cx.spawn(async move |_, cx| {
let agent_thread = agent.create_thread().await?;
Thread::load(agent_thread, project, cx).await
})
} }
} }
@ -210,7 +207,7 @@ impl Thread {
project: Entity<Project>, project: Entity<Project>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Result<Entity<Self>> { ) -> Result<Entity<Self>> {
let entries = agent.thread_entries(thread_id.clone()).await?; let entries = agent.thread_entries(thread_id.clone(), cx).await?;
cx.new(|cx| Self::new(agent, thread_id, entries, project, cx)) cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
} }
@ -241,11 +238,19 @@ impl Thread {
&self.entries &self.entries
} }
pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
self.entries.push(ThreadEntry {
id: self.next_entry_id.post_inc(),
content: entry,
});
cx.notify();
}
pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
let agent = self.agent.clone(); let agent = self.agent.clone();
let id = self.id; let id = self.id.clone();
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let mut events = agent.send_thread_message(id, message).await?; let mut events = agent.send_thread_message(id, message, cx).await?;
let mut pending_event_handlers = FuturesUnordered::new(); let mut pending_event_handlers = FuturesUnordered::new();
loop { loop {