acp: Clean up entry views on rewind (#36197)
We were leaking diffs and terminals on rewind, we'll now clean them up. This PR also introduces a refactor of how we mantain the entry view state to use a `Vec` that's kept in sync with the thread entries. Release Notes: - N/A
This commit is contained in:
parent
2acfa5e948
commit
43ee604179
13 changed files with 758 additions and 346 deletions
|
@ -13,7 +13,7 @@ path = "src/acp_thread.rs"
|
|||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "project/test-support"]
|
||||
test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"]
|
||||
|
||||
[dependencies]
|
||||
action_log.workspace = true
|
||||
|
@ -29,6 +29,7 @@ gpui.workspace = true
|
|||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
markdown.workspace = true
|
||||
parking_lot = { workspace = true, optional = true }
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
serde.workspace = true
|
||||
|
|
|
@ -1575,11 +1575,7 @@ mod tests {
|
|||
let project = Project::test(fs, [], cx).await;
|
||||
let connection = Rc::new(FakeAgentConnection::new());
|
||||
let thread = cx
|
||||
.spawn(async move |mut cx| {
|
||||
connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||
.await
|
||||
})
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1699,11 +1695,7 @@ mod tests {
|
|||
));
|
||||
|
||||
let thread = cx
|
||||
.spawn(async move |mut cx| {
|
||||
connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||
.await
|
||||
})
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1786,7 +1778,7 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
let thread = cx
|
||||
.spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1849,11 +1841,7 @@ mod tests {
|
|||
}));
|
||||
|
||||
let thread = cx
|
||||
.spawn(async move |mut cx| {
|
||||
connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||
.await
|
||||
})
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1961,10 +1949,11 @@ mod tests {
|
|||
}
|
||||
}));
|
||||
|
||||
let thread = connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -2021,8 +2010,8 @@ mod tests {
|
|||
.boxed_local()
|
||||
}
|
||||
}));
|
||||
let thread = connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -2227,7 +2216,7 @@ mod tests {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(
|
||||
rand::thread_rng()
|
||||
|
@ -2237,9 +2226,8 @@ mod tests {
|
|||
.collect::<String>()
|
||||
.into(),
|
||||
);
|
||||
let thread = cx
|
||||
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
|
||||
.unwrap();
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::AcpThread;
|
|||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use gpui::{AsyncApp, Entity, SharedString, Task};
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
|
@ -22,7 +22,7 @@ pub trait AgentConnection {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
|
@ -160,3 +160,155 @@ impl AgentModelList {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
mod test_support {
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{AppContext as _, WeakEntity};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct StubAgentConnection {
|
||||
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
|
||||
}
|
||||
|
||||
impl StubAgentConnection {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
next_prompt_updates: Default::default(),
|
||||
permission_requests: HashMap::default(),
|
||||
sessions: Arc::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
|
||||
*self.next_prompt_updates.lock() = updates;
|
||||
}
|
||||
|
||||
pub fn with_permission_requests(
|
||||
mut self,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
) -> Self {
|
||||
self.permission_requests = permission_requests;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn send_update(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
update: acp::SessionUpdate,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.sessions
|
||||
.lock()
|
||||
.get(&session_id)
|
||||
.unwrap()
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method_id: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
for update in self.next_prompt_updates.lock().drain(..) {
|
||||
let thread = thread.clone();
|
||||
let update = update.clone();
|
||||
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
|
||||
&& let Some(options) = self.permission_requests.get(&tool_call.id)
|
||||
{
|
||||
Some((tool_call.clone(), options.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let task = cx.spawn(async move |cx| {
|
||||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(
|
||||
tool_call.clone(),
|
||||
options.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
permission.await?;
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
cx.spawn(async move |_| {
|
||||
try_join_all(tasks).await?;
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
_session_id: &agent_client_protocol::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
Some(Rc::new(StubAgentSessionEditor))
|
||||
}
|
||||
}
|
||||
|
||||
struct StubAgentSessionEditor;
|
||||
|
||||
impl AgentSessionEditor for StubAgentSessionEditor {
|
||||
fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
pub use test_support::*;
|
||||
|
|
|
@ -522,7 +522,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
let agent = self.0.clone();
|
||||
log::info!("Creating new thread for project at: {:?}", cwd);
|
||||
|
@ -940,11 +940,7 @@ mod tests {
|
|||
// Create a thread/session
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
Rc::new(connection.clone()).new_thread(
|
||||
project.clone(),
|
||||
Path::new("/a"),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
@ -841,7 +841,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
// Create a thread using new_thread
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, cx))
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
|
|
|
@ -423,7 +423,7 @@ impl AgentConnection for AcpConnection {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let task = self.connection.request_any(
|
||||
acp_old::InitializeParams {
|
||||
|
|
|
@ -111,7 +111,7 @@ impl AgentConnection for AcpConnection {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let conn = self.connection.clone();
|
||||
let sessions = self.sessions.clone();
|
||||
|
|
|
@ -74,7 +74,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let cwd = cwd.to_owned();
|
||||
cx.spawn(async move |cx| {
|
||||
|
|
|
@ -422,8 +422,8 @@ pub async fn new_test_thread(
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let thread = connection
|
||||
.new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -103,6 +103,7 @@ workspace.workspace = true
|
|||
zed_actions.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
acp_thread = { workspace = true, features = ["test-support"] }
|
||||
agent = { workspace = true, features = ["test-support"] }
|
||||
assistant_context = { workspace = true, features = ["test-support"] }
|
||||
assistant_tools.workspace = true
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
mod completion_provider;
|
||||
mod entry_view_state;
|
||||
mod message_editor;
|
||||
mod model_selector;
|
||||
mod model_selector_popover;
|
||||
|
|
351
crates/agent_ui/src/acp/entry_view_state.rs
Normal file
351
crates/agent_ui/src/acp/entry_view_state.rs
Normal file
|
@ -0,0 +1,351 @@
|
|||
use std::{collections::HashMap, ops::Range};
|
||||
|
||||
use acp_thread::AcpThread;
|
||||
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
|
||||
use gpui::{
|
||||
AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use settings::Settings as _;
|
||||
use terminal_view::TerminalView;
|
||||
use theme::ThemeSettings;
|
||||
use ui::TextSize;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct EntryViewState {
|
||||
entries: Vec<Entry>,
|
||||
}
|
||||
|
||||
impl EntryViewState {
|
||||
pub fn entry(&self, index: usize) -> Option<&Entry> {
|
||||
self.entries.get(index)
|
||||
}
|
||||
|
||||
pub fn sync_entry(
|
||||
&mut self,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread: Entity<AcpThread>,
|
||||
index: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
debug_assert!(index <= self.entries.len());
|
||||
let entry = if let Some(entry) = self.entries.get_mut(index) {
|
||||
entry
|
||||
} else {
|
||||
self.entries.push(Entry::default());
|
||||
self.entries.last_mut().unwrap()
|
||||
};
|
||||
|
||||
entry.sync_diff_multibuffers(&thread, index, window, cx);
|
||||
entry.sync_terminals(&workspace, &thread, index, window, cx);
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, range: Range<usize>) {
|
||||
self.entries.drain(range);
|
||||
}
|
||||
|
||||
pub fn settings_changed(&mut self, cx: &mut App) {
|
||||
for entry in self.entries.iter() {
|
||||
for view in entry.views.values() {
|
||||
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
|
||||
diff_editor.update(cx, |diff_editor, cx| {
|
||||
diff_editor
|
||||
.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Entry {
|
||||
views: HashMap<EntityId, AnyEntity>,
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
|
||||
self.views
|
||||
.get(&diff.entity_id())
|
||||
.cloned()
|
||||
.map(|entity| entity.downcast::<Editor>().unwrap())
|
||||
}
|
||||
|
||||
pub fn terminal(
|
||||
&self,
|
||||
terminal: &Entity<acp_thread::Terminal>,
|
||||
) -> Option<Entity<TerminalView>> {
|
||||
self.views
|
||||
.get(&terminal.entity_id())
|
||||
.cloned()
|
||||
.map(|entity| entity.downcast::<TerminalView>().unwrap())
|
||||
}
|
||||
|
||||
fn sync_diff_multibuffers(
|
||||
&mut self,
|
||||
thread: &Entity<AcpThread>,
|
||||
index: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(entry) = thread.read(cx).entries().get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let multibuffers = entry
|
||||
.diffs()
|
||||
.map(|diff| diff.read(cx).multibuffer().clone());
|
||||
|
||||
let multibuffers = multibuffers.collect::<Vec<_>>();
|
||||
|
||||
for multibuffer in multibuffers {
|
||||
if self.views.contains_key(&multibuffer.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_vertical_scrollbar(false, cx);
|
||||
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::None, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
editor
|
||||
});
|
||||
|
||||
let entity_id = multibuffer.entity_id();
|
||||
self.views.insert(entity_id, editor.into_any());
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_terminals(
|
||||
&mut self,
|
||||
workspace: &WeakEntity<Workspace>,
|
||||
thread: &Entity<AcpThread>,
|
||||
index: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(entry) = thread.read(cx).entries().get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let terminals = entry
|
||||
.terminals()
|
||||
.map(|terminal| terminal.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for terminal in terminals {
|
||||
if self.views.contains_key(&terminal.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(strong_workspace) = workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let terminal_view = cx.new(|cx| {
|
||||
let mut view = TerminalView::new(
|
||||
terminal.read(cx).inner().clone(),
|
||||
workspace.clone(),
|
||||
None,
|
||||
strong_workspace.read(cx).project().downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
view.set_embedded_mode(Some(1000), cx);
|
||||
view
|
||||
});
|
||||
|
||||
let entity_id = terminal.entity_id();
|
||||
self.views.insert(entity_id, terminal_view.into_any());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.views.len()
|
||||
}
|
||||
}
|
||||
|
||||
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
|
||||
TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Entry {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// Avoid allocating in the heap by default
|
||||
views: HashMap::with_capacity(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{path::Path, rc::Rc};
|
||||
|
||||
use acp_thread::{AgentConnection, StubAgentConnection};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
|
||||
use editor::{EditorSettings, RowInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::{SemanticVersion, TestAppContext};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use pretty_assertions::assert_matches;
|
||||
use project::Project;
|
||||
use serde_json::json;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use theme::ThemeSettings;
|
||||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::acp::entry_view_state::EntryViewState;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_diff_sync(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
json!({
|
||||
"hello.txt": "hi world"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
|
||||
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let tool_call = acp::ToolCall {
|
||||
id: acp::ToolCallId("tool".into()),
|
||||
title: "Tool call".into(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::InProgress,
|
||||
content: vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: "/project/hello.txt".into(),
|
||||
old_text: Some("hi world".into()),
|
||||
new_text: "hello world".into(),
|
||||
},
|
||||
}],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
};
|
||||
let connection = Rc::new(StubAgentConnection::new());
|
||||
let thread = cx
|
||||
.update(|_, cx| {
|
||||
connection
|
||||
.clone()
|
||||
.new_thread(project, Path::new(path!("/project")), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
|
||||
|
||||
cx.update(|_, cx| {
|
||||
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
|
||||
});
|
||||
|
||||
let mut view_state = EntryViewState::default();
|
||||
cx.update(|window, cx| {
|
||||
view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
|
||||
});
|
||||
|
||||
let multibuffer = thread.read_with(cx, |thread, cx| {
|
||||
thread
|
||||
.entries()
|
||||
.get(0)
|
||||
.unwrap()
|
||||
.diffs()
|
||||
.next()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.multibuffer()
|
||||
.clone()
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let entry = view_state.entry(0).unwrap();
|
||||
let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
|
||||
assert_eq!(
|
||||
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
|
||||
"hi world\nhello world"
|
||||
);
|
||||
let row_infos = diff_editor.read_with(cx, |editor, cx| {
|
||||
let multibuffer = editor.buffer().read(cx);
|
||||
multibuffer
|
||||
.snapshot(cx)
|
||||
.row_infos(MultiBufferRow(0))
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
assert_matches!(
|
||||
row_infos.as_slice(),
|
||||
[
|
||||
RowInfo {
|
||||
multibuffer_row: Some(MultiBufferRow(0)),
|
||||
diff_status: Some(DiffHunkStatus {
|
||||
kind: DiffHunkStatusKind::Deleted,
|
||||
..
|
||||
}),
|
||||
..
|
||||
},
|
||||
RowInfo {
|
||||
multibuffer_row: Some(MultiBufferRow(1)),
|
||||
diff_status: Some(DiffHunkStatus {
|
||||
kind: DiffHunkStatusKind::Added,
|
||||
..
|
||||
}),
|
||||
..
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AgentSettings::register(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
EditorSettings::register(cx);
|
||||
});
|
||||
}
|
||||
}
|
|
@ -12,24 +12,22 @@ use audio::{Audio, Sound};
|
|||
use buffer_diff::BufferDiff;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::scroll::Autoscroll;
|
||||
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, SelectionEffects};
|
||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity,
|
||||
EntityId, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
|
||||
PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
|
||||
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
|
||||
linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
|
||||
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
|
||||
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
|
||||
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language::language_settings::SoftWrap;
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
||||
use project::Project;
|
||||
use prompt_store::PromptId;
|
||||
use rope::Point;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
|
||||
use terminal_view::TerminalView;
|
||||
use text::Anchor;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
|
@ -41,6 +39,7 @@ use workspace::{CollaboratorId, Workspace};
|
|||
use zed_actions::agent::{Chat, ToggleModelSelector};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
use super::entry_view_state::EntryViewState;
|
||||
use crate::acp::AcpModelSelectorPopover;
|
||||
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
use crate::agent_diff::AgentDiff;
|
||||
|
@ -61,8 +60,7 @@ pub struct AcpThreadView {
|
|||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
thread_state: ThreadState,
|
||||
diff_editors: HashMap<EntityId, Entity<Editor>>,
|
||||
terminal_views: HashMap<EntityId, Entity<TerminalView>>,
|
||||
entry_view_state: EntryViewState,
|
||||
message_editor: Entity<MessageEditor>,
|
||||
model_selector: Option<Entity<AcpModelSelectorPopover>>,
|
||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||
|
@ -149,8 +147,7 @@ impl AcpThreadView {
|
|||
model_selector: None,
|
||||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
diff_editors: Default::default(),
|
||||
terminal_views: Default::default(),
|
||||
entry_view_state: EntryViewState::default(),
|
||||
list_state: list_state.clone(),
|
||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||
last_error: None,
|
||||
|
@ -209,11 +206,18 @@ impl AcpThreadView {
|
|||
// })
|
||||
// .ok();
|
||||
|
||||
let result = match connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), &root_dir, cx)
|
||||
.await
|
||||
{
|
||||
let Some(result) = cx
|
||||
.update(|_, cx| {
|
||||
connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), &root_dir, cx)
|
||||
})
|
||||
.log_err()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let result = match result.await {
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
if e.is::<acp_thread::AuthRequired>() {
|
||||
|
@ -480,16 +484,29 @@ impl AcpThreadView {
|
|||
) {
|
||||
match event {
|
||||
AcpThreadEvent::NewEntry => {
|
||||
let index = thread.read(cx).entries().len() - 1;
|
||||
self.sync_thread_entry_view(index, window, cx);
|
||||
let len = thread.read(cx).entries().len();
|
||||
let index = len - 1;
|
||||
self.entry_view_state.sync_entry(
|
||||
self.workspace.clone(),
|
||||
thread.clone(),
|
||||
index,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.list_state.splice(index..index, 1);
|
||||
}
|
||||
AcpThreadEvent::EntryUpdated(index) => {
|
||||
self.sync_thread_entry_view(*index, window, cx);
|
||||
self.entry_view_state.sync_entry(
|
||||
self.workspace.clone(),
|
||||
thread.clone(),
|
||||
*index,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.list_state.splice(*index..index + 1, 1);
|
||||
}
|
||||
AcpThreadEvent::EntriesRemoved(range) => {
|
||||
// TODO: Clean up unused diff editors and terminal views
|
||||
self.entry_view_state.remove(range.clone());
|
||||
self.list_state.splice(range.clone(), 0);
|
||||
}
|
||||
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||
|
@ -523,128 +540,6 @@ impl AcpThreadView {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
fn sync_thread_entry_view(
|
||||
&mut self,
|
||||
entry_ix: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.sync_diff_multibuffers(entry_ix, window, cx);
|
||||
self.sync_terminals(entry_ix, window, cx);
|
||||
}
|
||||
|
||||
fn sync_diff_multibuffers(
|
||||
&mut self,
|
||||
entry_ix: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let multibuffers = multibuffers.collect::<Vec<_>>();
|
||||
|
||||
for multibuffer in multibuffers {
|
||||
if self.diff_editors.contains_key(&multibuffer.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_vertical_scrollbar(false, cx);
|
||||
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::None, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
editor
|
||||
});
|
||||
let entity_id = multibuffer.entity_id();
|
||||
cx.observe_release(&multibuffer, move |this, _, _| {
|
||||
this.diff_editors.remove(&entity_id);
|
||||
})
|
||||
.detach();
|
||||
|
||||
self.diff_editors.insert(entity_id, editor);
|
||||
}
|
||||
}
|
||||
|
||||
fn entry_diff_multibuffers(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
cx: &App,
|
||||
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
|
||||
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
|
||||
Some(
|
||||
entry
|
||||
.diffs()
|
||||
.map(|diff| diff.read(cx).multibuffer().clone()),
|
||||
)
|
||||
}
|
||||
|
||||
fn sync_terminals(&mut self, entry_ix: usize, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(terminals) = self.entry_terminals(entry_ix, cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let terminals = terminals.collect::<Vec<_>>();
|
||||
|
||||
for terminal in terminals {
|
||||
if self.terminal_views.contains_key(&terminal.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let terminal_view = cx.new(|cx| {
|
||||
let mut view = TerminalView::new(
|
||||
terminal.read(cx).inner().clone(),
|
||||
self.workspace.clone(),
|
||||
None,
|
||||
self.project.downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
view.set_embedded_mode(Some(1000), cx);
|
||||
view
|
||||
});
|
||||
|
||||
let entity_id = terminal.entity_id();
|
||||
cx.observe_release(&terminal, move |this, _, _| {
|
||||
this.terminal_views.remove(&entity_id);
|
||||
})
|
||||
.detach();
|
||||
|
||||
self.terminal_views.insert(entity_id, terminal_view);
|
||||
}
|
||||
}
|
||||
|
||||
fn entry_terminals(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
cx: &App,
|
||||
) -> Option<impl Iterator<Item = Entity<acp_thread::Terminal>>> {
|
||||
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
|
||||
Some(entry.terminals().map(|terminal| terminal.clone()))
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&mut self,
|
||||
method: acp::AuthMethodId,
|
||||
|
@ -712,7 +607,7 @@ impl AcpThreadView {
|
|||
|
||||
fn render_entry(
|
||||
&self,
|
||||
index: usize,
|
||||
entry_ix: usize,
|
||||
total_entries: usize,
|
||||
entry: &AgentThreadEntry,
|
||||
window: &mut Window,
|
||||
|
@ -720,7 +615,7 @@ impl AcpThreadView {
|
|||
) -> AnyElement {
|
||||
let primary = match &entry {
|
||||
AgentThreadEntry::UserMessage(message) => div()
|
||||
.id(("user_message", index))
|
||||
.id(("user_message", entry_ix))
|
||||
.py_4()
|
||||
.px_2()
|
||||
.children(message.id.clone().and_then(|message_id| {
|
||||
|
@ -749,7 +644,9 @@ impl AcpThreadView {
|
|||
.text_xs()
|
||||
.id("message")
|
||||
.on_click(cx.listener({
|
||||
move |this, _, window, cx| this.start_editing_message(index, window, cx)
|
||||
move |this, _, window, cx| {
|
||||
this.start_editing_message(entry_ix, window, cx)
|
||||
}
|
||||
}))
|
||||
.children(
|
||||
if let Some(editing) = self.editing_message.as_ref()
|
||||
|
@ -787,7 +684,7 @@ impl AcpThreadView {
|
|||
AssistantMessageChunk::Thought { block } => {
|
||||
block.markdown().map(|md| {
|
||||
self.render_thinking_block(
|
||||
index,
|
||||
entry_ix,
|
||||
chunk_ix,
|
||||
md.clone(),
|
||||
window,
|
||||
|
@ -803,7 +700,7 @@ impl AcpThreadView {
|
|||
v_flex()
|
||||
.px_5()
|
||||
.py_1()
|
||||
.when(index + 1 == total_entries, |this| this.pb_4())
|
||||
.when(entry_ix + 1 == total_entries, |this| this.pb_4())
|
||||
.w_full()
|
||||
.text_ui(cx)
|
||||
.child(message_body)
|
||||
|
@ -815,10 +712,12 @@ impl AcpThreadView {
|
|||
div().w_full().py_1p5().px_5().map(|this| {
|
||||
if has_terminals {
|
||||
this.children(tool_call.terminals().map(|terminal| {
|
||||
self.render_terminal_tool_call(terminal, tool_call, window, cx)
|
||||
self.render_terminal_tool_call(
|
||||
entry_ix, terminal, tool_call, window, cx,
|
||||
)
|
||||
}))
|
||||
} else {
|
||||
this.child(self.render_tool_call(index, tool_call, window, cx))
|
||||
this.child(self.render_tool_call(entry_ix, tool_call, window, cx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -830,7 +729,7 @@ impl AcpThreadView {
|
|||
};
|
||||
|
||||
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
|
||||
let primary = if index == total_entries - 1 && !is_generating {
|
||||
let primary = if entry_ix == total_entries - 1 && !is_generating {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.child(primary)
|
||||
|
@ -841,10 +740,10 @@ impl AcpThreadView {
|
|||
};
|
||||
|
||||
if let Some(editing) = self.editing_message.as_ref()
|
||||
&& editing.index < index
|
||||
&& editing.index < entry_ix
|
||||
{
|
||||
let backdrop = div()
|
||||
.id(("backdrop", index))
|
||||
.id(("backdrop", entry_ix))
|
||||
.size_full()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
|
@ -1125,7 +1024,9 @@ impl AcpThreadView {
|
|||
.w_full()
|
||||
.children(tool_call.content.iter().map(|content| {
|
||||
div()
|
||||
.child(self.render_tool_call_content(content, tool_call, window, cx))
|
||||
.child(
|
||||
self.render_tool_call_content(entry_ix, content, tool_call, window, cx),
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.child(self.render_permission_buttons(
|
||||
|
@ -1139,7 +1040,9 @@ impl AcpThreadView {
|
|||
.w_full()
|
||||
.children(tool_call.content.iter().map(|content| {
|
||||
div()
|
||||
.child(self.render_tool_call_content(content, tool_call, window, cx))
|
||||
.child(
|
||||
self.render_tool_call_content(entry_ix, content, tool_call, window, cx),
|
||||
)
|
||||
.into_any_element()
|
||||
})),
|
||||
ToolCallStatus::Rejected => v_flex().size_0(),
|
||||
|
@ -1257,6 +1160,7 @@ impl AcpThreadView {
|
|||
|
||||
fn render_tool_call_content(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
content: &ToolCallContent,
|
||||
tool_call: &ToolCall,
|
||||
window: &Window,
|
||||
|
@ -1273,10 +1177,10 @@ impl AcpThreadView {
|
|||
}
|
||||
}
|
||||
ToolCallContent::Diff(diff) => {
|
||||
self.render_diff_editor(&diff.read(cx).multibuffer(), cx)
|
||||
self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx)
|
||||
}
|
||||
ToolCallContent::Terminal(terminal) => {
|
||||
self.render_terminal_tool_call(terminal, tool_call, window, cx)
|
||||
self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1420,6 +1324,7 @@ impl AcpThreadView {
|
|||
|
||||
fn render_diff_editor(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
multibuffer: &Entity<MultiBuffer>,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
|
@ -1428,7 +1333,9 @@ impl AcpThreadView {
|
|||
.border_t_1()
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
.child(
|
||||
if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) {
|
||||
if let Some(entry) = self.entry_view_state.entry(entry_ix)
|
||||
&& let Some(editor) = entry.editor_for_diff(&multibuffer)
|
||||
{
|
||||
editor.clone().into_any_element()
|
||||
} else {
|
||||
Empty.into_any()
|
||||
|
@ -1439,6 +1346,7 @@ impl AcpThreadView {
|
|||
|
||||
fn render_terminal_tool_call(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
terminal: &Entity<acp_thread::Terminal>,
|
||||
tool_call: &ToolCall,
|
||||
window: &Window,
|
||||
|
@ -1627,8 +1535,11 @@ impl AcpThreadView {
|
|||
})),
|
||||
);
|
||||
|
||||
let show_output =
|
||||
self.terminal_expanded && self.terminal_views.contains_key(&terminal.entity_id());
|
||||
let terminal_view = self
|
||||
.entry_view_state
|
||||
.entry(entry_ix)
|
||||
.and_then(|entry| entry.terminal(&terminal));
|
||||
let show_output = self.terminal_expanded && terminal_view.is_some();
|
||||
|
||||
v_flex()
|
||||
.mb_2()
|
||||
|
@ -1661,8 +1572,6 @@ impl AcpThreadView {
|
|||
),
|
||||
)
|
||||
.when(show_output, |this| {
|
||||
let terminal_view = self.terminal_views.get(&terminal.entity_id()).unwrap();
|
||||
|
||||
this.child(
|
||||
div()
|
||||
.pt_2()
|
||||
|
@ -1672,7 +1581,7 @@ impl AcpThreadView {
|
|||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_b_md()
|
||||
.text_ui_sm(cx)
|
||||
.child(terminal_view.clone()),
|
||||
.children(terminal_view.clone()),
|
||||
)
|
||||
})
|
||||
.into_any()
|
||||
|
@ -3075,12 +2984,7 @@ impl AcpThreadView {
|
|||
}
|
||||
|
||||
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
for diff_editor in self.diff_editors.values() {
|
||||
diff_editor.update(cx, |diff_editor, cx| {
|
||||
diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
self.entry_view_state.settings_changed(cx);
|
||||
}
|
||||
|
||||
pub(crate) fn insert_dragged_files(
|
||||
|
@ -3379,18 +3283,6 @@ fn plan_label_markdown_style(
|
|||
}
|
||||
}
|
||||
|
||||
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
|
||||
TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let default_md_style = default_markdown_style(true, window, cx);
|
||||
|
||||
|
@ -3405,16 +3297,16 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
|||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use std::{path::Path, sync::Arc};
|
||||
use std::path::Path;
|
||||
|
||||
use acp_thread::StubAgentConnection;
|
||||
use agent::{TextThreadStore, ThreadStore};
|
||||
use agent_client_protocol::SessionId;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use parking_lot::Mutex;
|
||||
use rand::Rng;
|
||||
use project::Project;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
|
||||
use super::*;
|
||||
|
@ -3497,8 +3389,8 @@ pub(crate) mod tests {
|
|||
raw_input: None,
|
||||
raw_output: None,
|
||||
};
|
||||
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
|
||||
.with_permission_requests(HashMap::from_iter([(
|
||||
let connection =
|
||||
StubAgentConnection::new().with_permission_requests(HashMap::from_iter([(
|
||||
tool_call_id,
|
||||
vec![acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("1".into()),
|
||||
|
@ -3506,6 +3398,9 @@ pub(crate) mod tests {
|
|||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
}],
|
||||
)]));
|
||||
|
||||
connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]);
|
||||
|
||||
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
|
@ -3605,115 +3500,6 @@ pub(crate) mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct StubAgentConnection {
|
||||
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
updates: Vec<acp::SessionUpdate>,
|
||||
}
|
||||
|
||||
impl StubAgentConnection {
|
||||
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
|
||||
Self {
|
||||
updates,
|
||||
permission_requests: HashMap::default(),
|
||||
sessions: Arc::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_permission_requests(
|
||||
mut self,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
) -> Self {
|
||||
self.permission_requests = permission_requests;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = SessionId(
|
||||
rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(7)
|
||||
.map(char::from)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
);
|
||||
let thread = cx
|
||||
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
|
||||
.unwrap();
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method_id: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
for update in &self.updates {
|
||||
let thread = thread.clone();
|
||||
let update = update.clone();
|
||||
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
|
||||
&& let Some(options) = self.permission_requests.get(&tool_call.id)
|
||||
{
|
||||
Some((tool_call.clone(), options.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let task = cx.spawn(async move |cx| {
|
||||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(
|
||||
tool_call.clone(),
|
||||
options.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
permission.await?;
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
cx.spawn(async move |_| {
|
||||
try_join_all(tasks).await?;
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SaboteurAgentConnection;
|
||||
|
||||
|
@ -3722,19 +3508,17 @@ pub(crate) mod tests {
|
|||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
Task::ready(Ok(cx
|
||||
.new(|cx| {
|
||||
AcpThread::new(
|
||||
"SaboteurAgentConnection",
|
||||
self,
|
||||
project,
|
||||
SessionId("test".into()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()))
|
||||
Task::ready(Ok(cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
"SaboteurAgentConnection",
|
||||
self,
|
||||
project,
|
||||
SessionId("test".into()),
|
||||
cx,
|
||||
)
|
||||
})))
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
|
@ -3776,4 +3560,142 @@ pub(crate) mod tests {
|
|||
EditorSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_rewind_views(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
json!({
|
||||
"test1.txt": "old content 1",
|
||||
"test2.txt": "old content 2"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [Path::new("/project")], cx).await;
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let thread_store =
|
||||
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
|
||||
let text_thread_store =
|
||||
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
|
||||
|
||||
let connection = Rc::new(StubAgentConnection::new());
|
||||
let thread_view = cx.update(|window, cx| {
|
||||
cx.new(|cx| {
|
||||
AcpThreadView::new(
|
||||
Rc::new(StubAgentServer::new(connection.as_ref().clone())),
|
||||
workspace.downgrade(),
|
||||
project.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let thread = thread_view
|
||||
.read_with(cx, |view, _| view.thread().cloned())
|
||||
.unwrap();
|
||||
|
||||
// First user message
|
||||
connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId("tool1".into()),
|
||||
title: "Edit file 1".into(),
|
||||
kind: acp::ToolKind::Edit,
|
||||
status: acp::ToolCallStatus::Completed,
|
||||
content: vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: "/project/test1.txt".into(),
|
||||
old_text: Some("old content 1".into()),
|
||||
new_text: "new content 1".into(),
|
||||
},
|
||||
}],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
})]);
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Give me a diff", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 2);
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _| {
|
||||
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
|
||||
});
|
||||
|
||||
// Second user message
|
||||
connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId("tool2".into()),
|
||||
title: "Edit file 2".into(),
|
||||
kind: acp::ToolKind::Edit,
|
||||
status: acp::ToolCallStatus::Completed,
|
||||
content: vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: "/project/test2.txt".into(),
|
||||
old_text: Some("old content 2".into()),
|
||||
new_text: "new content 2".into(),
|
||||
},
|
||||
}],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
})]);
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Another one", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let second_user_message_id = thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 4);
|
||||
let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap()
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
user_message.id.clone().unwrap()
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _| {
|
||||
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
|
||||
assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1);
|
||||
});
|
||||
|
||||
// Rewind to first message
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.rewind(second_user_message_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 2);
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _| {
|
||||
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
|
||||
|
||||
// Old views should be dropped
|
||||
assert!(view.entry_view_state.entry(2).is_none());
|
||||
assert!(view.entry_view_state.entry(3).is_none());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue