diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index eb91c4f4a8..1aca0a938b 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -73,9 +73,9 @@ There are project rules that apply to these root directories: {{/each}} {{/if}} -{{#if has_default_user_rules}} +{{#if has_user_rules}} The user has specified the following rules that should be applied: -{{#each default_user_rules}} +{{#each user_rules}} {{#if title}} Rules title: {{title}} diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 2a4a00cf23..a0dd54218d 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,4 +1,4 @@ -use crate::context::{AssistantContext, ContextId, format_context_as_string}; +use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string}; use crate::context_picker::MentionLink; use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, @@ -688,6 +688,12 @@ fn open_markdown_link( } }), Some(MentionLink::Fetch(url)) => cx.open_url(&url), + Some(MentionLink::Rules(prompt_id)) => window.dispatch_action( + Box::new(OpenPromptLibrary { + prompt_to_select: Some(prompt_id.0), + }), + cx, + ), None => cx.open_url(&text), } } @@ -2957,10 +2963,10 @@ impl ActiveThread { return div().into_any(); }; - let default_user_rules_text = if project_context.default_user_rules.is_empty() { + let user_rules_text = if project_context.user_rules.is_empty() { None - } else if project_context.default_user_rules.len() == 1 { - let user_rules = &project_context.default_user_rules[0]; + } else if project_context.user_rules.len() == 1 { + let user_rules = &project_context.user_rules[0]; match user_rules.title.as_ref() { Some(title) => Some(format!("Using \"{title}\" user rule")), @@ -2969,14 +2975,14 @@ impl ActiveThread { } else { Some(format!( "Using {} user rules", - project_context.default_user_rules.len() + project_context.user_rules.len() )) }; - let first_default_user_rules_id = project_context - .default_user_rules + let first_user_rules_id = project_context + .user_rules .first() - .map(|user_rules| user_rules.uuid); + .map(|user_rules| user_rules.uuid.0); let rules_files = project_context .worktrees @@ -2993,7 +2999,7 @@ impl ActiveThread { rules_files => Some(format!("Using {} project rules files", rules_files.len())), }; - if default_user_rules_text.is_none() && rules_file_text.is_none() { + if user_rules_text.is_none() && rules_file_text.is_none() { return div().into_any(); } @@ -3001,45 +3007,42 @@ impl ActiveThread { .pt_2() .px_2p5() .gap_1() - .when_some( - default_user_rules_text, - |parent, default_user_rules_text| { - parent.child( - h_flex() - .w_full() - .child( - Icon::new(IconName::File) - .size(IconSize::XSmall) - .color(Color::Disabled), - ) - .child( - Label::new(default_user_rules_text) - .size(LabelSize::XSmall) - .color(Color::Muted) - .truncate() - .buffer_font(cx) - .ml_1p5() - .mr_0p5(), - ) - .child( - IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(Color::Ignored) - // TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding - .tooltip(Tooltip::text("View User Rules")) - .on_click(move |_event, window, cx| { - window.dispatch_action( - Box::new(OpenPromptLibrary { - prompt_to_focus: first_default_user_rules_id, - }), - cx, - ) - }), - ), - ) - }, - ) + .when_some(user_rules_text, |parent, user_rules_text| { + parent.child( + h_flex() + .w_full() + .child( + Icon::new(RULES_ICON) + .size(IconSize::XSmall) + .color(Color::Disabled), + ) + .child( + Label::new(user_rules_text) + .size(LabelSize::XSmall) + .color(Color::Muted) + .truncate() + .buffer_font(cx) + .ml_1p5() + .mr_0p5(), + ) + .child( + IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + // TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding + .tooltip(Tooltip::text("View User Rules")) + .on_click(move |_event, window, cx| { + window.dispatch_action( + Box::new(OpenPromptLibrary { + prompt_to_select: first_user_rules_id, + }), + cx, + ) + }), + ), + ) + }) .when_some(rules_file_text, |parent, rules_file_text| { parent.child( h_flex() @@ -3316,6 +3319,12 @@ pub(crate) fn open_context( } }) } + AssistantContext::Rules(rules_context) => window.dispatch_action( + Box::new(OpenPromptLibrary { + prompt_to_select: Some(rules_context.prompt_id.0), + }), + cx, + ), } } diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index cea3a781a2..82c7b2be9d 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -25,7 +25,7 @@ use language_model::{LanguageModelProviderTosView, LanguageModelRegistry}; use language_model_selector::ToggleModelSelector; use project::Project; use prompt_library::{PromptLibrary, open_prompt_library}; -use prompt_store::{PromptBuilder, PromptId}; +use prompt_store::{PromptBuilder, PromptId, UserPromptId}; use proto::Plan; use settings::{Settings, update_settings_file}; use time::UtcOffset; @@ -79,11 +79,11 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx)); } }) - .register_action(|workspace, _: &OpenPromptLibrary, window, cx| { + .register_action(|workspace, action: &OpenPromptLibrary, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| { - panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx) + panel.deploy_prompt_library(action, window, cx) }); } }) @@ -502,7 +502,9 @@ impl AssistantPanel { None, )) }), - action.prompt_to_focus.map(|uuid| PromptId::User { uuid }), + action.prompt_to_select.map(|uuid| PromptId::User { + uuid: UserPromptId(uuid), + }), cx, ) .detach_and_log_err(cx); diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index 0770c5b4d7..b6213cde54 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -4,6 +4,7 @@ use gpui::{App, Entity, SharedString}; use language::{Buffer, File}; use language_model::LanguageModelRequestMessage; use project::{ProjectPath, Worktree}; +use prompt_store::UserPromptId; use rope::Point; use serde::{Deserialize, Serialize}; use text::{Anchor, BufferId}; @@ -12,6 +13,8 @@ use util::post_inc; use crate::thread::Thread; +pub const RULES_ICON: IconName = IconName::Context; + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] pub struct ContextId(pub(crate) usize); @@ -20,6 +23,7 @@ impl ContextId { Self(post_inc(&mut self.0)) } } + pub enum ContextKind { File, Directory, @@ -27,6 +31,7 @@ pub enum ContextKind { Excerpt, FetchedUrl, Thread, + Rules, } impl ContextKind { @@ -38,6 +43,7 @@ impl ContextKind { ContextKind::Excerpt => IconName::Code, ContextKind::FetchedUrl => IconName::Globe, ContextKind::Thread => IconName::MessageBubbles, + ContextKind::Rules => RULES_ICON, } } } @@ -50,6 +56,7 @@ pub enum AssistantContext { FetchedUrl(FetchedUrlContext), Thread(ThreadContext), Excerpt(ExcerptContext), + Rules(RulesContext), } impl AssistantContext { @@ -61,6 +68,7 @@ impl AssistantContext { Self::FetchedUrl(url) => url.id, Self::Thread(thread) => thread.id, Self::Excerpt(excerpt) => excerpt.id, + Self::Rules(rules) => rules.id, } } } @@ -168,6 +176,14 @@ pub struct ExcerptContext { pub context_buffer: ContextBuffer, } +#[derive(Debug, Clone)] +pub struct RulesContext { + pub id: ContextId, + pub prompt_id: UserPromptId, + pub title: SharedString, + pub text: SharedString, +} + /// Formats a collection of contexts into a string representation pub fn format_context_as_string<'a>( contexts: impl Iterator, @@ -179,6 +195,7 @@ pub fn format_context_as_string<'a>( let mut excerpt_context = Vec::new(); let mut fetch_context = Vec::new(); let mut thread_context = Vec::new(); + let mut rules_context = Vec::new(); for context in contexts { match context { @@ -188,6 +205,7 @@ pub fn format_context_as_string<'a>( AssistantContext::Excerpt(context) => excerpt_context.push(context), AssistantContext::FetchedUrl(context) => fetch_context.push(context), AssistantContext::Thread(context) => thread_context.push(context), + AssistantContext::Rules(context) => rules_context.push(context), } } @@ -197,6 +215,7 @@ pub fn format_context_as_string<'a>( && excerpt_context.is_empty() && fetch_context.is_empty() && thread_context.is_empty() + && rules_context.is_empty() { return None; } @@ -263,6 +282,18 @@ pub fn format_context_as_string<'a>( result.push_str("\n"); } + if !rules_context.is_empty() { + result.push_str( + "\n\ + The user has specified the following rules that should be applied:\n\n", + ); + for context in &rules_context { + result.push_str(&context.text); + result.push('\n'); + } + result.push_str("\n"); + } + result.push_str("\n"); Some(result) } diff --git a/crates/agent/src/context_picker.rs b/crates/agent/src/context_picker.rs index 9e578c4fc0..ce08179fb9 100644 --- a/crates/agent/src/context_picker.rs +++ b/crates/agent/src/context_picker.rs @@ -1,6 +1,7 @@ mod completion_provider; mod fetch_context_picker; mod file_context_picker; +mod rules_context_picker; mod symbol_context_picker; mod thread_context_picker; @@ -18,17 +19,22 @@ use gpui::{ }; use multi_buffer::MultiBufferRow; use project::{Entry, ProjectPath}; +use prompt_store::UserPromptId; +use rules_context_picker::RulesContextEntry; use symbol_context_picker::SymbolContextPicker; use thread_context_picker::{ThreadContextEntry, render_thread_context_entry}; use ui::{ ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*, }; +use uuid::Uuid; use workspace::{Workspace, notifications::NotifyResultExt}; use crate::AssistantPanel; +use crate::context::RULES_ICON; pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider; use crate::context_picker::fetch_context_picker::FetchContextPicker; use crate::context_picker::file_context_picker::FileContextPicker; +use crate::context_picker::rules_context_picker::RulesContextPicker; use crate::context_picker::thread_context_picker::ThreadContextPicker; use crate::context_store::ContextStore; use crate::thread::ThreadId; @@ -40,6 +46,7 @@ enum ContextPickerMode { Symbol, Fetch, Thread, + Rules, } impl TryFrom<&str> for ContextPickerMode { @@ -51,6 +58,7 @@ impl TryFrom<&str> for ContextPickerMode { "symbol" => Ok(Self::Symbol), "fetch" => Ok(Self::Fetch), "thread" => Ok(Self::Thread), + "rules" => Ok(Self::Rules), _ => Err(format!("Invalid context picker mode: {}", value)), } } @@ -63,6 +71,7 @@ impl ContextPickerMode { Self::Symbol => "symbol", Self::Fetch => "fetch", Self::Thread => "thread", + Self::Rules => "rules", } } @@ -72,6 +81,7 @@ impl ContextPickerMode { Self::Symbol => "Symbols", Self::Fetch => "Fetch", Self::Thread => "Threads", + Self::Rules => "Rules", } } @@ -81,6 +91,7 @@ impl ContextPickerMode { Self::Symbol => IconName::Code, Self::Fetch => IconName::Globe, Self::Thread => IconName::MessageBubbles, + Self::Rules => RULES_ICON, } } } @@ -92,6 +103,7 @@ enum ContextPickerState { Symbol(Entity), Fetch(Entity), Thread(Entity), + Rules(Entity), } pub(super) struct ContextPicker { @@ -253,6 +265,19 @@ impl ContextPicker { })); } } + ContextPickerMode::Rules => { + if let Some(thread_store) = self.thread_store.as_ref() { + self.mode = ContextPickerState::Rules(cx.new(|cx| { + RulesContextPicker::new( + thread_store.clone(), + context_picker.clone(), + self.context_store.clone(), + window, + cx, + ) + })); + } + } } cx.notify(); @@ -381,6 +406,7 @@ impl ContextPicker { ContextPickerState::Symbol(entity) => entity.update(cx, |_, cx| cx.notify()), ContextPickerState::Fetch(entity) => entity.update(cx, |_, cx| cx.notify()), ContextPickerState::Thread(entity) => entity.update(cx, |_, cx| cx.notify()), + ContextPickerState::Rules(entity) => entity.update(cx, |_, cx| cx.notify()), } } } @@ -395,6 +421,7 @@ impl Focusable for ContextPicker { ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx), ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx), ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx), + ContextPickerState::Rules(user_rules_picker) => user_rules_picker.focus_handle(cx), } } } @@ -410,6 +437,9 @@ impl Render for ContextPicker { ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()), ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()), ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()), + ContextPickerState::Rules(user_rules_picker) => { + parent.child(user_rules_picker.clone()) + } }) } } @@ -431,6 +461,7 @@ fn supported_context_picker_modes( ]; if thread_store.is_some() { modes.push(ContextPickerMode::Thread); + modes.push(ContextPickerMode::Rules); } modes } @@ -626,6 +657,7 @@ pub enum MentionLink { Symbol(ProjectPath, String), Fetch(String), Thread(ThreadId), + Rules(UserPromptId), } impl MentionLink { @@ -633,14 +665,16 @@ impl MentionLink { const SYMBOL: &str = "@symbol"; const THREAD: &str = "@thread"; const FETCH: &str = "@fetch"; + const RULES: &str = "@rules"; const SEPARATOR: &str = ":"; pub fn is_valid(url: &str) -> bool { url.starts_with(Self::FILE) || url.starts_with(Self::SYMBOL) - || url.starts_with(Self::FETCH) || url.starts_with(Self::THREAD) + || url.starts_with(Self::FETCH) + || url.starts_with(Self::RULES) } pub fn for_file(file_name: &str, full_path: &str) -> String { @@ -657,12 +691,16 @@ impl MentionLink { ) } + pub fn for_thread(thread: &ThreadContextEntry) -> String { + format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id) + } + pub fn for_fetch(url: &str) -> String { format!("[@{}]({}:{})", url, Self::FETCH, url) } - pub fn for_thread(thread: &ThreadContextEntry) -> String { - format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id) + pub fn for_rules(rules: &RulesContextEntry) -> String { + format!("[@{}]({}:{})", rules.title, Self::RULES, rules.prompt_id.0) } pub fn try_parse(link: &str, workspace: &Entity, cx: &App) -> Option { @@ -706,6 +744,10 @@ impl MentionLink { Some(MentionLink::Thread(thread_id)) } Self::FETCH => Some(MentionLink::Fetch(argument.to_string())), + Self::RULES => { + let prompt_id = UserPromptId(Uuid::try_parse(argument).ok()?); + Some(MentionLink::Rules(prompt_id)) + } _ => None, } } diff --git a/crates/agent/src/context_picker/completion_provider.rs b/crates/agent/src/context_picker/completion_provider.rs index efa1cf5431..6703e93d66 100644 --- a/crates/agent/src/context_picker/completion_provider.rs +++ b/crates/agent/src/context_picker/completion_provider.rs @@ -14,11 +14,13 @@ use http_client::HttpClientWithUrl; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId}; +use prompt_store::PromptId; use rope::Point; use text::{Anchor, ToPoint}; use ui::prelude::*; use workspace::Workspace; +use crate::context::RULES_ICON; use crate::context_picker::file_context_picker::search_files; use crate::context_picker::symbol_context_picker::search_symbols; use crate::context_store::ContextStore; @@ -26,6 +28,7 @@ use crate::thread_store::ThreadStore; use super::fetch_context_picker::fetch_url_content; use super::file_context_picker::FileMatch; +use super::rules_context_picker::{RulesContextEntry, search_rules}; use super::symbol_context_picker::SymbolMatch; use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; use super::{ @@ -38,6 +41,7 @@ pub(crate) enum Match { File(FileMatch), Thread(ThreadMatch), Fetch(SharedString), + Rules(RulesContextEntry), Mode(ModeMatch), } @@ -54,6 +58,7 @@ impl Match { Match::Thread(_) => 1., Match::Symbol(_) => 1., Match::Fetch(_) => 1., + Match::Rules(_) => 1., } } } @@ -112,6 +117,21 @@ fn search( Task::ready(Vec::new()) } } + Some(ContextPickerMode::Rules) => { + if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) { + let search_rules_task = + search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx); + cx.background_spawn(async move { + search_rules_task + .await + .into_iter() + .map(Match::Rules) + .collect::>() + }) + } else { + Task::ready(Vec::new()) + } + } None => { if query.is_empty() { let mut matches = recent_entries @@ -287,6 +307,60 @@ impl ContextPickerCompletionProvider { } } + fn completion_for_rules( + rules: RulesContextEntry, + excerpt_id: ExcerptId, + source_range: Range, + editor: Entity, + context_store: Entity, + thread_store: Entity, + ) -> Completion { + let new_text = MentionLink::for_rules(&rules); + let new_text_len = new_text.len(); + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(rules.title.to_string(), None), + documentation: None, + insert_text_mode: None, + source: project::CompletionSource::Custom, + icon_path: Some(RULES_ICON.path().into()), + confirm: Some(confirm_completion_callback( + RULES_ICON.path().into(), + rules.title.clone(), + excerpt_id, + source_range.start, + new_text_len, + editor.clone(), + move |cx| { + let prompt_uuid = rules.prompt_id; + let prompt_id = PromptId::User { uuid: prompt_uuid }; + let context_store = context_store.clone(); + let Some(prompt_store) = thread_store.read(cx).prompt_store() else { + log::error!("Can't add user rules as prompt store is missing."); + return; + }; + let prompt_store = prompt_store.read(cx); + let Some(metadata) = prompt_store.metadata(prompt_id) else { + return; + }; + let Some(title) = metadata.title else { + return; + }; + let text_task = prompt_store.load(prompt_id, cx); + + cx.spawn(async move |cx| { + let text = text_task.await?; + context_store.update(cx, |context_store, cx| { + context_store.add_rules(prompt_uuid, title, text, false, cx) + }) + }) + .detach_and_log_err(cx); + }, + )), + } + } + fn completion_for_fetch( source_range: Range, url_to_fetch: SharedString, @@ -593,6 +667,17 @@ impl CompletionProvider for ContextPickerCompletionProvider { thread_store, )) } + Match::Rules(user_rules) => { + let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?; + Some(Self::completion_for_rules( + user_rules, + excerpt_id, + source_range.clone(), + editor.clone(), + context_store.clone(), + thread_store, + )) + } Match::Fetch(url) => Some(Self::completion_for_fetch( source_range.clone(), url, diff --git a/crates/agent/src/context_picker/rules_context_picker.rs b/crates/agent/src/context_picker/rules_context_picker.rs new file mode 100644 index 0000000000..4c1fc65303 --- /dev/null +++ b/crates/agent/src/context_picker/rules_context_picker.rs @@ -0,0 +1,248 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicBool; + +use anyhow::anyhow; +use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity}; +use picker::{Picker, PickerDelegate}; +use prompt_store::{PromptId, UserPromptId}; +use ui::{ListItem, prelude::*}; + +use crate::context::RULES_ICON; +use crate::context_picker::ContextPicker; +use crate::context_store::{self, ContextStore}; +use crate::thread_store::ThreadStore; + +pub struct RulesContextPicker { + picker: Entity>, +} + +impl RulesContextPicker { + pub fn new( + thread_store: WeakEntity, + context_picker: WeakEntity, + context_store: WeakEntity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store); + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); + + RulesContextPicker { picker } + } +} + +impl Focusable for RulesContextPicker { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl Render for RulesContextPicker { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.picker.clone() + } +} + +#[derive(Debug, Clone)] +pub struct RulesContextEntry { + pub prompt_id: UserPromptId, + pub title: SharedString, +} + +pub struct RulesContextPickerDelegate { + thread_store: WeakEntity, + context_picker: WeakEntity, + context_store: WeakEntity, + matches: Vec, + selected_index: usize, +} + +impl RulesContextPickerDelegate { + pub fn new( + thread_store: WeakEntity, + context_picker: WeakEntity, + context_store: WeakEntity, + ) -> Self { + RulesContextPickerDelegate { + thread_store, + context_picker, + context_store, + matches: Vec::new(), + selected_index: 0, + } + } +} + +impl PickerDelegate for RulesContextPickerDelegate { + type ListItem = ListItem; + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) { + self.selected_index = ix; + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Search available rules…".into() + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + let Some(thread_store) = self.thread_store.upgrade() else { + return Task::ready(()); + }; + + let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx); + cx.spawn_in(window, async move |this, cx| { + let matches = search_task.await; + this.update(cx, |this, cx| { + this.delegate.matches = matches; + this.delegate.selected_index = 0; + cx.notify(); + }) + .ok(); + }) + } + + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { + let Some(entry) = self.matches.get(self.selected_index) else { + return; + }; + + let Some(thread_store) = self.thread_store.upgrade() else { + return; + }; + + let prompt_id = entry.prompt_id; + + let load_rules_task = thread_store.update(cx, |thread_store, cx| { + thread_store.load_rules(prompt_id, cx) + }); + + cx.spawn(async move |this, cx| { + let (metadata, text) = load_rules_task.await?; + let Some(title) = metadata.title else { + return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context.")); + }; + this.update(cx, |this, cx| { + this.delegate + .context_store + .update(cx, |context_store, cx| { + context_store.add_rules(prompt_id, title, text, true, cx) + }) + .ok(); + }) + }) + .detach_and_log_err(cx); + } + + fn dismissed(&mut self, _window: &mut Window, cx: &mut Context>) { + self.context_picker + .update(cx, |_, cx| { + cx.emit(DismissEvent); + }) + .ok(); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + cx: &mut Context>, + ) -> Option { + let thread = &self.matches[ix]; + + Some(ListItem::new(ix).inset(true).toggle_state(selected).child( + render_thread_context_entry(thread, self.context_store.clone(), cx), + )) + } +} + +pub fn render_thread_context_entry( + user_rules: &RulesContextEntry, + context_store: WeakEntity, + cx: &mut App, +) -> Div { + let added = context_store.upgrade().map_or(false, |ctx_store| { + ctx_store + .read(cx) + .includes_user_rules(&user_rules.prompt_id) + .is_some() + }); + + h_flex() + .gap_1p5() + .w_full() + .justify_between() + .child( + h_flex() + .gap_1p5() + .max_w_72() + .child( + Icon::new(RULES_ICON) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Label::new(user_rules.title.clone()).truncate()), + ) + .when(added, |el| { + el.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Check) + .size(IconSize::Small) + .color(Color::Success), + ) + .child(Label::new("Added").size(LabelSize::Small)), + ) + }) +} + +pub(crate) fn search_rules( + query: String, + cancellation_flag: Arc, + thread_store: Entity, + cx: &mut App, +) -> Task> { + let Some(prompt_store) = thread_store.read(cx).prompt_store() else { + return Task::ready(vec![]); + }; + let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx); + cx.background_spawn(async move { + search_task + .await + .into_iter() + .flat_map(|metadata| { + // Default prompts are filtered out as they are automatically included. + if metadata.default { + None + } else { + match metadata.id { + PromptId::EditWorkflow => None, + PromptId::User { uuid } => Some(RulesContextEntry { + prompt_id: uuid, + title: metadata.title?, + }), + } + } + }) + .collect::>() + }) +} diff --git a/crates/agent/src/context_picker/thread_context_picker.rs b/crates/agent/src/context_picker/thread_context_picker.rs index 941926a898..030eaf06af 100644 --- a/crates/agent/src/context_picker/thread_context_picker.rs +++ b/crates/agent/src/context_picker/thread_context_picker.rs @@ -103,11 +103,11 @@ impl PickerDelegate for ThreadContextPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let Some(threads) = self.thread_store.upgrade() else { + let Some(thread_store) = self.thread_store.upgrade() else { return Task::ready(()); }; - let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx); + let search_task = search_threads(query, Arc::new(AtomicBool::default()), thread_store, cx); cx.spawn_in(window, async move |this, cx| { let matches = search_task.await; this.update(cx, |this, cx| { @@ -217,15 +217,15 @@ pub(crate) fn search_threads( thread_store: Entity, cx: &mut App, ) -> Task> { - let threads = thread_store.update(cx, |this, _cx| { - this.threads() - .into_iter() - .map(|thread| ThreadContextEntry { - id: thread.id, - summary: thread.summary, - }) - .collect::>() - }); + let threads = thread_store + .read(cx) + .threads() + .into_iter() + .map(|thread| ThreadContextEntry { + id: thread.id, + summary: thread.summary, + }) + .collect::>(); let executor = cx.background_executor().clone(); cx.background_spawn(async move { diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index a44d0caedf..6045a48a27 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -9,6 +9,7 @@ use futures::{self, Future, FutureExt, future}; use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity}; use language::{Buffer, File}; use project::{Project, ProjectItem, ProjectPath, Worktree}; +use prompt_store::UserPromptId; use rope::{Point, Rope}; use text::{Anchor, BufferId, OffsetRangeExt}; use util::{ResultExt as _, maybe}; @@ -16,7 +17,7 @@ use util::{ResultExt as _, maybe}; use crate::ThreadStore; use crate::context::{ AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext, - ExcerptContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext, + ExcerptContext, FetchedUrlContext, FileContext, RulesContext, SymbolContext, ThreadContext, }; use crate::context_strip::SuggestedContext; use crate::thread::{Thread, ThreadId}; @@ -25,7 +26,6 @@ pub struct ContextStore { project: WeakEntity, context: Vec, thread_store: Option>, - // TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId. next_context_id: ContextId, files: BTreeMap, directories: HashMap, @@ -35,6 +35,7 @@ pub struct ContextStore { threads: HashMap, thread_summary_tasks: Vec>, fetched_urls: HashMap, + user_rules: HashMap, } impl ContextStore { @@ -55,6 +56,7 @@ impl ContextStore { threads: HashMap::default(), thread_summary_tasks: Vec::new(), fetched_urls: HashMap::default(), + user_rules: HashMap::default(), } } @@ -72,6 +74,7 @@ impl ContextStore { self.directories.clear(); self.threads.clear(); self.fetched_urls.clear(); + self.user_rules.clear(); } pub fn add_file_from_path( @@ -390,6 +393,42 @@ impl ContextStore { cx.notify(); } + pub fn add_rules( + &mut self, + prompt_id: UserPromptId, + title: impl Into, + text: impl Into, + remove_if_exists: bool, + cx: &mut Context, + ) { + if let Some(context_id) = self.includes_user_rules(&prompt_id) { + if remove_if_exists { + self.remove_context(context_id, cx); + } + } else { + self.insert_user_rules(prompt_id, title, text, cx); + } + } + + pub fn insert_user_rules( + &mut self, + prompt_id: UserPromptId, + title: impl Into, + text: impl Into, + cx: &mut Context, + ) { + let id = self.next_context_id.post_inc(); + + self.user_rules.insert(prompt_id, id); + self.context.push(AssistantContext::Rules(RulesContext { + id, + prompt_id, + title: title.into(), + text: text.into(), + })); + cx.notify(); + } + pub fn add_fetched_url( &mut self, url: String, @@ -518,6 +557,9 @@ impl ContextStore { AssistantContext::Thread(_) => { self.threads.retain(|_, context_id| *context_id != id); } + AssistantContext::Rules(RulesContext { prompt_id, .. }) => { + self.user_rules.remove(&prompt_id); + } } cx.notify(); @@ -614,6 +656,10 @@ impl ContextStore { self.threads.get(thread_id).copied() } + pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option { + self.user_rules.get(prompt_id).copied() + } + pub fn includes_url(&self, url: &str) -> Option { self.fetched_urls.get(url).copied() } @@ -641,7 +687,8 @@ impl ContextStore { | AssistantContext::Symbol(_) | AssistantContext::Excerpt(_) | AssistantContext::FetchedUrl(_) - | AssistantContext::Thread(_) => None, + | AssistantContext::Thread(_) + | AssistantContext::Rules(_) => None, }) .collect() } @@ -876,6 +923,10 @@ pub fn refresh_context_store_text( // and doing the caching properly could be tricky (unless it's already handled by // the HttpClient?). AssistantContext::FetchedUrl(_) => {} + AssistantContext::Rules(user_rules_context) => { + let context_store = context_store.clone(); + return Some(refresh_user_rules(context_store, user_rules_context, cx)); + } } None @@ -1026,6 +1077,45 @@ fn refresh_thread_text( }) } +fn refresh_user_rules( + context_store: Entity, + user_rules_context: &RulesContext, + cx: &App, +) -> Task<()> { + let id = user_rules_context.id; + let prompt_id = user_rules_context.prompt_id; + let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else { + return Task::ready(()); + }; + let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| { + thread_store.load_rules(prompt_id, cx) + }) else { + return Task::ready(()); + }; + cx.spawn(async move |cx| { + if let Ok((metadata, text)) = load_task.await { + if let Some(title) = metadata.title.clone() { + context_store + .update(cx, |context_store, _cx| { + context_store.replace_context(AssistantContext::Rules(RulesContext { + id, + prompt_id, + title, + text: text.into(), + })); + }) + .ok(); + return; + } + } + context_store + .update(cx, |context_store, cx| { + context_store.remove_context(id, cx); + }) + .ok(); + }) +} + fn refresh_context_buffer( context_buffer: &ContextBuffer, cx: &App, diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 061be6f738..9769fd92ba 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -774,7 +774,9 @@ impl Thread { cx, ); } - AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {} + AssistantContext::FetchedUrl(_) + | AssistantContext::Thread(_) + | AssistantContext::Rules(_) => {} } } }); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 646a26d26d..1f57c08cc7 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -24,8 +24,8 @@ use heed::types::SerdeBincode; use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::{Project, Worktree}; use prompt_store::{ - DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore, - PromptsUpdatedEvent, RulesFileContext, WorktreeContext, + ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent, + RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; @@ -62,6 +62,7 @@ pub struct ThreadStore { project: Entity, tools: Entity, prompt_builder: Arc, + prompt_store: Option>, context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, threads: Vec, @@ -135,6 +136,7 @@ impl ThreadStore { let (ready_tx, ready_rx) = oneshot::channel(); let mut ready_tx = Some(ready_tx); let reload_system_prompt_task = cx.spawn({ + let prompt_store = prompt_store.clone(); async move |thread_store, cx| { loop { let Some(reload_task) = thread_store @@ -158,6 +160,7 @@ impl ThreadStore { project, tools, prompt_builder, + prompt_store, context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), @@ -245,7 +248,7 @@ impl ThreadStore { let default_user_rules = default_user_rules .into_iter() .flat_map(|(contents, prompt_metadata)| match contents { - Ok(contents) => Some(DefaultUserRulesContext { + Ok(contents) => Some(UserRulesContext { uuid: match prompt_metadata.id { PromptId::User { uuid } => uuid, PromptId::EditWorkflow => return None, @@ -346,6 +349,27 @@ impl ThreadStore { self.context_server_manager.clone() } + pub fn prompt_store(&self) -> Option> { + self.prompt_store.clone() + } + + pub fn load_rules( + &self, + prompt_id: UserPromptId, + cx: &App, + ) -> Task> { + let prompt_id = PromptId::User { uuid: prompt_id }; + let Some(prompt_store) = self.prompt_store.as_ref() else { + return Task::ready(Err(anyhow!("Prompt store unexpectedly missing."))); + }; + let prompt_store = prompt_store.read(cx); + let Some(metadata) = prompt_store.metadata(prompt_id) else { + return Task::ready(Err(anyhow!("User rules not found in library."))); + }; + let text_task = prompt_store.load(prompt_id, cx); + cx.background_spawn(async move { Ok((metadata, text_task.await?)) }) + } + pub fn tools(&self) -> Entity { self.tools.clone() } diff --git a/crates/agent/src/ui/context_pill.rs b/crates/agent/src/ui/context_pill.rs index da07522745..cd5e63bc4f 100644 --- a/crates/agent/src/ui/context_pill.rs +++ b/crates/agent/src/ui/context_pill.rs @@ -354,6 +354,16 @@ impl AddedContext { .read(cx) .is_generating_detailed_summary(), }, + + AssistantContext::Rules(user_rules_context) => AddedContext { + id: user_rules_context.id, + kind: ContextKind::Rules, + name: user_rules_context.title.clone(), + parent: None, + tooltip: None, + icon_path: None, + summarizing: false, + }, } } } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 4760d4454f..205db386e2 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -27,7 +27,7 @@ use language_model::{ }; use project::Project; use prompt_library::{PromptLibrary, open_prompt_library}; -use prompt_store::{PromptBuilder, PromptId}; +use prompt_store::{PromptBuilder, PromptId, UserPromptId}; use search::{BufferSearchBar, buffer_search::DivRegistrar}; use settings::{Settings, update_settings_file}; @@ -58,11 +58,11 @@ pub fn init(cx: &mut App) { .register_action(AssistantPanel::show_configuration) .register_action(AssistantPanel::create_new_context) .register_action(AssistantPanel::restart_context_servers) - .register_action(|workspace, _: &OpenPromptLibrary, window, cx| { + .register_action(|workspace, action: &OpenPromptLibrary, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| { - panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx) + panel.deploy_prompt_library(action, window, cx) }); } }); @@ -1060,7 +1060,9 @@ impl AssistantPanel { None, )) }), - action.prompt_to_focus.map(|uuid| PromptId::User { uuid }), + action.prompt_to_select.map(|uuid| PromptId::User { + uuid: UserPromptId(uuid), + }), cx, ) .detach_and_log_err(cx); diff --git a/crates/assistant_slash_commands/src/prompt_command.rs b/crates/assistant_slash_commands/src/prompt_command.rs index a057023197..c177f9f359 100644 --- a/crates/assistant_slash_commands/src/prompt_command.rs +++ b/crates/assistant_slash_commands/src/prompt_command.rs @@ -44,9 +44,10 @@ impl SlashCommand for PromptSlashCommand { let store = PromptStore::global(cx); let query = arguments.to_owned().join(" "); cx.spawn(async move |cx| { + let cancellation_flag = Arc::new(AtomicBool::default()); let prompts: Vec = store .await? - .read_with(cx, |store, cx| store.search(query, cx))? + .read_with(cx, |store, cx| store.search(query, cancellation_flag, cx))? .await; Ok(prompts .into_iter() diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs index a0e9fbc9a3..d630ae2035 100644 --- a/crates/prompt_library/src/prompt_library.rs +++ b/crates/prompt_library/src/prompt_library.rs @@ -16,6 +16,7 @@ use release_channel::ReleaseChannel; use rope::Rope; use settings::Settings; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::time::Duration; use theme::ThemeSettings; use ui::{ @@ -75,7 +76,7 @@ pub fn open_prompt_library( language_registry: Arc, inline_assist_delegate: Box, make_completion_provider: Arc Box>, - prompt_to_focus: Option, + prompt_to_select: Option, cx: &mut App, ) -> Task>> { let store = PromptStore::global(cx); @@ -90,8 +91,8 @@ pub fn open_prompt_library( if let Some(existing_window) = existing_window { existing_window .update(cx, |prompt_library, window, cx| { - if let Some(prompt_to_focus) = prompt_to_focus { - prompt_library.load_prompt(prompt_to_focus, true, window, cx); + if let Some(prompt_to_select) = prompt_to_select { + prompt_library.load_prompt(prompt_to_select, true, window, cx); } window.activate_window() }) @@ -126,18 +127,15 @@ pub fn open_prompt_library( }, |window, cx| { cx.new(|cx| { - let mut prompt_library = PromptLibrary::new( + PromptLibrary::new( store, language_registry, inline_assist_delegate, make_completion_provider, + prompt_to_select, window, cx, - ); - if let Some(prompt_to_focus) = prompt_to_focus { - prompt_library.load_prompt(prompt_to_focus, true, window, cx); - } - prompt_library + ) }) }, ) @@ -221,7 +219,8 @@ impl PickerDelegate for PromptPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let search = self.store.read(cx).search(query, cx); + let cancellation_flag = Arc::new(AtomicBool::default()); + let search = self.store.read(cx).search(query, cancellation_flag, cx); let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id); cx.spawn_in(window, async move |this, cx| { let (matches, selected_index) = cx @@ -353,13 +352,26 @@ impl PromptLibrary { language_registry: Arc, inline_assist_delegate: Box, make_completion_provider: Arc Box>, + prompt_to_select: Option, window: &mut Window, cx: &mut Context, ) -> Self { + let (selected_index, matches) = if let Some(prompt_to_select) = prompt_to_select { + let matches = store.read(cx).all_prompt_metadata(); + let selected_index = matches + .iter() + .enumerate() + .find(|(_, metadata)| metadata.id == prompt_to_select) + .map_or(0, |(ix, _)| ix); + (selected_index, matches) + } else { + (0, vec![]) + }; + let delegate = PromptPickerDelegate { store: store.clone(), - selected_index: 0, - matches: Vec::new(), + selected_index, + matches, }; let picker = cx.new(|cx| { diff --git a/crates/prompt_store/src/prompt_store.rs b/crates/prompt_store/src/prompt_store.rs index 66e4b9072f..84aaa688cd 100644 --- a/crates/prompt_store/src/prompt_store.rs +++ b/crates/prompt_store/src/prompt_store.rs @@ -54,14 +54,14 @@ pub struct PromptMetadata { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(tag = "kind")] pub enum PromptId { - User { uuid: Uuid }, + User { uuid: UserPromptId }, EditWorkflow, } impl PromptId { pub fn new() -> PromptId { PromptId::User { - uuid: Uuid::new_v4(), + uuid: UserPromptId::new(), } } @@ -70,6 +70,22 @@ impl PromptId { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct UserPromptId(pub Uuid); + +impl UserPromptId { + pub fn new() -> UserPromptId { + UserPromptId(Uuid::new_v4()) + } +} + +impl From for UserPromptId { + fn from(uuid: Uuid) -> Self { + UserPromptId(uuid) + } +} + pub struct PromptStore { env: heed::Env, metadata_cache: RwLock, @@ -212,7 +228,7 @@ impl PromptStore { for (prompt_id_v1, metadata_v1) in metadata_v1 { let prompt_id_v2 = PromptId::User { - uuid: prompt_id_v1.0, + uuid: UserPromptId(prompt_id_v1.0), }; let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { continue; @@ -257,6 +273,10 @@ impl PromptStore { }) } + pub fn all_prompt_metadata(&self) -> Vec { + self.metadata_cache.read().metadata.clone() + } + pub fn default_prompt_metadata(&self) -> Vec { return self .metadata_cache @@ -314,7 +334,12 @@ impl PromptStore { Some(metadata.id) } - pub fn search(&self, query: String, cx: &App) -> Task> { + pub fn search( + &self, + query: String, + cancellation_flag: Arc, + cx: &App, + ) -> Task> { let cached_metadata = self.metadata_cache.read().metadata.clone(); let executor = cx.background_executor().clone(); cx.background_spawn(async move { @@ -333,7 +358,7 @@ impl PromptStore { &query, false, 100, - &AtomicBool::default(), + &cancellation_flag, executor, ) .await; diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 9dd3df3523..35430e8f06 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -15,34 +15,32 @@ use std::{ }; use text::LineEnding; use util::{ResultExt, get_system_shell}; -use uuid::Uuid; + +use crate::UserPromptId; #[derive(Debug, Clone, Serialize)] pub struct ProjectContext { pub worktrees: Vec, /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this. pub has_rules: bool, - pub default_user_rules: Vec, - /// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this. - pub has_default_user_rules: bool, + pub user_rules: Vec, + /// `!user_rules.is_empty()` - provided as a field because handlebars can't do this. + pub has_user_rules: bool, pub os: String, pub arch: String, pub shell: String, } impl ProjectContext { - pub fn new( - worktrees: Vec, - default_user_rules: Vec, - ) -> Self { + pub fn new(worktrees: Vec, default_user_rules: Vec) -> Self { let has_rules = worktrees .iter() .any(|worktree| worktree.rules_file.is_some()); Self { worktrees, has_rules, - has_default_user_rules: !default_user_rules.is_empty(), - default_user_rules, + has_user_rules: !default_user_rules.is_empty(), + user_rules: default_user_rules, os: std::env::consts::OS.to_string(), arch: std::env::consts::ARCH.to_string(), shell: get_system_shell(), @@ -51,8 +49,8 @@ impl ProjectContext { } #[derive(Debug, Clone, Serialize)] -pub struct DefaultUserRulesContext { - pub uuid: Uuid, +pub struct UserRulesContext { + pub uuid: UserPromptId, pub title: Option, pub contents: String, } @@ -397,6 +395,7 @@ impl PromptBuilder { #[cfg(test)] mod test { use super::*; + use uuid::Uuid; #[test] fn test_assistant_system_prompt_renders() { @@ -408,8 +407,8 @@ mod test { text: "".into(), }), }]; - let default_user_rules = vec![DefaultUserRulesContext { - uuid: Uuid::nil(), + let default_user_rules = vec![UserRulesContext { + uuid: UserPromptId(Uuid::nil()), title: Some("Rules title".into()), contents: "Rules contents".into(), }]; diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 6c1852f882..b420662ebb 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -201,7 +201,7 @@ pub mod assistant { #[serde(deny_unknown_fields)] pub struct OpenPromptLibrary { #[serde(skip)] - pub prompt_to_focus: Option, + pub prompt_to_select: Option, } impl_action_with_deprecated_aliases!(