Add ability to attach rules as context (#29109)

Release Notes:

- agent: Added support for adding rules as context.
This commit is contained in:
Michael Sloan 2025-04-21 14:16:51 -06:00 committed by GitHub
parent 3b31860d52
commit 7aa0fa1543
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 694 additions and 112 deletions

View file

@ -73,9 +73,9 @@ There are project rules that apply to these root directories:
{{/each}}
{{/if}}
{{#if has_default_user_rules}}
{{#if has_user_rules}}
The user has specified the following rules that should be applied:
{{#each default_user_rules}}
{{#each user_rules}}
{{#if title}}
Rules title: {{title}}

View file

@ -1,4 +1,4 @@
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
@ -688,6 +688,12 @@ fn open_markdown_link(
}
}),
Some(MentionLink::Fetch(url)) => cx.open_url(&url),
Some(MentionLink::Rules(prompt_id)) => window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: Some(prompt_id.0),
}),
cx,
),
None => cx.open_url(&text),
}
}
@ -2957,10 +2963,10 @@ impl ActiveThread {
return div().into_any();
};
let default_user_rules_text = if project_context.default_user_rules.is_empty() {
let user_rules_text = if project_context.user_rules.is_empty() {
None
} else if project_context.default_user_rules.len() == 1 {
let user_rules = &project_context.default_user_rules[0];
} else if project_context.user_rules.len() == 1 {
let user_rules = &project_context.user_rules[0];
match user_rules.title.as_ref() {
Some(title) => Some(format!("Using \"{title}\" user rule")),
@ -2969,14 +2975,14 @@ impl ActiveThread {
} else {
Some(format!(
"Using {} user rules",
project_context.default_user_rules.len()
project_context.user_rules.len()
))
};
let first_default_user_rules_id = project_context
.default_user_rules
let first_user_rules_id = project_context
.user_rules
.first()
.map(|user_rules| user_rules.uuid);
.map(|user_rules| user_rules.uuid.0);
let rules_files = project_context
.worktrees
@ -2993,7 +2999,7 @@ impl ActiveThread {
rules_files => Some(format!("Using {} project rules files", rules_files.len())),
};
if default_user_rules_text.is_none() && rules_file_text.is_none() {
if user_rules_text.is_none() && rules_file_text.is_none() {
return div().into_any();
}
@ -3001,45 +3007,42 @@ impl ActiveThread {
.pt_2()
.px_2p5()
.gap_1()
.when_some(
default_user_rules_text,
|parent, default_user_rules_text| {
parent.child(
h_flex()
.w_full()
.child(
Icon::new(IconName::File)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(default_user_rules_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.truncate()
.buffer_font(cx)
.ml_1p5()
.mr_0p5(),
)
.child(
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
.tooltip(Tooltip::text("View User Rules"))
.on_click(move |_event, window, cx| {
window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_focus: first_default_user_rules_id,
}),
cx,
)
}),
),
)
},
)
.when_some(user_rules_text, |parent, user_rules_text| {
parent.child(
h_flex()
.w_full()
.child(
Icon::new(RULES_ICON)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(user_rules_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.truncate()
.buffer_font(cx)
.ml_1p5()
.mr_0p5(),
)
.child(
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
.tooltip(Tooltip::text("View User Rules"))
.on_click(move |_event, window, cx| {
window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: first_user_rules_id,
}),
cx,
)
}),
),
)
})
.when_some(rules_file_text, |parent, rules_file_text| {
parent.child(
h_flex()
@ -3316,6 +3319,12 @@ pub(crate) fn open_context(
}
})
}
AssistantContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: Some(rules_context.prompt_id.0),
}),
cx,
),
}
}

View file

@ -25,7 +25,7 @@ use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use proto::Plan;
use settings::{Settings, update_settings_file};
use time::UtcOffset;
@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
}
})
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx)
panel.deploy_prompt_library(action, window, cx)
});
}
})
@ -502,7 +502,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
cx,
)
.detach_and_log_err(cx);

View file

@ -4,6 +4,7 @@ use gpui::{App, Entity, SharedString};
use language::{Buffer, File};
use language_model::LanguageModelRequestMessage;
use project::{ProjectPath, Worktree};
use prompt_store::UserPromptId;
use rope::Point;
use serde::{Deserialize, Serialize};
use text::{Anchor, BufferId};
@ -12,6 +13,8 @@ use util::post_inc;
use crate::thread::Thread;
pub const RULES_ICON: IconName = IconName::Context;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize);
@ -20,6 +23,7 @@ impl ContextId {
Self(post_inc(&mut self.0))
}
}
pub enum ContextKind {
File,
Directory,
@ -27,6 +31,7 @@ pub enum ContextKind {
Excerpt,
FetchedUrl,
Thread,
Rules,
}
impl ContextKind {
@ -38,6 +43,7 @@ impl ContextKind {
ContextKind::Excerpt => IconName::Code,
ContextKind::FetchedUrl => IconName::Globe,
ContextKind::Thread => IconName::MessageBubbles,
ContextKind::Rules => RULES_ICON,
}
}
}
@ -50,6 +56,7 @@ pub enum AssistantContext {
FetchedUrl(FetchedUrlContext),
Thread(ThreadContext),
Excerpt(ExcerptContext),
Rules(RulesContext),
}
impl AssistantContext {
@ -61,6 +68,7 @@ impl AssistantContext {
Self::FetchedUrl(url) => url.id,
Self::Thread(thread) => thread.id,
Self::Excerpt(excerpt) => excerpt.id,
Self::Rules(rules) => rules.id,
}
}
}
@ -168,6 +176,14 @@ pub struct ExcerptContext {
pub context_buffer: ContextBuffer,
}
#[derive(Debug, Clone)]
pub struct RulesContext {
pub id: ContextId,
pub prompt_id: UserPromptId,
pub title: SharedString,
pub text: SharedString,
}
/// Formats a collection of contexts into a string representation
pub fn format_context_as_string<'a>(
contexts: impl Iterator<Item = &'a AssistantContext>,
@ -179,6 +195,7 @@ pub fn format_context_as_string<'a>(
let mut excerpt_context = Vec::new();
let mut fetch_context = Vec::new();
let mut thread_context = Vec::new();
let mut rules_context = Vec::new();
for context in contexts {
match context {
@ -188,6 +205,7 @@ pub fn format_context_as_string<'a>(
AssistantContext::Excerpt(context) => excerpt_context.push(context),
AssistantContext::FetchedUrl(context) => fetch_context.push(context),
AssistantContext::Thread(context) => thread_context.push(context),
AssistantContext::Rules(context) => rules_context.push(context),
}
}
@ -197,6 +215,7 @@ pub fn format_context_as_string<'a>(
&& excerpt_context.is_empty()
&& fetch_context.is_empty()
&& thread_context.is_empty()
&& rules_context.is_empty()
{
return None;
}
@ -263,6 +282,18 @@ pub fn format_context_as_string<'a>(
result.push_str("</conversation_threads>\n");
}
if !rules_context.is_empty() {
result.push_str(
"<user_rules>\n\
The user has specified the following rules that should be applied:\n\n",
);
for context in &rules_context {
result.push_str(&context.text);
result.push('\n');
}
result.push_str("</user_rules>\n");
}
result.push_str("</context>\n");
Some(result)
}

View file

@ -1,6 +1,7 @@
mod completion_provider;
mod fetch_context_picker;
mod file_context_picker;
mod rules_context_picker;
mod symbol_context_picker;
mod thread_context_picker;
@ -18,17 +19,22 @@ use gpui::{
};
use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath};
use prompt_store::UserPromptId;
use rules_context_picker::RulesContextEntry;
use symbol_context_picker::SymbolContextPicker;
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
};
use uuid::Uuid;
use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel;
use crate::context::RULES_ICON;
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::rules_context_picker::RulesContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore;
use crate::thread::ThreadId;
@ -40,6 +46,7 @@ enum ContextPickerMode {
Symbol,
Fetch,
Thread,
Rules,
}
impl TryFrom<&str> for ContextPickerMode {
@ -51,6 +58,7 @@ impl TryFrom<&str> for ContextPickerMode {
"symbol" => Ok(Self::Symbol),
"fetch" => Ok(Self::Fetch),
"thread" => Ok(Self::Thread),
"rules" => Ok(Self::Rules),
_ => Err(format!("Invalid context picker mode: {}", value)),
}
}
@ -63,6 +71,7 @@ impl ContextPickerMode {
Self::Symbol => "symbol",
Self::Fetch => "fetch",
Self::Thread => "thread",
Self::Rules => "rules",
}
}
@ -72,6 +81,7 @@ impl ContextPickerMode {
Self::Symbol => "Symbols",
Self::Fetch => "Fetch",
Self::Thread => "Threads",
Self::Rules => "Rules",
}
}
@ -81,6 +91,7 @@ impl ContextPickerMode {
Self::Symbol => IconName::Code,
Self::Fetch => IconName::Globe,
Self::Thread => IconName::MessageBubbles,
Self::Rules => RULES_ICON,
}
}
}
@ -92,6 +103,7 @@ enum ContextPickerState {
Symbol(Entity<SymbolContextPicker>),
Fetch(Entity<FetchContextPicker>),
Thread(Entity<ThreadContextPicker>),
Rules(Entity<RulesContextPicker>),
}
pub(super) struct ContextPicker {
@ -253,6 +265,19 @@ impl ContextPicker {
}));
}
}
ContextPickerMode::Rules => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
}
}
cx.notify();
@ -381,6 +406,7 @@ impl ContextPicker {
ContextPickerState::Symbol(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Fetch(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Thread(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Rules(entity) => entity.update(cx, |_, cx| cx.notify()),
}
}
}
@ -395,6 +421,7 @@ impl Focusable for ContextPicker {
ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx),
ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx),
ContextPickerState::Rules(user_rules_picker) => user_rules_picker.focus_handle(cx),
}
}
}
@ -410,6 +437,9 @@ impl Render for ContextPicker {
ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()),
ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()),
ContextPickerState::Rules(user_rules_picker) => {
parent.child(user_rules_picker.clone())
}
})
}
}
@ -431,6 +461,7 @@ fn supported_context_picker_modes(
];
if thread_store.is_some() {
modes.push(ContextPickerMode::Thread);
modes.push(ContextPickerMode::Rules);
}
modes
}
@ -626,6 +657,7 @@ pub enum MentionLink {
Symbol(ProjectPath, String),
Fetch(String),
Thread(ThreadId),
Rules(UserPromptId),
}
impl MentionLink {
@ -633,14 +665,16 @@ impl MentionLink {
const SYMBOL: &str = "@symbol";
const THREAD: &str = "@thread";
const FETCH: &str = "@fetch";
const RULES: &str = "@rules";
const SEPARATOR: &str = ":";
pub fn is_valid(url: &str) -> bool {
url.starts_with(Self::FILE)
|| url.starts_with(Self::SYMBOL)
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::THREAD)
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::RULES)
}
pub fn for_file(file_name: &str, full_path: &str) -> String {
@ -657,12 +691,16 @@ impl MentionLink {
)
}
pub fn for_thread(thread: &ThreadContextEntry) -> String {
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
}
pub fn for_fetch(url: &str) -> String {
format!("[@{}]({}:{})", url, Self::FETCH, url)
}
pub fn for_thread(thread: &ThreadContextEntry) -> String {
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
pub fn for_rules(rules: &RulesContextEntry) -> String {
format!("[@{}]({}:{})", rules.title, Self::RULES, rules.prompt_id.0)
}
pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> {
@ -706,6 +744,10 @@ impl MentionLink {
Some(MentionLink::Thread(thread_id))
}
Self::FETCH => Some(MentionLink::Fetch(argument.to_string())),
Self::RULES => {
let prompt_id = UserPromptId(Uuid::try_parse(argument).ok()?);
Some(MentionLink::Rules(prompt_id))
}
_ => None,
}
}

