WIP
This commit is contained in:
parent
6769b650d0
commit
160e6d5747
5 changed files with 211 additions and 114 deletions
|
@ -1,19 +1,18 @@
|
||||||
mod connection;
|
mod connection;
|
||||||
|
mod diff;
|
||||||
|
|
||||||
pub use connection::*;
|
pub use connection::*;
|
||||||
|
pub use diff::*;
|
||||||
|
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use assistant_tool::ActionLog;
|
use assistant_tool::ActionLog;
|
||||||
use buffer_diff::BufferDiff;
|
use editor::Bias;
|
||||||
use editor::{Bias, MultiBuffer, PathKey};
|
|
||||||
use futures::future::{Fuse, FusedFuture};
|
use futures::future::{Fuse, FusedFuture};
|
||||||
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
||||||
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use language::{
|
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
|
||||||
Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
|
|
||||||
text_diff,
|
|
||||||
};
|
|
||||||
use markdown::Markdown;
|
use markdown::Markdown;
|
||||||
use project::{AgentLocation, Project};
|
use project::{AgentLocation, Project};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
@ -141,7 +140,7 @@ impl AgentThreadEntry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
|
||||||
if let AgentThreadEntry::ToolCall(call) = self {
|
if let AgentThreadEntry::ToolCall(call) = self {
|
||||||
itertools::Either::Left(call.diffs())
|
itertools::Either::Left(call.diffs())
|
||||||
} else {
|
} else {
|
||||||
|
@ -250,7 +249,7 @@ impl ToolCall {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
|
||||||
self.content.iter().filter_map(|content| match content {
|
self.content.iter().filter_map(|content| match content {
|
||||||
ToolCallContent::ContentBlock { .. } => None,
|
ToolCallContent::ContentBlock { .. } => None,
|
||||||
ToolCallContent::Diff { diff } => Some(diff),
|
ToolCallContent::Diff { diff } => Some(diff),
|
||||||
|
@ -390,7 +389,7 @@ impl ContentBlock {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ToolCallContent {
|
pub enum ToolCallContent {
|
||||||
ContentBlock { content: ContentBlock },
|
ContentBlock { content: ContentBlock },
|
||||||
Diff { diff: Diff },
|
Diff { diff: Entity<Diff> },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolCallContent {
|
impl ToolCallContent {
|
||||||
|
@ -404,7 +403,7 @@ impl ToolCallContent {
|
||||||
content: ContentBlock::new(content, &language_registry, cx),
|
content: ContentBlock::new(content, &language_registry, cx),
|
||||||
},
|
},
|
||||||
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
||||||
diff: Diff::from_acp(diff, language_registry, cx),
|
diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -412,108 +411,11 @@ impl ToolCallContent {
|
||||||
pub fn to_markdown(&self, cx: &App) -> String {
|
pub fn to_markdown(&self, cx: &App) -> String {
|
||||||
match self {
|
match self {
|
||||||
Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
|
Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
|
||||||
Self::Diff { diff } => diff.to_markdown(cx),
|
Self::Diff { diff } => diff.read(cx).to_markdown(cx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Diff {
|
|
||||||
pub multibuffer: Entity<MultiBuffer>,
|
|
||||||
pub path: PathBuf,
|
|
||||||
_task: Task<Result<()>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Diff {
|
|
||||||
pub fn from_acp(
|
|
||||||
diff: acp::Diff,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Self {
|
|
||||||
let acp::Diff {
|
|
||||||
path,
|
|
||||||
old_text,
|
|
||||||
new_text,
|
|
||||||
} = diff;
|
|
||||||
|
|
||||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
|
||||||
|
|
||||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
|
||||||
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
|
|
||||||
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
|
|
||||||
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
|
|
||||||
|
|
||||||
let task = cx.spawn({
|
|
||||||
let multibuffer = multibuffer.clone();
|
|
||||||
let path = path.clone();
|
|
||||||
async move |cx| {
|
|
||||||
let language = language_registry
|
|
||||||
.language_for_file_path(&path)
|
|
||||||
.await
|
|
||||||
.log_err();
|
|
||||||
|
|
||||||
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
|
|
||||||
|
|
||||||
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
|
|
||||||
buffer.set_language(language, cx);
|
|
||||||
buffer.snapshot()
|
|
||||||
})?;
|
|
||||||
|
|
||||||
buffer_diff
|
|
||||||
.update(cx, |diff, cx| {
|
|
||||||
diff.set_base_text(
|
|
||||||
old_buffer_snapshot,
|
|
||||||
Some(language_registry),
|
|
||||||
new_buffer_snapshot,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
multibuffer
|
|
||||||
.update(cx, |multibuffer, cx| {
|
|
||||||
let hunk_ranges = {
|
|
||||||
let buffer = new_buffer.read(cx);
|
|
||||||
let diff = buffer_diff.read(cx);
|
|
||||||
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
|
|
||||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
};
|
|
||||||
|
|
||||||
multibuffer.set_excerpts_for_path(
|
|
||||||
PathKey::for_buffer(&new_buffer, cx),
|
|
||||||
new_buffer.clone(),
|
|
||||||
hunk_ranges,
|
|
||||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
multibuffer.add_diff(buffer_diff, cx);
|
|
||||||
})
|
|
||||||
.log_err();
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Self {
|
|
||||||
multibuffer,
|
|
||||||
path,
|
|
||||||
_task: task,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_markdown(&self, cx: &App) -> String {
|
|
||||||
let buffer_text = self
|
|
||||||
.multibuffer
|
|
||||||
.read(cx)
|
|
||||||
.all_buffers()
|
|
||||||
.iter()
|
|
||||||
.map(|buffer| buffer.read(cx).text())
|
|
||||||
.join("\n");
|
|
||||||
format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct Plan {
|
pub struct Plan {
|
||||||
pub entries: Vec<PlanEntry>,
|
pub entries: Vec<PlanEntry>,
|
||||||
|
@ -828,6 +730,21 @@ impl AcpThread {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_tool_call_diff(
|
||||||
|
&mut self,
|
||||||
|
tool_call_id: &acp::ToolCallId,
|
||||||
|
diff: Entity<Diff>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let (ix, current_call) = self
|
||||||
|
.tool_call_mut(tool_call_id)
|
||||||
|
.context("Tool call not found")?;
|
||||||
|
current_call.content.clear();
|
||||||
|
current_call.content.push(ToolCallContent::Diff { diff });
|
||||||
|
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
|
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
|
||||||
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
|
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
|
||||||
let status = ToolCallStatus::Allowed {
|
let status = ToolCallStatus::Allowed {
|
||||||
|
|
161
crates/acp_thread/src/diff.rs
Normal file
161
crates/acp_thread/src/diff.rs
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use anyhow::Result;
|
||||||
|
use buffer_diff::BufferDiff;
|
||||||
|
use editor::{MultiBuffer, PathKey};
|
||||||
|
use gpui::{App, AppContext, Context, Entity, Task};
|
||||||
|
use itertools::Itertools;
|
||||||
|
use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
|
||||||
|
use std::{
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
pub enum Diff {
|
||||||
|
Pending {
|
||||||
|
multibuffer: Entity<MultiBuffer>,
|
||||||
|
base_text: Arc<String>,
|
||||||
|
buffer: Entity<Buffer>,
|
||||||
|
buffer_diff: Entity<BufferDiff>,
|
||||||
|
},
|
||||||
|
Ready {
|
||||||
|
path: PathBuf,
|
||||||
|
multibuffer: Entity<MultiBuffer>,
|
||||||
|
_task: Task<Result<()>>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Diff {
|
||||||
|
pub fn from_acp(
|
||||||
|
diff: acp::Diff,
|
||||||
|
language_registry: Arc<LanguageRegistry>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Self {
|
||||||
|
let acp::Diff {
|
||||||
|
path,
|
||||||
|
old_text,
|
||||||
|
new_text,
|
||||||
|
} = diff;
|
||||||
|
|
||||||
|
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||||
|
|
||||||
|
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||||
|
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
|
||||||
|
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
|
||||||
|
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
|
||||||
|
|
||||||
|
let task = cx.spawn({
|
||||||
|
let multibuffer = multibuffer.clone();
|
||||||
|
let path = path.clone();
|
||||||
|
async move |_, cx| {
|
||||||
|
let language = language_registry
|
||||||
|
.language_for_file_path(&path)
|
||||||
|
.await
|
||||||
|
.log_err();
|
||||||
|
|
||||||
|
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
|
||||||
|
|
||||||
|
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
|
||||||
|
buffer.set_language(language, cx);
|
||||||
|
buffer.snapshot()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
buffer_diff
|
||||||
|
.update(cx, |diff, cx| {
|
||||||
|
diff.set_base_text(
|
||||||
|
old_buffer_snapshot,
|
||||||
|
Some(language_registry),
|
||||||
|
new_buffer_snapshot,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
multibuffer
|
||||||
|
.update(cx, |multibuffer, cx| {
|
||||||
|
let hunk_ranges = {
|
||||||
|
let buffer = new_buffer.read(cx);
|
||||||
|
let diff = buffer_diff.read(cx);
|
||||||
|
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
|
||||||
|
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
multibuffer.set_excerpts_for_path(
|
||||||
|
PathKey::for_buffer(&new_buffer, cx),
|
||||||
|
new_buffer.clone(),
|
||||||
|
hunk_ranges,
|
||||||
|
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
multibuffer.add_diff(buffer_diff, cx);
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Self::Ready {
|
||||||
|
multibuffer,
|
||||||
|
path,
|
||||||
|
_task: task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Self {
|
||||||
|
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||||
|
let base_text = buffer_snapshot.text();
|
||||||
|
let language_registry = buffer.read(cx).language_registry();
|
||||||
|
let text_snapshot = buffer.read(cx).text_snapshot();
|
||||||
|
let buffer_diff = cx.new(|cx| {
|
||||||
|
let mut diff = BufferDiff::new(&text_snapshot, cx);
|
||||||
|
let _ = diff.set_base_text(
|
||||||
|
buffer_snapshot.clone(),
|
||||||
|
language_registry,
|
||||||
|
text_snapshot,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
diff
|
||||||
|
});
|
||||||
|
|
||||||
|
let multibuffer = cx.new(|cx| {
|
||||||
|
let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly);
|
||||||
|
multibuffer.add_diff(buffer_diff.clone(), cx);
|
||||||
|
multibuffer
|
||||||
|
});
|
||||||
|
|
||||||
|
Self::Pending {
|
||||||
|
multibuffer,
|
||||||
|
base_text: Arc::new(base_text),
|
||||||
|
buffer,
|
||||||
|
buffer_diff,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn multibuffer(&self) -> &Entity<MultiBuffer> {
|
||||||
|
match self {
|
||||||
|
Self::Pending { multibuffer, .. } => multibuffer,
|
||||||
|
Self::Ready { multibuffer, .. } => multibuffer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_markdown(&self, cx: &App) -> String {
|
||||||
|
let buffer_text = self
|
||||||
|
.multibuffer()
|
||||||
|
.read(cx)
|
||||||
|
.all_buffers()
|
||||||
|
.iter()
|
||||||
|
.map(|buffer| buffer.read(cx).text())
|
||||||
|
.join("\n");
|
||||||
|
let path = match self {
|
||||||
|
Diff::Pending { buffer, .. } => buffer.read(cx).file().map(|file| file.path().as_ref()),
|
||||||
|
Diff::Ready { path, .. } => Some(path.as_path()),
|
||||||
|
};
|
||||||
|
format!(
|
||||||
|
"Diff: {}\n```\n{}\n```\n",
|
||||||
|
path.unwrap_or(Path::new("untitled")).display(),
|
||||||
|
buffer_text
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -562,6 +562,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
|
AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.set_tool_call_diff(
|
||||||
|
&tool_call_diff.tool_call_id,
|
||||||
|
tool_call_diff.diff,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})??;
|
||||||
|
}
|
||||||
AgentResponseEvent::Stop(stop_reason) => {
|
AgentResponseEvent::Stop(stop_reason) => {
|
||||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
return Ok(acp::PromptResponse { stop_reason });
|
return Ok(acp::PromptResponse { stop_reason });
|
||||||
|
|
|
@ -103,6 +103,7 @@ pub enum AgentResponseEvent {
|
||||||
ToolCall(acp::ToolCall),
|
ToolCall(acp::ToolCall),
|
||||||
ToolCallUpdate(acp::ToolCallUpdate),
|
ToolCallUpdate(acp::ToolCallUpdate),
|
||||||
ToolCallAuthorization(ToolCallAuthorization),
|
ToolCallAuthorization(ToolCallAuthorization),
|
||||||
|
ToolCallDiff(ToolCallDiff),
|
||||||
Stop(acp::StopReason),
|
Stop(acp::StopReason),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +114,12 @@ pub struct ToolCallAuthorization {
|
||||||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ToolCallDiff {
|
||||||
|
pub tool_call_id: acp::ToolCallId,
|
||||||
|
pub diff: Entity<acp_thread::Diff>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Thread {
|
pub struct Thread {
|
||||||
messages: Vec<AgentMessage>,
|
messages: Vec<AgentMessage>,
|
||||||
completion_mode: CompletionMode,
|
completion_mode: CompletionMode,
|
||||||
|
|
|
@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace};
|
||||||
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
||||||
|
|
||||||
use ::acp_thread::{
|
use ::acp_thread::{
|
||||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
|
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
||||||
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -731,7 +731,11 @@ impl AcpThreadView {
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
|
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
|
||||||
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
|
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
|
||||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
Some(
|
||||||
|
entry
|
||||||
|
.diffs()
|
||||||
|
.map(|diff| diff.read(cx).multibuffer().clone()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(
|
fn authenticate(
|
||||||
|
@ -1313,10 +1317,9 @@ impl AcpThreadView {
|
||||||
Empty.into_any_element()
|
Empty.into_any_element()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ToolCallContent::Diff {
|
ToolCallContent::Diff { diff, .. } => {
|
||||||
diff: Diff { multibuffer, .. },
|
self.render_diff_editor(&diff.read(cx).multibuffer())
|
||||||
..
|
}
|
||||||
} => self.render_diff_editor(multibuffer),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue