agent2: Port Zed AI features (#36172)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2025-08-15 13:17:17 +02:00 committed by GitHub
parent f8b0105258
commit 6f3cd42411
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 994 additions and 358 deletions

View file

@ -33,13 +33,23 @@ pub struct UserMessage {
pub id: Option<UserMessageId>,
pub content: ContentBlock,
pub chunks: Vec<acp::ContentBlock>,
pub checkpoint: Option<GitStoreCheckpoint>,
pub checkpoint: Option<Checkpoint>,
}
#[derive(Debug)]
pub struct Checkpoint {
git_checkpoint: GitStoreCheckpoint,
pub show: bool,
}
impl UserMessage {
fn to_markdown(&self, cx: &App) -> String {
let mut markdown = String::new();
if let Some(_) = self.checkpoint {
if self
.checkpoint
.as_ref()
.map_or(false, |checkpoint| checkpoint.show)
{
writeln!(markdown, "## User (checkpoint)").unwrap();
} else {
writeln!(markdown, "## User").unwrap();
@ -1145,9 +1155,12 @@ impl AcpThread {
self.project.read(cx).languages().clone(),
cx,
);
let request = acp::PromptRequest {
prompt: message.clone(),
session_id: self.session_id.clone(),
};
let git_store = self.project.read(cx).git_store().clone();
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
let message_id = if self
.connection
.session_editor(&self.session_id, cx)
@ -1161,68 +1174,63 @@ impl AcpThread {
AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(),
content: block,
chunks: message.clone(),
chunks: message,
checkpoint: None,
}),
cx,
);
self.run_turn(cx, async move |this, cx| {
let old_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
.context("failed to get old checkpoint")
.log_err();
this.update(cx, |this, cx| {
if let Some((_ix, message)) = this.last_user_message() {
message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
git_checkpoint,
show: false,
});
}
this.connection.prompt(message_id, request, cx)
})?
.await
})
}
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
self.run_turn(cx, async move |this, cx| {
this.update(cx, |this, cx| {
this.connection
.resume(&this.session_id, cx)
.map(|resume| resume.run(cx))
})?
.context("resuming a session is not supported")?
.await
})
}
fn run_turn(
&mut self,
cx: &mut Context<Self>,
f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
self.clear_completed_plan_entries(cx);
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
let request = acp::PromptRequest {
prompt: message,
session_id: self.session_id.clone(),
};
self.send_task = Some(cx.spawn({
let message_id = message_id.clone();
async move |this, cx| {
cancel_task.await;
old_checkpoint_tx.send(old_checkpoint.await).ok();
if let Ok(result) = this.update(cx, |this, cx| {
this.connection.prompt(message_id, request, cx)
}) {
tx.send(result.await).log_err();
}
}
self.send_task = Some(cx.spawn(async move |this, cx| {
cancel_task.await;
tx.send(f(this, cx).await).ok();
}));
cx.spawn(async move |this, cx| {
let old_checkpoint = old_checkpoint_rx
.await
.map_err(|_| anyhow!("send canceled"))
.flatten()
.context("failed to get old checkpoint")
.log_err();
let response = rx.await;
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
let new_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
.context("failed to get new checkpoint")
.log_err();
if let Some(new_checkpoint) = new_checkpoint {
let equal = git_store
.update(cx, |git, cx| {
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
})?
.await
.unwrap_or(true);
if !equal {
this.update(cx, |this, cx| {
if let Some((ix, message)) = this.user_message_mut(&message_id) {
message.checkpoint = Some(old_checkpoint);
cx.emit(AcpThreadEvent::EntryUpdated(ix));
}
})?;
}
}
}
this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
.await?;
this.update(cx, |this, cx| {
match response {
@ -1294,7 +1302,10 @@ impl AcpThread {
return Task::ready(Err(anyhow!("message not found")));
};
let checkpoint = message.checkpoint.clone();
let checkpoint = message
.checkpoint
.as_ref()
.map(|c| c.git_checkpoint.clone());
let git_store = self.project.read(cx).git_store().clone();
cx.spawn(async move |this, cx| {
@ -1316,6 +1327,59 @@ impl AcpThread {
})
}
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let git_store = self.project.read(cx).git_store().clone();
let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
if let Some(checkpoint) = message.checkpoint.as_ref() {
checkpoint.git_checkpoint.clone()
} else {
return Task::ready(Ok(()));
}
} else {
return Task::ready(Ok(()));
};
let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
cx.spawn(async move |this, cx| {
let new_checkpoint = new_checkpoint
.await
.context("failed to get new checkpoint")
.log_err();
if let Some(new_checkpoint) = new_checkpoint {
let equal = git_store
.update(cx, |git, cx| {
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
})?
.await
.unwrap_or(true);
this.update(cx, |this, cx| {
let (ix, message) = this.last_user_message().context("no user message")?;
let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
checkpoint.show = !equal;
cx.emit(AcpThreadEvent::EntryUpdated(ix));
anyhow::Ok(())
})??;
}
Ok(())
})
}
fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
self.entries
.iter_mut()
.enumerate()
.rev()
.find_map(|(ix, entry)| {
if let AgentThreadEntry::UserMessage(message) = entry {
Some((ix, message))
} else {
None
}
})
}
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry {
@ -1552,6 +1616,7 @@ mod tests {
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::{
any::Any,
cell::RefCell,
path::Path,
rc::Rc,
@ -2284,6 +2349,10 @@ mod tests {
_session_id: session_id.clone(),
}))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct FakeAgentSessionEditor {

View file

@ -4,7 +4,7 @@ use anyhow::Result;
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use project::Project;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
@ -36,6 +36,14 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn resume(
&self,
_session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionResume>> {
None
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
fn session_editor(
@ -53,12 +61,24 @@ pub trait AgentConnection {
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
None
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}
impl dyn AgentConnection {
pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
self.into_any().downcast().ok()
}
}
pub trait AgentSessionEditor {
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
pub trait AgentSessionResume {
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
#[derive(Debug)]
pub struct AuthRequired;
@ -299,6 +319,10 @@ mod test_support {
) -> Option<Rc<dyn AgentSessionEditor>> {
Some(Rc::new(StubAgentSessionEditor))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct StubAgentSessionEditor;