From b8ddb0141c0625a47fdc7b68aa8a8a782c439f62 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Tue, 19 Aug 2025 11:12:57 +0200 Subject: [PATCH] agent2: Port rules UI (#36429) Release Notes: - N/A --- crates/agent2/src/agent.rs | 19 +-- crates/agent2/src/tests/mod.rs | 10 +- crates/agent2/src/thread.rs | 20 +-- crates/agent2/src/tools/edit_file_tool.rs | 20 +-- crates/agent_ui/src/acp/thread_view.rs | 160 +++++++++++++++++++++- 5 files changed, 197 insertions(+), 32 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 985de4d123..6347f5f9a4 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -22,7 +22,6 @@ use prompt_store::{ }; use settings::update_settings_file; use std::any::Any; -use std::cell::RefCell; use std::collections::HashMap; use std::path::Path; use std::rc::Rc; @@ -156,7 +155,7 @@ pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, /// Shared project context for all threads - project_context: Rc>, + project_context: Entity, project_context_needs_refresh: watch::Sender<()>, _maintain_project_context: Task>, context_server_registry: Entity, @@ -200,7 +199,7 @@ impl NativeAgent { watch::channel(()); Self { sessions: HashMap::new(), - project_context: Rc::new(RefCell::new(project_context)), + project_context: cx.new(|_| project_context), project_context_needs_refresh: project_context_needs_refresh_tx, _maintain_project_context: cx.spawn(async move |this, cx| { Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await @@ -233,7 +232,9 @@ impl NativeAgent { Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) })? .await; - this.update(cx, |this, _| this.project_context.replace(project_context))?; + this.update(cx, |this, cx| { + this.project_context = cx.new(|_| project_context); + })?; } Ok(()) @@ -872,8 +873,8 @@ mod tests { ) .await .unwrap(); - agent.read_with(cx, |agent, _| { - assert_eq!(agent.project_context.borrow().worktrees, vec![]) + agent.read_with(cx, |agent, cx| { + assert_eq!(agent.project_context.read(cx).worktrees, vec![]) }); let worktree = project @@ -881,9 +882,9 @@ mod tests { .await .unwrap(); cx.run_until_parked(); - agent.read_with(cx, |agent, _| { + agent.read_with(cx, |agent, cx| { assert_eq!( - agent.project_context.borrow().worktrees, + agent.project_context.read(cx).worktrees, vec![WorktreeContext { root_name: "a".into(), abs_path: Path::new("/a").into(), @@ -898,7 +899,7 @@ mod tests { agent.read_with(cx, |agent, cx| { let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap(); assert_eq!( - agent.project_context.borrow().worktrees, + agent.project_context.read(cx).worktrees, vec![WorktreeContext { root_name: "a".into(), abs_path: Path::new("/a").into(), diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index e3e3050d49..13b37fbaa2 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt; -use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration}; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; mod test_tools; @@ -101,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) { } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - project_context.borrow_mut().shell = "test-shell".into(); + project_context.update(cx, |project_context, _cx| { + project_context.shell = "test-shell".into() + }); thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread .update(cx, |thread, cx| { @@ -1447,7 +1449,7 @@ fn stop_events(result_events: Vec>) -> Vec, thread: Entity, - project_context: Rc>, + project_context: Entity, fs: Arc, } @@ -1543,7 +1545,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { }) .await; - let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let project_context = cx.new(|_cx| ProjectContext::default()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index e0819abcc5..7f0465f5ce 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -25,7 +25,7 @@ use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; -use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; +use std::{collections::BTreeMap, path::Path, sync::Arc}; use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; use uuid::Uuid; @@ -479,7 +479,7 @@ pub struct Thread { tool_use_limit_reached: bool, context_server_registry: Entity, profile_id: AgentProfileId, - project_context: Rc>, + project_context: Entity, templates: Arc, model: Option>, project: Entity, @@ -489,7 +489,7 @@ pub struct Thread { impl Thread { pub fn new( project: Entity, - project_context: Rc>, + project_context: Entity, context_server_registry: Entity, action_log: Entity, templates: Arc, @@ -520,6 +520,10 @@ impl Thread { &self.project } + pub fn project_context(&self) -> &Entity { + &self.project_context + } + pub fn action_log(&self) -> &Entity { &self.action_log } @@ -750,10 +754,10 @@ impl Thread { Ok(events_rx) } - pub fn build_system_message(&self) -> LanguageModelRequestMessage { + pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { - project: &self.project_context.borrow(), + project: &self.project_context.read(cx), available_tools: self.tools.keys().cloned().collect(), } .render(&self.templates) @@ -1030,7 +1034,7 @@ impl Thread { log::debug!("Completion intent: {:?}", completion_intent); log::debug!("Completion mode: {:?}", self.completion_mode); - let messages = self.build_request_messages(); + let messages = self.build_request_messages(cx); log::info!("Request will include {} messages", messages.len()); let tools = if let Some(tools) = self.tools(cx).log_err() { @@ -1101,12 +1105,12 @@ impl Thread { ))) } - fn build_request_messages(&self) -> Vec { + fn build_request_messages(&self, cx: &App) -> Vec { log::trace!( "Building request messages from {} thread messages", self.messages.len() ); - let mut messages = vec![self.build_system_message()]; + let mut messages = vec![self.build_system_message(cx)]; for message in &self.messages { match message { Message::User(message) => messages.push(message.to_request()), diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index e70e5e8a14..8ebd2936a5 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -503,9 +503,9 @@ mod tests { use fs::Fs; use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; + use prompt_store::ProjectContext; use serde_json::json; use settings::SettingsStore; - use std::rc::Rc; use util::path; #[gpui::test] @@ -522,7 +522,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project, - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, action_log, Templates::new(), @@ -719,7 +719,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project, - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, action_log.clone(), Templates::new(), @@ -855,7 +855,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project, - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, action_log.clone(), Templates::new(), @@ -981,7 +981,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project, - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, action_log.clone(), Templates::new(), @@ -1118,7 +1118,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project, - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, action_log.clone(), Templates::new(), @@ -1228,7 +1228,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry.clone(), action_log.clone(), Templates::new(), @@ -1309,7 +1309,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry.clone(), action_log.clone(), Templates::new(), @@ -1393,7 +1393,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry.clone(), action_log.clone(), Templates::new(), @@ -1474,7 +1474,7 @@ mod tests { let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, action_log.clone(), Templates::new(), diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2cfedfe840..2fffe1b179 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -30,7 +30,7 @@ use language::Buffer; use language_model::LanguageModelRegistry; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; -use project::Project; +use project::{Project, ProjectEntryId}; use prompt_store::PromptId; use rope::Point; use settings::{Settings as _, SettingsStore}; @@ -703,6 +703,38 @@ impl AcpThreadView { }) } + fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.as_native_thread(cx) else { + return; + }; + let project_context = thread.read(cx).project_context().read(cx); + + let project_entry_ids = project_context + .worktrees + .iter() + .flat_map(|worktree| worktree.rules_file.as_ref()) + .map(|rules_file| ProjectEntryId::from_usize(rules_file.project_entry_id)) + .collect::>(); + + self.workspace + .update(cx, move |workspace, cx| { + // TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules + // files clear. For example, if rules file 1 is already open but rules file 2 is not, + // this would open and focus rules file 2 in a tab that is not next to rules file 1. + let project = workspace.project().read(cx); + let project_paths = project_entry_ids + .into_iter() + .flat_map(|entry_id| project.path_for_entry(entry_id, cx)) + .collect::>(); + for project_path in project_paths { + workspace + .open_path(project_path, None, true, window, cx) + .detach_and_log_err(cx); + } + }) + .ok(); + } + fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context) { self.thread_error = Some(ThreadError::from_err(error)); cx.notify(); @@ -858,6 +890,12 @@ impl AcpThreadView { let editor_focus = editor.focus_handle(cx).is_focused(window); let focus_border = cx.theme().colors().border_focused; + let rules_item = if entry_ix == 0 { + self.render_rules_item(cx) + } else { + None + }; + div() .id(("user_message", entry_ix)) .py_4() @@ -874,6 +912,7 @@ impl AcpThreadView { })) }) })) + .children(rules_item) .child( div() .relative() @@ -1862,6 +1901,125 @@ impl AcpThreadView { .into_any_element() } + fn render_rules_item(&self, cx: &Context) -> Option { + let project_context = self + .as_native_thread(cx)? + .read(cx) + .project_context() + .read(cx); + + let user_rules_text = if project_context.user_rules.is_empty() { + None + } else if project_context.user_rules.len() == 1 { + let user_rules = &project_context.user_rules[0]; + + match user_rules.title.as_ref() { + Some(title) => Some(format!("Using \"{title}\" user rule")), + None => Some("Using user rule".into()), + } + } else { + Some(format!( + "Using {} user rules", + project_context.user_rules.len() + )) + }; + + let first_user_rules_id = project_context + .user_rules + .first() + .map(|user_rules| user_rules.uuid.0); + + let rules_files = project_context + .worktrees + .iter() + .filter_map(|worktree| worktree.rules_file.as_ref()) + .collect::>(); + + let rules_file_text = match rules_files.as_slice() { + &[] => None, + &[rules_file] => Some(format!( + "Using project {:?} file", + rules_file.path_in_worktree + )), + rules_files => Some(format!("Using {} project rules files", rules_files.len())), + }; + + if user_rules_text.is_none() && rules_file_text.is_none() { + return None; + } + + Some( + v_flex() + .pt_2() + .px_2p5() + .gap_1() + .when_some(user_rules_text, |parent, user_rules_text| { + parent.child( + h_flex() + .w_full() + .child( + Icon::new(IconName::Reader) + .size(IconSize::XSmall) + .color(Color::Disabled), + ) + .child( + Label::new(user_rules_text) + .size(LabelSize::XSmall) + .color(Color::Muted) + .truncate() + .buffer_font(cx) + .ml_1p5() + .mr_0p5(), + ) + .child( + IconButton::new("open-prompt-library", IconName::ArrowUpRight) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + // TODO: Figure out a way to pass focus handle here so we can display the `OpenRulesLibrary` keybinding + .tooltip(Tooltip::text("View User Rules")) + .on_click(move |_event, window, cx| { + window.dispatch_action( + Box::new(OpenRulesLibrary { + prompt_to_select: first_user_rules_id, + }), + cx, + ) + }), + ), + ) + }) + .when_some(rules_file_text, |parent, rules_file_text| { + parent.child( + h_flex() + .w_full() + .child( + Icon::new(IconName::File) + .size(IconSize::XSmall) + .color(Color::Disabled), + ) + .child( + Label::new(rules_file_text) + .size(LabelSize::XSmall) + .color(Color::Muted) + .buffer_font(cx) + .ml_1p5() + .mr_0p5(), + ) + .child( + IconButton::new("open-rule", IconName::ArrowUpRight) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + .on_click(cx.listener(Self::handle_open_rules)) + .tooltip(Tooltip::text("View Rules")), + ), + ) + }) + .into_any(), + ) + } + fn render_empty_state(&self, cx: &App) -> AnyElement { let loading = matches!(&self.thread_state, ThreadState::Loading { .. });