agent: Use default prompts from prompt library in system prompt (#28915)

Related to #28490.

- Default prompts from the prompt library are now included as "user
rules" in the system prompt.
- Presence of these user rules is shown at the beginning of the thread
in the UI.
_ Now uses an `Entity<PromptStore>` instead of an `Arc<PromptStore>`.
Motivation for this is emitting a `PromptsUpdatedEvent`.
- Now disallows concurrent reloading of the system prompt. Before this
change it was possible for reloads to race.

Release Notes:

- agent: Added support for including default prompts from the Prompt
Library as "user rules" in the system prompt.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
Michael Sloan 2025-04-18 09:32:35 -06:00 committed by GitHub
parent eea6cfb383
commit 502a0f6535
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 433 additions and 165 deletions

View file

@ -144,6 +144,19 @@ In Markdown, hash marks signify headings. For example:
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks. This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
</style> </style>
{{#if has_default_user_rules}}
The user has specified the following rules that should be applied:
{{#each default_user_rules}}
{{#if title}}
Rules title: {{title}}
{{/if}}
``````
{{contents}}
``````
{{/each}}
{{/if}}
The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files: The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files:
{{#each worktrees}} {{#each worktrees}}
@ -151,7 +164,7 @@ The user has opened a project that contains the following root directories/files
{{/each}} {{/each}}
{{#if has_rules}} {{#if has_rules}}
There are rules that apply to these root directories: There are project rules that apply to these root directories:
{{#each worktrees}} {{#each worktrees}}
{{#if rules_file}} {{#if rules_file}}

View file

@ -42,6 +42,7 @@ use ui::{
}; };
use util::ResultExt as _; use util::ResultExt as _;
use workspace::{OpenOptions, Workspace}; use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenPromptLibrary;
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
@ -2948,53 +2949,106 @@ impl ActiveThread {
return div().into_any(); return div().into_any();
}; };
let default_user_rules_text = if project_context.default_user_rules.is_empty() {
None
} else if project_context.default_user_rules.len() == 1 {
let user_rules = &project_context.default_user_rules[0];
match user_rules.title.as_ref() {
Some(title) => Some(format!("Using \"{title}\" user rule")),
None => Some("Using user rule".into()),
}
} else {
Some(format!(
"Using {} user rules",
project_context.default_user_rules.len()
))
};
let rules_files = project_context let rules_files = project_context
.worktrees .worktrees
.iter() .iter()
.filter_map(|worktree| worktree.rules_file.as_ref()) .filter_map(|worktree| worktree.rules_file.as_ref())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let label_text = match rules_files.as_slice() { let rules_file_text = match rules_files.as_slice() {
&[] => return div().into_any(), &[] => None,
&[rules_file] => { &[rules_file] => Some(format!(
format!("Using {:?} file", rules_file.path_in_worktree) "Using project {:?} file",
} rules_file.path_in_worktree
rules_files => { )),
format!("Using {} rules files", rules_files.len()) rules_files => Some(format!("Using {} project rules files", rules_files.len())),
}
}; };
div() if default_user_rules_text.is_none() && rules_file_text.is_none() {
return div().into_any();
}
v_flex()
.pt_2() .pt_2()
.px_2p5() .px_2p5()
.child( .gap_1()
h_flex() .when_some(
.w_full() default_user_rules_text,
.gap_0p5() |parent, default_user_rules_text| {
.child( parent.child(
h_flex() h_flex()
.gap_1p5() .w_full()
.child( .child(
Icon::new(IconName::File) Icon::new(IconName::File)
.size(IconSize::XSmall) .size(IconSize::XSmall)
.color(Color::Disabled), .color(Color::Disabled),
) )
.child( .child(
Label::new(label_text) Label::new(default_user_rules_text)
.size(LabelSize::XSmall) .size(LabelSize::XSmall)
.color(Color::Muted) .color(Color::Muted)
.buffer_font(cx), .truncate()
.buffer_font(cx)
.ml_1p5()
.mr_0p5(),
)
.child(
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
.tooltip(Tooltip::text("View User Rules"))
.on_click(|_event, window, cx| {
window.dispatch_action(Box::new(OpenPromptLibrary), cx)
}),
), ),
) )
.child( },
IconButton::new("open-rule", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.on_click(cx.listener(Self::handle_open_rules))
.tooltip(Tooltip::text("View Rules")),
),
) )
.when_some(rules_file_text, |parent, rules_file_text| {
parent.child(
h_flex()
.w_full()
.child(
Icon::new(IconName::File)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(rules_file_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx)
.ml_1p5()
.mr_0p5(),
)
.child(
IconButton::new("open-rule", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.on_click(cx.listener(Self::handle_open_rules))
.tooltip(Tooltip::text("View Rules")),
),
)
})
.into_any() .into_any()
} }

View file

@ -922,6 +922,7 @@ mod tests {
language::init(cx); language::init(cx);
Project::init_settings(cx); Project::init_settings(cx);
AssistantSettings::register(cx); AssistantSettings::register(cx);
prompt_store::init(cx);
thread_store::init(cx); thread_store::init(cx);
workspace::init_settings(cx); workspace::init_settings(cx);
ThemeSettings::register(cx); ThemeSettings::register(cx);
@ -951,7 +952,8 @@ mod tests {
cx, cx,
) )
}) })
.await; .await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

View file

@ -213,7 +213,7 @@ impl AssistantPanel {
let project = workspace.project().clone(); let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx) ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
})? })?
.await; .await?;
let slash_commands = Arc::new(SlashCommandWorkingSet::default()); let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace let context_store = workspace

View file

@ -4,7 +4,7 @@ use std::ops::Range;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings; use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
@ -939,7 +939,7 @@ impl Thread {
pub fn to_completion_request( pub fn to_completion_request(
&self, &self,
request_kind: RequestKind, request_kind: RequestKind,
cx: &App, cx: &mut Context<Self>,
) -> LanguageModelRequest { ) -> LanguageModelRequest {
let mut request = LanguageModelRequest { let mut request = LanguageModelRequest {
messages: vec![], messages: vec![],
@ -949,20 +949,33 @@ impl Thread {
}; };
if let Some(project_context) = self.project_context.borrow().as_ref() { if let Some(project_context) = self.project_context.borrow().as_ref() {
if let Some(system_prompt) = self match self
.prompt_builder .prompt_builder
.generate_assistant_system_prompt(project_context) .generate_assistant_system_prompt(project_context)
.context("failed to generate assistant system prompt")
.log_err()
{ {
request.messages.push(LanguageModelRequestMessage { Err(err) => {
role: Role::System, let message = format!("{err:?}").into();
content: vec![MessageContent::Text(system_prompt)], log::error!("{message}");
cache: true, cx.emit(ThreadEvent::ShowError(ThreadError::Message {
}); header: "Error generating system prompt".into(),
message,
}));
}
Ok(system_prompt) => {
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![MessageContent::Text(system_prompt)],
cache: true,
});
}
} }
} else { } else {
log::error!("project_context not set.") let message = "Context for system prompt unexpectedly not ready.".into();
log::error!("{message}");
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Error generating system prompt".into(),
message,
}));
} }
for message in &self.messages { for message in &self.messages {
@ -2163,7 +2176,7 @@ fn main() {{
assert_eq!(message.context, expected_context); assert_eq!(message.context, expected_context);
// Check message in request // Check message in request
let request = thread.read_with(cx, |thread, cx| { let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx) thread.to_completion_request(RequestKind::Chat, cx)
}); });
@ -2255,7 +2268,7 @@ fn main() {{
assert!(message3.context.contains("file3.rs")); assert!(message3.context.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included // Check entire request to make sure all contexts are properly included
let request = thread.read_with(cx, |thread, cx| { let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx) thread.to_completion_request(RequestKind::Chat, cx)
}); });
@ -2307,7 +2320,7 @@ fn main() {{
assert_eq!(message.context, ""); assert_eq!(message.context, "");
// Check message in request // Check message in request
let request = thread.read_with(cx, |thread, cx| { let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx) thread.to_completion_request(RequestKind::Chat, cx)
}); });
@ -2327,7 +2340,7 @@ fn main() {{
assert_eq!(message2.context, ""); assert_eq!(message2.context, "");
// Check that both messages appear in the request // Check that both messages appear in the request
let request = thread.read_with(cx, |thread, cx| { let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx) thread.to_completion_request(RequestKind::Chat, cx)
}); });
@ -2369,7 +2382,7 @@ fn main() {{
}); });
// Create a request and check that it doesn't have a stale buffer warning yet // Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.read_with(cx, |thread, cx| { let initial_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx) thread.to_completion_request(RequestKind::Chat, cx)
}); });
@ -2399,7 +2412,7 @@ fn main() {{
}); });
// Create a new request and check for the stale buffer warning // Create a new request and check for the stale buffer warning
let new_request = thread.read_with(cx, |thread, cx| { let new_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx) thread.to_completion_request(RequestKind::Chat, cx)
}); });
@ -2428,6 +2441,7 @@ fn main() {{
language::init(cx); language::init(cx);
Project::init_settings(cx); Project::init_settings(cx);
AssistantSettings::register(cx); AssistantSettings::register(cx);
prompt_store::init(cx);
thread_store::init(cx); thread_store::init(cx);
workspace::init_settings(cx); workspace::init_settings(cx);
ThemeSettings::register(cx); ThemeSettings::register(cx);
@ -2467,7 +2481,8 @@ fn main() {{
cx, cx,
) )
}) })
.await; .await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

