parent
ed14ab8c02
commit
b8ddb0141c
5 changed files with 197 additions and 32 deletions
|
@ -22,7 +22,6 @@ use prompt_store::{
|
|||
};
|
||||
use settings::update_settings_file;
|
||||
use std::any::Any;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
@ -156,7 +155,7 @@ pub struct NativeAgent {
|
|||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
/// Shared project context for all threads
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
_maintain_project_context: Task<Result<()>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
|
@ -200,7 +199,7 @@ impl NativeAgent {
|
|||
watch::channel(());
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
project_context: Rc::new(RefCell::new(project_context)),
|
||||
project_context: cx.new(|_| project_context),
|
||||
project_context_needs_refresh: project_context_needs_refresh_tx,
|
||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
||||
|
@ -233,7 +232,9 @@ impl NativeAgent {
|
|||
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
|
||||
})?
|
||||
.await;
|
||||
this.update(cx, |this, _| this.project_context.replace(project_context))?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.project_context = cx.new(|_| project_context);
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -872,8 +873,8 @@ mod tests {
|
|||
)
|
||||
.await
|
||||
.unwrap();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
assert_eq!(agent.project_context.borrow().worktrees, vec![])
|
||||
agent.read_with(cx, |agent, cx| {
|
||||
assert_eq!(agent.project_context.read(cx).worktrees, vec![])
|
||||
});
|
||||
|
||||
let worktree = project
|
||||
|
@ -881,9 +882,9 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
agent.read_with(cx, |agent, cx| {
|
||||
assert_eq!(
|
||||
agent.project_context.borrow().worktrees,
|
||||
agent.project_context.read(cx).worktrees,
|
||||
vec![WorktreeContext {
|
||||
root_name: "a".into(),
|
||||
abs_path: Path::new("/a").into(),
|
||||
|
@ -898,7 +899,7 @@ mod tests {
|
|||
agent.read_with(cx, |agent, cx| {
|
||||
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
|
||||
assert_eq!(
|
||||
agent.project_context.borrow().worktrees,
|
||||
agent.project_context.read(cx).worktrees,
|
||||
vec![WorktreeContext {
|
||||
root_name: "a".into(),
|
||||
abs_path: Path::new("/a").into(),
|
||||
|
|
|
@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
mod test_tools;
|
||||
|
@ -101,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
|||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
project_context.borrow_mut().shell = "test-shell".into();
|
||||
project_context.update(cx, |project_context, _cx| {
|
||||
project_context.shell = "test-shell".into()
|
||||
});
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
|
@ -1447,7 +1449,7 @@ fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopR
|
|||
struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
fs: Arc<FakeFs>,
|
||||
}
|
||||
|
||||
|
@ -1543,7 +1545,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
})
|
||||
.await;
|
||||
|
||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
|
|
@ -25,7 +25,7 @@ use schemars::{JsonSchema, Schema};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
|
||||
use std::{collections::BTreeMap, path::Path, sync::Arc};
|
||||
use std::{fmt::Write, ops::Range};
|
||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||
use uuid::Uuid;
|
||||
|
@ -479,7 +479,7 @@ pub struct Thread {
|
|||
tool_use_limit_reached: bool,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
templates: Arc<Templates>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
project: Entity<Project>,
|
||||
|
@ -489,7 +489,7 @@ pub struct Thread {
|
|||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
|
@ -520,6 +520,10 @@ impl Thread {
|
|||
&self.project
|
||||
}
|
||||
|
||||
pub fn project_context(&self) -> &Entity<ProjectContext> {
|
||||
&self.project_context
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
@ -750,10 +754,10 @@ impl Thread {
|
|||
Ok(events_rx)
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
|
||||
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
project: &self.project_context.borrow(),
|
||||
project: &self.project_context.read(cx),
|
||||
available_tools: self.tools.keys().cloned().collect(),
|
||||
}
|
||||
.render(&self.templates)
|
||||
|
@ -1030,7 +1034,7 @@ impl Thread {
|
|||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||
|
||||
let messages = self.build_request_messages();
|
||||
let messages = self.build_request_messages(cx);
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
||||
|
@ -1101,12 +1105,12 @@ impl Thread {
|
|||
)))
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
self.messages.len()
|
||||
);
|
||||
let mut messages = vec![self.build_system_message()];
|
||||
let mut messages = vec![self.build_system_message(cx)];
|
||||
for message in &self.messages {
|
||||
match message {
|
||||
Message::User(message) => messages.push(message.to_request()),
|
||||
|
|
|
@ -503,9 +503,9 @@ mod tests {
|
|||
use fs::Fs;
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use prompt_store::ProjectContext;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::rc::Rc;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -522,7 +522,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
Templates::new(),
|
||||
|
@ -719,7 +719,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -855,7 +855,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -981,7 +981,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -1118,7 +1118,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -1228,7 +1228,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -1309,7 +1309,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -1393,7 +1393,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
@ -1474,7 +1474,7 @@ mod tests {
|
|||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
|
|
|
@ -30,7 +30,7 @@ use language::Buffer;
|
|||
|
||||
use language_model::LanguageModelRegistry;
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
||||
use project::Project;
|
||||
use project::{Project, ProjectEntryId};
|
||||
use prompt_store::PromptId;
|
||||
use rope::Point;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
|
@ -703,6 +703,38 @@ impl AcpThreadView {
|
|||
})
|
||||
}
|
||||
|
||||
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(thread) = self.as_native_thread(cx) else {
|
||||
return;
|
||||
};
|
||||
let project_context = thread.read(cx).project_context().read(cx);
|
||||
|
||||
let project_entry_ids = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.flat_map(|worktree| worktree.rules_file.as_ref())
|
||||
.map(|rules_file| ProjectEntryId::from_usize(rules_file.project_entry_id))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.workspace
|
||||
.update(cx, move |workspace, cx| {
|
||||
// TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules
|
||||
// files clear. For example, if rules file 1 is already open but rules file 2 is not,
|
||||
// this would open and focus rules file 2 in a tab that is not next to rules file 1.
|
||||
let project = workspace.project().read(cx);
|
||||
let project_paths = project_entry_ids
|
||||
.into_iter()
|
||||
.flat_map(|entry_id| project.path_for_entry(entry_id, cx))
|
||||
.collect::<Vec<_>>();
|
||||
for project_path in project_paths {
|
||||
workspace
|
||||
.open_path(project_path, None, true, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context<Self>) {
|
||||
self.thread_error = Some(ThreadError::from_err(error));
|
||||
cx.notify();
|
||||
|
@ -858,6 +890,12 @@ impl AcpThreadView {
|
|||
let editor_focus = editor.focus_handle(cx).is_focused(window);
|
||||
let focus_border = cx.theme().colors().border_focused;
|
||||
|
||||
let rules_item = if entry_ix == 0 {
|
||||
self.render_rules_item(cx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
div()
|
||||
.id(("user_message", entry_ix))
|
||||
.py_4()
|
||||
|
@ -874,6 +912,7 @@ impl AcpThreadView {
|
|||
}))
|
||||
})
|
||||
}))
|
||||
.children(rules_item)
|
||||
.child(
|
||||
div()
|
||||
.relative()
|
||||
|
@ -1862,6 +1901,125 @@ impl AcpThreadView {
|
|||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_rules_item(&self, cx: &Context<Self>) -> Option<AnyElement> {
|
||||
let project_context = self
|
||||
.as_native_thread(cx)?
|
||||
.read(cx)
|
||||
.project_context()
|
||||
.read(cx);
|
||||
|
||||
let user_rules_text = if project_context.user_rules.is_empty() {
|
||||
None
|
||||
} else if project_context.user_rules.len() == 1 {
|
||||
let user_rules = &project_context.user_rules[0];
|
||||
|
||||
match user_rules.title.as_ref() {
|
||||
Some(title) => Some(format!("Using \"{title}\" user rule")),
|
||||
None => Some("Using user rule".into()),
|
||||
}
|
||||
} else {
|
||||
Some(format!(
|
||||
"Using {} user rules",
|
||||
project_context.user_rules.len()
|
||||
))
|
||||
};
|
||||
|
||||
let first_user_rules_id = project_context
|
||||
.user_rules
|
||||
.first()
|
||||
.map(|user_rules| user_rules.uuid.0);
|
||||
|
||||
let rules_files = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.filter_map(|worktree| worktree.rules_file.as_ref())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let rules_file_text = match rules_files.as_slice() {
|
||||
&[] => None,
|
||||
&[rules_file] => Some(format!(
|
||||
"Using project {:?} file",
|
||||
rules_file.path_in_worktree
|
||||
)),
|
||||
rules_files => Some(format!("Using {} project rules files", rules_files.len())),
|
||||
};
|
||||
|
||||
if user_rules_text.is_none() && rules_file_text.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.pt_2()
|
||||
.px_2p5()
|
||||
.gap_1()
|
||||
.when_some(user_rules_text, |parent, user_rules_text| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.child(
|
||||
Icon::new(IconName::Reader)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Disabled),
|
||||
)
|
||||
.child(
|
||||
Label::new(user_rules_text)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
.truncate()
|
||||
.buffer_font(cx)
|
||||
.ml_1p5()
|
||||
.mr_0p5(),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("open-prompt-library", IconName::ArrowUpRight)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
// TODO: Figure out a way to pass focus handle here so we can display the `OpenRulesLibrary` keybinding
|
||||
.tooltip(Tooltip::text("View User Rules"))
|
||||
.on_click(move |_event, window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(OpenRulesLibrary {
|
||||
prompt_to_select: first_user_rules_id,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
),
|
||||
)
|
||||
})
|
||||
.when_some(rules_file_text, |parent, rules_file_text| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.child(
|
||||
Icon::new(IconName::File)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Disabled),
|
||||
)
|
||||
.child(
|
||||
Label::new(rules_file_text)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx)
|
||||
.ml_1p5()
|
||||
.mr_0p5(),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("open-rule", IconName::ArrowUpRight)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
.on_click(cx.listener(Self::handle_open_rules))
|
||||
.tooltip(Tooltip::text("View Rules")),
|
||||
),
|
||||
)
|
||||
})
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_empty_state(&self, cx: &App) -> AnyElement {
|
||||
let loading = matches!(&self.thread_state, ThreadState::Loading { .. });
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue