Add ability to attach rules as context (#29109)

Release Notes:

- agent: Added support for adding rules as context.
This commit is contained in:
Michael Sloan 2025-04-21 14:16:51 -06:00 committed by GitHub
parent 3b31860d52
commit 7aa0fa1543
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 694 additions and 112 deletions

View file

@ -73,9 +73,9 @@ There are project rules that apply to these root directories:
{{/each}} {{/each}}
{{/if}} {{/if}}
{{#if has_default_user_rules}} {{#if has_user_rules}}
The user has specified the following rules that should be applied: The user has specified the following rules that should be applied:
{{#each default_user_rules}} {{#each user_rules}}
{{#if title}} {{#if title}}
Rules title: {{title}} Rules title: {{title}}

View file

@ -1,4 +1,4 @@
use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
use crate::context_picker::MentionLink; use crate::context_picker::MentionLink;
use crate::thread::{ use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
@ -688,6 +688,12 @@ fn open_markdown_link(
} }
}), }),
Some(MentionLink::Fetch(url)) => cx.open_url(&url), Some(MentionLink::Fetch(url)) => cx.open_url(&url),
Some(MentionLink::Rules(prompt_id)) => window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: Some(prompt_id.0),
}),
cx,
),
None => cx.open_url(&text), None => cx.open_url(&text),
} }
} }
@ -2957,10 +2963,10 @@ impl ActiveThread {
return div().into_any(); return div().into_any();
}; };
let default_user_rules_text = if project_context.default_user_rules.is_empty() { let user_rules_text = if project_context.user_rules.is_empty() {
None None
} else if project_context.default_user_rules.len() == 1 { } else if project_context.user_rules.len() == 1 {
let user_rules = &project_context.default_user_rules[0]; let user_rules = &project_context.user_rules[0];
match user_rules.title.as_ref() { match user_rules.title.as_ref() {
Some(title) => Some(format!("Using \"{title}\" user rule")), Some(title) => Some(format!("Using \"{title}\" user rule")),
@ -2969,14 +2975,14 @@ impl ActiveThread {
} else { } else {
Some(format!( Some(format!(
"Using {} user rules", "Using {} user rules",
project_context.default_user_rules.len() project_context.user_rules.len()
)) ))
}; };
let first_default_user_rules_id = project_context let first_user_rules_id = project_context
.default_user_rules .user_rules
.first() .first()
.map(|user_rules| user_rules.uuid); .map(|user_rules| user_rules.uuid.0);
let rules_files = project_context let rules_files = project_context
.worktrees .worktrees
@ -2993,7 +2999,7 @@ impl ActiveThread {
rules_files => Some(format!("Using {} project rules files", rules_files.len())), rules_files => Some(format!("Using {} project rules files", rules_files.len())),
}; };
if default_user_rules_text.is_none() && rules_file_text.is_none() { if user_rules_text.is_none() && rules_file_text.is_none() {
return div().into_any(); return div().into_any();
} }
@ -3001,45 +3007,42 @@ impl ActiveThread {
.pt_2() .pt_2()
.px_2p5() .px_2p5()
.gap_1() .gap_1()
.when_some( .when_some(user_rules_text, |parent, user_rules_text| {
default_user_rules_text, parent.child(
|parent, default_user_rules_text| { h_flex()
parent.child( .w_full()
h_flex() .child(
.w_full() Icon::new(RULES_ICON)
.child( .size(IconSize::XSmall)
Icon::new(IconName::File) .color(Color::Disabled),
.size(IconSize::XSmall) )
.color(Color::Disabled), .child(
) Label::new(user_rules_text)
.child( .size(LabelSize::XSmall)
Label::new(default_user_rules_text) .color(Color::Muted)
.size(LabelSize::XSmall) .truncate()
.color(Color::Muted) .buffer_font(cx)
.truncate() .ml_1p5()
.buffer_font(cx) .mr_0p5(),
.ml_1p5() )
.mr_0p5(), .child(
) IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
.child( .shape(ui::IconButtonShape::Square)
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt) .icon_size(IconSize::XSmall)
.shape(ui::IconButtonShape::Square) .icon_color(Color::Ignored)
.icon_size(IconSize::XSmall) // TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
.icon_color(Color::Ignored) .tooltip(Tooltip::text("View User Rules"))
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding .on_click(move |_event, window, cx| {
.tooltip(Tooltip::text("View User Rules")) window.dispatch_action(
.on_click(move |_event, window, cx| { Box::new(OpenPromptLibrary {
window.dispatch_action( prompt_to_select: first_user_rules_id,
Box::new(OpenPromptLibrary { }),
prompt_to_focus: first_default_user_rules_id, cx,
}), )
cx, }),
) ),
}), )
), })
)
},
)
.when_some(rules_file_text, |parent, rules_file_text| { .when_some(rules_file_text, |parent, rules_file_text| {
parent.child( parent.child(
h_flex() h_flex()
@ -3316,6 +3319,12 @@ pub(crate) fn open_context(
} }
}) })
} }
AssistantContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: Some(rules_context.prompt_id.0),
}),
cx,
),
} }
} }

View file

@ -25,7 +25,7 @@ use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector; use language_model_selector::ToggleModelSelector;
use project::Project; use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library}; use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId}; use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use proto::Plan; use proto::Plan;
use settings::{Settings, update_settings_file}; use settings::{Settings, update_settings_file};
use time::UtcOffset; use time::UtcOffset;
@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx)); panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
} }
}) })
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| { .register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) { if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx); workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| { panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx) panel.deploy_prompt_library(action, window, cx)
}); });
} }
}) })
@ -502,7 +502,9 @@ impl AssistantPanel {
None, None,
)) ))
}), }),
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }), action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
cx, cx,
) )
.detach_and_log_err(cx); .detach_and_log_err(cx);

View file

@ -4,6 +4,7 @@ use gpui::{App, Entity, SharedString};
use language::{Buffer, File}; use language::{Buffer, File};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::{ProjectPath, Worktree}; use project::{ProjectPath, Worktree};
use prompt_store::UserPromptId;
use rope::Point; use rope::Point;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use text::{Anchor, BufferId}; use text::{Anchor, BufferId};
@ -12,6 +13,8 @@ use util::post_inc;
use crate::thread::Thread; use crate::thread::Thread;
pub const RULES_ICON: IconName = IconName::Context;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize); pub struct ContextId(pub(crate) usize);
@ -20,6 +23,7 @@ impl ContextId {
Self(post_inc(&mut self.0)) Self(post_inc(&mut self.0))
} }
} }
pub enum ContextKind { pub enum ContextKind {
File, File,
Directory, Directory,
@ -27,6 +31,7 @@ pub enum ContextKind {
Excerpt, Excerpt,
FetchedUrl, FetchedUrl,
Thread, Thread,
Rules,
} }
impl ContextKind { impl ContextKind {
@ -38,6 +43,7 @@ impl ContextKind {
ContextKind::Excerpt => IconName::Code, ContextKind::Excerpt => IconName::Code,
ContextKind::FetchedUrl => IconName::Globe, ContextKind::FetchedUrl => IconName::Globe,
ContextKind::Thread => IconName::MessageBubbles, ContextKind::Thread => IconName::MessageBubbles,
ContextKind::Rules => RULES_ICON,
} }
} }
} }
@ -50,6 +56,7 @@ pub enum AssistantContext {
FetchedUrl(FetchedUrlContext), FetchedUrl(FetchedUrlContext),
Thread(ThreadContext), Thread(ThreadContext),
Excerpt(ExcerptContext), Excerpt(ExcerptContext),
Rules(RulesContext),
} }
impl AssistantContext { impl AssistantContext {
@ -61,6 +68,7 @@ impl AssistantContext {
Self::FetchedUrl(url) => url.id, Self::FetchedUrl(url) => url.id,
Self::Thread(thread) => thread.id, Self::Thread(thread) => thread.id,
Self::Excerpt(excerpt) => excerpt.id, Self::Excerpt(excerpt) => excerpt.id,
Self::Rules(rules) => rules.id,
} }
} }
} }
@ -168,6 +176,14 @@ pub struct ExcerptContext {
pub context_buffer: ContextBuffer, pub context_buffer: ContextBuffer,
} }
#[derive(Debug, Clone)]
pub struct RulesContext {
pub id: ContextId,
pub prompt_id: UserPromptId,
pub title: SharedString,
pub text: SharedString,
}
/// Formats a collection of contexts into a string representation /// Formats a collection of contexts into a string representation
pub fn format_context_as_string<'a>( pub fn format_context_as_string<'a>(
contexts: impl Iterator<Item = &'a AssistantContext>, contexts: impl Iterator<Item = &'a AssistantContext>,
@ -179,6 +195,7 @@ pub fn format_context_as_string<'a>(
let mut excerpt_context = Vec::new(); let mut excerpt_context = Vec::new();
let mut fetch_context = Vec::new(); let mut fetch_context = Vec::new();
let mut thread_context = Vec::new(); let mut thread_context = Vec::new();
let mut rules_context = Vec::new();
for context in contexts { for context in contexts {
match context { match context {
@ -188,6 +205,7 @@ pub fn format_context_as_string<'a>(
AssistantContext::Excerpt(context) => excerpt_context.push(context), AssistantContext::Excerpt(context) => excerpt_context.push(context),
AssistantContext::FetchedUrl(context) => fetch_context.push(context), AssistantContext::FetchedUrl(context) => fetch_context.push(context),
AssistantContext::Thread(context) => thread_context.push(context), AssistantContext::Thread(context) => thread_context.push(context),
AssistantContext::Rules(context) => rules_context.push(context),
} }
} }
@ -197,6 +215,7 @@ pub fn format_context_as_string<'a>(
&& excerpt_context.is_empty() && excerpt_context.is_empty()
&& fetch_context.is_empty() && fetch_context.is_empty()
&& thread_context.is_empty() && thread_context.is_empty()
&& rules_context.is_empty()
{ {
return None; return None;
} }
@ -263,6 +282,18 @@ pub fn format_context_as_string<'a>(
result.push_str("</conversation_threads>\n"); result.push_str("</conversation_threads>\n");
} }
if !rules_context.is_empty() {
result.push_str(
"<user_rules>\n\
The user has specified the following rules that should be applied:\n\n",
);
for context in &rules_context {
result.push_str(&context.text);
result.push('\n');
}
result.push_str("</user_rules>\n");
}
result.push_str("</context>\n"); result.push_str("</context>\n");
Some(result) Some(result)
} }

View file

@ -1,6 +1,7 @@
mod completion_provider; mod completion_provider;
mod fetch_context_picker; mod fetch_context_picker;
mod file_context_picker; mod file_context_picker;
mod rules_context_picker;
mod symbol_context_picker; mod symbol_context_picker;
mod thread_context_picker; mod thread_context_picker;
@ -18,17 +19,22 @@ use gpui::{
}; };
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath}; use project::{Entry, ProjectPath};
use prompt_store::UserPromptId;
use rules_context_picker::RulesContextEntry;
use symbol_context_picker::SymbolContextPicker; use symbol_context_picker::SymbolContextPicker;
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry}; use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
use ui::{ use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*, ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
}; };
use uuid::Uuid;
use workspace::{Workspace, notifications::NotifyResultExt}; use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel; use crate::AssistantPanel;
use crate::context::RULES_ICON;
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider; pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
use crate::context_picker::fetch_context_picker::FetchContextPicker; use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker; use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::rules_context_picker::RulesContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker; use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
use crate::thread::ThreadId; use crate::thread::ThreadId;
@ -40,6 +46,7 @@ enum ContextPickerMode {
Symbol, Symbol,
Fetch, Fetch,
Thread, Thread,
Rules,
} }
impl TryFrom<&str> for ContextPickerMode { impl TryFrom<&str> for ContextPickerMode {
@ -51,6 +58,7 @@ impl TryFrom<&str> for ContextPickerMode {
"symbol" => Ok(Self::Symbol), "symbol" => Ok(Self::Symbol),
"fetch" => Ok(Self::Fetch), "fetch" => Ok(Self::Fetch),
"thread" => Ok(Self::Thread), "thread" => Ok(Self::Thread),
"rules" => Ok(Self::Rules),
_ => Err(format!("Invalid context picker mode: {}", value)), _ => Err(format!("Invalid context picker mode: {}", value)),
} }
} }
@ -63,6 +71,7 @@ impl ContextPickerMode {
Self::Symbol => "symbol", Self::Symbol => "symbol",
Self::Fetch => "fetch", Self::Fetch => "fetch",
Self::Thread => "thread", Self::Thread => "thread",
Self::Rules => "rules",
} }
} }
@ -72,6 +81,7 @@ impl ContextPickerMode {
Self::Symbol => "Symbols", Self::Symbol => "Symbols",
Self::Fetch => "Fetch", Self::Fetch => "Fetch",
Self::Thread => "Threads", Self::Thread => "Threads",
Self::Rules => "Rules",
} }
} }
@ -81,6 +91,7 @@ impl ContextPickerMode {
Self::Symbol => IconName::Code, Self::Symbol => IconName::Code,
Self::Fetch => IconName::Globe, Self::Fetch => IconName::Globe,
Self::Thread => IconName::MessageBubbles, Self::Thread => IconName::MessageBubbles,
Self::Rules => RULES_ICON,
} }
} }
} }
@ -92,6 +103,7 @@ enum ContextPickerState {
Symbol(Entity<SymbolContextPicker>), Symbol(Entity<SymbolContextPicker>),
Fetch(Entity<FetchContextPicker>), Fetch(Entity<FetchContextPicker>),
Thread(Entity<ThreadContextPicker>), Thread(Entity<ThreadContextPicker>),
Rules(Entity<RulesContextPicker>),
} }
pub(super) struct ContextPicker { pub(super) struct ContextPicker {
@ -253,6 +265,19 @@ impl ContextPicker {
})); }));
} }
} }
ContextPickerMode::Rules => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
}
} }
cx.notify(); cx.notify();
@ -381,6 +406,7 @@ impl ContextPicker {
ContextPickerState::Symbol(entity) => entity.update(cx, |_, cx| cx.notify()), ContextPickerState::Symbol(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Fetch(entity) => entity.update(cx, |_, cx| cx.notify()), ContextPickerState::Fetch(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Thread(entity) => entity.update(cx, |_, cx| cx.notify()), ContextPickerState::Thread(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Rules(entity) => entity.update(cx, |_, cx| cx.notify()),
} }
} }
} }
@ -395,6 +421,7 @@ impl Focusable for ContextPicker {
ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx), ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx),
ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx), ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx), ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx),
ContextPickerState::Rules(user_rules_picker) => user_rules_picker.focus_handle(cx),
} }
} }
} }
@ -410,6 +437,9 @@ impl Render for ContextPicker {
ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()), ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()),
ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()), ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()), ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()),
ContextPickerState::Rules(user_rules_picker) => {
parent.child(user_rules_picker.clone())
}
}) })
} }
} }
@ -431,6 +461,7 @@ fn supported_context_picker_modes(
]; ];
if thread_store.is_some() { if thread_store.is_some() {
modes.push(ContextPickerMode::Thread); modes.push(ContextPickerMode::Thread);
modes.push(ContextPickerMode::Rules);
} }
modes modes
} }
@ -626,6 +657,7 @@ pub enum MentionLink {
Symbol(ProjectPath, String), Symbol(ProjectPath, String),
Fetch(String), Fetch(String),
Thread(ThreadId), Thread(ThreadId),
Rules(UserPromptId),
} }
impl MentionLink { impl MentionLink {
@ -633,14 +665,16 @@ impl MentionLink {
const SYMBOL: &str = "@symbol"; const SYMBOL: &str = "@symbol";
const THREAD: &str = "@thread"; const THREAD: &str = "@thread";
const FETCH: &str = "@fetch"; const FETCH: &str = "@fetch";
const RULES: &str = "@rules";
const SEPARATOR: &str = ":"; const SEPARATOR: &str = ":";
pub fn is_valid(url: &str) -> bool { pub fn is_valid(url: &str) -> bool {
url.starts_with(Self::FILE) url.starts_with(Self::FILE)
|| url.starts_with(Self::SYMBOL) || url.starts_with(Self::SYMBOL)
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::THREAD) || url.starts_with(Self::THREAD)
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::RULES)
} }
pub fn for_file(file_name: &str, full_path: &str) -> String { pub fn for_file(file_name: &str, full_path: &str) -> String {
@ -657,12 +691,16 @@ impl MentionLink {
) )
} }
pub fn for_thread(thread: &ThreadContextEntry) -> String {
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
}
pub fn for_fetch(url: &str) -> String { pub fn for_fetch(url: &str) -> String {
format!("[@{}]({}:{})", url, Self::FETCH, url) format!("[@{}]({}:{})", url, Self::FETCH, url)
} }
pub fn for_thread(thread: &ThreadContextEntry) -> String { pub fn for_rules(rules: &RulesContextEntry) -> String {
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id) format!("[@{}]({}:{})", rules.title, Self::RULES, rules.prompt_id.0)
} }
pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> { pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> {
@ -706,6 +744,10 @@ impl MentionLink {
Some(MentionLink::Thread(thread_id)) Some(MentionLink::Thread(thread_id))
} }
Self::FETCH => Some(MentionLink::Fetch(argument.to_string())), Self::FETCH => Some(MentionLink::Fetch(argument.to_string())),
Self::RULES => {
let prompt_id = UserPromptId(Uuid::try_parse(argument).ok()?);
Some(MentionLink::Rules(prompt_id))
}
_ => None, _ => None,
} }
} }

View file

@ -14,11 +14,13 @@ use http_client::HttpClientWithUrl;
use language::{Buffer, CodeLabel, HighlightId}; use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext; use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId}; use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
use prompt_store::PromptId;
use rope::Point; use rope::Point;
use text::{Anchor, ToPoint}; use text::{Anchor, ToPoint};
use ui::prelude::*; use ui::prelude::*;
use workspace::Workspace; use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::context_picker::file_context_picker::search_files; use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols; use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
@ -26,6 +28,7 @@ use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content; use super::fetch_context_picker::fetch_url_content;
use super::file_context_picker::FileMatch; use super::file_context_picker::FileMatch;
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch; use super::symbol_context_picker::SymbolMatch;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{ use super::{
@ -38,6 +41,7 @@ pub(crate) enum Match {
File(FileMatch), File(FileMatch),
Thread(ThreadMatch), Thread(ThreadMatch),
Fetch(SharedString), Fetch(SharedString),
Rules(RulesContextEntry),
Mode(ModeMatch), Mode(ModeMatch),
} }
@ -54,6 +58,7 @@ impl Match {
Match::Thread(_) => 1., Match::Thread(_) => 1.,
Match::Symbol(_) => 1., Match::Symbol(_) => 1.,
Match::Fetch(_) => 1., Match::Fetch(_) => 1.,
Match::Rules(_) => 1.,
} }
} }
} }
@ -112,6 +117,21 @@ fn search(
Task::ready(Vec::new()) Task::ready(Vec::new())
} }
} }
Some(ContextPickerMode::Rules) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
.into_iter()
.map(Match::Rules)
.collect::<Vec<_>>()
})
} else {
Task::ready(Vec::new())
}
}
None => { None => {
if query.is_empty() { if query.is_empty() {
let mut matches = recent_entries let mut matches = recent_entries
@ -287,6 +307,60 @@ impl ContextPickerCompletionProvider {
} }
} }
fn completion_for_rules(
rules: RulesContextEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
Completion {
replace_range: source_range.clone(),
new_text,
label: CodeLabel::plain(rules.title.to_string(), None),
documentation: None,
insert_text_mode: None,
source: project::CompletionSource::Custom,
icon_path: Some(RULES_ICON.path().into()),
confirm: Some(confirm_completion_callback(
RULES_ICON.path().into(),
rules.title.clone(),
excerpt_id,
source_range.start,
new_text_len,
editor.clone(),
move |cx| {
let prompt_uuid = rules.prompt_id;
let prompt_id = PromptId::User { uuid: prompt_uuid };
let context_store = context_store.clone();
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
log::error!("Can't add user rules as prompt store is missing.");
return;
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return;
};
let Some(title) = metadata.title else {
return;
};
let text_task = prompt_store.load(prompt_id, cx);
cx.spawn(async move |cx| {
let text = text_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(prompt_uuid, title, text, false, cx)
})
})
.detach_and_log_err(cx);
},
)),
}
}
fn completion_for_fetch( fn completion_for_fetch(
source_range: Range<Anchor>, source_range: Range<Anchor>,
url_to_fetch: SharedString, url_to_fetch: SharedString,
@ -593,6 +667,17 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store, thread_store,
)) ))
} }
Match::Rules(user_rules) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Fetch(url) => Some(Self::completion_for_fetch( Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(), source_range.clone(),
url, url,

View file

@ -0,0 +1,248 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use prompt_store::{PromptId, UserPromptId};
use ui::{ListItem, prelude::*};
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
}
impl RulesContextPicker {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
}
}
impl Focusable for RulesContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for RulesContextPicker {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
#[derive(Debug, Clone)]
pub struct RulesContextEntry {
pub prompt_id: UserPromptId,
pub title: SharedString,
}
pub struct RulesContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
selected_index: usize,
}
impl RulesContextPickerDelegate {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
thread_store,
context_picker,
context_store,
matches: Vec::new(),
selected_index: 0,
}
}
}
impl PickerDelegate for RulesContextPickerDelegate {
type ListItem = ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search available rules…".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
this.delegate.matches = matches;
this.delegate.selected_index = 0;
cx.notify();
})
.ok();
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
let Some(thread_store) = self.thread_store.upgrade() else {
return;
};
let prompt_id = entry.prompt_id;
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
});
cx.spawn(async move |this, cx| {
let (metadata, text) = load_rules_task.await?;
let Some(title) = metadata.title else {
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
};
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(prompt_id, title, text, true, cx)
})
.ok();
})
})
.detach_and_log_err(cx);
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
self.context_picker
.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
fn render_match(
&self,
ix: usize,
selected: bool,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let thread = &self.matches[ix];
Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
render_thread_context_entry(thread, self.context_store.clone(), cx),
))
}
}
pub fn render_thread_context_entry(
user_rules: &RulesContextEntry,
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store
.read(cx)
.includes_user_rules(&user_rules.prompt_id)
.is_some()
});
h_flex()
.gap_1p5()
.w_full()
.justify_between()
.child(
h_flex()
.gap_1p5()
.max_w_72()
.child(
Icon::new(RULES_ICON)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(user_rules.title.clone()).truncate()),
)
.when(added, |el| {
el.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success),
)
.child(Label::new("Added").size(LabelSize::Small)),
)
})
}
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
return Task::ready(vec![]);
};
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task
.await
.into_iter()
.flat_map(|metadata| {
// Default prompts are filtered out as they are automatically included.
if metadata.default {
None
} else {
match metadata.id {
PromptId::EditWorkflow => None,
PromptId::User { uuid } => Some(RulesContextEntry {
prompt_id: uuid,
title: metadata.title?,
}),
}
}
})
.collect::<Vec<_>>()
})
}

View file

@ -103,11 +103,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
window: &mut Window, window: &mut Window,
cx: &mut Context<Picker<Self>>, cx: &mut Context<Picker<Self>>,
) -> Task<()> { ) -> Task<()> {
let Some(threads) = self.thread_store.upgrade() else { let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(()); return Task::ready(());
}; };
let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx); let search_task = search_threads(query, Arc::new(AtomicBool::default()), thread_store, cx);
cx.spawn_in(window, async move |this, cx| { cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await; let matches = search_task.await;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
@ -217,15 +217,15 @@ pub(crate) fn search_threads(
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
cx: &mut App, cx: &mut App,
) -> Task<Vec<ThreadMatch>> { ) -> Task<Vec<ThreadMatch>> {
let threads = thread_store.update(cx, |this, _cx| { let threads = thread_store
this.threads() .read(cx)
.into_iter() .threads()
.map(|thread| ThreadContextEntry { .into_iter()
id: thread.id, .map(|thread| ThreadContextEntry {
summary: thread.summary, id: thread.id,
}) summary: thread.summary,
.collect::<Vec<_>>() })
}); .collect::<Vec<_>>();
let executor = cx.background_executor().clone(); let executor = cx.background_executor().clone();
cx.background_spawn(async move { cx.background_spawn(async move {

View file

@ -9,6 +9,7 @@ use futures::{self, Future, FutureExt, future};
use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity}; use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity};
use language::{Buffer, File}; use language::{Buffer, File};
use project::{Project, ProjectItem, ProjectPath, Worktree}; use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::UserPromptId;
use rope::{Point, Rope}; use rope::{Point, Rope};
use text::{Anchor, BufferId, OffsetRangeExt}; use text::{Anchor, BufferId, OffsetRangeExt};
use util::{ResultExt as _, maybe}; use util::{ResultExt as _, maybe};
@ -16,7 +17,7 @@ use util::{ResultExt as _, maybe};
use crate::ThreadStore; use crate::ThreadStore;
use crate::context::{ use crate::context::{
AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext, AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext,
ExcerptContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext, ExcerptContext, FetchedUrlContext, FileContext, RulesContext, SymbolContext, ThreadContext,
}; };
use crate::context_strip::SuggestedContext; use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId}; use crate::thread::{Thread, ThreadId};
@ -25,7 +26,6 @@ pub struct ContextStore {
project: WeakEntity<Project>, project: WeakEntity<Project>,
context: Vec<AssistantContext>, context: Vec<AssistantContext>,
thread_store: Option<WeakEntity<ThreadStore>>, thread_store: Option<WeakEntity<ThreadStore>>,
// TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId.
next_context_id: ContextId, next_context_id: ContextId,
files: BTreeMap<BufferId, ContextId>, files: BTreeMap<BufferId, ContextId>,
directories: HashMap<ProjectPath, ContextId>, directories: HashMap<ProjectPath, ContextId>,
@ -35,6 +35,7 @@ pub struct ContextStore {
threads: HashMap<ThreadId, ContextId>, threads: HashMap<ThreadId, ContextId>,
thread_summary_tasks: Vec<Task<()>>, thread_summary_tasks: Vec<Task<()>>,
fetched_urls: HashMap<String, ContextId>, fetched_urls: HashMap<String, ContextId>,
user_rules: HashMap<UserPromptId, ContextId>,
} }
impl ContextStore { impl ContextStore {
@ -55,6 +56,7 @@ impl ContextStore {
threads: HashMap::default(), threads: HashMap::default(),
thread_summary_tasks: Vec::new(), thread_summary_tasks: Vec::new(),
fetched_urls: HashMap::default(), fetched_urls: HashMap::default(),
user_rules: HashMap::default(),
} }
} }
@ -72,6 +74,7 @@ impl ContextStore {
self.directories.clear(); self.directories.clear();
self.threads.clear(); self.threads.clear();
self.fetched_urls.clear(); self.fetched_urls.clear();
self.user_rules.clear();
} }
pub fn add_file_from_path( pub fn add_file_from_path(
@ -390,6 +393,42 @@ impl ContextStore {
cx.notify(); cx.notify();
} }
pub fn add_rules(
&mut self,
prompt_id: UserPromptId,
title: impl Into<SharedString>,
text: impl Into<SharedString>,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) {
if let Some(context_id) = self.includes_user_rules(&prompt_id) {
if remove_if_exists {
self.remove_context(context_id, cx);
}
} else {
self.insert_user_rules(prompt_id, title, text, cx);
}
}
pub fn insert_user_rules(
&mut self,
prompt_id: UserPromptId,
title: impl Into<SharedString>,
text: impl Into<SharedString>,
cx: &mut Context<ContextStore>,
) {
let id = self.next_context_id.post_inc();
self.user_rules.insert(prompt_id, id);
self.context.push(AssistantContext::Rules(RulesContext {
id,
prompt_id,
title: title.into(),
text: text.into(),
}));
cx.notify();
}
pub fn add_fetched_url( pub fn add_fetched_url(
&mut self, &mut self,
url: String, url: String,
@ -518,6 +557,9 @@ impl ContextStore {
AssistantContext::Thread(_) => { AssistantContext::Thread(_) => {
self.threads.retain(|_, context_id| *context_id != id); self.threads.retain(|_, context_id| *context_id != id);
} }
AssistantContext::Rules(RulesContext { prompt_id, .. }) => {
self.user_rules.remove(&prompt_id);
}
} }
cx.notify(); cx.notify();
@ -614,6 +656,10 @@ impl ContextStore {
self.threads.get(thread_id).copied() self.threads.get(thread_id).copied()
} }
pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option<ContextId> {
self.user_rules.get(prompt_id).copied()
}
pub fn includes_url(&self, url: &str) -> Option<ContextId> { pub fn includes_url(&self, url: &str) -> Option<ContextId> {
self.fetched_urls.get(url).copied() self.fetched_urls.get(url).copied()
} }
@ -641,7 +687,8 @@ impl ContextStore {
| AssistantContext::Symbol(_) | AssistantContext::Symbol(_)
| AssistantContext::Excerpt(_) | AssistantContext::Excerpt(_)
| AssistantContext::FetchedUrl(_) | AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_) => None, | AssistantContext::Thread(_)
| AssistantContext::Rules(_) => None,
}) })
.collect() .collect()
} }
@ -876,6 +923,10 @@ pub fn refresh_context_store_text(
// and doing the caching properly could be tricky (unless it's already handled by // and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?). // the HttpClient?).
AssistantContext::FetchedUrl(_) => {} AssistantContext::FetchedUrl(_) => {}
AssistantContext::Rules(user_rules_context) => {
let context_store = context_store.clone();
return Some(refresh_user_rules(context_store, user_rules_context, cx));
}
} }
None None
@ -1026,6 +1077,45 @@ fn refresh_thread_text(
}) })
} }
fn refresh_user_rules(
context_store: Entity<ContextStore>,
user_rules_context: &RulesContext,
cx: &App,
) -> Task<()> {
let id = user_rules_context.id;
let prompt_id = user_rules_context.prompt_id;
let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else {
return Task::ready(());
};
let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
}) else {
return Task::ready(());
};
cx.spawn(async move |cx| {
if let Ok((metadata, text)) = load_task.await {
if let Some(title) = metadata.title.clone() {
context_store
.update(cx, |context_store, _cx| {
context_store.replace_context(AssistantContext::Rules(RulesContext {
id,
prompt_id,
title,
text: text.into(),
}));
})
.ok();
return;
}
}
context_store
.update(cx, |context_store, cx| {
context_store.remove_context(id, cx);
})
.ok();
})
}
fn refresh_context_buffer( fn refresh_context_buffer(
context_buffer: &ContextBuffer, context_buffer: &ContextBuffer,
cx: &App, cx: &App,

View file

@ -774,7 +774,9 @@ impl Thread {
cx, cx,
); );
} }
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {} AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
| AssistantContext::Rules(_) => {}
} }
} }
}); });

View file

@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree}; use project::{Project, Worktree};
use prompt_store::{ use prompt_store::{
DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore, ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
PromptsUpdatedEvent, RulesFileContext, WorktreeContext, RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore}; use settings::{Settings as _, SettingsStore};
@ -62,6 +62,7 @@ pub struct ThreadStore {
project: Entity<Project>, project: Entity<Project>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>, context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>, threads: Vec<SerializedThreadMetadata>,
@ -135,6 +136,7 @@ impl ThreadStore {
let (ready_tx, ready_rx) = oneshot::channel(); let (ready_tx, ready_rx) = oneshot::channel();
let mut ready_tx = Some(ready_tx); let mut ready_tx = Some(ready_tx);
let reload_system_prompt_task = cx.spawn({ let reload_system_prompt_task = cx.spawn({
let prompt_store = prompt_store.clone();
async move |thread_store, cx| { async move |thread_store, cx| {
loop { loop {
let Some(reload_task) = thread_store let Some(reload_task) = thread_store
@ -158,6 +160,7 @@ impl ThreadStore {
project, project,
tools, tools,
prompt_builder, prompt_builder,
prompt_store,
context_server_manager, context_server_manager,
context_server_tool_ids: HashMap::default(), context_server_tool_ids: HashMap::default(),
threads: Vec::new(), threads: Vec::new(),
@ -245,7 +248,7 @@ impl ThreadStore {
let default_user_rules = default_user_rules let default_user_rules = default_user_rules
.into_iter() .into_iter()
.flat_map(|(contents, prompt_metadata)| match contents { .flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(DefaultUserRulesContext { Ok(contents) => Some(UserRulesContext {
uuid: match prompt_metadata.id { uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid, PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None, PromptId::EditWorkflow => return None,
@ -346,6 +349,27 @@ impl ThreadStore {
self.context_server_manager.clone() self.context_server_manager.clone()
} }
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
self.prompt_store.clone()
}
pub fn load_rules(
&self,
prompt_id: UserPromptId,
cx: &App,
) -> Task<Result<(PromptMetadata, String)>> {
let prompt_id = PromptId::User { uuid: prompt_id };
let Some(prompt_store) = self.prompt_store.as_ref() else {
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return Task::ready(Err(anyhow!("User rules not found in library.")));
};
let text_task = prompt_store.load(prompt_id, cx);
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
}
pub fn tools(&self) -> Entity<ToolWorkingSet> { pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone() self.tools.clone()
} }

View file

@ -354,6 +354,16 @@ impl AddedContext {
.read(cx) .read(cx)
.is_generating_detailed_summary(), .is_generating_detailed_summary(),
}, },
AssistantContext::Rules(user_rules_context) => AddedContext {
id: user_rules_context.id,
kind: ContextKind::Rules,
name: user_rules_context.title.clone(),
parent: None,
tooltip: None,
icon_path: None,
summarizing: false,
},
} }
} }
} }

View file

@ -27,7 +27,7 @@ use language_model::{
}; };
use project::Project; use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library}; use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId}; use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use search::{BufferSearchBar, buffer_search::DivRegistrar}; use search::{BufferSearchBar, buffer_search::DivRegistrar};
use settings::{Settings, update_settings_file}; use settings::{Settings, update_settings_file};
@ -58,11 +58,11 @@ pub fn init(cx: &mut App) {
.register_action(AssistantPanel::show_configuration) .register_action(AssistantPanel::show_configuration)
.register_action(AssistantPanel::create_new_context) .register_action(AssistantPanel::create_new_context)
.register_action(AssistantPanel::restart_context_servers) .register_action(AssistantPanel::restart_context_servers)
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| { .register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) { if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx); workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| { panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx) panel.deploy_prompt_library(action, window, cx)
}); });
} }
}); });
@ -1060,7 +1060,9 @@ impl AssistantPanel {
None, None,
)) ))
}), }),
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }), action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
cx, cx,
) )
.detach_and_log_err(cx); .detach_and_log_err(cx);

View file

@ -44,9 +44,10 @@ impl SlashCommand for PromptSlashCommand {
let store = PromptStore::global(cx); let store = PromptStore::global(cx);
let query = arguments.to_owned().join(" "); let query = arguments.to_owned().join(" ");
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let cancellation_flag = Arc::new(AtomicBool::default());
let prompts: Vec<PromptMetadata> = store let prompts: Vec<PromptMetadata> = store
.await? .await?
.read_with(cx, |store, cx| store.search(query, cx))? .read_with(cx, |store, cx| store.search(query, cancellation_flag, cx))?
.await; .await;
Ok(prompts Ok(prompts
.into_iter() .into_iter()

View file

@ -16,6 +16,7 @@ use release_channel::ReleaseChannel;
use rope::Rope; use rope::Rope;
use settings::Settings; use settings::Settings;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration; use std::time::Duration;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{ use ui::{
@ -75,7 +76,7 @@ pub fn open_prompt_library(
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
inline_assist_delegate: Box<dyn InlineAssistDelegate>, inline_assist_delegate: Box<dyn InlineAssistDelegate>,
make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>, make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
prompt_to_focus: Option<PromptId>, prompt_to_select: Option<PromptId>,
cx: &mut App, cx: &mut App,
) -> Task<Result<WindowHandle<PromptLibrary>>> { ) -> Task<Result<WindowHandle<PromptLibrary>>> {
let store = PromptStore::global(cx); let store = PromptStore::global(cx);
@ -90,8 +91,8 @@ pub fn open_prompt_library(
if let Some(existing_window) = existing_window { if let Some(existing_window) = existing_window {
existing_window existing_window
.update(cx, |prompt_library, window, cx| { .update(cx, |prompt_library, window, cx| {
if let Some(prompt_to_focus) = prompt_to_focus { if let Some(prompt_to_select) = prompt_to_select {
prompt_library.load_prompt(prompt_to_focus, true, window, cx); prompt_library.load_prompt(prompt_to_select, true, window, cx);
} }
window.activate_window() window.activate_window()
}) })
@ -126,18 +127,15 @@ pub fn open_prompt_library(
}, },
|window, cx| { |window, cx| {
cx.new(|cx| { cx.new(|cx| {
let mut prompt_library = PromptLibrary::new( PromptLibrary::new(
store, store,
language_registry, language_registry,
inline_assist_delegate, inline_assist_delegate,
make_completion_provider, make_completion_provider,
prompt_to_select,
window, window,
cx, cx,
); )
if let Some(prompt_to_focus) = prompt_to_focus {
prompt_library.load_prompt(prompt_to_focus, true, window, cx);
}
prompt_library
}) })
}, },
) )
@ -221,7 +219,8 @@ impl PickerDelegate for PromptPickerDelegate {
window: &mut Window, window: &mut Window,
cx: &mut Context<Picker<Self>>, cx: &mut Context<Picker<Self>>,
) -> Task<()> { ) -> Task<()> {
let search = self.store.read(cx).search(query, cx); let cancellation_flag = Arc::new(AtomicBool::default());
let search = self.store.read(cx).search(query, cancellation_flag, cx);
let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id); let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id);
cx.spawn_in(window, async move |this, cx| { cx.spawn_in(window, async move |this, cx| {
let (matches, selected_index) = cx let (matches, selected_index) = cx
@ -353,13 +352,26 @@ impl PromptLibrary {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
inline_assist_delegate: Box<dyn InlineAssistDelegate>, inline_assist_delegate: Box<dyn InlineAssistDelegate>,
make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>, make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
prompt_to_select: Option<PromptId>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let (selected_index, matches) = if let Some(prompt_to_select) = prompt_to_select {
let matches = store.read(cx).all_prompt_metadata();
let selected_index = matches
.iter()
.enumerate()
.find(|(_, metadata)| metadata.id == prompt_to_select)
.map_or(0, |(ix, _)| ix);
(selected_index, matches)
} else {
(0, vec![])
};
let delegate = PromptPickerDelegate { let delegate = PromptPickerDelegate {
store: store.clone(), store: store.clone(),
selected_index: 0, selected_index,
matches: Vec::new(), matches,
}; };
let picker = cx.new(|cx| { let picker = cx.new(|cx| {

View file

@ -54,14 +54,14 @@ pub struct PromptMetadata {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")] #[serde(tag = "kind")]
pub enum PromptId { pub enum PromptId {
User { uuid: Uuid }, User { uuid: UserPromptId },
EditWorkflow, EditWorkflow,
} }
impl PromptId { impl PromptId {
pub fn new() -> PromptId { pub fn new() -> PromptId {
PromptId::User { PromptId::User {
uuid: Uuid::new_v4(), uuid: UserPromptId::new(),
} }
} }
@ -70,6 +70,22 @@ impl PromptId {
} }
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
impl UserPromptId {
pub fn new() -> UserPromptId {
UserPromptId(Uuid::new_v4())
}
}
impl From<Uuid> for UserPromptId {
fn from(uuid: Uuid) -> Self {
UserPromptId(uuid)
}
}
pub struct PromptStore { pub struct PromptStore {
env: heed::Env, env: heed::Env,
metadata_cache: RwLock<MetadataCache>, metadata_cache: RwLock<MetadataCache>,
@ -212,7 +228,7 @@ impl PromptStore {
for (prompt_id_v1, metadata_v1) in metadata_v1 { for (prompt_id_v1, metadata_v1) in metadata_v1 {
let prompt_id_v2 = PromptId::User { let prompt_id_v2 = PromptId::User {
uuid: prompt_id_v1.0, uuid: UserPromptId(prompt_id_v1.0),
}; };
let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
continue; continue;
@ -257,6 +273,10 @@ impl PromptStore {
}) })
} }
pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
self.metadata_cache.read().metadata.clone()
}
pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> { pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
return self return self
.metadata_cache .metadata_cache
@ -314,7 +334,12 @@ impl PromptStore {
Some(metadata.id) Some(metadata.id)
} }
pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> { pub fn search(
&self,
query: String,
cancellation_flag: Arc<AtomicBool>,
cx: &App,
) -> Task<Vec<PromptMetadata>> {
let cached_metadata = self.metadata_cache.read().metadata.clone(); let cached_metadata = self.metadata_cache.read().metadata.clone();
let executor = cx.background_executor().clone(); let executor = cx.background_executor().clone();
cx.background_spawn(async move { cx.background_spawn(async move {
@ -333,7 +358,7 @@ impl PromptStore {
&query, &query,
false, false,
100, 100,
&AtomicBool::default(), &cancellation_flag,
executor, executor,
) )
.await; .await;

View file

@ -15,34 +15,32 @@ use std::{
}; };
use text::LineEnding; use text::LineEnding;
use util::{ResultExt, get_system_shell}; use util::{ResultExt, get_system_shell};
use uuid::Uuid;
use crate::UserPromptId;
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct ProjectContext { pub struct ProjectContext {
pub worktrees: Vec<WorktreeContext>, pub worktrees: Vec<WorktreeContext>,
/// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this. /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this.
pub has_rules: bool, pub has_rules: bool,
pub default_user_rules: Vec<DefaultUserRulesContext>, pub user_rules: Vec<UserRulesContext>,
/// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this. /// `!user_rules.is_empty()` - provided as a field because handlebars can't do this.
pub has_default_user_rules: bool, pub has_user_rules: bool,
pub os: String, pub os: String,
pub arch: String, pub arch: String,
pub shell: String, pub shell: String,
} }
impl ProjectContext { impl ProjectContext {
pub fn new( pub fn new(worktrees: Vec<WorktreeContext>, default_user_rules: Vec<UserRulesContext>) -> Self {
worktrees: Vec<WorktreeContext>,
default_user_rules: Vec<DefaultUserRulesContext>,
) -> Self {
let has_rules = worktrees let has_rules = worktrees
.iter() .iter()
.any(|worktree| worktree.rules_file.is_some()); .any(|worktree| worktree.rules_file.is_some());
Self { Self {
worktrees, worktrees,
has_rules, has_rules,
has_default_user_rules: !default_user_rules.is_empty(), has_user_rules: !default_user_rules.is_empty(),
default_user_rules, user_rules: default_user_rules,
os: std::env::consts::OS.to_string(), os: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(), arch: std::env::consts::ARCH.to_string(),
shell: get_system_shell(), shell: get_system_shell(),
@ -51,8 +49,8 @@ impl ProjectContext {
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct DefaultUserRulesContext { pub struct UserRulesContext {
pub uuid: Uuid, pub uuid: UserPromptId,
pub title: Option<String>, pub title: Option<String>,
pub contents: String, pub contents: String,
} }
@ -397,6 +395,7 @@ impl PromptBuilder {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
use uuid::Uuid;
#[test] #[test]
fn test_assistant_system_prompt_renders() { fn test_assistant_system_prompt_renders() {
@ -408,8 +407,8 @@ mod test {
text: "".into(), text: "".into(),
}), }),
}]; }];
let default_user_rules = vec![DefaultUserRulesContext { let default_user_rules = vec![UserRulesContext {
uuid: Uuid::nil(), uuid: UserPromptId(Uuid::nil()),
title: Some("Rules title".into()), title: Some("Rules title".into()),
contents: "Rules contents".into(), contents: "Rules contents".into(),
}]; }];

View file

@ -201,7 +201,7 @@ pub mod assistant {
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct OpenPromptLibrary { pub struct OpenPromptLibrary {
#[serde(skip)] #[serde(skip)]
pub prompt_to_focus: Option<Uuid>, pub prompt_to_select: Option<Uuid>,
} }
impl_action_with_deprecated_aliases!( impl_action_with_deprecated_aliases!(