agent2: Port rules UI (#36429)

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-08-19 11:12:57 +02:00 committed by GitHub
parent ed14ab8c02
commit b8ddb0141c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 197 additions and 32 deletions

View file

@ -22,7 +22,6 @@ use prompt_store::{
};
use settings::update_settings_file;
use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
@ -156,7 +155,7 @@ pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
/// Shared project context for all threads
project_context: Rc<RefCell<ProjectContext>>,
project_context: Entity<ProjectContext>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
context_server_registry: Entity<ContextServerRegistry>,
@ -200,7 +199,7 @@ impl NativeAgent {
watch::channel(());
Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
@ -233,7 +232,9 @@ impl NativeAgent {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})?
.await;
this.update(cx, |this, _| this.project_context.replace(project_context))?;
this.update(cx, |this, cx| {
this.project_context = cx.new(|_| project_context);
})?;
}
Ok(())
@ -872,8 +873,8 @@ mod tests {
)
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
agent.read_with(cx, |agent, cx| {
assert_eq!(agent.project_context.read(cx).worktrees, vec![])
});
let worktree = project
@ -881,9 +882,9 @@ mod tests {
.await
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, _| {
agent.read_with(cx, |agent, cx| {
assert_eq!(
agent.project_context.borrow().worktrees,
agent.project_context.read(cx).worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
@ -898,7 +899,7 @@ mod tests {
agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!(
agent.project_context.borrow().worktrees,
agent.project_context.read(cx).worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),

View file

@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
mod test_tools;
@ -101,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
project_context.borrow_mut().shell = "test-shell".into();
project_context.update(cx, |project_context, _cx| {
project_context.shell = "test-shell".into()
});
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread
.update(cx, |thread, cx| {
@ -1447,7 +1449,7 @@ fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopR
struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
project_context: Entity<ProjectContext>,
fs: Arc<FakeFs>,
}
@ -1543,7 +1545,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
})
.await;
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let project_context = cx.new(|_cx| ProjectContext::default());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));

View file

@ -25,7 +25,7 @@ use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
use std::{collections::BTreeMap, path::Path, sync::Arc};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
@ -479,7 +479,7 @@ pub struct Thread {
tool_use_limit_reached: bool,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
project_context: Entity<ProjectContext>,
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
project: Entity<Project>,
@ -489,7 +489,7 @@ pub struct Thread {
impl Thread {
pub fn new(
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
project_context: Entity<ProjectContext>,
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
@ -520,6 +520,10 @@ impl Thread {
&self.project
}
pub fn project_context(&self) -> &Entity<ProjectContext> {
&self.project_context
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
@ -750,10 +754,10 @@ impl Thread {
Ok(events_rx)
}
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
project: &self.project_context.borrow(),
project: &self.project_context.read(cx),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
@ -1030,7 +1034,7 @@ impl Thread {
log::debug!("Completion intent: {:?}", completion_intent);
log::debug!("Completion mode: {:?}", self.completion_mode);
let messages = self.build_request_messages();
let messages = self.build_request_messages(cx);
log::info!("Request will include {} messages", messages.len());
let tools = if let Some(tools) = self.tools(cx).log_err() {
@ -1101,12 +1105,12 @@ impl Thread {
)))
}
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
log::trace!(
"Building request messages from {} thread messages",
self.messages.len()
);
let mut messages = vec![self.build_system_message()];
let mut messages = vec![self.build_system_message(cx)];
for message in &self.messages {
match message {
Message::User(message) => messages.push(message.to_request()),

View file

@ -503,9 +503,9 @@ mod tests {
use fs::Fs;
use gpui::{TestAppContext, UpdateGlobal};
use language_model::fake_provider::FakeLanguageModel;
use prompt_store::ProjectContext;
use serde_json::json;
use settings::SettingsStore;
use std::rc::Rc;
use util::path;
#[gpui::test]
@ -522,7 +522,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
action_log,
Templates::new(),
@ -719,7 +719,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
action_log.clone(),
Templates::new(),
@ -855,7 +855,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
action_log.clone(),
Templates::new(),
@ -981,7 +981,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
action_log.clone(),
Templates::new(),
@ -1118,7 +1118,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
action_log.clone(),
Templates::new(),
@ -1228,7 +1228,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
@ -1309,7 +1309,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
@ -1393,7 +1393,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
@ -1474,7 +1474,7 @@ mod tests {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
action_log.clone(),
Templates::new(),