From 1b1c2e55f32d10d351dae213f1c9af26b17cc630 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 16 Jan 2025 15:06:16 -0500 Subject: [PATCH] Extract `PromptStore` and `PromptBuilder` to new `prompt_library` crate (#23254) This PR adds a new `prompt_library` crate and extracts the `PromptStore` and `PromptBuilder` to it. Eventually we'll want to house the `PromptLibrary` itself in this crate, but right now that involves untangling a few dependencies. Release Notes: - N/A --- Cargo.lock | 33 +- Cargo.toml | 2 + crates/assistant/Cargo.toml | 4 +- crates/assistant/src/assistant.rs | 11 +- crates/assistant/src/assistant_panel.rs | 2 +- crates/assistant/src/context.rs | 2 +- crates/assistant/src/context/context_tests.rs | 6 +- crates/assistant/src/context_store.rs | 5 +- crates/assistant/src/inline_assistant.rs | 7 +- crates/assistant/src/prompt_library.rs | 421 +----------------- .../src/slash_command/default_command.rs | 2 +- .../src/slash_command/prompt_command.rs | 2 +- .../src/terminal_inline_assistant.rs | 4 +- crates/assistant2/Cargo.toml | 4 +- crates/assistant2/src/assistant.rs | 7 +- crates/assistant2/src/buffer_codegen.rs | 6 +- crates/assistant2/src/inline_assistant.rs | 3 +- crates/assistant2/src/prompts.rs | 312 ------------- .../src/terminal_inline_assistant.rs | 2 +- crates/prompt_library/Cargo.toml | 33 ++ crates/prompt_library/LICENSE-GPL | 1 + crates/prompt_library/src/prompt_library.rs | 11 + crates/prompt_library/src/prompt_store.rs | 412 +++++++++++++++++ .../src/prompts.rs | 0 24 files changed, 524 insertions(+), 768 deletions(-) delete mode 100644 crates/assistant2/src/prompts.rs create mode 100644 crates/prompt_library/Cargo.toml create mode 120000 crates/prompt_library/LICENSE-GPL create mode 100644 crates/prompt_library/src/prompt_library.rs create mode 100644 crates/prompt_library/src/prompt_store.rs rename crates/{assistant => prompt_library}/src/prompts.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 4abb09fc7c..939e4b4453 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -373,7 +373,6 @@ name = "assistant" version = "0.1.0" dependencies = [ "anyhow", - "assets", "assistant_settings", "assistant_slash_command", "assistant_tool", @@ -395,8 +394,6 @@ dependencies = [ "fuzzy", "globset", "gpui", - "handlebars 4.5.0", - "heed", "html_to_markdown", "http_client", "indexed_docs", @@ -418,6 +415,7 @@ dependencies = [ "picker", "pretty_assertions", "project", + "prompt_library", "proto", "rand 0.8.5", "regex", @@ -456,7 +454,6 @@ name = "assistant2" version = "0.1.0" dependencies = [ "anyhow", - "assets", "assistant_settings", "assistant_tool", "async-watch", @@ -474,7 +471,6 @@ dependencies = [ "futures 0.3.31", "fuzzy", "gpui", - "handlebars 4.5.0", "html_to_markdown", "http_client", "indoc", @@ -490,9 +486,9 @@ dependencies = [ "multi_buffer", "ordered-float 2.10.1", "parking_lot", - "paths", "picker", "project", + "prompt_library", "proto", "rand 0.8.5", "rope", @@ -9820,6 +9816,31 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "prompt_library" +version = "0.1.0" +dependencies = [ + "anyhow", + "assets", + "chrono", + "collections", + "fs", + "futures 0.3.31", + "fuzzy", + "gpui", + "handlebars 4.5.0", + "heed", + "language", + "log", + "parking_lot", + "paths", + "rope", + "serde", + "text", + "util", + "uuid", +] + [[package]] name = "prost" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 293bc3f7fa..1ce69abdba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,6 +89,7 @@ members = [ "crates/project", "crates/project_panel", "crates/project_symbols", + "crates/prompt_library", "crates/proto", "crates/recent_projects", "crates/refineable", @@ -279,6 +280,7 @@ prettier = { path = "crates/prettier" } project = { path = "crates/project" } project_panel = { path = "crates/project_panel" } project_symbols = { path = "crates/project_symbols" } +prompt_library = { path = "crates/prompt_library" } proto = { path = "crates/proto" } recent_projects = { path = "crates/recent_projects" } refineable = { path = "crates/refineable" } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index ed889b0bb9..2b92b7e223 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -22,7 +22,6 @@ test-support = [ [dependencies] anyhow.workspace = true -assets.workspace = true assistant_settings.workspace = true assistant_slash_command.workspace = true assistant_tool.workspace = true @@ -42,8 +41,6 @@ futures.workspace = true fuzzy.workspace = true globset.workspace = true gpui.workspace = true -handlebars.workspace = true -heed.workspace = true html_to_markdown.workspace = true http_client.workspace = true indexed_docs.workspace = true @@ -63,6 +60,7 @@ parking_lot.workspace = true paths.workspace = true picker.workspace = true project.workspace = true +prompt_library.workspace = true proto.workspace = true regex.workspace = true release_channel.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 5e6a96369c..fd5628d68e 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -6,7 +6,6 @@ pub mod context_store; mod inline_assistant; mod patch; mod prompt_library; -mod prompts; mod slash_command; pub(crate) mod slash_command_picker; pub mod slash_command_settings; @@ -14,6 +13,8 @@ mod streaming_diff; mod terminal_inline_assistant; use crate::slash_command::project_command::ProjectSlashCommandFeatureFlag; +pub use ::prompt_library::PromptBuilder; +use ::prompt_library::PromptLoadingParams; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; @@ -31,8 +32,6 @@ use language_model::{ LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage, }; pub use patch::*; -pub use prompts::PromptBuilder; -use prompts::PromptLoadingParams; use semantic_index::{CloudEmbeddingProvider, SemanticDb}; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -225,14 +224,14 @@ pub fn init( .detach(); context_store::init(&client.clone().into()); - prompt_library::init(cx); + ::prompt_library::init(cx); init_language_model_settings(cx); assistant_slash_command::init(cx); assistant_tool::init(cx); assistant_panel::init(cx); context_server::init(cx); - let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams { + let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams { fs: fs.clone(), repo_path: stdout_is_a_pty .then(|| std::env::current_dir().log_err()) @@ -241,7 +240,7 @@ pub fn init( })) .log_err() .map(Arc::new) - .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); + .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap())); register_slash_commands(Some(prompt_builder.clone()), cx); inline_assistant::init( fs.clone(), diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 6b2916a6ff..fe3840fefc 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -2,7 +2,6 @@ use crate::slash_command::file_command::codeblock_fence_for_path; use crate::{ humanize_token_count, prompt_library::open_prompt_library, - prompts::PromptBuilder, slash_command::{ default_command::DefaultSlashCommand, docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, @@ -59,6 +58,7 @@ use multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; use project::lsp_store::LocalLspAdapterDelegate; use project::{Project, Worktree}; +use prompt_library::PromptBuilder; use rope::Point; use search::{buffer_search::DivRegistrar, BufferSearchBar}; use serde::{Deserialize, Serialize}; diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 135a59157c..99822f10b9 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -2,7 +2,6 @@ mod context_tests; use crate::{ - prompts::PromptBuilder, slash_command::{file_command::FileCommandMetadata, SlashCommandLine}, AssistantEdit, AssistantPatch, AssistantPatchStatus, MessageId, MessageStatus, }; @@ -22,6 +21,7 @@ use gpui::{ AppContext, Context as _, EventEmitter, Model, ModelContext, RenderImage, SharedString, Subscription, Task, }; +use prompt_library::PromptBuilder; use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset}; use language_model::{ diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs index 9689a882c3..756d57e88e 100644 --- a/crates/assistant/src/context/context_tests.rs +++ b/crates/assistant/src/context/context_tests.rs @@ -1,8 +1,8 @@ use super::{AssistantEdit, MessageCacheMetadata}; use crate::{ - assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus, - Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId, - MessageStatus, PromptBuilder, + assistant_panel, slash_command::file_command, AssistantEditKind, CacheStatus, Context, + ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId, MessageStatus, + PromptBuilder, }; use anyhow::Result; use assistant_slash_command::{ diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index f1ba0d5673..b1fb79afbd 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -1,8 +1,8 @@ use crate::slash_command::context_server_command; use crate::SlashCommandId; use crate::{ - prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, - SavedContext, SavedContextMetadata, + Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, + SavedContextMetadata, }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::SlashCommandWorkingSet; @@ -21,6 +21,7 @@ use gpui::{ use language::LanguageRegistry; use paths::contexts_dir; use project::Project; +use prompt_library::PromptBuilder; use regex::Regex; use rpc::AnyProtoClient; use std::sync::LazyLock; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 4ea19d1912..0a6511c044 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,7 +1,7 @@ use crate::{ - humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent, - CharOperation, CycleNextInlineAssist, CyclePreviousInlineAssist, LineDiff, LineOperation, - RequestType, StreamingDiff, + humanize_token_count, AssistantPanel, AssistantPanelEvent, CharOperation, + CycleNextInlineAssist, CyclePreviousInlineAssist, LineDiff, LineOperation, RequestType, + StreamingDiff, }; use anyhow::{anyhow, Context as _, Result}; use assistant_settings::AssistantSettings; @@ -41,6 +41,7 @@ use language_models::report_assistant_event; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use project::{CodeAction, ProjectTransaction}; +use prompt_library::PromptBuilder; use rope::Rope; use settings::{update_settings_file, Settings, SettingsStore}; use smol::future::FutureExt; diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 705f059331..03997d4ef5 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,48 +1,30 @@ use crate::SlashCommandWorkingSet; use crate::{slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssistant}; -use anyhow::{anyhow, Result}; -use chrono::{DateTime, Utc}; +use anyhow::Result; use collections::{HashMap, HashSet}; use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle}; -use futures::{ - future::{self, BoxFuture, Shared}, - FutureExt, -}; -use fuzzy::StringMatchCandidate; use gpui::{ - actions, point, size, transparent_black, Action, AppContext, BackgroundExecutor, Bounds, - EventEmitter, Global, PromptLevel, ReadGlobal, Subscription, Task, TextStyle, TitlebarOptions, - UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions, -}; -use heed::{ - types::{SerdeBincode, SerdeJson, Str}, - Database, RoTxn, + actions, point, size, transparent_black, Action, AppContext, Bounds, EventEmitter, PromptLevel, + Subscription, Task, TextStyle, TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, + WindowOptions, }; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; -use parking_lot::RwLock; use picker::{Picker, PickerDelegate}; +use prompt_library::{PromptId, PromptMetadata, PromptStore}; use release_channel::ReleaseChannel; use rope::Rope; -use serde::{Deserialize, Serialize}; use settings::Settings; -use std::{ - cmp::Reverse, - future::Future, - path::PathBuf, - sync::{atomic::AtomicBool, Arc}, - time::Duration, -}; -use text::LineEnding; +use std::sync::Arc; +use std::time::Duration; use theme::ThemeSettings; use ui::{ div, prelude::*, IconButtonShape, KeyBinding, ListItem, ListItemSpacing, ParentElement, Render, SharedString, Styled, Tooltip, ViewContext, VisualContext, }; use util::{ResultExt, TryFutureExt}; -use uuid::Uuid; use workspace::Workspace; use zed_actions::InlineAssist; @@ -56,17 +38,6 @@ actions!( ] ); -/// Init starts loading the PromptStore in the background and assigns -/// a shared future to a global. -pub fn init(cx: &mut AppContext) { - let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb"); - let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone()) - .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) - .boxed() - .shared(); - cx.set_global(GlobalPromptStore(prompt_store_future)) -} - const BUILT_IN_TOOLTIP_TEXT: &'static str = concat!( "This prompt supports special functionality.\n", "It's read-only, but you can remove it from your default prompt." @@ -1165,381 +1136,3 @@ impl Render for PromptLibrary { }) } } - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PromptMetadata { - pub id: PromptId, - pub title: Option, - pub default: bool, - pub saved_at: DateTime, -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(tag = "kind")] -pub enum PromptId { - User { uuid: Uuid }, - EditWorkflow, -} - -impl PromptId { - pub fn new() -> PromptId { - PromptId::User { - uuid: Uuid::new_v4(), - } - } - - pub fn is_built_in(&self) -> bool { - !matches!(self, PromptId::User { .. }) - } -} - -pub struct PromptStore { - executor: BackgroundExecutor, - env: heed::Env, - metadata_cache: RwLock, - metadata: Database, SerdeJson>, - bodies: Database, Str>, -} - -#[derive(Default)] -struct MetadataCache { - metadata: Vec, - metadata_by_id: HashMap, -} - -impl MetadataCache { - fn from_db( - db: Database, SerdeJson>, - txn: &RoTxn, - ) -> Result { - let mut cache = MetadataCache::default(); - for result in db.iter(txn)? { - let (prompt_id, metadata) = result?; - cache.metadata.push(metadata.clone()); - cache.metadata_by_id.insert(prompt_id, metadata); - } - cache.sort(); - Ok(cache) - } - - fn insert(&mut self, metadata: PromptMetadata) { - self.metadata_by_id.insert(metadata.id, metadata.clone()); - if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) { - *old_metadata = metadata; - } else { - self.metadata.push(metadata); - } - self.sort(); - } - - fn remove(&mut self, id: PromptId) { - self.metadata.retain(|metadata| metadata.id != id); - self.metadata_by_id.remove(&id); - } - - fn sort(&mut self) { - self.metadata.sort_unstable_by(|a, b| { - a.title - .cmp(&b.title) - .then_with(|| b.saved_at.cmp(&a.saved_at)) - }); - } -} - -impl PromptStore { - pub fn global(cx: &AppContext) -> impl Future>> { - let store = GlobalPromptStore::global(cx).0.clone(); - async move { store.await.map_err(|err| anyhow!(err)) } - } - - pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task> { - executor.spawn({ - let executor = executor.clone(); - async move { - std::fs::create_dir_all(&db_path)?; - - let db_env = unsafe { - heed::EnvOpenOptions::new() - .map_size(1024 * 1024 * 1024) // 1GB - .max_dbs(4) // Metadata and bodies (possibly v1 of both as well) - .open(db_path)? - }; - - let mut txn = db_env.write_txn()?; - let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?; - let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?; - - // Remove edit workflow prompt, as we decided to opt into it using - // a slash command instead. - metadata.delete(&mut txn, &PromptId::EditWorkflow).ok(); - bodies.delete(&mut txn, &PromptId::EditWorkflow).ok(); - - txn.commit()?; - - Self::upgrade_dbs(&db_env, metadata, bodies).log_err(); - - let txn = db_env.read_txn()?; - let metadata_cache = MetadataCache::from_db(metadata, &txn)?; - txn.commit()?; - - Ok(PromptStore { - executor, - env: db_env, - metadata_cache: RwLock::new(metadata_cache), - metadata, - bodies, - }) - } - }) - } - - fn upgrade_dbs( - env: &heed::Env, - metadata_db: heed::Database, SerdeJson>, - bodies_db: heed::Database, Str>, - ) -> Result<()> { - #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] - pub struct PromptIdV1(Uuid); - - #[derive(Clone, Debug, Serialize, Deserialize)] - pub struct PromptMetadataV1 { - pub id: PromptIdV1, - pub title: Option, - pub default: bool, - pub saved_at: DateTime, - } - - let mut txn = env.write_txn()?; - let Some(bodies_v1_db) = env - .open_database::, SerdeBincode>( - &txn, - Some("bodies"), - )? - else { - return Ok(()); - }; - let mut bodies_v1 = bodies_v1_db - .iter(&txn)? - .collect::>>()?; - - let Some(metadata_v1_db) = env - .open_database::, SerdeBincode>( - &txn, - Some("metadata"), - )? - else { - return Ok(()); - }; - let metadata_v1 = metadata_v1_db - .iter(&txn)? - .collect::>>()?; - - for (prompt_id_v1, metadata_v1) in metadata_v1 { - let prompt_id_v2 = PromptId::User { - uuid: prompt_id_v1.0, - }; - let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { - continue; - }; - - if metadata_db - .get(&txn, &prompt_id_v2)? - .map_or(true, |metadata_v2| { - metadata_v1.saved_at > metadata_v2.saved_at - }) - { - metadata_db.put( - &mut txn, - &prompt_id_v2, - &PromptMetadata { - id: prompt_id_v2, - title: metadata_v1.title.clone(), - default: metadata_v1.default, - saved_at: metadata_v1.saved_at, - }, - )?; - bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?; - } - } - - txn.commit()?; - - Ok(()) - } - - pub fn load(&self, id: PromptId) -> Task> { - let env = self.env.clone(); - let bodies = self.bodies; - self.executor.spawn(async move { - let txn = env.read_txn()?; - let mut prompt = bodies - .get(&txn, &id)? - .ok_or_else(|| anyhow!("prompt not found"))? - .into(); - LineEnding::normalize(&mut prompt); - Ok(prompt) - }) - } - - pub fn default_prompt_metadata(&self) -> Vec { - return self - .metadata_cache - .read() - .metadata - .iter() - .filter(|metadata| metadata.default) - .cloned() - .collect::>(); - } - - pub fn delete(&self, id: PromptId) -> Task> { - self.metadata_cache.write().remove(id); - - let db_connection = self.env.clone(); - let bodies = self.bodies; - let metadata = self.metadata; - - self.executor.spawn(async move { - let mut txn = db_connection.write_txn()?; - - metadata.delete(&mut txn, &id)?; - bodies.delete(&mut txn, &id)?; - - txn.commit()?; - Ok(()) - }) - } - - /// Returns the number of prompts in the store. - fn prompt_count(&self) -> usize { - self.metadata_cache.read().metadata.len() - } - - fn metadata(&self, id: PromptId) -> Option { - self.metadata_cache.read().metadata_by_id.get(&id).cloned() - } - - pub fn id_for_title(&self, title: &str) -> Option { - let metadata_cache = self.metadata_cache.read(); - let metadata = metadata_cache - .metadata - .iter() - .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?; - Some(metadata.id) - } - - pub fn search(&self, query: String) -> Task> { - let cached_metadata = self.metadata_cache.read().metadata.clone(); - let executor = self.executor.clone(); - self.executor.spawn(async move { - let mut matches = if query.is_empty() { - cached_metadata - } else { - let candidates = cached_metadata - .iter() - .enumerate() - .filter_map(|(ix, metadata)| { - Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?)) - }) - .collect::>(); - let matches = fuzzy::match_strings( - &candidates, - &query, - false, - 100, - &AtomicBool::default(), - executor, - ) - .await; - matches - .into_iter() - .map(|mat| cached_metadata[mat.candidate_id].clone()) - .collect() - }; - matches.sort_by_key(|metadata| Reverse(metadata.default)); - matches - }) - } - - fn save( - &self, - id: PromptId, - title: Option, - default: bool, - body: Rope, - ) -> Task> { - if id.is_built_in() { - return Task::ready(Err(anyhow!("built-in prompts cannot be saved"))); - } - - let prompt_metadata = PromptMetadata { - id, - title, - default, - saved_at: Utc::now(), - }; - self.metadata_cache.write().insert(prompt_metadata.clone()); - - let db_connection = self.env.clone(); - let bodies = self.bodies; - let metadata = self.metadata; - - self.executor.spawn(async move { - let mut txn = db_connection.write_txn()?; - - metadata.put(&mut txn, &id, &prompt_metadata)?; - bodies.put(&mut txn, &id, &body.to_string())?; - - txn.commit()?; - - Ok(()) - }) - } - - fn save_metadata( - &self, - id: PromptId, - mut title: Option, - default: bool, - ) -> Task> { - let mut cache = self.metadata_cache.write(); - - if id.is_built_in() { - title = cache - .metadata_by_id - .get(&id) - .and_then(|metadata| metadata.title.clone()); - } - - let prompt_metadata = PromptMetadata { - id, - title, - default, - saved_at: Utc::now(), - }; - - cache.insert(prompt_metadata.clone()); - - let db_connection = self.env.clone(); - let metadata = self.metadata; - - self.executor.spawn(async move { - let mut txn = db_connection.write_txn()?; - metadata.put(&mut txn, &id, &prompt_metadata)?; - txn.commit()?; - - Ok(()) - }) - } - - fn first(&self) -> Option { - self.metadata_cache.read().metadata.first().cloned() - } -} - -/// Wraps a shared future to a prompt store so it can be assigned as a context global. -pub struct GlobalPromptStore( - Shared, Arc>>>, -); - -impl Global for GlobalPromptStore {} diff --git a/crates/assistant/src/slash_command/default_command.rs b/crates/assistant/src/slash_command/default_command.rs index 49a7b244e9..9577d2ac7f 100644 --- a/crates/assistant/src/slash_command/default_command.rs +++ b/crates/assistant/src/slash_command/default_command.rs @@ -1,4 +1,3 @@ -use crate::prompt_library::PromptStore; use anyhow::{anyhow, Result}; use assistant_slash_command::{ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, @@ -6,6 +5,7 @@ use assistant_slash_command::{ }; use gpui::{Task, WeakView}; use language::{BufferSnapshot, LspAdapterDelegate}; +use prompt_library::PromptStore; use std::{ fmt::Write, sync::{atomic::AtomicBool, Arc}, diff --git a/crates/assistant/src/slash_command/prompt_command.rs b/crates/assistant/src/slash_command/prompt_command.rs index 9eb44d3418..2091201fab 100644 --- a/crates/assistant/src/slash_command/prompt_command.rs +++ b/crates/assistant/src/slash_command/prompt_command.rs @@ -1,4 +1,3 @@ -use crate::prompt_library::PromptStore; use anyhow::{anyhow, Context, Result}; use assistant_slash_command::{ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, @@ -6,6 +5,7 @@ use assistant_slash_command::{ }; use gpui::{Task, WeakView}; use language::{BufferSnapshot, LspAdapterDelegate}; +use prompt_library::PromptStore; use std::sync::{atomic::AtomicBool, Arc}; use ui::prelude::*; use workspace::Workspace; diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index df0ff24f3f..b55a731c47 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -1,6 +1,5 @@ use crate::{ - humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent, RequestType, - DEFAULT_CONTEXT_LINES, + humanize_token_count, AssistantPanel, AssistantPanelEvent, RequestType, DEFAULT_CONTEXT_LINES, }; use anyhow::{Context as _, Result}; use assistant_settings::AssistantSettings; @@ -22,6 +21,7 @@ use language_model::{ }; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use language_models::report_assistant_event; +use prompt_library::PromptBuilder; use settings::{update_settings_file, Settings}; use std::{ cmp, diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index ca0e929525..2e5d5eff41 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -14,7 +14,6 @@ doctest = false [dependencies] anyhow.workspace = true -assets.workspace = true assistant_settings.workspace = true assistant_tool.workspace = true async-watch.workspace = true @@ -32,7 +31,6 @@ fs.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true -handlebars.workspace = true html_to_markdown.workspace = true http_client.workspace = true itertools.workspace = true @@ -47,9 +45,9 @@ menu.workspace = true multi_buffer.workspace = true ordered-float.workspace = true parking_lot.workspace = true -paths.workspace = true picker.workspace = true project.workspace = true +prompt_library.workspace = true proto.workspace = true rope.workspace = true serde.workspace = true diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index c68ce8e377..2334748a40 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -9,7 +9,6 @@ mod context_strip; mod inline_assistant; mod inline_prompt_editor; mod message_editor; -mod prompts; mod streaming_diff; mod terminal_codegen; mod terminal_inline_assistant; @@ -26,7 +25,7 @@ use command_palette_hooks::CommandPaletteFilter; use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt}; use fs::Fs; use gpui::{actions, AppContext}; -use prompts::PromptLoadingParams; +use prompt_library::{PromptBuilder, PromptLoadingParams}; use settings::Settings as _; use util::ResultExt; @@ -62,7 +61,7 @@ pub fn init(fs: Arc, client: Arc, stdout_is_a_pty: bool, cx: &mu AssistantSettings::register(cx); assistant_panel::init(cx); - let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams { + let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams { fs: fs.clone(), repo_path: stdout_is_a_pty .then(|| std::env::current_dir().log_err()) @@ -71,7 +70,7 @@ pub fn init(fs: Arc, client: Arc, stdout_is_a_pty: bool, cx: &mu })) .log_err() .map(Arc::new) - .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); + .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap())); inline_assistant::init( fs.clone(), prompt_builder.clone(), diff --git a/crates/assistant2/src/buffer_codegen.rs b/crates/assistant2/src/buffer_codegen.rs index 90d830dea3..7538f008db 100644 --- a/crates/assistant2/src/buffer_codegen.rs +++ b/crates/assistant2/src/buffer_codegen.rs @@ -1,10 +1,7 @@ use crate::context::attach_context_to_message; use crate::context_store::ContextStore; use crate::inline_prompt_editor::CodegenStatus; -use crate::{ - prompts::PromptBuilder, - streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}, -}; +use crate::streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; use collections::HashSet; @@ -19,6 +16,7 @@ use language_model::{ use language_models::report_assistant_event; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; +use prompt_library::PromptBuilder; use rope::Rope; use smol::future::FutureExt; use std::{ diff --git a/crates/assistant2/src/inline_assistant.rs b/crates/assistant2/src/inline_assistant.rs index 275b0a6646..15a0adeaa0 100644 --- a/crates/assistant2/src/inline_assistant.rs +++ b/crates/assistant2/src/inline_assistant.rs @@ -29,6 +29,7 @@ use language_models::report_assistant_event; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use project::{CodeAction, ProjectTransaction}; +use prompt_library::PromptBuilder; use settings::{Settings, SettingsStore}; use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; use terminal_view::{terminal_panel::TerminalPanel, TerminalView}; @@ -42,9 +43,9 @@ use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace}; use crate::buffer_codegen::{BufferCodegen, CodegenAlternative, CodegenEvent}; use crate::context_store::ContextStore; use crate::inline_prompt_editor::{CodegenStatus, InlineAssistId, PromptEditor, PromptEditorEvent}; +use crate::terminal_inline_assistant::TerminalInlineAssistant; use crate::thread_store::ThreadStore; use crate::AssistantPanel; -use crate::{prompts::PromptBuilder, terminal_inline_assistant::TerminalInlineAssistant}; pub fn init( fs: Arc, diff --git a/crates/assistant2/src/prompts.rs b/crates/assistant2/src/prompts.rs deleted file mode 100644 index 1edfd9cb46..0000000000 --- a/crates/assistant2/src/prompts.rs +++ /dev/null @@ -1,312 +0,0 @@ -use anyhow::Result; -use assets::Assets; -use fs::Fs; -use futures::StreamExt; -use gpui::AssetSource; -use handlebars::{Handlebars, RenderError}; -use language::{BufferSnapshot, LanguageName, Point}; -use parking_lot::Mutex; -use serde::Serialize; -use std::{ops::Range, path::PathBuf, sync::Arc, time::Duration}; -use text::LineEnding; -use util::ResultExt; - -#[derive(Serialize)] -pub struct ContentPromptDiagnosticContext { - pub line_number: usize, - pub error_message: String, - pub code_content: String, -} - -#[derive(Serialize)] -pub struct ContentPromptContext { - pub content_type: String, - pub language_name: Option, - pub is_insert: bool, - pub is_truncated: bool, - pub document_content: String, - pub user_prompt: String, - pub rewrite_section: Option, - pub diagnostic_errors: Vec, -} - -#[derive(Serialize)] -pub struct TerminalAssistantPromptContext { - pub os: String, - pub arch: String, - pub shell: Option, - pub working_directory: Option, - pub latest_output: Vec, - pub user_prompt: String, -} - -#[derive(Serialize)] -pub struct ProjectSlashCommandPromptContext { - pub context_buffer: String, -} - -pub struct PromptLoadingParams<'a> { - pub fs: Arc, - pub repo_path: Option, - pub cx: &'a gpui::AppContext, -} - -pub struct PromptBuilder { - handlebars: Arc>>, -} - -impl PromptBuilder { - pub fn new(loading_params: Option) -> Result { - let mut handlebars = Handlebars::new(); - Self::register_built_in_templates(&mut handlebars)?; - - let handlebars = Arc::new(Mutex::new(handlebars)); - - if let Some(params) = loading_params { - Self::watch_fs_for_template_overrides(params, handlebars.clone()); - } - - Ok(Self { handlebars }) - } - - /// Watches the filesystem for changes to prompt template overrides. - /// - /// This function sets up a file watcher on the prompt templates directory. It performs - /// an initial scan of the directory and registers any existing template overrides. - /// Then it continuously monitors for changes, reloading templates as they are - /// modified or added. - /// - /// If the templates directory doesn't exist initially, it waits for it to be created. - /// If the directory is removed, it restores the built-in templates and waits for the - /// directory to be recreated. - /// - /// # Arguments - /// - /// * `params` - A `PromptLoadingParams` struct containing the filesystem, repository path, - /// and application context. - /// * `handlebars` - An `Arc>` for registering and updating templates. - fn watch_fs_for_template_overrides( - params: PromptLoadingParams, - handlebars: Arc>>, - ) { - let templates_dir = paths::prompt_overrides_dir(params.repo_path.as_deref()); - params.cx.background_executor() - .spawn(async move { - let Some(parent_dir) = templates_dir.parent() else { - return; - }; - - let mut found_dir_once = false; - loop { - // Check if the templates directory exists and handle its status - // If it exists, log its presence and check if it's a symlink - // If it doesn't exist: - // - Log that we're using built-in prompts - // - Check if it's a broken symlink and log if so - // - Set up a watcher to detect when it's created - // After the first check, set the `found_dir_once` flag - // This allows us to avoid logging when looping back around after deleting the prompt overrides directory. - let dir_status = params.fs.is_dir(&templates_dir).await; - let symlink_status = params.fs.read_link(&templates_dir).await.ok(); - if dir_status { - let mut log_message = format!("Prompt template overrides directory found at {}", templates_dir.display()); - if let Some(target) = symlink_status { - log_message.push_str(" -> "); - log_message.push_str(&target.display().to_string()); - } - log::info!("{}.", log_message); - } else { - if !found_dir_once { - log::info!("No prompt template overrides directory found at {}. Using built-in prompts.", templates_dir.display()); - if let Some(target) = symlink_status { - log::info!("Symlink found pointing to {}, but target is invalid.", target.display()); - } - } - - if params.fs.is_dir(parent_dir).await { - let (mut changes, _watcher) = params.fs.watch(parent_dir, Duration::from_secs(1)).await; - while let Some(changed_paths) = changes.next().await { - if changed_paths.iter().any(|p| &p.path == &templates_dir) { - let mut log_message = format!("Prompt template overrides directory detected at {}", templates_dir.display()); - if let Ok(target) = params.fs.read_link(&templates_dir).await { - log_message.push_str(" -> "); - log_message.push_str(&target.display().to_string()); - } - log::info!("{}.", log_message); - break; - } - } - } else { - return; - } - } - - found_dir_once = true; - - // Initial scan of the prompt overrides directory - if let Ok(mut entries) = params.fs.read_dir(&templates_dir).await { - while let Some(Ok(file_path)) = entries.next().await { - if file_path.to_string_lossy().ends_with(".hbs") { - if let Ok(content) = params.fs.load(&file_path).await { - let file_name = file_path.file_stem().unwrap().to_string_lossy(); - log::debug!("Registering prompt template override: {}", file_name); - handlebars.lock().register_template_string(&file_name, content).log_err(); - } - } - } - } - - // Watch both the parent directory and the template overrides directory: - // - Monitor the parent directory to detect if the template overrides directory is deleted. - // - Monitor the template overrides directory to re-register templates when they change. - // Combine both watch streams into a single stream. - let (parent_changes, parent_watcher) = params.fs.watch(parent_dir, Duration::from_secs(1)).await; - let (changes, watcher) = params.fs.watch(&templates_dir, Duration::from_secs(1)).await; - let mut combined_changes = futures::stream::select(changes, parent_changes); - - while let Some(changed_paths) = combined_changes.next().await { - if changed_paths.iter().any(|p| &p.path == &templates_dir) { - if !params.fs.is_dir(&templates_dir).await { - log::info!("Prompt template overrides directory removed. Restoring built-in prompt templates."); - Self::register_built_in_templates(&mut handlebars.lock()).log_err(); - break; - } - } - for event in changed_paths { - if event.path.starts_with(&templates_dir) && event.path.extension().map_or(false, |ext| ext == "hbs") { - log::info!("Reloading prompt template override: {}", event.path.display()); - if let Some(content) = params.fs.load(&event.path).await.log_err() { - let file_name = event.path.file_stem().unwrap().to_string_lossy(); - handlebars.lock().register_template_string(&file_name, content).log_err(); - } - } - } - } - - drop(watcher); - drop(parent_watcher); - } - }) - .detach(); - } - - fn register_built_in_templates(handlebars: &mut Handlebars) -> Result<()> { - for path in Assets.list("prompts")? { - if let Some(id) = path.split('/').last().and_then(|s| s.strip_suffix(".hbs")) { - if let Some(prompt) = Assets.load(path.as_ref()).log_err().flatten() { - log::debug!("Registering built-in prompt template: {}", id); - let prompt = String::from_utf8_lossy(prompt.as_ref()); - handlebars.register_template_string(id, LineEnding::normalize_cow(prompt))? - } - } - } - - Ok(()) - } - - pub fn generate_inline_transformation_prompt( - &self, - user_prompt: String, - language_name: Option<&LanguageName>, - buffer: BufferSnapshot, - range: Range, - ) -> Result { - let content_type = match language_name.as_ref().map(|l| l.0.as_ref()) { - None | Some("Markdown" | "Plain Text") => "text", - Some(_) => "code", - }; - - const MAX_CTX: usize = 50000; - let is_insert = range.is_empty(); - let mut is_truncated = false; - - let before_range = 0..range.start; - let truncated_before = if before_range.len() > MAX_CTX { - is_truncated = true; - let start = buffer.clip_offset(range.start - MAX_CTX, text::Bias::Right); - start..range.start - } else { - before_range - }; - - let after_range = range.end..buffer.len(); - let truncated_after = if after_range.len() > MAX_CTX { - is_truncated = true; - let end = buffer.clip_offset(range.end + MAX_CTX, text::Bias::Left); - range.end..end - } else { - after_range - }; - - let mut document_content = String::new(); - for chunk in buffer.text_for_range(truncated_before) { - document_content.push_str(chunk); - } - if is_insert { - document_content.push_str(""); - } else { - document_content.push_str("\n"); - for chunk in buffer.text_for_range(range.clone()) { - document_content.push_str(chunk); - } - document_content.push_str("\n"); - } - for chunk in buffer.text_for_range(truncated_after) { - document_content.push_str(chunk); - } - - let rewrite_section = if !is_insert { - let mut section = String::new(); - for chunk in buffer.text_for_range(range.clone()) { - section.push_str(chunk); - } - Some(section) - } else { - None - }; - let diagnostics = buffer.diagnostics_in_range::<_, Point>(range, false); - let diagnostic_errors: Vec = diagnostics - .map(|entry| { - let start = entry.range.start; - ContentPromptDiagnosticContext { - line_number: (start.row + 1) as usize, - error_message: entry.diagnostic.message.clone(), - code_content: buffer.text_for_range(entry.range.clone()).collect(), - } - }) - .collect(); - - let context = ContentPromptContext { - content_type: content_type.to_string(), - language_name: language_name.map(|s| s.to_string()), - is_insert, - is_truncated, - document_content, - user_prompt, - rewrite_section, - diagnostic_errors, - }; - self.handlebars.lock().render("content_prompt", &context) - } - - pub fn generate_terminal_assistant_prompt( - &self, - user_prompt: &str, - shell: Option<&str>, - working_directory: Option<&str>, - latest_output: &[String], - ) -> Result { - let context = TerminalAssistantPromptContext { - os: std::env::consts::OS.to_string(), - arch: std::env::consts::ARCH.to_string(), - shell: shell.map(|s| s.to_string()), - working_directory: working_directory.map(|s| s.to_string()), - latest_output: latest_output.to_vec(), - user_prompt: user_prompt.to_string(), - }; - - self.handlebars - .lock() - .render("terminal_assistant_prompt", &context) - } -} diff --git a/crates/assistant2/src/terminal_inline_assistant.rs b/crates/assistant2/src/terminal_inline_assistant.rs index 1b75dc2c3e..65be6a6e98 100644 --- a/crates/assistant2/src/terminal_inline_assistant.rs +++ b/crates/assistant2/src/terminal_inline_assistant.rs @@ -3,7 +3,6 @@ use crate::context_store::ContextStore; use crate::inline_prompt_editor::{ CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId, }; -use crate::prompts::PromptBuilder; use crate::terminal_codegen::{CodegenEvent, TerminalCodegen, CLEAR_INPUT}; use crate::thread_store::ThreadStore; use anyhow::{Context as _, Result}; @@ -20,6 +19,7 @@ use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; use language_models::report_assistant_event; +use prompt_library::PromptBuilder; use std::sync::Arc; use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; use terminal_view::TerminalView; diff --git a/crates/prompt_library/Cargo.toml b/crates/prompt_library/Cargo.toml new file mode 100644 index 0000000000..3ad8c72534 --- /dev/null +++ b/crates/prompt_library/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "prompt_library" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/prompt_library.rs" + +[dependencies] +anyhow.workspace = true +assets.workspace = true +chrono.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +fuzzy.workspace = true +gpui.workspace = true +handlebars.workspace = true +heed.workspace = true +language.workspace = true +log.workspace = true +parking_lot.workspace = true +paths.workspace = true +rope.workspace = true +serde.workspace = true +text.workspace = true +util.workspace = true +uuid.workspace = true diff --git a/crates/prompt_library/LICENSE-GPL b/crates/prompt_library/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/prompt_library/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs new file mode 100644 index 0000000000..4e607deaae --- /dev/null +++ b/crates/prompt_library/src/prompt_library.rs @@ -0,0 +1,11 @@ +mod prompt_store; +mod prompts; + +use gpui::AppContext; + +pub use crate::prompt_store::*; +pub use crate::prompts::*; + +pub fn init(cx: &mut AppContext) { + prompt_store::init(cx); +} diff --git a/crates/prompt_library/src/prompt_store.rs b/crates/prompt_library/src/prompt_store.rs new file mode 100644 index 0000000000..56ee87e489 --- /dev/null +++ b/crates/prompt_library/src/prompt_store.rs @@ -0,0 +1,412 @@ +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Utc}; +use collections::HashMap; +use futures::future::{self, BoxFuture, Shared}; +use futures::FutureExt as _; +use fuzzy::StringMatchCandidate; +use gpui::{AppContext, BackgroundExecutor, Global, ReadGlobal, SharedString, Task}; +use heed::{ + types::{SerdeBincode, SerdeJson, Str}, + Database, RoTxn, +}; +use parking_lot::RwLock; +use rope::Rope; +use serde::{Deserialize, Serialize}; +use std::{ + cmp::Reverse, + future::Future, + path::PathBuf, + sync::{atomic::AtomicBool, Arc}, +}; +use text::LineEnding; +use util::ResultExt; +use uuid::Uuid; + +/// Init starts loading the PromptStore in the background and assigns +/// a shared future to a global. +pub fn init(cx: &mut AppContext) { + let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb"); + let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone()) + .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) + .boxed() + .shared(); + cx.set_global(GlobalPromptStore(prompt_store_future)) +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PromptMetadata { + pub id: PromptId, + pub title: Option, + pub default: bool, + pub saved_at: DateTime, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(tag = "kind")] +pub enum PromptId { + User { uuid: Uuid }, + EditWorkflow, +} + +impl PromptId { + pub fn new() -> PromptId { + PromptId::User { + uuid: Uuid::new_v4(), + } + } + + pub fn is_built_in(&self) -> bool { + !matches!(self, PromptId::User { .. }) + } +} + +pub struct PromptStore { + executor: BackgroundExecutor, + env: heed::Env, + metadata_cache: RwLock, + metadata: Database, SerdeJson>, + bodies: Database, Str>, +} + +#[derive(Default)] +struct MetadataCache { + metadata: Vec, + metadata_by_id: HashMap, +} + +impl MetadataCache { + fn from_db( + db: Database, SerdeJson>, + txn: &RoTxn, + ) -> Result { + let mut cache = MetadataCache::default(); + for result in db.iter(txn)? { + let (prompt_id, metadata) = result?; + cache.metadata.push(metadata.clone()); + cache.metadata_by_id.insert(prompt_id, metadata); + } + cache.sort(); + Ok(cache) + } + + fn insert(&mut self, metadata: PromptMetadata) { + self.metadata_by_id.insert(metadata.id, metadata.clone()); + if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) { + *old_metadata = metadata; + } else { + self.metadata.push(metadata); + } + self.sort(); + } + + fn remove(&mut self, id: PromptId) { + self.metadata.retain(|metadata| metadata.id != id); + self.metadata_by_id.remove(&id); + } + + fn sort(&mut self) { + self.metadata.sort_unstable_by(|a, b| { + a.title + .cmp(&b.title) + .then_with(|| b.saved_at.cmp(&a.saved_at)) + }); + } +} + +impl PromptStore { + pub fn global(cx: &AppContext) -> impl Future>> { + let store = GlobalPromptStore::global(cx).0.clone(); + async move { store.await.map_err(|err| anyhow!(err)) } + } + + pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task> { + executor.spawn({ + let executor = executor.clone(); + async move { + std::fs::create_dir_all(&db_path)?; + + let db_env = unsafe { + heed::EnvOpenOptions::new() + .map_size(1024 * 1024 * 1024) // 1GB + .max_dbs(4) // Metadata and bodies (possibly v1 of both as well) + .open(db_path)? + }; + + let mut txn = db_env.write_txn()?; + let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?; + let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?; + + // Remove edit workflow prompt, as we decided to opt into it using + // a slash command instead. + metadata.delete(&mut txn, &PromptId::EditWorkflow).ok(); + bodies.delete(&mut txn, &PromptId::EditWorkflow).ok(); + + txn.commit()?; + + Self::upgrade_dbs(&db_env, metadata, bodies).log_err(); + + let txn = db_env.read_txn()?; + let metadata_cache = MetadataCache::from_db(metadata, &txn)?; + txn.commit()?; + + Ok(PromptStore { + executor, + env: db_env, + metadata_cache: RwLock::new(metadata_cache), + metadata, + bodies, + }) + } + }) + } + + fn upgrade_dbs( + env: &heed::Env, + metadata_db: heed::Database, SerdeJson>, + bodies_db: heed::Database, Str>, + ) -> Result<()> { + #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] + pub struct PromptIdV1(Uuid); + + #[derive(Clone, Debug, Serialize, Deserialize)] + pub struct PromptMetadataV1 { + pub id: PromptIdV1, + pub title: Option, + pub default: bool, + pub saved_at: DateTime, + } + + let mut txn = env.write_txn()?; + let Some(bodies_v1_db) = env + .open_database::, SerdeBincode>( + &txn, + Some("bodies"), + )? + else { + return Ok(()); + }; + let mut bodies_v1 = bodies_v1_db + .iter(&txn)? + .collect::>>()?; + + let Some(metadata_v1_db) = env + .open_database::, SerdeBincode>( + &txn, + Some("metadata"), + )? + else { + return Ok(()); + }; + let metadata_v1 = metadata_v1_db + .iter(&txn)? + .collect::>>()?; + + for (prompt_id_v1, metadata_v1) in metadata_v1 { + let prompt_id_v2 = PromptId::User { + uuid: prompt_id_v1.0, + }; + let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { + continue; + }; + + if metadata_db + .get(&txn, &prompt_id_v2)? + .map_or(true, |metadata_v2| { + metadata_v1.saved_at > metadata_v2.saved_at + }) + { + metadata_db.put( + &mut txn, + &prompt_id_v2, + &PromptMetadata { + id: prompt_id_v2, + title: metadata_v1.title.clone(), + default: metadata_v1.default, + saved_at: metadata_v1.saved_at, + }, + )?; + bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?; + } + } + + txn.commit()?; + + Ok(()) + } + + pub fn load(&self, id: PromptId) -> Task> { + let env = self.env.clone(); + let bodies = self.bodies; + self.executor.spawn(async move { + let txn = env.read_txn()?; + let mut prompt = bodies + .get(&txn, &id)? + .ok_or_else(|| anyhow!("prompt not found"))? + .into(); + LineEnding::normalize(&mut prompt); + Ok(prompt) + }) + } + + pub fn default_prompt_metadata(&self) -> Vec { + return self + .metadata_cache + .read() + .metadata + .iter() + .filter(|metadata| metadata.default) + .cloned() + .collect::>(); + } + + pub fn delete(&self, id: PromptId) -> Task> { + self.metadata_cache.write().remove(id); + + let db_connection = self.env.clone(); + let bodies = self.bodies; + let metadata = self.metadata; + + self.executor.spawn(async move { + let mut txn = db_connection.write_txn()?; + + metadata.delete(&mut txn, &id)?; + bodies.delete(&mut txn, &id)?; + + txn.commit()?; + Ok(()) + }) + } + + /// Returns the number of prompts in the store. + pub fn prompt_count(&self) -> usize { + self.metadata_cache.read().metadata.len() + } + + pub fn metadata(&self, id: PromptId) -> Option { + self.metadata_cache.read().metadata_by_id.get(&id).cloned() + } + + pub fn first(&self) -> Option { + self.metadata_cache.read().metadata.first().cloned() + } + + pub fn id_for_title(&self, title: &str) -> Option { + let metadata_cache = self.metadata_cache.read(); + let metadata = metadata_cache + .metadata + .iter() + .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?; + Some(metadata.id) + } + + pub fn search(&self, query: String) -> Task> { + let cached_metadata = self.metadata_cache.read().metadata.clone(); + let executor = self.executor.clone(); + self.executor.spawn(async move { + let mut matches = if query.is_empty() { + cached_metadata + } else { + let candidates = cached_metadata + .iter() + .enumerate() + .filter_map(|(ix, metadata)| { + Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?)) + }) + .collect::>(); + let matches = fuzzy::match_strings( + &candidates, + &query, + false, + 100, + &AtomicBool::default(), + executor, + ) + .await; + matches + .into_iter() + .map(|mat| cached_metadata[mat.candidate_id].clone()) + .collect() + }; + matches.sort_by_key(|metadata| Reverse(metadata.default)); + matches + }) + } + + pub fn save( + &self, + id: PromptId, + title: Option, + default: bool, + body: Rope, + ) -> Task> { + if id.is_built_in() { + return Task::ready(Err(anyhow!("built-in prompts cannot be saved"))); + } + + let prompt_metadata = PromptMetadata { + id, + title, + default, + saved_at: Utc::now(), + }; + self.metadata_cache.write().insert(prompt_metadata.clone()); + + let db_connection = self.env.clone(); + let bodies = self.bodies; + let metadata = self.metadata; + + self.executor.spawn(async move { + let mut txn = db_connection.write_txn()?; + + metadata.put(&mut txn, &id, &prompt_metadata)?; + bodies.put(&mut txn, &id, &body.to_string())?; + + txn.commit()?; + + Ok(()) + }) + } + + pub fn save_metadata( + &self, + id: PromptId, + mut title: Option, + default: bool, + ) -> Task> { + let mut cache = self.metadata_cache.write(); + + if id.is_built_in() { + title = cache + .metadata_by_id + .get(&id) + .and_then(|metadata| metadata.title.clone()); + } + + let prompt_metadata = PromptMetadata { + id, + title, + default, + saved_at: Utc::now(), + }; + + cache.insert(prompt_metadata.clone()); + + let db_connection = self.env.clone(); + let metadata = self.metadata; + + self.executor.spawn(async move { + let mut txn = db_connection.write_txn()?; + metadata.put(&mut txn, &id, &prompt_metadata)?; + txn.commit()?; + + Ok(()) + }) + } +} + +/// Wraps a shared future to a prompt store so it can be assigned as a context global. +pub struct GlobalPromptStore( + Shared, Arc>>>, +); + +impl Global for GlobalPromptStore {} diff --git a/crates/assistant/src/prompts.rs b/crates/prompt_library/src/prompts.rs similarity index 100% rename from crates/assistant/src/prompts.rs rename to crates/prompt_library/src/prompts.rs