View file

@ -12,8 +12,9 @@ use collections::HashMap;
use context_server::manager::ContextServerManager; use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool}; use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use fs::Fs; use fs::Fs;
use futures::FutureExt as _; use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared}; use futures::future::{self, BoxFuture, Shared};
use futures::{FutureExt as _, StreamExt as _};
use gpui::{ use gpui::{
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*, Subscription, Task, prelude::*,
@ -22,7 +23,10 @@ use heed::Database;
use heed::types::SerdeBincode; use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree}; use project::{Project, Worktree};
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext}; use prompt_store::{
DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent,
RulesFileContext, WorktreeContext,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore}; use settings::{Settings as _, SettingsStore};
use util::ResultExt as _; use util::ResultExt as _;
@ -62,6 +66,8 @@ pub struct ThreadStore {
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>, context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>, threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext, project_context: SharedProjectContext,
reload_system_prompt_tx: mpsc::Sender<()>,
_reload_system_prompt_task: Task<()>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -77,12 +83,22 @@ impl ThreadStore {
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
cx: &mut App, cx: &mut App,
) -> Task<Entity<Self>> { ) -> Task<Result<Entity<Self>>> {
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx)); let prompt_store = PromptStore::global(cx);
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx)); cx.spawn(async move |cx| {
cx.foreground_executor().spawn(async move { let prompt_store = prompt_store.await.ok();
reload.await; let (thread_store, ready_rx) = cx.update(|cx| {
thread_store let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
let (thread_store, ready_rx) =
Self::new(project, tools, prompt_builder, prompt_store, cx);
option_ready_rx = Some(ready_rx);
thread_store
});
(thread_store, option_ready_rx.take().unwrap())
})?;
ready_rx.await?;
Ok(thread_store)
}) })
} }
@ -90,17 +106,53 @@ impl ThreadStore {
project: Entity<Project>, project: Entity<Project>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> (Self, oneshot::Receiver<()>) {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx); let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| { let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx) ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
}); });
let settings_subscription =
let mut subscriptions = vec![
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| { cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx); this.load_default_profile(cx);
}); }),
let project_subscription = cx.subscribe(&project, Self::handle_project_event); cx.subscribe(&project, Self::handle_project_event),
];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(
prompt_store,
|this, _prompt_store, PromptsUpdatedEvent, _cx| {
this.enqueue_system_prompt_reload();
},
))
}
// This channel and task prevent concurrent and redundant loading of the system prompt.
let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
let (ready_tx, ready_rx) = oneshot::channel();
let mut ready_tx = Some(ready_tx);
let reload_system_prompt_task = cx.spawn({
async move |thread_store, cx| {
loop {
let Some(reload_task) = thread_store
.update(cx, |thread_store, cx| {
thread_store.reload_system_prompt(prompt_store.clone(), cx)
})
.ok()
else {
return;
};
reload_task.await;
if let Some(ready_tx) = ready_tx.take() {
ready_tx.send(()).ok();
}
reload_system_prompt_rx.next().await;
}
}
});
let this = Self { let this = Self {
project, project,
@ -110,23 +162,25 @@ impl ThreadStore {
context_server_tool_ids: HashMap::default(), context_server_tool_ids: HashMap::default(),
threads: Vec::new(), threads: Vec::new(),
project_context: SharedProjectContext::default(), project_context: SharedProjectContext::default(),
_subscriptions: vec![settings_subscription, project_subscription], reload_system_prompt_tx,
_reload_system_prompt_task: reload_system_prompt_task,
_subscriptions: subscriptions,
}; };
this.load_default_profile(cx); this.load_default_profile(cx);
this.register_context_server_handlers(cx); this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx); this.reload(cx).detach_and_log_err(cx);
this (this, ready_rx)
} }
fn handle_project_event( fn handle_project_event(
&mut self, &mut self,
_project: Entity<Project>, _project: Entity<Project>,
event: &project::Event, event: &project::Event,
cx: &mut Context<Self>, _cx: &mut Context<Self>,
) { ) {
match event { match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.reload_system_prompt(cx).detach(); self.enqueue_system_prompt_reload();
} }
project::Event::WorktreeUpdatedEntries(_, items) => { project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| { if items.iter().any(|(path, _, _)| {
@ -134,16 +188,25 @@ impl ThreadStore {
.iter() .iter()
.any(|name| path.as_ref() == Path::new(name)) .any(|name| path.as_ref() == Path::new(name))
}) { }) {
self.reload_system_prompt(cx).detach(); self.enqueue_system_prompt_reload();
} }
} }
_ => {} _ => {}
} }
} }
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> { fn enqueue_system_prompt_reload(&mut self) {
self.reload_system_prompt_tx.try_send(()).ok();
}
// Note that this should only be called from `reload_system_prompt_task`.
fn reload_system_prompt(
&self,
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
) -> Task<()> {
let project = self.project.read(cx); let project = self.project.read(cx);
let tasks = project let worktree_tasks = project
.visible_worktrees(cx) .visible_worktrees(cx)
.map(|worktree| { .map(|worktree| {
Self::load_worktree_info_for_system_prompt( Self::load_worktree_info_for_system_prompt(
@ -153,10 +216,23 @@ impl ThreadStore {
) )
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let default_user_rules_task = match prompt_store {
None => Task::ready(vec![]),
Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
let prompts = prompt_store.default_prompt_metadata();
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
let contents = prompt_store.load(prompt_metadata.id, cx);
async move { (contents.await, prompt_metadata) }
});
cx.background_spawn(future::join_all(load_tasks))
}),
};
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let results = futures::future::join_all(tasks).await; let (worktrees, default_user_rules) =
let worktrees = results future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
let worktrees = worktrees
.into_iter() .into_iter()
.map(|(worktree, rules_error)| { .map(|(worktree, rules_error)| {
if let Some(rules_error) = rules_error { if let Some(rules_error) = rules_error {
@ -165,8 +241,29 @@ impl ThreadStore {
worktree worktree
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(DefaultUserRulesContext {
title: prompt_metadata.title.map(|title| title.to_string()),
contents,
}),
Err(err) => {
this.update(cx, |_, cx| {
cx.emit(RulesLoadingError {
message: format!("{err:?}").into(),
});
})
.ok();
None
}
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| { this.update(cx, |this, _cx| {
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees)); *this.project_context.0.borrow_mut() =
Some(ProjectContext::new(worktrees, default_user_rules));
}) })
.ok(); .ok();
}) })

View file

@ -54,9 +54,9 @@ impl SlashCommand for DefaultSlashCommand {
cx: &mut App, cx: &mut App,
) -> Task<SlashCommandResult> { ) -> Task<SlashCommandResult> {
let store = PromptStore::global(cx); let store = PromptStore::global(cx);
cx.background_spawn(async move { cx.spawn(async move |cx| {
let store = store.await?; let store = store.await?;
let prompts = store.default_prompt_metadata(); let prompts = store.read_with(cx, |store, _cx| store.default_prompt_metadata())?;
let mut text = String::new(); let mut text = String::new();
text.push('\n'); text.push('\n');

View file

@ -5,7 +5,7 @@ use assistant_slash_command::{
}; };
use gpui::{Task, WeakEntity}; use gpui::{Task, WeakEntity};
use language::{BufferSnapshot, LspAdapterDelegate}; use language::{BufferSnapshot, LspAdapterDelegate};
use prompt_store::PromptStore; use prompt_store::{PromptMetadata, PromptStore};
use std::sync::{Arc, atomic::AtomicBool}; use std::sync::{Arc, atomic::AtomicBool};
use ui::prelude::*; use ui::prelude::*;
use workspace::Workspace; use workspace::Workspace;
@ -43,8 +43,11 @@ impl SlashCommand for PromptSlashCommand {
) -> Task<Result<Vec<ArgumentCompletion>>> { ) -> Task<Result<Vec<ArgumentCompletion>>> {
let store = PromptStore::global(cx); let store = PromptStore::global(cx);
let query = arguments.to_owned().join(" "); let query = arguments.to_owned().join(" ");
cx.background_spawn(async move { cx.spawn(async move |cx| {
let prompts = store.await?.search(query).await; let prompts: Vec<PromptMetadata> = store
.await?
.read_with(cx, |store, cx| store.search(query, cx))?
.await;
Ok(prompts Ok(prompts
.into_iter() .into_iter()
.filter_map(|prompt| { .filter_map(|prompt| {
@ -77,14 +80,18 @@ impl SlashCommand for PromptSlashCommand {
let store = PromptStore::global(cx); let store = PromptStore::global(cx);
let title = SharedString::from(title.clone()); let title = SharedString::from(title.clone());
let prompt = cx.background_spawn({ let prompt = cx.spawn({
let title = title.clone(); let title = title.clone();
async move { async move |cx| {
let store = store.await?; let store = store.await?;
let prompt_id = store let body = store
.id_for_title(&title) .read_with(cx, |store, cx| {
.with_context(|| format!("no prompt found with title {:?}", title))?; let prompt_id = store
let body = store.load(prompt_id).await?; .id_for_title(&title)
.with_context(|| format!("no prompt found with title {:?}", title))?;
anyhow::Ok(store.load(prompt_id, cx))
})??
.await?;
anyhow::Ok(body) anyhow::Ok(body)
} }
}); });

View file

@ -309,7 +309,7 @@ impl Example {
return Err(anyhow!("Setup only mode")); return Err(anyhow!("Setup only mode"));
} }
let thread_store = thread_store.await; let thread_store = thread_store.await?;
let thread = let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

View file

@ -136,7 +136,7 @@ pub fn open_prompt_library(
} }
pub struct PromptLibrary { pub struct PromptLibrary {
store: Arc<PromptStore>, store: Entity<PromptStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
prompt_editors: HashMap<PromptId, PromptEditor>, prompt_editors: HashMap<PromptId, PromptEditor>,
active_prompt_id: Option<PromptId>, active_prompt_id: Option<PromptId>,
@ -158,7 +158,7 @@ struct PromptEditor {
} }
struct PromptPickerDelegate { struct PromptPickerDelegate {
store: Arc<PromptStore>, store: Entity<PromptStore>,
selected_index: usize, selected_index: usize,
matches: Vec<PromptMetadata>, matches: Vec<PromptMetadata>,
} }
@ -179,8 +179,8 @@ impl PickerDelegate for PromptPickerDelegate {
self.matches.len() self.matches.len()
} }
fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option<SharedString> { fn no_matches_text(&self, _window: &mut Window, cx: &mut App) -> Option<SharedString> {
let text = if self.store.prompt_count() == 0 { let text = if self.store.read(cx).prompt_count() == 0 {
"No prompts.".into() "No prompts.".into()
} else { } else {
"No prompts found matching your search.".into() "No prompts found matching your search.".into()
@ -211,7 +211,7 @@ impl PickerDelegate for PromptPickerDelegate {
window: &mut Window, window: &mut Window,
cx: &mut Context<Picker<Self>>, cx: &mut Context<Picker<Self>>,
) -> Task<()> { ) -> Task<()> {
let search = self.store.search(query); let search = self.store.read(cx).search(query, cx);
let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id); let prev_prompt_id = self.matches.get(self.selected_index).map(|mat| mat.id);
cx.spawn_in(window, async move |this, cx| { cx.spawn_in(window, async move |this, cx| {
let (matches, selected_index) = cx let (matches, selected_index) = cx
@ -339,7 +339,7 @@ impl PickerDelegate for PromptPickerDelegate {
impl PromptLibrary { impl PromptLibrary {
fn new( fn new(
store: Arc<PromptStore>, store: Entity<PromptStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
inline_assist_delegate: Box<dyn InlineAssistDelegate>, inline_assist_delegate: Box<dyn InlineAssistDelegate>,
make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>, make_completion_provider: Arc<dyn Fn() -> Box<dyn CompletionProvider>>,
@ -398,7 +398,7 @@ impl PromptLibrary {
pub fn new_prompt(&mut self, window: &mut Window, cx: &mut Context<Self>) { pub fn new_prompt(&mut self, window: &mut Window, cx: &mut Context<Self>) {
// If we already have an untitled prompt, use that instead // If we already have an untitled prompt, use that instead
// of creating a new one. // of creating a new one.
if let Some(metadata) = self.store.first() { if let Some(metadata) = self.store.read(cx).first() {
if metadata.title.is_none() { if metadata.title.is_none() {
self.load_prompt(metadata.id, true, window, cx); self.load_prompt(metadata.id, true, window, cx);
return; return;
@ -406,7 +406,9 @@ impl PromptLibrary {
} }
let prompt_id = PromptId::new(); let prompt_id = PromptId::new();
let save = self.store.save(prompt_id, None, false, "".into()); let save = self.store.update(cx, |store, cx| {
store.save(prompt_id, None, false, "".into(), cx)
});
self.picker self.picker
.update(cx, |picker, cx| picker.refresh(window, cx)); .update(cx, |picker, cx| picker.refresh(window, cx));
cx.spawn_in(window, async move |this, cx| { cx.spawn_in(window, async move |this, cx| {
@ -430,7 +432,7 @@ impl PromptLibrary {
return; return;
} }
let prompt_metadata = self.store.metadata(prompt_id).unwrap(); let prompt_metadata = self.store.read(cx).metadata(prompt_id).unwrap();
let prompt_editor = self.prompt_editors.get_mut(&prompt_id).unwrap(); let prompt_editor = self.prompt_editors.get_mut(&prompt_id).unwrap();
let title = prompt_editor.title_editor.read(cx).text(cx); let title = prompt_editor.title_editor.read(cx).text(cx);
let body = prompt_editor.body_editor.update(cx, |editor, cx| { let body = prompt_editor.body_editor.update(cx, |editor, cx| {
@ -465,10 +467,13 @@ impl PromptLibrary {
} else { } else {
Some(SharedString::from(title)) Some(SharedString::from(title))
}; };
store cx.update(|_window, cx| {
.save(prompt_id, title, prompt_metadata.default, body) store.update(cx, |store, cx| {
.await store.save(prompt_id, title, prompt_metadata.default, body, cx)
.log_err(); })
})?
.await
.log_err();
this.update_in(cx, |this, window, cx| { this.update_in(cx, |this, window, cx| {
this.picker this.picker
.update(cx, |picker, cx| picker.refresh(window, cx)); .update(cx, |picker, cx| picker.refresh(window, cx));
@ -521,14 +526,21 @@ impl PromptLibrary {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(prompt_metadata) = self.store.metadata(prompt_id) { self.store.update(cx, move |store, cx| {
self.store if let Some(prompt_metadata) = store.metadata(prompt_id) {
.save_metadata(prompt_id, prompt_metadata.title, !prompt_metadata.default) store
.detach_and_log_err(cx); .save_metadata(
self.picker prompt_id,
.update(cx, |picker, cx| picker.refresh(window, cx)); prompt_metadata.title,
cx.notify(); !prompt_metadata.default,
} cx,
)
.detach_and_log_err(cx);
}
});
self.picker
.update(cx, |picker, cx| picker.refresh(window, cx));
cx.notify();
} }
pub fn load_prompt( pub fn load_prompt(
@ -545,9 +557,9 @@ impl PromptLibrary {
.update(cx, |editor, cx| window.focus(&editor.focus_handle(cx))); .update(cx, |editor, cx| window.focus(&editor.focus_handle(cx)));
} }
self.set_active_prompt(Some(prompt_id), window, cx); self.set_active_prompt(Some(prompt_id), window, cx);
} else if let Some(prompt_metadata) = self.store.metadata(prompt_id) { } else if let Some(prompt_metadata) = self.store.read(cx).metadata(prompt_id) {
let language_registry = self.language_registry.clone(); let language_registry = self.language_registry.clone();
let prompt = self.store.load(prompt_id); let prompt = self.store.read(cx).load(prompt_id, cx);
let make_completion_provider = self.make_completion_provider.clone(); let make_completion_provider = self.make_completion_provider.clone();
self.pending_load = cx.spawn_in(window, async move |this, cx| { self.pending_load = cx.spawn_in(window, async move |this, cx| {
let prompt = prompt.await; let prompt = prompt.await;
@ -673,7 +685,7 @@ impl PromptLibrary {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(metadata) = self.store.metadata(prompt_id) { if let Some(metadata) = self.store.read(cx).metadata(prompt_id) {
let confirmation = window.prompt( let confirmation = window.prompt(
PromptLevel::Warning, PromptLevel::Warning,
&format!( &format!(
@ -692,7 +704,9 @@ impl PromptLibrary {
this.set_active_prompt(None, window, cx); this.set_active_prompt(None, window, cx);
} }
this.prompt_editors.remove(&prompt_id); this.prompt_editors.remove(&prompt_id);
this.store.delete(prompt_id).detach_and_log_err(cx); this.store
.update(cx, |store, cx| store.delete(prompt_id, cx))
.detach_and_log_err(cx);
this.picker this.picker
.update(cx, |picker, cx| picker.refresh(window, cx)); .update(cx, |picker, cx| picker.refresh(window, cx));
cx.notify(); cx.notify();
@ -736,9 +750,9 @@ impl PromptLibrary {
let new_id = PromptId::new(); let new_id = PromptId::new();
let body = prompt.body_editor.read(cx).text(cx); let body = prompt.body_editor.read(cx).text(cx);
let save = self let save = self.store.update(cx, |store, cx| {
.store store.save(new_id, Some(title.into()), false, body.into(), cx)
.save(new_id, Some(title.into()), false, body.into()); });
self.picker self.picker
.update(cx, |picker, cx| picker.refresh(window, cx)); .update(cx, |picker, cx| picker.refresh(window, cx));
cx.spawn_in(window, async move |this, cx| { cx.spawn_in(window, async move |this, cx| {
@ -968,7 +982,7 @@ impl PromptLibrary {
.flex_none() .flex_none()
.min_w_64() .min_w_64()
.children(self.active_prompt_id.and_then(|prompt_id| { .children(self.active_prompt_id.and_then(|prompt_id| {
let prompt_metadata = self.store.metadata(prompt_id)?; let prompt_metadata = self.store.read(cx).metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id]; let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx); let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let model = LanguageModelRegistry::read_global(cx) let model = LanguageModelRegistry::read_global(cx)
@ -1238,7 +1252,7 @@ impl Render for PromptLibrary {
.text_color(theme.colors().text) .text_color(theme.colors().text)
.child(self.render_prompt_list(cx)) .child(self.render_prompt_list(cx))
.map(|el| { .map(|el| {
if self.store.prompt_count() == 0 { if self.store.read(cx).prompt_count() == 0 {
el.child( el.child(
v_flex() v_flex()
.w_2_3() .w_2_3()

View file

@ -4,9 +4,11 @@ use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::HashMap; use collections::HashMap;
use futures::FutureExt as _; use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared}; use futures::future::Shared;
use fuzzy::StringMatchCandidate; use fuzzy::StringMatchCandidate;
use gpui::{App, BackgroundExecutor, Global, ReadGlobal, SharedString, Task}; use gpui::{
App, AppContext, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, Task,
};
use heed::{ use heed::{
Database, RoTxn, Database, RoTxn,
types::{SerdeBincode, SerdeJson, Str}, types::{SerdeBincode, SerdeJson, Str},
@ -29,11 +31,16 @@ use uuid::Uuid;
/// a shared future to a global. /// a shared future to a global.
pub fn init(cx: &mut App) { pub fn init(cx: &mut App) {
let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb"); let db_path = paths::prompts_dir().join("prompts-library-db.0.mdb");
let prompt_store_future = PromptStore::new(db_path, cx.background_executor().clone()) let prompt_store_task = PromptStore::new(db_path, cx);
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) let prompt_store_entity_task = cx
.boxed() .spawn(async move |cx| {
prompt_store_task
.await
.and_then(|prompt_store| cx.new(|_cx| prompt_store))
.map_err(Arc::new)
})
.shared(); .shared();
cx.set_global(GlobalPromptStore(prompt_store_future)) cx.set_global(GlobalPromptStore(prompt_store_entity_task))
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
@ -64,13 +71,16 @@ impl PromptId {
} }
pub struct PromptStore { pub struct PromptStore {
executor: BackgroundExecutor,
env: heed::Env, env: heed::Env,
metadata_cache: RwLock<MetadataCache>, metadata_cache: RwLock<MetadataCache>,
metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>, metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
bodies: Database<SerdeJson<PromptId>, Str>, bodies: Database<SerdeJson<PromptId>, Str>,
} }
pub struct PromptsUpdatedEvent;
impl EventEmitter<PromptsUpdatedEvent> for PromptStore {}
#[derive(Default)] #[derive(Default)]
struct MetadataCache { struct MetadataCache {
metadata: Vec<PromptMetadata>, metadata: Vec<PromptMetadata>,
@ -117,49 +127,45 @@ impl MetadataCache {
} }
impl PromptStore { impl PromptStore {
pub fn global(cx: &App) -> impl Future<Output = Result<Arc<Self>>> + use<> { pub fn global(cx: &App) -> impl Future<Output = Result<Entity<Self>>> + use<> {
let store = GlobalPromptStore::global(cx).0.clone(); let store = GlobalPromptStore::global(cx).0.clone();
async move { store.await.map_err(|err| anyhow!(err)) } async move { store.await.map_err(|err| anyhow!(err)) }
} }
pub fn new(db_path: PathBuf, executor: BackgroundExecutor) -> Task<Result<Self>> { pub fn new(db_path: PathBuf, cx: &App) -> Task<Result<Self>> {
executor.spawn({ cx.background_spawn(async move {
let executor = executor.clone(); std::fs::create_dir_all(&db_path)?;
async move {
std::fs::create_dir_all(&db_path)?;
let db_env = unsafe { let db_env = unsafe {
heed::EnvOpenOptions::new() heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024) // 1GB .map_size(1024 * 1024 * 1024) // 1GB
.max_dbs(4) // Metadata and bodies (possibly v1 of both as well) .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
.open(db_path)? .open(db_path)?
}; };
let mut txn = db_env.write_txn()?; let mut txn = db_env.write_txn()?;
let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?; let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?; let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
// Remove edit workflow prompt, as we decided to opt into it using // Remove edit workflow prompt, as we decided to opt into it using
// a slash command instead. // a slash command instead.
metadata.delete(&mut txn, &PromptId::EditWorkflow).ok(); metadata.delete(&mut txn, &PromptId::EditWorkflow).ok();
bodies.delete(&mut txn, &PromptId::EditWorkflow).ok(); bodies.delete(&mut txn, &PromptId::EditWorkflow).ok();
txn.commit()?; txn.commit()?;
Self::upgrade_dbs(&db_env, metadata, bodies).log_err(); Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
let txn = db_env.read_txn()?; let txn = db_env.read_txn()?;
let metadata_cache = MetadataCache::from_db(metadata, &txn)?; let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
txn.commit()?; txn.commit()?;
Ok(PromptStore { Ok(PromptStore {
executor, env: db_env,
env: db_env, metadata_cache: RwLock::new(metadata_cache),
metadata_cache: RwLock::new(metadata_cache), metadata,
metadata, bodies,
bodies, })
})
}
}) })
} }
@ -237,10 +243,10 @@ impl PromptStore {
Ok(()) Ok(())
} }
pub fn load(&self, id: PromptId) -> Task<Result<String>> { pub fn load(&self, id: PromptId, cx: &App) -> Task<Result<String>> {
let env = self.env.clone(); let env = self.env.clone();
let bodies = self.bodies; let bodies = self.bodies;
self.executor.spawn(async move { cx.background_spawn(async move {
let txn = env.read_txn()?; let txn = env.read_txn()?;
let mut prompt = bodies let mut prompt = bodies
.get(&txn, &id)? .get(&txn, &id)?
@ -262,21 +268,27 @@ impl PromptStore {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
} }
pub fn delete(&self, id: PromptId) -> Task<Result<()>> { pub fn delete(&self, id: PromptId, cx: &Context<Self>) -> Task<Result<()>> {
self.metadata_cache.write().remove(id); self.metadata_cache.write().remove(id);
let db_connection = self.env.clone(); let db_connection = self.env.clone();
let bodies = self.bodies; let bodies = self.bodies;
let metadata = self.metadata; let metadata = self.metadata;
self.executor.spawn(async move { let task = cx.background_spawn(async move {
let mut txn = db_connection.write_txn()?; let mut txn = db_connection.write_txn()?;
metadata.delete(&mut txn, &id)?; metadata.delete(&mut txn, &id)?;
bodies.delete(&mut txn, &id)?; bodies.delete(&mut txn, &id)?;
txn.commit()?; txn.commit()?;
Ok(()) anyhow::Ok(())
});
cx.spawn(async move |this, cx| {
task.await?;
this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
anyhow::Ok(())
}) })
} }
@ -302,10 +314,10 @@ impl PromptStore {
Some(metadata.id) Some(metadata.id)
} }
pub fn search(&self, query: String) -> Task<Vec<PromptMetadata>> { pub fn search(&self, query: String, cx: &App) -> Task<Vec<PromptMetadata>> {
let cached_metadata = self.metadata_cache.read().metadata.clone(); let cached_metadata = self.metadata_cache.read().metadata.clone();
let executor = self.executor.clone(); let executor = cx.background_executor().clone();
self.executor.spawn(async move { cx.background_spawn(async move {
let mut matches = if query.is_empty() { let mut matches = if query.is_empty() {
cached_metadata cached_metadata
} else { } else {
@ -341,6 +353,7 @@ impl PromptStore {
title: Option<SharedString>, title: Option<SharedString>,
default: bool, default: bool,
body: Rope, body: Rope,
cx: &Context<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if id.is_built_in() { if id.is_built_in() {
return Task::ready(Err(anyhow!("built-in prompts cannot be saved"))); return Task::ready(Err(anyhow!("built-in prompts cannot be saved")));
@ -358,7 +371,7 @@ impl PromptStore {
let bodies = self.bodies; let bodies = self.bodies;
let metadata = self.metadata; let metadata = self.metadata;
self.executor.spawn(async move { let task = cx.background_spawn(async move {
let mut txn = db_connection.write_txn()?; let mut txn = db_connection.write_txn()?;
metadata.put(&mut txn, &id, &prompt_metadata)?; metadata.put(&mut txn, &id, &prompt_metadata)?;
@ -366,7 +379,13 @@ impl PromptStore {
txn.commit()?; txn.commit()?;
Ok(()) anyhow::Ok(())
});
cx.spawn(async move |this, cx| {
task.await?;
this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
anyhow::Ok(())
}) })
} }
@ -375,6 +394,7 @@ impl PromptStore {
id: PromptId, id: PromptId,
mut title: Option<SharedString>, mut title: Option<SharedString>,
default: bool, default: bool,
cx: &Context<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
let mut cache = self.metadata_cache.write(); let mut cache = self.metadata_cache.write();
@ -397,19 +417,23 @@ impl PromptStore {
let db_connection = self.env.clone(); let db_connection = self.env.clone();
let metadata = self.metadata; let metadata = self.metadata;
self.executor.spawn(async move { let task = cx.background_spawn(async move {
let mut txn = db_connection.write_txn()?; let mut txn = db_connection.write_txn()?;
metadata.put(&mut txn, &id, &prompt_metadata)?; metadata.put(&mut txn, &id, &prompt_metadata)?;
txn.commit()?; txn.commit()?;
Ok(()) anyhow::Ok(())
});
cx.spawn(async move |this, cx| {
task.await?;
this.update(cx, |_, cx| cx.emit(PromptsUpdatedEvent)).ok();
anyhow::Ok(())
}) })
} }
} }
/// Wraps a shared future to a prompt store so it can be assigned as a context global. /// Wraps a shared future to a prompt store so it can be assigned as a context global.
pub struct GlobalPromptStore( pub struct GlobalPromptStore(Shared<Task<Result<Entity<PromptStore>, Arc<anyhow::Error>>>>);
Shared<BoxFuture<'static, Result<Arc<PromptStore>, Arc<anyhow::Error>>>>,
);
impl Global for GlobalPromptStore {} impl Global for GlobalPromptStore {}

View file

@ -19,20 +19,29 @@ use util::{ResultExt, get_system_shell};
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct ProjectContext { pub struct ProjectContext {
pub worktrees: Vec<WorktreeContext>, pub worktrees: Vec<WorktreeContext>,
/// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this.
pub has_rules: bool, pub has_rules: bool,
pub default_user_rules: Vec<DefaultUserRulesContext>,
/// `!default_user_rules.is_empty()` - provided as a field because handlebars can't do this.
pub has_default_user_rules: bool,
pub os: String, pub os: String,
pub arch: String, pub arch: String,
pub shell: String, pub shell: String,
} }
impl ProjectContext { impl ProjectContext {
pub fn new(worktrees: Vec<WorktreeContext>) -> Self { pub fn new(
worktrees: Vec<WorktreeContext>,
default_user_rules: Vec<DefaultUserRulesContext>,
) -> Self {
let has_rules = worktrees let has_rules = worktrees
.iter() .iter()
.any(|worktree| worktree.rules_file.is_some()); .any(|worktree| worktree.rules_file.is_some());
Self { Self {
worktrees, worktrees,
has_rules, has_rules,
has_default_user_rules: !default_user_rules.is_empty(),
default_user_rules,
os: std::env::consts::OS.to_string(), os: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(), arch: std::env::consts::ARCH.to_string(),
shell: get_system_shell(), shell: get_system_shell(),
@ -40,6 +49,12 @@ impl ProjectContext {
} }
} }
#[derive(Debug, Clone, Serialize)]
pub struct DefaultUserRulesContext {
pub title: Option<String>,
pub contents: String,
}
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct WorktreeContext { pub struct WorktreeContext {
pub root_name: String, pub root_name: String,
@ -377,3 +392,30 @@ impl PromptBuilder {
self.handlebars.lock().render("suggest_edits", &()) self.handlebars.lock().render("suggest_edits", &())
} }
} }
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_assistant_system_prompt_renders() {
let worktrees = vec![WorktreeContext {
root_name: "path".into(),
abs_path: Path::new("/some/path").into(),
rules_file: Some(RulesFileContext {
path_in_worktree: Path::new(".rules").into(),
abs_path: Path::new("/some/path/.rules").into(),
text: "".into(),
}),
}];
let default_user_rules = vec![DefaultUserRulesContext {
title: Some("Rules title".into()),
contents: "Rules contents".into(),
}];
let project_context = ProjectContext::new(worktrees, default_user_rules);
PromptBuilder::new(None)
.unwrap()
.generate_assistant_system_prompt(&project_context)
.unwrap();
}
}