mod prompts; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use collections::HashMap; use futures::FutureExt as _; use futures::future::Shared; use fuzzy::StringMatchCandidate; use gpui::{ App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task, }; use heed::{ Database, RoTxn, types::{SerdeBincode, SerdeJson, Str}, }; use parking_lot::RwLock; pub use prompts::*; use rope::Rope; use serde::{Deserialize, Serialize}; use std::{ cmp::Reverse, future::Future, path::PathBuf, sync::{Arc, atomic::AtomicBool}, }; use text::LineEnding; use util::ResultExt; use uuid::Uuid; /// Init starts loading the PromptStore in the background and assigns /// a shared future to a global. pub fn init(cx: &mut App) { let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb"); let prompt_store_task = PromptStore::new(db_path, cx); let prompt_store_entity_task = cx .spawn(async move |cx| { prompt_store_task .await .and_then(|prompt_store| cx.new(|_cx| prompt_store)) .map_err(Arc::new) }) .shared(); cx.set_global(GlobalPromptStore(prompt_store_entity_task)) } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PromptMetadata { pub id: PromptId, pub title: Option, pub default: bool, pub saved_at: DateTime, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(tag = "kind")] pub enum PromptId { User { uuid: UserPromptId }, EditWorkflow, } impl PromptId { pub fn new() -> PromptId { UserPromptId::new().into() } pub fn is_built_in(&self) -> bool { !matches!(self, PromptId::User { .. }) } } impl From for PromptId { fn from(uuid: UserPromptId) -> Self { PromptId::User { uuid } } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct UserPromptId(pub Uuid); impl UserPromptId { pub fn new() -> UserPromptId { UserPromptId(Uuid::new_v4()) } } impl From for UserPromptId { fn from(uuid: Uuid) -> Self { UserPromptId(uuid) } } impl std::fmt::Display for PromptId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { PromptId::User { uuid } => write!(f, "{}", uuid.0), PromptId::EditWorkflow => write!(f, "Edit workflow"), } } } pub struct PromptStore { env: heed::Env, metadata_cache: RwLock, metadata: Database, SerdeJson>, bodies: Database, Str>, } pub struct PromptsUpdatedEvent; impl EventEmitter for PromptStore {} #[derive(Default)] struct MetadataCache { metadata: Vec, metadata_by_id: HashMap, } impl MetadataCache { fn from_db( db: Database, SerdeJson>, txn: &RoTxn, ) -> Result { let mut cache = MetadataCache::default(); for result in db.iter(txn)? { let (prompt_id, metadata) = result?; cache.metadata.push(metadata.clone()); cache.metadata_by_id.insert(prompt_id, metadata); } cache.sort(); Ok(cache) } fn insert(&mut self, metadata: PromptMetadata) { self.metadata_by_id.insert(metadata.id, metadata.clone()); if let Some(old_metadata) = self.metadata.iter_mut().find(|m| m.id == metadata.id) { *old_metadata = metadata; } else { self.metadata.push(metadata); } self.sort(); } fn remove(&mut self, id: PromptId) { self.metadata.retain(|metadata| metadata.id != id); self.metadata_by_id.remove(&id); } fn sort(&mut self) { self.metadata.sort_unstable_by(|a, b| { a.title .cmp(&b.title) .then_with(|| b.saved_at.cmp(&a.saved_at)) }); } } impl PromptStore { pub fn global(cx: &App) -> impl Future>> + use<> { let store = GlobalPromptStore::global(cx).0.clone(); async move { store.await.map_err(|err| anyhow!(err)) } } pub fn new(db_path: PathBuf, cx: &App) -> Task> { cx.background_spawn(async move { std::fs::create_dir_all(&db_path)?; let db_env = unsafe { heed::EnvOpenOptions::new() .map_size(1024 * 1024 * 1024) // 1GB .max_dbs(4) // Metadata and bodies (possibly v1 of both as well) .open(db_path)? }; let mut txn = db_env.write_txn()?; let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?; let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?; // Remove edit workflow prompt, as we decided to opt into it using // a slash command instead. metadata.delete(&mut txn, &PromptId::EditWorkflow).ok(); bodies.delete(&mut txn, &PromptId::EditWorkflow).ok(); txn.commit()?; Self::upgrade_dbs(&db_env, metadata, bodies).log_err(); let txn = db_env.read_txn()?; let metadata_cache = MetadataCache::from_db(metadata, &txn)?; txn.commit()?; Ok(PromptStore { env: db_env, metadata_cache: RwLock::new(metadata_cache), metadata, bodies, }) }) } fn upgrade_dbs( env: &heed::Env, metadata_db: heed::Database, SerdeJson>, bodies_db: heed::Database, Str>, ) -> Result<()> { #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] pub struct PromptIdV1(Uuid); #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PromptMetadataV1 { pub id: PromptIdV1, pub title: Option, pub default: bool, pub saved_at: DateTime, } let mut txn = env.write_txn()?; let Some(bodies_v1_db) = env .open_database::, SerdeBincode>( &txn, Some("bodies"), )? else { return Ok(()); }; let mut bodies_v1 = bodies_v1_db .iter(&txn)? .collect::>>()?; let Some(metadata_v1_db) = env .open_database::, SerdeBincode>( &txn, Some("metadata"), )? else { return Ok(()); }; let metadata_v1 = metadata_v1_db .iter(&txn)? .collect::>>()?; for (prompt_id_v1, metadata_v1) in metadata_v1 { let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into(); let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else { continue; }; if metadata_db .get(&txn, &prompt_id_v2)? .is_none_or(|metadata_v2| metadata_v1.saved_at > metadata_v2.saved_at) { metadata_db.put( &mut txn, &prompt_id_v2, &PromptMetadata { id: prompt_id_v2, title: metadata_v1.title.clone(), default: metadata_v1.default, saved_at: metadata_v1.saved_at, }, )?; bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?; } } txn.commit()?; Ok(()) } pub fn load(&self, id: PromptId, cx: &App) -> Task> { let env = self.env.clone(); let bodies = self.bodies; cx.background_spawn(async move { let txn = env.read_txn()?; let mut prompt = bodies.get(&txn, &id)?.context("prompt not found")?.into(); LineEnding::normalize(&mut prompt); Ok(prompt) }) } pub fn all_prompt_metadata(&self) -> Vec { self.metadata_cache.read().metadata.clone() } pub fn default_prompt_metadata(&self) -> Vec { return self .metadata_cache .read() .metadata .iter() .filter(|metadata| metadata.default) .cloned() .collect::>(); } pub fn delete(&self, id: PromptId, cx: &Context) -> Task> { self.metadata_cache.write().remove(id); let db_connection = self.env.clone(); let bodies = self.bodies; let metadata = self.metadata; let task = cx.background_spawn(async move { let mut txn = db_connection.write_txn()?; metadata.delete(&mut txn, &id)?; bodies.delete(&mut txn, &id)?; txn.commit()?; anyhow::Ok(()) }); cx.spawn(async move |this, cx| { task.await?; this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok(); anyhow::Ok(()) }) } /// Returns the number of prompts in the store. pub fn prompt_count(&self) -> usize { self.metadata_cache.read().metadata.len() } pub fn metadata(&self, id: PromptId) -> Option { self.metadata_cache.read().metadata_by_id.get(&id).cloned() } pub fn first(&self) -> Option { self.metadata_cache.read().metadata.first().cloned() } pub fn id_for_title(&self, title: &str) -> Option { let metadata_cache = self.metadata_cache.read(); let metadata = metadata_cache .metadata .iter() .find(|metadata| metadata.title.as_ref().map(|title| &***title) == Some(title))?; Some(metadata.id) } pub fn search( &self, query: String, cancellation_flag: Arc, cx: &App, ) -> Task> { let cached_metadata = self.metadata_cache.read().metadata.clone(); let executor = cx.background_executor().clone(); cx.background_spawn(async move { let mut matches = if query.is_empty() { cached_metadata } else { let candidates = cached_metadata .iter() .enumerate() .filter_map(|(ix, metadata)| { Some(StringMatchCandidate::new(ix, metadata.title.as_ref()?)) }) .collect::>(); let matches = fuzzy::match_strings( &candidates, &query, false, true, 100, &cancellation_flag, executor, ) .await; matches .into_iter() .map(|mat| cached_metadata[mat.candidate_id].clone()) .collect() }; matches.sort_by_key(|metadata| Reverse(metadata.default)); matches }) } pub fn save( &self, id: PromptId, title: Option, default: bool, body: Rope, cx: &Context, ) -> Task> { if id.is_built_in() { return Task::ready(Err(anyhow!("built-in prompts cannot be saved"))); } let prompt_metadata = PromptMetadata { id, title, default, saved_at: Utc::now(), }; self.metadata_cache.write().insert(prompt_metadata.clone()); let db_connection = self.env.clone(); let bodies = self.bodies; let metadata = self.metadata; let task = cx.background_spawn(async move { let mut txn = db_connection.write_txn()?; metadata.put(&mut txn, &id, &prompt_metadata)?; bodies.put(&mut txn, &id, &body.to_string())?; txn.commit()?; anyhow::Ok(()) }); cx.spawn(async move |this, cx| { task.await?; this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok(); anyhow::Ok(()) }) } pub fn save_metadata( &self, id: PromptId, mut title: Option, default: bool, cx: &Context, ) -> Task> { let mut cache = self.metadata_cache.write(); if id.is_built_in() { title = cache .metadata_by_id .get(&id) .and_then(|metadata| metadata.title.clone()); } let prompt_metadata = PromptMetadata { id, title, default, saved_at: Utc::now(), }; cache.insert(prompt_metadata.clone()); let db_connection = self.env.clone(); let metadata = self.metadata; let task = cx.background_spawn(async move { let mut txn = db_connection.write_txn()?; metadata.put(&mut txn, &id, &prompt_metadata)?; txn.commit()?; anyhow::Ok(()) }); cx.spawn(async move |this, cx| { task.await?; this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok(); anyhow::Ok(()) }) } } /// Wraps a shared future to a prompt store so it can be assigned as a context global. pub struct GlobalPromptStore(Shared, Arc>>>); impl Global for GlobalPromptStore {}