View file

@ -14,11 +14,13 @@ use http_client::HttpClientWithUrl;
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
use prompt_store::PromptId;
use rope::Point;
use text::{Anchor, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
@ -26,6 +28,7 @@ use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::file_context_picker::FileMatch;
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
@ -38,6 +41,7 @@ pub(crate) enum Match {
File(FileMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Rules(RulesContextEntry),
Mode(ModeMatch),
}
@ -54,6 +58,7 @@ impl Match {
Match::Thread(_) => 1.,
Match::Symbol(_) => 1.,
Match::Fetch(_) => 1.,
Match::Rules(_) => 1.,
}
}
}
@ -112,6 +117,21 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
.into_iter()
.map(Match::Rules)
.collect::<Vec<_>>()
})
} else {
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
@ -287,6 +307,60 @@ impl ContextPickerCompletionProvider {
}
}
fn completion_for_rules(
rules: RulesContextEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
Completion {
replace_range: source_range.clone(),
new_text,
label: CodeLabel::plain(rules.title.to_string(), None),
documentation: None,
insert_text_mode: None,
source: project::CompletionSource::Custom,
icon_path: Some(RULES_ICON.path().into()),
confirm: Some(confirm_completion_callback(
RULES_ICON.path().into(),
rules.title.clone(),
excerpt_id,
source_range.start,
new_text_len,
editor.clone(),
move |cx| {
let prompt_uuid = rules.prompt_id;
let prompt_id = PromptId::User { uuid: prompt_uuid };
let context_store = context_store.clone();
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
log::error!("Can't add user rules as prompt store is missing.");
return;
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return;
};
let Some(title) = metadata.title else {
return;
};
let text_task = prompt_store.load(prompt_id, cx);
cx.spawn(async move |cx| {
let text = text_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(prompt_uuid, title, text, false, cx)
})
})
.detach_and_log_err(cx);
},
)),
}
}
fn completion_for_fetch(
source_range: Range<Anchor>,
url_to_fetch: SharedString,
@ -593,6 +667,17 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store,
))
}
Match::Rules(user_rules) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,

View file

@ -0,0 +1,248 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use prompt_store::{PromptId, UserPromptId};
use ui::{ListItem, prelude::*};
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
}
impl RulesContextPicker {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
}
}
impl Focusable for RulesContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for RulesContextPicker {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
#[derive(Debug, Clone)]
pub struct RulesContextEntry {
pub prompt_id: UserPromptId,
pub title: SharedString,
}
pub struct RulesContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
selected_index: usize,
}
impl RulesContextPickerDelegate {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
thread_store,
context_picker,
context_store,
matches: Vec::new(),
selected_index: 0,
}
}
}
impl PickerDelegate for RulesContextPickerDelegate {
type ListItem = ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search available rules…".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
this.delegate.matches = matches;
this.delegate.selected_index = 0;
cx.notify();
})
.ok();
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
let Some(thread_store) = self.thread_store.upgrade() else {
return;
};
let prompt_id = entry.prompt_id;
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
});
cx.spawn(async move |this, cx| {
let (metadata, text) = load_rules_task.await?;
let Some(title) = metadata.title else {
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
};
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(prompt_id, title, text, true, cx)
})
.ok();
})
})
.detach_and_log_err(cx);
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
self.context_picker
.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
fn render_match(
&self,
ix: usize,
selected: bool,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let thread = &self.matches[ix];
Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
render_thread_context_entry(thread, self.context_store.clone(), cx),
))
}
}
pub fn render_thread_context_entry(
user_rules: &RulesContextEntry,
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store
.read(cx)
.includes_user_rules(&user_rules.prompt_id)
.is_some()
});
h_flex()
.gap_1p5()
.w_full()
.justify_between()
.child(
h_flex()
.gap_1p5()
.max_w_72()
.child(
Icon::new(RULES_ICON)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(user_rules.title.clone()).truncate()),
)
.when(added, |el| {
el.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success),
)
.child(Label::new("Added").size(LabelSize::Small)),
)
})
}
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
return Task::ready(vec![]);
};
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task
.await
.into_iter()
.flat_map(|metadata| {
// Default prompts are filtered out as they are automatically included.
if metadata.default {
None
} else {
match metadata.id {
PromptId::EditWorkflow => None,
PromptId::User { uuid } => Some(RulesContextEntry {
prompt_id: uuid,
title: metadata.title?,
}),
}
}
})
.collect::<Vec<_>>()
})
}

View file

@ -103,11 +103,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(threads) = self.thread_store.upgrade() else {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx);
let search_task = search_threads(query, Arc::new(AtomicBool::default()), thread_store, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
@ -217,15 +217,15 @@ pub(crate) fn search_threads(
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store.update(cx, |this, _cx| {
this.threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,
summary: thread.summary,
})
.collect::<Vec<_>>()
});
let threads = thread_store
.read(cx)
.threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,
summary: thread.summary,
})
.collect::<Vec<_>>();
let executor = cx.background_executor().clone();
cx.background_spawn(async move {

View file

@ -9,6 +9,7 @@ use futures::{self, Future, FutureExt, future};
use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity};
use language::{Buffer, File};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::UserPromptId;
use rope::{Point, Rope};
use text::{Anchor, BufferId, OffsetRangeExt};
use util::{ResultExt as _, maybe};
@ -16,7 +17,7 @@ use util::{ResultExt as _, maybe};
use crate::ThreadStore;
use crate::context::{
AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext,
ExcerptContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
ExcerptContext, FetchedUrlContext, FileContext, RulesContext, SymbolContext, ThreadContext,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
@ -25,7 +26,6 @@ pub struct ContextStore {
project: WeakEntity<Project>,
context: Vec<AssistantContext>,
thread_store: Option<WeakEntity<ThreadStore>>,
// TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId.
next_context_id: ContextId,
files: BTreeMap<BufferId, ContextId>,
directories: HashMap<ProjectPath, ContextId>,
@ -35,6 +35,7 @@ pub struct ContextStore {
threads: HashMap<ThreadId, ContextId>,
thread_summary_tasks: Vec<Task<()>>,
fetched_urls: HashMap<String, ContextId>,
user_rules: HashMap<UserPromptId, ContextId>,
}
impl ContextStore {
@ -55,6 +56,7 @@ impl ContextStore {
threads: HashMap::default(),
thread_summary_tasks: Vec::new(),
fetched_urls: HashMap::default(),
user_rules: HashMap::default(),
}
}
@ -72,6 +74,7 @@ impl ContextStore {
self.directories.clear();
self.threads.clear();
self.fetched_urls.clear();
self.user_rules.clear();
}
pub fn add_file_from_path(
@ -390,6 +393,42 @@ impl ContextStore {
cx.notify();
}
pub fn add_rules(
&mut self,
prompt_id: UserPromptId,
title: impl Into<SharedString>,
text: impl Into<SharedString>,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) {
if let Some(context_id) = self.includes_user_rules(&prompt_id) {
if remove_if_exists {
self.remove_context(context_id, cx);
}
} else {
self.insert_user_rules(prompt_id, title, text, cx);
}
}
pub fn insert_user_rules(
&mut self,
prompt_id: UserPromptId,
title: impl Into<SharedString>,
text: impl Into<SharedString>,
cx: &mut Context<ContextStore>,
) {
let id = self.next_context_id.post_inc();
self.user_rules.insert(prompt_id, id);
self.context.push(AssistantContext::Rules(RulesContext {
id,
prompt_id,
title: title.into(),
text: text.into(),
}));
cx.notify();
}
pub fn add_fetched_url(
&mut self,
url: String,
@ -518,6 +557,9 @@ impl ContextStore {
AssistantContext::Thread(_) => {
self.threads.retain(|_, context_id| *context_id != id);
}
AssistantContext::Rules(RulesContext { prompt_id, .. }) => {
self.user_rules.remove(&prompt_id);
}
}
cx.notify();
@ -614,6 +656,10 @@ impl ContextStore {
self.threads.get(thread_id).copied()
}
pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option<ContextId> {
self.user_rules.get(prompt_id).copied()
}
pub fn includes_url(&self, url: &str) -> Option<ContextId> {
self.fetched_urls.get(url).copied()
}
@ -641,7 +687,8 @@ impl ContextStore {
| AssistantContext::Symbol(_)
| AssistantContext::Excerpt(_)
| AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_) => None,
| AssistantContext::Thread(_)
| AssistantContext::Rules(_) => None,
})
.collect()
}
@ -876,6 +923,10 @@ pub fn refresh_context_store_text(
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
AssistantContext::Rules(user_rules_context) => {
let context_store = context_store.clone();
return Some(refresh_user_rules(context_store, user_rules_context, cx));
}
}
None
@ -1026,6 +1077,45 @@ fn refresh_thread_text(
})
}
fn refresh_user_rules(
context_store: Entity<ContextStore>,
user_rules_context: &RulesContext,
cx: &App,
) -> Task<()> {
let id = user_rules_context.id;
let prompt_id = user_rules_context.prompt_id;
let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else {
return Task::ready(());
};
let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
}) else {
return Task::ready(());
};
cx.spawn(async move |cx| {
if let Ok((metadata, text)) = load_task.await {
if let Some(title) = metadata.title.clone() {
context_store
.update(cx, |context_store, _cx| {
context_store.replace_context(AssistantContext::Rules(RulesContext {
id,
prompt_id,
title,
text: text.into(),
}));
})
.ok();
return;
}
}
context_store
.update(cx, |context_store, cx| {
context_store.remove_context(id, cx);
})
.ok();
})
}
fn refresh_context_buffer(
context_buffer: &ContextBuffer,
cx: &App,

View file

@ -774,7 +774,9 @@ impl Thread {
cx,
);
}
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
| AssistantContext::Rules(_) => {}
}
}
});

View file

@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{
DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
@ -62,6 +62,7 @@ pub struct ThreadStore {
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
@ -135,6 +136,7 @@ impl ThreadStore {
let (ready_tx, ready_rx) = oneshot::channel();
let mut ready_tx = Some(ready_tx);
let reload_system_prompt_task = cx.spawn({
let prompt_store = prompt_store.clone();
async move |thread_store, cx| {
loop {
let Some(reload_task) = thread_store
@ -158,6 +160,7 @@ impl ThreadStore {
project,
tools,
prompt_builder,
prompt_store,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
@ -245,7 +248,7 @@ impl ThreadStore {
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(DefaultUserRulesContext {
Ok(contents) => Some(UserRulesContext {
uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None,
@ -346,6 +349,27 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
self.prompt_store.clone()
}
pub fn load_rules(
&self,
prompt_id: UserPromptId,
cx: &App,
) -> Task<Result<(PromptMetadata, String)>> {
let prompt_id = PromptId::User { uuid: prompt_id };
let Some(prompt_store) = self.prompt_store.as_ref() else {
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return Task::ready(Err(anyhow!("User rules not found in library.")));
};
let text_task = prompt_store.load(prompt_id, cx);
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone()
}

View file

@ -354,6 +354,16 @@ impl AddedContext {
.read(cx)
.is_generating_detailed_summary(),
},
AssistantContext::Rules(user_rules_context) => AddedContext {
id: user_rules_context.id,
kind: ContextKind::Rules,
name: user_rules_context.title.clone(),
parent: None,
tooltip: None,
icon_path: None,
summarizing: false,
},
}
}
}

View file

@ -27,7 +27,7 @@ use language_model::{
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use search::{BufferSearchBar, buffer_search::DivRegistrar};
use settings::{Settings, update_settings_file};
@ -58,11 +58,11 @@ pub fn init(cx: &mut App) {
.register_action(AssistantPanel::show_configuration)
.register_action(AssistantPanel::create_new_context)
.register_action(AssistantPanel::restart_context_servers)
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx)
panel.deploy_prompt_library(action, window, cx)
});
}
});
@ -1060,7 +1060,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
cx,
)
.detach_and_log_err(cx);

View file

@ -44,9 +44,10 @@ impl SlashCommand for PromptSlashCommand {
let store = PromptStore::global(cx);
let query = arguments.to_owned().join(" ");
cx.spawn(async move |cx| {
let cancellation_flag = Arc::new(AtomicBool::default());
let prompts: Vec<PromptMetadata> = store
.await?
.read_with(cx, |store, cx| store.search(query, cx))?
.read_with(cx, |store, cx| store.search(query, cancellation_flag, cx))?
.await;
Ok(prompts
.into_iter()

View file

@ -16,6 +16,7 @@ use release_channel::ReleaseChannel;
use rope::Rope;
use settings::Settings;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use theme::ThemeSettings;
use ui::{
@ -75,7 +76,7 @@ pub fn open_prompt_library(
language_registry: Arc<LanguageRegistry>,
inline_assist_delegate: Box<dyn InlineAssistDelegate>,
make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
prompt_to_focus: Option<PromptId>,
prompt_to_select: Option<PromptId>,
cx: &mut App,
) -> Task<Result<WindowHandle<PromptLibrary>>> {
let store = PromptStore::global(cx);
@ -90,8 +91,8 @@ pub fn open_prompt_library(
if let Some(existing_window) = existing_window {
existing_window
.update(cx, |prompt_library, window, cx| {
if let Some(prompt_to_focus) = prompt_to_focus {
prompt_library.load_prompt(prompt_to_focus, true, window, cx);
if let Some(prompt_to_select) = prompt_to_select {
prompt_library.load_prompt(prompt_to_select, true, window, cx);
}
window.activate_window()
})
@ -126,18 +127,15 @@ pub fn open_prompt_library(
},
|window, cx| {
cx.new(|cx| {
let mut prompt_library = PromptLibrary::new(
PromptLibrary::new(
store,
language_registry,
inline_assist_delegate,
make_completion_provider,
prompt_to_select,
window,
cx,
);
if let Some(prompt_to_focus) = prompt_to_focus {
prompt_library.load_prompt(prompt_to_focus, true, window, cx);
}
prompt_library
)
})
},
)
@ -221,7 +219,8 @@ impl PickerDelegate for PromptPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let search = self.store.read(cx).search(query, cx);
let cancellation_flag = Arc::new(AtomicBool::default());
let search = self.store.read(cx).search(query, cancellation_flag, cx);
let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id);
cx.spawn_in(window, async move |this, cx| {
let (matches, selected_index) = cx
@ -353,13 +352,26 @@ impl PromptLibrary {
language_registry: Arc<LanguageRegistry>,
inline_assist_delegate: Box<dyn InlineAssistDelegate>,
make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
prompt_to_select: Option<PromptId>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let (selected_index, matches) = if let Some(prompt_to_select) = prompt_to_select {
let matches = store.read(cx).all_prompt_metadata();
let selected_index = matches
.iter()
.enumerate()
.find(|(_, metadata)| metadata.id == prompt_to_select)
.map_or(0, |(ix, _)| ix);
(selected_index, matches)
} else {
(0, vec![])
};
let delegate = PromptPickerDelegate {
store: store.clone(),
selected_index: 0,
matches: Vec::new(),
selected_index,
matches,
};
let picker = cx.new(|cx| {

View file

@ -54,14 +54,14 @@ pub struct PromptMetadata {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum PromptId {
User { uuid: Uuid },
User { uuid: UserPromptId },
EditWorkflow,
}
impl PromptId {
pub fn new() -> PromptId {
PromptId::User {
uuid: Uuid::new_v4(),
uuid: UserPromptId::new(),
}
}
@ -70,6 +70,22 @@ impl PromptId {
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
impl UserPromptId {
pub fn new() -> UserPromptId {
UserPromptId(Uuid::new_v4())
}
}
impl From<Uuid> for UserPromptId {
fn from(uuid: Uuid) -> Self {
UserPromptId(uuid)
}
}
pub struct PromptStore {
env: heed::Env,
metadata_cache: RwLock<MetadataCache>,
@ -212,7 +228,7 @@ impl PromptStore {
for (prompt_id_v1, metadata_v1) in metadata_v1 {
let prompt_id_v2 = PromptId::User {
uuid: prompt_id_v1.0,
uuid: UserPromptId(prompt_id_v1.0),
};
let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
continue;
@ -257,6 +273,10 @@ impl PromptStore {
})
}
pub fn all_prompt_metadata(&self) -> Vec<PromptMetadata> {
self.metadata_cache.read().metadata.clone()
}
pub fn default_prompt_metadata(&self) -> Vec<PromptMetadata> {
return self
.metadata_cache
@ -314,7 +334,12 @@ impl PromptStore {
Some(metadata.id)
}
pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> {
pub fn search(
&self,
query: String,
cancellation_flag: Arc<AtomicBool>,
cx: &App,
) -> Task<Vec<PromptMetadata>> {
let cached_metadata = self.metadata_cache.read().metadata.clone();
let executor = cx.background_executor().clone();
cx.background_spawn(async move {
@ -333,7 +358,7 @@ impl PromptStore {
&query,
false,
100,
&AtomicBool::default(),
&cancellation_flag,
executor,
)
.await;

View file

@ -15,34 +15,32 @@ use std::{
};
use text::LineEnding;
use util::{ResultExt, get_system_shell};
use uuid::Uuid;
use crate::UserPromptId;
#[derive(Debug, Clone, Serialize)]
pub struct ProjectContext {
pub worktrees: Vec<WorktreeContext>,
/// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this.
pub has_rules: bool,
pub default_user_rules: Vec<DefaultUserRulesContext>,
/// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this.
pub has_default_user_rules: bool,
pub user_rules: Vec<UserRulesContext>,
/// `!user_rules.is_empty()` - provided as a field because handlebars can't do this.
pub has_user_rules: bool,
pub os: String,
pub arch: String,
pub shell: String,
}
impl ProjectContext {
pub fn new(
worktrees: Vec<WorktreeContext>,
default_user_rules: Vec<DefaultUserRulesContext>,
) -> Self {
pub fn new(worktrees: Vec<WorktreeContext>, default_user_rules: Vec<UserRulesContext>) -> Self {
let has_rules = worktrees
.iter()
.any(|worktree| worktree.rules_file.is_some());
Self {
worktrees,
has_rules,
has_default_user_rules: !default_user_rules.is_empty(),
default_user_rules,
has_user_rules: !default_user_rules.is_empty(),
user_rules: default_user_rules,
os: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(),
shell: get_system_shell(),
@ -51,8 +49,8 @@ impl ProjectContext {
}
#[derive(Debug, Clone, Serialize)]
pub struct DefaultUserRulesContext {
pub uuid: Uuid,
pub struct UserRulesContext {
pub uuid: UserPromptId,
pub title: Option<String>,
pub contents: String,
}
@ -397,6 +395,7 @@ impl PromptBuilder {
#[cfg(test)]
mod test {
use super::*;
use uuid::Uuid;
#[test]
fn test_assistant_system_prompt_renders() {
@ -408,8 +407,8 @@ mod test {
text: "".into(),
}),
}];
let default_user_rules = vec![DefaultUserRulesContext {
uuid: Uuid::nil(),
let default_user_rules = vec![UserRulesContext {
uuid: UserPromptId(Uuid::nil()),
title: Some("Rules title".into()),
contents: "Rules contents".into(),
}];

View file

@ -201,7 +201,7 @@ pub mod assistant {
#[serde(deny_unknown_fields)]
pub struct OpenPromptLibrary {
#[serde(skip)]
pub prompt_to_focus: Option<Uuid>,
pub prompt_to_select: Option<Uuid>,
}
impl_action_with_deprecated_aliases!(