agent2: Port Zed AI features (#36172)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f8b0105258
commit
6f3cd42411
17 changed files with 994 additions and 358 deletions
|
@ -33,13 +33,23 @@ pub struct UserMessage {
|
|||
pub id: Option<UserMessageId>,
|
||||
pub content: ContentBlock,
|
||||
pub chunks: Vec<acp::ContentBlock>,
|
||||
pub checkpoint: Option<GitStoreCheckpoint>,
|
||||
pub checkpoint: Option<Checkpoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Checkpoint {
|
||||
git_checkpoint: GitStoreCheckpoint,
|
||||
pub show: bool,
|
||||
}
|
||||
|
||||
impl UserMessage {
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
let mut markdown = String::new();
|
||||
if let Some(_) = self.checkpoint {
|
||||
if self
|
||||
.checkpoint
|
||||
.as_ref()
|
||||
.map_or(false, |checkpoint| checkpoint.show)
|
||||
{
|
||||
writeln!(markdown, "## User (checkpoint)").unwrap();
|
||||
} else {
|
||||
writeln!(markdown, "## User").unwrap();
|
||||
|
@ -1145,9 +1155,12 @@ impl AcpThread {
|
|||
self.project.read(cx).languages().clone(),
|
||||
cx,
|
||||
);
|
||||
let request = acp::PromptRequest {
|
||||
prompt: message.clone(),
|
||||
session_id: self.session_id.clone(),
|
||||
};
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
||||
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||
let message_id = if self
|
||||
.connection
|
||||
.session_editor(&self.session_id, cx)
|
||||
|
@ -1161,68 +1174,63 @@ impl AcpThread {
|
|||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id.clone(),
|
||||
content: block,
|
||||
chunks: message.clone(),
|
||||
chunks: message,
|
||||
checkpoint: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
||||
self.run_turn(cx, async move |this, cx| {
|
||||
let old_checkpoint = git_store
|
||||
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||
.await
|
||||
.context("failed to get old checkpoint")
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((_ix, message)) = this.last_user_message() {
|
||||
message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
|
||||
git_checkpoint,
|
||||
show: false,
|
||||
});
|
||||
}
|
||||
this.connection.prompt(message_id, request, cx)
|
||||
})?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
|
||||
self.run_turn(cx, async move |this, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.connection
|
||||
.resume(&this.session_id, cx)
|
||||
.map(|resume| resume.run(cx))
|
||||
})?
|
||||
.context("resuming a session is not supported")?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn run_turn(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
|
||||
) -> BoxFuture<'static, Result<()>> {
|
||||
self.clear_completed_plan_entries(cx);
|
||||
|
||||
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let cancel_task = self.cancel(cx);
|
||||
let request = acp::PromptRequest {
|
||||
prompt: message,
|
||||
session_id: self.session_id.clone(),
|
||||
};
|
||||
|
||||
self.send_task = Some(cx.spawn({
|
||||
let message_id = message_id.clone();
|
||||
async move |this, cx| {
|
||||
cancel_task.await;
|
||||
|
||||
old_checkpoint_tx.send(old_checkpoint.await).ok();
|
||||
if let Ok(result) = this.update(cx, |this, cx| {
|
||||
this.connection.prompt(message_id, request, cx)
|
||||
}) {
|
||||
tx.send(result.await).log_err();
|
||||
}
|
||||
}
|
||||
self.send_task = Some(cx.spawn(async move |this, cx| {
|
||||
cancel_task.await;
|
||||
tx.send(f(this, cx).await).ok();
|
||||
}));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let old_checkpoint = old_checkpoint_rx
|
||||
.await
|
||||
.map_err(|_| anyhow!("send canceled"))
|
||||
.flatten()
|
||||
.context("failed to get old checkpoint")
|
||||
.log_err();
|
||||
|
||||
let response = rx.await;
|
||||
|
||||
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
|
||||
let new_checkpoint = git_store
|
||||
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||
.await
|
||||
.context("failed to get new checkpoint")
|
||||
.log_err();
|
||||
if let Some(new_checkpoint) = new_checkpoint {
|
||||
let equal = git_store
|
||||
.update(cx, |git, cx| {
|
||||
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||
})?
|
||||
.await
|
||||
.unwrap_or(true);
|
||||
if !equal {
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((ix, message)) = this.user_message_mut(&message_id) {
|
||||
message.checkpoint = Some(old_checkpoint);
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
|
||||
.await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
match response {
|
||||
|
@ -1294,7 +1302,10 @@ impl AcpThread {
|
|||
return Task::ready(Err(anyhow!("message not found")));
|
||||
};
|
||||
|
||||
let checkpoint = message.checkpoint.clone();
|
||||
let checkpoint = message
|
||||
.checkpoint
|
||||
.as_ref()
|
||||
.map(|c| c.git_checkpoint.clone());
|
||||
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
|
@ -1316,6 +1327,59 @@ impl AcpThread {
|
|||
})
|
||||
}
|
||||
|
||||
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
||||
let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
|
||||
if let Some(checkpoint) = message.checkpoint.as_ref() {
|
||||
checkpoint.git_checkpoint.clone()
|
||||
} else {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
} else {
|
||||
return Task::ready(Ok(()));
|
||||
};
|
||||
|
||||
let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||
cx.spawn(async move |this, cx| {
|
||||
let new_checkpoint = new_checkpoint
|
||||
.await
|
||||
.context("failed to get new checkpoint")
|
||||
.log_err();
|
||||
if let Some(new_checkpoint) = new_checkpoint {
|
||||
let equal = git_store
|
||||
.update(cx, |git, cx| {
|
||||
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||
})?
|
||||
.await
|
||||
.unwrap_or(true);
|
||||
this.update(cx, |this, cx| {
|
||||
let (ix, message) = this.last_user_message().context("no user message")?;
|
||||
let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
|
||||
checkpoint.show = !equal;
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
|
||||
self.entries
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.find_map(|(ix, entry)| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
Some((ix, message))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
||||
self.entries.iter().find_map(|entry| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
|
@ -1552,6 +1616,7 @@ mod tests {
|
|||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::{
|
||||
any::Any,
|
||||
cell::RefCell,
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
|
@ -2284,6 +2349,10 @@ mod tests {
|
|||
_session_id: session_id.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct FakeAgentSessionEditor {
|
||||
|
|
|
@ -4,7 +4,7 @@ use anyhow::Result;
|
|||
use collections::IndexMap;
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -36,6 +36,14 @@ pub trait AgentConnection {
|
|||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn resume(
|
||||
&self,
|
||||
_session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionResume>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
fn session_editor(
|
||||
|
@ -53,12 +61,24 @@ pub trait AgentConnection {
|
|||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
||||
}
|
||||
|
||||
impl dyn AgentConnection {
|
||||
pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
|
||||
self.into_any().downcast().ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentSessionEditor {
|
||||
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
pub trait AgentSessionResume {
|
||||
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
|
@ -299,6 +319,10 @@ mod test_support {
|
|||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
Some(Rc::new(StubAgentSessionEditor))
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct StubAgentSessionEditor;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue