diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index 654c1848a7..60b2cee74e 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -144,6 +144,19 @@ In Markdown, hash marks signify headings. For example: This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks. +{{#if has_default_user_rules}} +The user has specified the following rules that should be applied: +{{#each default_user_rules}} + +{{#if title}} +Rules title: {{title}} +{{/if}} +`````` +{{contents}} +`````` +{{/each}} + +{{/if}} The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files: {{#each worktrees}} @@ -151,7 +164,7 @@ The user has opened a project that contains the following root directories/files {{/each}} {{#if has_rules}} -There are rules that apply to these root directories: +There are project rules that apply to these root directories: {{#each worktrees}} {{#if rules_file}} diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 991906a236..daec26a733 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -42,6 +42,7 @@ use ui::{ }; use util::ResultExt as _; use workspace::{OpenOptions, Workspace}; +use zed_actions::assistant::OpenPromptLibrary; use crate::context_store::ContextStore; @@ -2948,53 +2949,106 @@ impl ActiveThread { return div().into_any(); }; + let default_user_rules_text = if project_context.default_user_rules.is_empty() { + None + } else if project_context.default_user_rules.len() == 1 { + let user_rules = &project_context.default_user_rules[0]; + + match user_rules.title.as_ref() { + Some(title) => Some(format!("Using \"{title}\" user rule")), + None => Some("Using user rule".into()), + } + } else { + Some(format!( + "Using {} user rules", + project_context.default_user_rules.len() + )) + }; + let rules_files = project_context .worktrees .iter() .filter_map(|worktree| worktree.rules_file.as_ref()) .collect::>(); - let label_text = match rules_files.as_slice() { - &[] => return div().into_any(), - &[rules_file] => { - format!("Using {:?} file", rules_file.path_in_worktree) - } - rules_files => { - format!("Using {} rules files", rules_files.len()) - } + let rules_file_text = match rules_files.as_slice() { + &[] => None, + &[rules_file] => Some(format!( + "Using project {:?} file", + rules_file.path_in_worktree + )), + rules_files => Some(format!("Using {} project rules files", rules_files.len())), }; - div() + if default_user_rules_text.is_none() && rules_file_text.is_none() { + return div().into_any(); + } + + v_flex() .pt_2() .px_2p5() - .child( - h_flex() - .w_full() - .gap_0p5() - .child( + .gap_1() + .when_some( + default_user_rules_text, + |parent, default_user_rules_text| { + parent.child( h_flex() - .gap_1p5() + .w_full() .child( Icon::new(IconName::File) .size(IconSize::XSmall) .color(Color::Disabled), ) .child( - Label::new(label_text) + Label::new(default_user_rules_text) .size(LabelSize::XSmall) .color(Color::Muted) - .buffer_font(cx), + .truncate() + .buffer_font(cx) + .ml_1p5() + .mr_0p5(), + ) + .child( + IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + // TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding + .tooltip(Tooltip::text("View User Rules")) + .on_click(|_event, window, cx| { + window.dispatch_action(Box::new(OpenPromptLibrary), cx) + }), ), ) - .child( - IconButton::new("open-rule", IconName::ArrowUpRightAlt) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(Color::Ignored) - .on_click(cx.listener(Self::handle_open_rules)) - .tooltip(Tooltip::text("View Rules")), - ), + }, ) + .when_some(rules_file_text, |parent, rules_file_text| { + parent.child( + h_flex() + .w_full() + .child( + Icon::new(IconName::File) + .size(IconSize::XSmall) + .color(Color::Disabled), + ) + .child( + Label::new(rules_file_text) + .size(LabelSize::XSmall) + .color(Color::Muted) + .buffer_font(cx) + .ml_1p5() + .mr_0p5(), + ) + .child( + IconButton::new("open-rule", IconName::ArrowUpRightAlt) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + .on_click(cx.listener(Self::handle_open_rules)) + .tooltip(Tooltip::text("View Rules")), + ), + ) + }) .into_any() } diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index bf2f42640d..05f003d03a 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -922,6 +922,7 @@ mod tests { language::init(cx); Project::init_settings(cx); AssistantSettings::register(cx); + prompt_store::init(cx); thread_store::init(cx); workspace::init_settings(cx); ThemeSettings::register(cx); @@ -951,7 +952,8 @@ mod tests { cx, ) }) - .await; + .await + .unwrap(); let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 6a72677eaa..1f450921f7 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -213,7 +213,7 @@ impl AssistantPanel { let project = workspace.project().clone(); ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx) })? - .await; + .await?; let slash_commands = Arc::new(SlashCommandWorkingSet::default()); let context_store = workspace diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 7adc78591e..ea9db9d9f5 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -4,7 +4,7 @@ use std::ops::Range; use std::sync::Arc; use std::time::Instant; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use assistant_settings::AssistantSettings; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; @@ -939,7 +939,7 @@ impl Thread { pub fn to_completion_request( &self, request_kind: RequestKind, - cx: &App, + cx: &mut Context, ) -> LanguageModelRequest { let mut request = LanguageModelRequest { messages: vec![], @@ -949,20 +949,33 @@ impl Thread { }; if let Some(project_context) = self.project_context.borrow().as_ref() { - if let Some(system_prompt) = self + match self .prompt_builder .generate_assistant_system_prompt(project_context) - .context("failed to generate assistant system prompt") - .log_err() { - request.messages.push(LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text(system_prompt)], - cache: true, - }); + Err(err) => { + let message = format!("{err:?}").into(); + log::error!("{message}"); + cx.emit(ThreadEvent::ShowError(ThreadError::Message { + header: "Error generating system prompt".into(), + message, + })); + } + Ok(system_prompt) => { + request.messages.push(LanguageModelRequestMessage { + role: Role::System, + content: vec![MessageContent::Text(system_prompt)], + cache: true, + }); + } } } else { - log::error!("project_context not set.") + let message = "Context for system prompt unexpectedly not ready.".into(); + log::error!("{message}"); + cx.emit(ThreadEvent::ShowError(ThreadError::Message { + header: "Error generating system prompt".into(), + message, + })); } for message in &self.messages { @@ -2163,7 +2176,7 @@ fn main() {{ assert_eq!(message.context, expected_context); // Check message in request - let request = thread.read_with(cx, |thread, cx| { + let request = thread.update(cx, |thread, cx| { thread.to_completion_request(RequestKind::Chat, cx) }); @@ -2255,7 +2268,7 @@ fn main() {{ assert!(message3.context.contains("file3.rs")); // Check entire request to make sure all contexts are properly included - let request = thread.read_with(cx, |thread, cx| { + let request = thread.update(cx, |thread, cx| { thread.to_completion_request(RequestKind::Chat, cx) }); @@ -2307,7 +2320,7 @@ fn main() {{ assert_eq!(message.context, ""); // Check message in request - let request = thread.read_with(cx, |thread, cx| { + let request = thread.update(cx, |thread, cx| { thread.to_completion_request(RequestKind::Chat, cx) }); @@ -2327,7 +2340,7 @@ fn main() {{ assert_eq!(message2.context, ""); // Check that both messages appear in the request - let request = thread.read_with(cx, |thread, cx| { + let request = thread.update(cx, |thread, cx| { thread.to_completion_request(RequestKind::Chat, cx) }); @@ -2369,7 +2382,7 @@ fn main() {{ }); // Create a request and check that it doesn't have a stale buffer warning yet - let initial_request = thread.read_with(cx, |thread, cx| { + let initial_request = thread.update(cx, |thread, cx| { thread.to_completion_request(RequestKind::Chat, cx) }); @@ -2399,7 +2412,7 @@ fn main() {{ }); // Create a new request and check for the stale buffer warning - let new_request = thread.read_with(cx, |thread, cx| { + let new_request = thread.update(cx, |thread, cx| { thread.to_completion_request(RequestKind::Chat, cx) }); @@ -2428,6 +2441,7 @@ fn main() {{ language::init(cx); Project::init_settings(cx); AssistantSettings::register(cx); + prompt_store::init(cx); thread_store::init(cx); workspace::init_settings(cx); ThemeSettings::register(cx); @@ -2467,7 +2481,8 @@ fn main() {{ cx, ) }) - .await; + .await + .unwrap(); let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index a72313061c..5de5d7404a 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -12,8 +12,9 @@ use collections::HashMap; use context_server::manager::ContextServerManager; use context_server::{ContextServerFactoryRegistry, ContextServerTool}; use fs::Fs; -use futures::FutureExt as _; +use futures::channel::{mpsc, oneshot}; use futures::future::{self, BoxFuture, Shared}; +use futures::{FutureExt as _, StreamExt as _}; use gpui::{ App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Subscription, Task, prelude::*, @@ -22,7 +23,10 @@ use heed::Database; use heed::types::SerdeBincode; use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::{Project, Worktree}; -use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext}; +use prompt_store::{ + DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent, + RulesFileContext, WorktreeContext, +}; use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; @@ -62,6 +66,8 @@ pub struct ThreadStore { context_server_tool_ids: HashMap, Vec>, threads: Vec, project_context: SharedProjectContext, + reload_system_prompt_tx: mpsc::Sender<()>, + _reload_system_prompt_task: Task<()>, _subscriptions: Vec, } @@ -77,12 +83,22 @@ impl ThreadStore { tools: Entity, prompt_builder: Arc, cx: &mut App, - ) -> Task> { - let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx)); - let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx)); - cx.foreground_executor().spawn(async move { - reload.await; - thread_store + ) -> Task>> { + let prompt_store = PromptStore::global(cx); + cx.spawn(async move |cx| { + let prompt_store = prompt_store.await.ok(); + let (thread_store, ready_rx) = cx.update(|cx| { + let mut option_ready_rx = None; + let thread_store = cx.new(|cx| { + let (thread_store, ready_rx) = + Self::new(project, tools, prompt_builder, prompt_store, cx); + option_ready_rx = Some(ready_rx); + thread_store + }); + (thread_store, option_ready_rx.take().unwrap()) + })?; + ready_rx.await?; + Ok(thread_store) }) } @@ -90,17 +106,53 @@ impl ThreadStore { project: Entity, tools: Entity, prompt_builder: Arc, + prompt_store: Option>, cx: &mut Context, - ) -> Self { + ) -> (Self, oneshot::Receiver<()>) { let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx); let context_server_manager = cx.new(|cx| { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); - let settings_subscription = + + let mut subscriptions = vec![ cx.observe_global::(move |this: &mut Self, cx| { this.load_default_profile(cx); - }); - let project_subscription = cx.subscribe(&project, Self::handle_project_event); + }), + cx.subscribe(&project, Self::handle_project_event), + ]; + + if let Some(prompt_store) = prompt_store.as_ref() { + subscriptions.push(cx.subscribe( + prompt_store, + |this, _prompt_store, PromptsUpdatedEvent, _cx| { + this.enqueue_system_prompt_reload(); + }, + )) + } + + // This channel and task prevent concurrent and redundant loading of the system prompt. + let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1); + let (ready_tx, ready_rx) = oneshot::channel(); + let mut ready_tx = Some(ready_tx); + let reload_system_prompt_task = cx.spawn({ + async move |thread_store, cx| { + loop { + let Some(reload_task) = thread_store + .update(cx, |thread_store, cx| { + thread_store.reload_system_prompt(prompt_store.clone(), cx) + }) + .ok() + else { + return; + }; + reload_task.await; + if let Some(ready_tx) = ready_tx.take() { + ready_tx.send(()).ok(); + } + reload_system_prompt_rx.next().await; + } + } + }); let this = Self { project, @@ -110,23 +162,25 @@ impl ThreadStore { context_server_tool_ids: HashMap::default(), threads: Vec::new(), project_context: SharedProjectContext::default(), - _subscriptions: vec![settings_subscription, project_subscription], + reload_system_prompt_tx, + _reload_system_prompt_task: reload_system_prompt_task, + _subscriptions: subscriptions, }; this.load_default_profile(cx); this.register_context_server_handlers(cx); this.reload(cx).detach_and_log_err(cx); - this + (this, ready_rx) } fn handle_project_event( &mut self, _project: Entity, event: &project::Event, - cx: &mut Context, + _cx: &mut Context, ) { match event { project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { - self.reload_system_prompt(cx).detach(); + self.enqueue_system_prompt_reload(); } project::Event::WorktreeUpdatedEntries(_, items) => { if items.iter().any(|(path, _, _)| { @@ -134,16 +188,25 @@ impl ThreadStore { .iter() .any(|name| path.as_ref() == Path::new(name)) }) { - self.reload_system_prompt(cx).detach(); + self.enqueue_system_prompt_reload(); } } _ => {} } } - pub fn reload_system_prompt(&self, cx: &mut Context) -> Task<()> { + fn enqueue_system_prompt_reload(&mut self) { + self.reload_system_prompt_tx.try_send(()).ok(); + } + + // Note that this should only be called from `reload_system_prompt_task`. + fn reload_system_prompt( + &self, + prompt_store: Option>, + cx: &mut Context, + ) -> Task<()> { let project = self.project.read(cx); - let tasks = project + let worktree_tasks = project .visible_worktrees(cx) .map(|worktree| { Self::load_worktree_info_for_system_prompt( @@ -153,10 +216,23 @@ impl ThreadStore { ) }) .collect::>(); + let default_user_rules_task = match prompt_store { + None => Task::ready(vec![]), + Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| { + let prompts = prompt_store.default_prompt_metadata(); + let load_tasks = prompts.into_iter().map(|prompt_metadata| { + let contents = prompt_store.load(prompt_metadata.id, cx); + async move { (contents.await, prompt_metadata) } + }); + cx.background_spawn(future::join_all(load_tasks)) + }), + }; cx.spawn(async move |this, cx| { - let results = futures::future::join_all(tasks).await; - let worktrees = results + let (worktrees, default_user_rules) = + future::join(future::join_all(worktree_tasks), default_user_rules_task).await; + + let worktrees = worktrees .into_iter() .map(|(worktree, rules_error)| { if let Some(rules_error) = rules_error { @@ -165,8 +241,29 @@ impl ThreadStore { worktree }) .collect::>(); + + let default_user_rules = default_user_rules + .into_iter() + .flat_map(|(contents, prompt_metadata)| match contents { + Ok(contents) => Some(DefaultUserRulesContext { + title: prompt_metadata.title.map(|title| title.to_string()), + contents, + }), + Err(err) => { + this.update(cx, |_, cx| { + cx.emit(RulesLoadingError { + message: format!("{err:?}").into(), + }); + }) + .ok(); + None + } + }) + .collect::>(); + this.update(cx, |this, _cx| { - *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees)); + *this.project_context.0.borrow_mut() = + Some(ProjectContext::new(worktrees, default_user_rules)); }) .ok(); }) diff --git a/crates/assistant_slash_commands/src/default_command.rs b/crates/assistant_slash_commands/src/default_command.rs index e5bda3c03b..6fce7f07a4 100644 --- a/crates/assistant_slash_commands/src/default_command.rs +++ b/crates/assistant_slash_commands/src/default_command.rs @@ -54,9 +54,9 @@ impl SlashCommand for DefaultSlashCommand { cx: &mut App, ) -> Task { let store = PromptStore::global(cx); - cx.background_spawn(async move { + cx.spawn(async move |cx| { let store = store.await?; - let prompts = store.default_prompt_metadata(); + let prompts = store.read_with(cx, |store, _cx| store.default_prompt_metadata())?; let mut text = String::new(); text.push('\n'); diff --git a/crates/assistant_slash_commands/src/prompt_command.rs b/crates/assistant_slash_commands/src/prompt_command.rs index 7d535f803a..a057023197 100644 --- a/crates/assistant_slash_commands/src/prompt_command.rs +++ b/crates/assistant_slash_commands/src/prompt_command.rs @@ -5,7 +5,7 @@ use assistant_slash_command::{ }; use gpui::{Task, WeakEntity}; use language::{BufferSnapshot, LspAdapterDelegate}; -use prompt_store::PromptStore; +use prompt_store::{PromptMetadata, PromptStore}; use std::sync::{Arc, atomic::AtomicBool}; use ui::prelude::*; use workspace::Workspace; @@ -43,8 +43,11 @@ impl SlashCommand for PromptSlashCommand { ) -> Task>> { let store = PromptStore::global(cx); let query = arguments.to_owned().join(" "); - cx.background_spawn(async move { - let prompts = store.await?.search(query).await; + cx.spawn(async move |cx| { + let prompts: Vec = store + .await? + .read_with(cx, |store, cx| store.search(query, cx))? + .await; Ok(prompts .into_iter() .filter_map(|prompt| { @@ -77,14 +80,18 @@ impl SlashCommand for PromptSlashCommand { let store = PromptStore::global(cx); let title = SharedString::from(title.clone()); - let prompt = cx.background_spawn({ + let prompt = cx.spawn({ let title = title.clone(); - async move { + async move |cx| { let store = store.await?; - let prompt_id = store - .id_for_title(&title) - .with_context(|| format!("no prompt found with title {:?}", title))?; - let body = store.load(prompt_id).await?; + let body = store + .read_with(cx, |store, cx| { + let prompt_id = store + .id_for_title(&title) + .with_context(|| format!("no prompt found with title {:?}", title))?; + anyhow::Ok(store.load(prompt_id, cx)) + })?? + .await?; anyhow::Ok(body) } }); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index cbd8ee7a8d..7e48c11313 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -309,7 +309,7 @@ impl Example { return Err(anyhow!("Setup only mode")); } - let thread_store = thread_store.await; + let thread_store = thread_store.await?; let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs index 7fff6d1258..3d6e97b724 100644 --- a/crates/prompt_library/src/prompt_library.rs +++ b/crates/prompt_library/src/prompt_library.rs @@ -136,7 +136,7 @@ pub fn open_prompt_library( } pub struct PromptLibrary { - store: Arc, + store: Entity, language_registry: Arc, prompt_editors: HashMap, active_prompt_id: Option, @@ -158,7 +158,7 @@ struct PromptEditor { } struct PromptPickerDelegate { - store: Arc, + store: Entity, selected_index: usize, matches: Vec, } @@ -179,8 +179,8 @@ impl PickerDelegate for PromptPickerDelegate { self.matches.len() } - fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { - let text = if self.store.prompt_count() == 0 { + fn no_matches_text(&self, _window: &mut Window, cx: &mut App) -> Option { + let text = if self.store.read(cx).prompt_count() == 0 { "No prompts.".into() } else { "No prompts found matching your search.".into() @@ -211,7 +211,7 @@ impl PickerDelegate for PromptPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let search = self.store.search(query); + let search = self.store.read(cx).search(query, cx); let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id); cx.spawn_in(window, async move |this, cx| { let (matches, selected_index) = cx @@ -339,7 +339,7 @@ impl PickerDelegate for PromptPickerDelegate { impl PromptLibrary { fn new( - store: Arc, + store: Entity, language_registry: Arc, inline_assist_delegate: Box, make_completion_provider: Arc Box>, @@ -398,7 +398,7 @@ impl PromptLibrary { pub fn new_prompt(&mut self, window: &mut Window, cx: &mut Context) { // If we already have an untitled prompt, use that instead // of creating a new one. - if let Some(metadata) = self.store.first() { + if let Some(metadata) = self.store.read(cx).first() { if metadata.title.is_none() { self.load_prompt(metadata.id, true, window, cx); return; @@ -406,7 +406,9 @@ impl PromptLibrary { } let prompt_id = PromptId::new(); - let save = self.store.save(prompt_id, None, false, "".into()); + let save = self.store.update(cx, |store, cx| { + store.save(prompt_id, None, false, "".into(), cx) + }); self.picker .update(cx, |picker, cx| picker.refresh(window, cx)); cx.spawn_in(window, async move |this, cx| { @@ -430,7 +432,7 @@ impl PromptLibrary { return; } - let prompt_metadata = self.store.metadata(prompt_id).unwrap(); + let prompt_metadata = self.store.read(cx).metadata(prompt_id).unwrap(); let prompt_editor = self.prompt_editors.get_mut(&prompt_id).unwrap(); let title = prompt_editor.title_editor.read(cx).text(cx); let body = prompt_editor.body_editor.update(cx, |editor, cx| { @@ -465,10 +467,13 @@ impl PromptLibrary { } else { Some(SharedString::from(title)) }; - store - .save(prompt_id, title, prompt_metadata.default, body) - .await - .log_err(); + cx.update(|_window, cx| { + store.update(cx, |store, cx| { + store.save(prompt_id, title, prompt_metadata.default, body, cx) + }) + })? + .await + .log_err(); this.update_in(cx, |this, window, cx| { this.picker .update(cx, |picker, cx| picker.refresh(window, cx)); @@ -521,14 +526,21 @@ impl PromptLibrary { window: &mut Window, cx: &mut Context, ) { - if let Some(prompt_metadata) = self.store.metadata(prompt_id) { - self.store - .save_metadata(prompt_id, prompt_metadata.title, !prompt_metadata.default) - .detach_and_log_err(cx); - self.picker - .update(cx, |picker, cx| picker.refresh(window, cx)); - cx.notify(); - } + self.store.update(cx, move |store, cx| { + if let Some(prompt_metadata) = store.metadata(prompt_id) { + store + .save_metadata( + prompt_id, + prompt_metadata.title, + !prompt_metadata.default, + cx, + ) + .detach_and_log_err(cx); + } + }); + self.picker + .update(cx, |picker, cx| picker.refresh(window, cx)); + cx.notify(); } pub fn load_prompt( @@ -545,9 +557,9 @@ impl PromptLibrary { .update(cx, |editor, cx| window.focus(&editor.focus_handle(cx))); } self.set_active_prompt(Some(prompt_id), window, cx); - } else if let Some(prompt_metadata) = self.store.metadata(prompt_id) { + } else if let Some(prompt_metadata) = self.store.read(cx).metadata(prompt_id) { let language_registry = self.language_registry.clone(); - let prompt = self.store.load(prompt_id); + let prompt = self.store.read(cx).load(prompt_id, cx); let make_completion_provider = self.make_completion_provider.clone(); self.pending_load = cx.spawn_in(window, async move |this, cx| { let prompt = prompt.await; @@ -673,7 +685,7 @@ impl PromptLibrary { window: &mut Window, cx: &mut Context, ) { - if let Some(metadata) = self.store.metadata(prompt_id) { + if let Some(metadata) = self.store.read(cx).metadata(prompt_id) { let confirmation = window.prompt( PromptLevel::Warning, &format!( @@ -692,7 +704,9 @@ impl PromptLibrary { this.set_active_prompt(None, window, cx); } this.prompt_editors.remove(&prompt_id); - this.store.delete(prompt_id).detach_and_log_err(cx); + this.store + .update(cx, |store, cx| store.delete(prompt_id, cx)) + .detach_and_log_err(cx); this.picker .update(cx, |picker, cx| picker.refresh(window, cx)); cx.notify(); @@ -736,9 +750,9 @@ impl PromptLibrary { let new_id = PromptId::new(); let body = prompt.body_editor.read(cx).text(cx); - let save = self - .store - .save(new_id, Some(title.into()), false, body.into()); + let save = self.store.update(cx, |store, cx| { + store.save(new_id, Some(title.into()), false, body.into(), cx) + }); self.picker .update(cx, |picker, cx| picker.refresh(window, cx)); cx.spawn_in(window, async move |this, cx| { @@ -968,7 +982,7 @@ impl PromptLibrary { .flex_none() .min_w_64() .children(self.active_prompt_id.and_then(|prompt_id| { - let prompt_metadata = self.store.metadata(prompt_id)?; + let prompt_metadata = self.store.read(cx).metadata(prompt_id)?; let prompt_editor = &self.prompt_editors[&prompt_id]; let focus_handle = prompt_editor.body_editor.focus_handle(cx); let model = LanguageModelRegistry::read_global(cx) @@ -1238,7 +1252,7 @@ impl Render for PromptLibrary { .text_color(theme.colors().text) .child(self.render_prompt_list(cx)) .map(|el| { - if self.store.prompt_count() == 0 { + if self.store.read(cx).prompt_count() == 0 { el.child( v_flex() .w_2_3() diff --git a/crates/prompt_store/src/prompt_store.rs b/crates/prompt_store/src/prompt_store.rs index 57e3e04e79..66e4b9072f 100644 --- a/crates/prompt_store/src/prompt_store.rs +++ b/crates/prompt_store/src/prompt_store.rs @@ -4,9 +4,11 @@ use anyhow::{Result, anyhow}; use chrono::{DateTime, Utc}; use collections::HashMap; use futures::FutureExt as _; -use futures::future::{self, BoxFuture, Shared}; +use futures::future::Shared; use fuzzy::StringMatchCandidate; -use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task}; +use gpui::{ + App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task, +}; use heed::{ Database, RoTxn, types::{SerdeBincode, SerdeJson, Str}, @@ -29,11 +31,16 @@ use uuid::Uuid; /// a shared future to a global. pub fn init(cx: &mut App) { let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb"); - let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone()) - .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) - .boxed() + let prompt_store_task = PromptStore::new(db_path, cx); + let prompt_store_entity_task = cx + .spawn(async move |cx| { + prompt_store_task + .await + .and_then(|prompt_store| cx.new(|_cx| prompt_store)) + .map_err(Arc::new) + }) .shared(); - cx.set_global(GlobalPromptStore(prompt_store_future)) + cx.set_global(GlobalPromptStore(prompt_store_entity_task)) } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -64,13 +71,16 @@ impl PromptId { } pub struct PromptStore { - executor: BackgroundExecutor, env: heed::Env, metadata_cache: RwLock, metadata: Database, SerdeJson>, bodies: Database, Str>, } +pub struct PromptsUpdatedEvent; + +impl EventEmitter for PromptStore {} + #[derive(Default)] struct MetadataCache { metadata: Vec, @@ -117,49 +127,45 @@ impl MetadataCache { } impl PromptStore { - pub fn global(cx: &App) -> impl Future>> + use<> { + pub fn global(cx: &App) -> impl Future>> + use<> { let store = GlobalPromptStore::global(cx).0.clone(); async move { store.await.map_err(|err| anyhow!(err)) } } - pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task> { - executor.spawn({ - let executor = executor.clone(); - async move { - std::fs::create_dir_all(&db_path)?; + pub fn new(db_path: PathBuf, cx: &App) -> Task> { + cx.background_spawn(async move { + std::fs::create_dir_all(&db_path)?; - let db_env = unsafe { - heed::EnvOpenOptions::new() - .map_size(1024 * 1024 * 1024) // 1GB - .max_dbs(4) // Metadata and bodies (possibly v1 of both as well) - .open(db_path)? - }; + let db_env = unsafe { + heed::EnvOpenOptions::new() + .map_size(1024 * 1024 * 1024) // 1GB + .max_dbs(4) // Metadata and bodies (possibly v1 of both as well) + .open(db_path)? + }; - let mut txn = db_env.write_txn()?; - let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?; - let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?; + let mut txn = db_env.write_txn()?; + let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?; + let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?; - // Remove edit workflow prompt, as we decided to opt into it using - // a slash command instead. - metadata.delete(&mut txn, &PromptId::EditWorkflow).ok(); - bodies.delete(&mut txn, &PromptId::EditWorkflow).ok(); + // Remove edit workflow prompt, as we decided to opt into it using + // a slash command instead. + metadata.delete(&mut txn, &PromptId::EditWorkflow).ok(); + bodies.delete(&mut txn, &PromptId::EditWorkflow).ok(); - txn.commit()?; + txn.commit()?; - Self::upgrade_dbs(&db_env, metadata, bodies).log_err(); + Self::upgrade_dbs(&db_env, metadata, bodies).log_err(); - let txn = db_env.read_txn()?; - let metadata_cache = MetadataCache::from_db(metadata, &txn)?; - txn.commit()?; + let txn = db_env.read_txn()?; + let metadata_cache = MetadataCache::from_db(metadata, &txn)?; + txn.commit()?; - Ok(PromptStore { - executor, - env: db_env, - metadata_cache: RwLock::new(metadata_cache), - metadata, - bodies, - }) - } + Ok(PromptStore { + env: db_env, + metadata_cache: RwLock::new(metadata_cache), + metadata, + bodies, + }) }) } @@ -237,10 +243,10 @@ impl PromptStore { Ok(()) } - pub fn load(&self, id: PromptId) -> Task> { + pub fn load(&self, id: PromptId, cx: &App) -> Task> { let env = self.env.clone(); let bodies = self.bodies; - self.executor.spawn(async move { + cx.background_spawn(async move { let txn = env.read_txn()?; let mut prompt = bodies .get(&txn, &id)? @@ -262,21 +268,27 @@ impl PromptStore { .collect::>(); } - pub fn delete(&self, id: PromptId) -> Task> { + pub fn delete(&self, id: PromptId, cx: &Context) -> Task> { self.metadata_cache.write().remove(id); let db_connection = self.env.clone(); let bodies = self.bodies; let metadata = self.metadata; - self.executor.spawn(async move { + let task = cx.background_spawn(async move { let mut txn = db_connection.write_txn()?; metadata.delete(&mut txn, &id)?; bodies.delete(&mut txn, &id)?; txn.commit()?; - Ok(()) + anyhow::Ok(()) + }); + + cx.spawn(async move |this, cx| { + task.await?; + this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok(); + anyhow::Ok(()) }) } @@ -302,10 +314,10 @@ impl PromptStore { Some(metadata.id) } - pub fn search(&self, query: String) -> Task> { + pub fn search(&self, query: String, cx: &App) -> Task> { let cached_metadata = self.metadata_cache.read().metadata.clone(); - let executor = self.executor.clone(); - self.executor.spawn(async move { + let executor = cx.background_executor().clone(); + cx.background_spawn(async move { let mut matches = if query.is_empty() { cached_metadata } else { @@ -341,6 +353,7 @@ impl PromptStore { title: Option, default: bool, body: Rope, + cx: &Context, ) -> Task> { if id.is_built_in() { return Task::ready(Err(anyhow!("built-in prompts cannot be saved"))); @@ -358,7 +371,7 @@ impl PromptStore { let bodies = self.bodies; let metadata = self.metadata; - self.executor.spawn(async move { + let task = cx.background_spawn(async move { let mut txn = db_connection.write_txn()?; metadata.put(&mut txn, &id, &prompt_metadata)?; @@ -366,7 +379,13 @@ impl PromptStore { txn.commit()?; - Ok(()) + anyhow::Ok(()) + }); + + cx.spawn(async move |this, cx| { + task.await?; + this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok(); + anyhow::Ok(()) }) } @@ -375,6 +394,7 @@ impl PromptStore { id: PromptId, mut title: Option, default: bool, + cx: &Context, ) -> Task> { let mut cache = self.metadata_cache.write(); @@ -397,19 +417,23 @@ impl PromptStore { let db_connection = self.env.clone(); let metadata = self.metadata; - self.executor.spawn(async move { + let task = cx.background_spawn(async move { let mut txn = db_connection.write_txn()?; metadata.put(&mut txn, &id, &prompt_metadata)?; txn.commit()?; - Ok(()) + anyhow::Ok(()) + }); + + cx.spawn(async move |this, cx| { + task.await?; + this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok(); + anyhow::Ok(()) }) } } /// Wraps a shared future to a prompt store so it can be assigned as a context global. -pub struct GlobalPromptStore( - Shared, Arc>>>, -); +pub struct GlobalPromptStore(Shared, Arc>>>); impl Global for GlobalPromptStore {} diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 717af5fc55..54aa632c98 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -19,20 +19,29 @@ use util::{ResultExt, get_system_shell}; #[derive(Debug, Clone, Serialize)] pub struct ProjectContext { pub worktrees: Vec, + /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this. pub has_rules: bool, + pub default_user_rules: Vec, + /// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this. + pub has_default_user_rules: bool, pub os: String, pub arch: String, pub shell: String, } impl ProjectContext { - pub fn new(worktrees: Vec) -> Self { + pub fn new( + worktrees: Vec, + default_user_rules: Vec, + ) -> Self { let has_rules = worktrees .iter() .any(|worktree| worktree.rules_file.is_some()); Self { worktrees, has_rules, + has_default_user_rules: !default_user_rules.is_empty(), + default_user_rules, os: std::env::consts::OS.to_string(), arch: std::env::consts::ARCH.to_string(), shell: get_system_shell(), @@ -40,6 +49,12 @@ impl ProjectContext { } } +#[derive(Debug, Clone, Serialize)] +pub struct DefaultUserRulesContext { + pub title: Option, + pub contents: String, +} + #[derive(Debug, Clone, Serialize)] pub struct WorktreeContext { pub root_name: String, @@ -377,3 +392,30 @@ impl PromptBuilder { self.handlebars.lock().render("suggest_edits", &()) } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_assistant_system_prompt_renders() { + let worktrees = vec![WorktreeContext { + root_name: "path".into(), + abs_path: Path::new("/some/path").into(), + rules_file: Some(RulesFileContext { + path_in_worktree: Path::new(".rules").into(), + abs_path: Path::new("/some/path/.rules").into(), + text: "".into(), + }), + }]; + let default_user_rules = vec![DefaultUserRulesContext { + title: Some("Rules title".into()), + contents: "Rules contents".into(), + }]; + let project_context = ProjectContext::new(worktrees, default_user_rules); + PromptBuilder::new(None) + .unwrap() + .generate_assistant_system_prompt(&project_context) + .unwrap(); + } +}