Allow /workflow and step resolution prompts to be overridden (#15892)

This will help us as we hit issues with the /workflow and step
resolution. We can override the baked-in prompts and make tweaks, then
import our refinements back into the source tree when we're ready.

Release Notes:

- N/A
This commit is contained in:
Nathan Sobo 2024-08-06 21:47:42 -06:00 committed by GitHub
parent c8f1358629
commit 990774247e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 196 additions and 104 deletions

View file

@ -25,6 +25,7 @@ use language_model::{
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
}; };
pub(crate) use model_selector::*; pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore}; use settings::{update_settings_file, Settings, SettingsStore};
@ -163,7 +164,7 @@ impl Assistant {
} }
} }
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) { pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
cx.set_global(Assistant::default()); cx.set_global(Assistant::default());
AssistantSettings::register(cx); AssistantSettings::register(cx);
@ -196,19 +197,25 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
prompt_library::init(cx); prompt_library::init(cx);
init_language_model_settings(cx); init_language_model_settings(cx);
assistant_slash_command::init(cx); assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx); assistant_panel::init(cx);
if let Some(prompt_builder) = prompts::PromptBuilder::new(Some((fs.clone(), cx))).log_err() { let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
let prompt_builder = Arc::new(prompt_builder); .log_err()
inline_assistant::init( .map(Arc::new)
fs.clone(), .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
prompt_builder.clone(), register_slash_commands(Some(prompt_builder.clone()), cx);
client.telemetry().clone(), inline_assistant::init(
cx, fs.clone(),
); prompt_builder.clone(),
terminal_inline_assistant::init(fs.clone(), prompt_builder, client.telemetry().clone(), cx); client.telemetry().clone(),
} cx,
);
terminal_inline_assistant::init(
fs.clone(),
prompt_builder.clone(),
client.telemetry().clone(),
cx,
);
IndexedDocsRegistry::init_global(cx); IndexedDocsRegistry::init_global(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| { CommandPaletteFilter::update_global(cx, |filter, _cx| {
@ -226,6 +233,8 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
}); });
}) })
.detach(); .detach();
prompt_builder
} }
fn init_language_model_settings(cx: &mut AppContext) { fn init_language_model_settings(cx: &mut AppContext) {
@ -256,7 +265,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
}); });
} }
fn register_slash_commands(cx: &mut AppContext) { fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx); let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true); slash_command_registry.register_command(file_command::FileSlashCommand, true);
slash_command_registry.register_command(active_command::ActiveSlashCommand, true); slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
@ -270,7 +279,12 @@ fn register_slash_commands(cx: &mut AppContext) {
slash_command_registry.register_command(now_command::NowSlashCommand, true); slash_command_registry.register_command(now_command::NowSlashCommand, true);
slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true); slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
slash_command_registry.register_command(docs_command::DocsSlashCommand, true); slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
slash_command_registry.register_command(workflow_command::WorkflowSlashCommand, true); if let Some(prompt_builder) = prompt_builder {
slash_command_registry.register_command(
workflow_command::WorkflowSlashCommand::new(prompt_builder),
true,
);
}
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false); slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
} }

View file

@ -2,6 +2,7 @@ use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings}, assistant_settings::{AssistantDockPosition, AssistantSettings},
humanize_token_count, humanize_token_count,
prompt_library::open_prompt_library, prompt_library::open_prompt_library,
prompts::PromptBuilder,
slash_command::{ slash_command::{
default_command::DefaultSlashCommand, default_command::DefaultSlashCommand,
docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, docs_command::{DocsSlashCommand, DocsSlashCommandArgs},
@ -315,14 +316,17 @@ impl PickerDelegate for SavedContextPickerDelegate {
impl AssistantPanel { impl AssistantPanel {
pub fn load( pub fn load(
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
prompt_builder: Arc<PromptBuilder>,
cx: AsyncWindowContext, cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> { ) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let context_store = workspace let context_store = workspace
.update(&mut cx, |workspace, cx| { .update(&mut cx, |workspace, cx| {
ContextStore::new(workspace.project().clone(), cx) let project = workspace.project().clone();
ContextStore::new(project, prompt_builder.clone(), cx)
})? })?
.await?; .await?;
workspace.update(&mut cx, |workspace, cx| { workspace.update(&mut cx, |workspace, cx| {
// TODO: deserialize state. // TODO: deserialize state.
cx.new_view(|cx| Self::new(workspace, context_store, cx)) cx.new_view(|cx| Self::new(workspace, context_store, cx))

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion, prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
InlineAssistId, InlineAssistant, MessageId, MessageStatus, InlineAssistId, InlineAssistant, MessageId, MessageStatus,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
@ -611,6 +611,7 @@ pub struct Context {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
workflow_steps: Vec<WorkflowStep>, workflow_steps: Vec<WorkflowStep>,
project: Option<Model<Project>>, project: Option<Model<Project>>,
prompt_builder: Arc<PromptBuilder>,
} }
impl EventEmitter<ContextEvent> for Context {} impl EventEmitter<ContextEvent> for Context {}
@ -620,6 +621,7 @@ impl Context {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
project: Option<Model<Project>>, project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
Self::new( Self::new(
@ -627,17 +629,20 @@ impl Context {
ReplicaId::default(), ReplicaId::default(),
language::Capability::ReadWrite, language::Capability::ReadWrite,
language_registry, language_registry,
prompt_builder,
project, project,
telemetry, telemetry,
cx, cx,
) )
} }
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
id: ContextId, id: ContextId,
replica_id: ReplicaId, replica_id: ReplicaId,
capability: language::Capability, capability: language::Capability,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
project: Option<Model<Project>>, project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
@ -680,6 +685,7 @@ impl Context {
project, project,
language_registry, language_registry,
workflow_steps: Vec::new(), workflow_steps: Vec::new(),
prompt_builder,
}; };
let first_message_id = MessageId(clock::Lamport { let first_message_id = MessageId(clock::Lamport {
@ -749,6 +755,7 @@ impl Context {
saved_context: SavedContext, saved_context: SavedContext,
path: PathBuf, path: PathBuf,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
project: Option<Model<Project>>, project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
@ -759,6 +766,7 @@ impl Context {
ReplicaId::default(), ReplicaId::default(),
language::Capability::ReadWrite, language::Capability::ReadWrite,
language_registry, language_registry,
prompt_builder,
project, project,
telemetry, telemetry,
cx, cx,
@ -1246,9 +1254,9 @@ impl Context {
cx.spawn(|this, mut cx| { cx.spawn(|this, mut cx| {
async move { async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?; let mut prompt = this.update(&mut cx, |this, _| {
this.prompt_builder.generate_step_resolution_prompt()
let mut prompt = prompt_store.step_resolution_prompt()?; })??;
prompt.push_str(&step_text); prompt.push_str(&step_text);
request.messages.push(LanguageModelRequestMessage { request.messages.push(LanguageModelRequestMessage {
@ -2448,8 +2456,9 @@ mod tests {
cx.set_global(settings_store); cx.set_global(settings_store);
assistant_panel::init(cx); assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let buffer = context.read(cx).buffer.clone(); let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone(); let message_1 = context.read(cx).message_anchors[0].clone();
@ -2580,7 +2589,9 @@ mod tests {
assistant_panel::init(cx); assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let buffer = context.read(cx).buffer.clone(); let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone(); let message_1 = context.read(cx).message_anchors[0].clone();
@ -2673,7 +2684,9 @@ mod tests {
cx.set_global(settings_store); cx.set_global(settings_store);
assistant_panel::init(cx); assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let buffer = context.read(cx).buffer.clone(); let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone(); let message_1 = context.read(cx).message_anchors[0].clone();
@ -2778,7 +2791,10 @@ mod tests {
slash_command_registry.register_command(active_command::ActiveSlashCommand, false); slash_command_registry.register_command(active_command::ActiveSlashCommand, false);
let registry = Arc::new(LanguageRegistry::test(cx.executor())); let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx)); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| {
Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)
});
let output_ranges = Rc::new(RefCell::new(HashSet::default())); let output_ranges = Rc::new(RefCell::new(HashSet::default()));
context.update(cx, |_, cx| { context.update(cx, |_, cx| {
@ -2905,7 +2921,16 @@ mod tests {
let registry = Arc::new(LanguageRegistry::test(cx.executor())); let registry = Arc::new(LanguageRegistry::test(cx.executor()));
// Create a new context // Create a new context
let context = cx.new_model(|cx| Context::local(registry.clone(), Some(project), None, cx)); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| {
Context::local(
registry.clone(),
Some(project),
None,
prompt_builder.clone(),
cx,
)
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone()); let buffer = context.read_with(cx, |context, _| context.buffer.clone());
// Simulate user input // Simulate user input
@ -3070,7 +3095,10 @@ mod tests {
cx.update(LanguageModelRegistry::test); cx.update(LanguageModelRegistry::test);
cx.update(assistant_panel::init); cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor())); let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx)); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| {
Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone()); let buffer = context.read_with(cx, |context, _| context.buffer.clone());
let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
let message_1 = context.update(cx, |context, cx| { let message_1 = context.update(cx, |context, cx| {
@ -3109,6 +3137,7 @@ mod tests {
serialized_context, serialized_context,
Default::default(), Default::default(),
registry.clone(), registry.clone(),
prompt_builder.clone(),
None, None,
None, None,
cx, cx,
@ -3158,6 +3187,7 @@ mod tests {
let num_peers = rng.gen_range(min_peers..=max_peers); let num_peers = rng.gen_range(min_peers..=max_peers);
let context_id = ContextId::new(); let context_id = ContextId::new();
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
for i in 0..num_peers { for i in 0..num_peers {
let context = cx.new_model(|cx| { let context = cx.new_model(|cx| {
Context::new( Context::new(
@ -3165,6 +3195,7 @@ mod tests {
i as ReplicaId, i as ReplicaId,
language::Capability::ReadWrite, language::Capability::ReadWrite,
registry.clone(), registry.clone(),
prompt_builder.clone(),
None, None,
None, None,
cx, cx,

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion,
SavedContextMetadata, SavedContext, SavedContextMetadata,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
@ -52,6 +52,7 @@ pub struct ContextStore {
project_is_shared: bool, project_is_shared: bool,
client_subscription: Option<client::Subscription>, client_subscription: Option<client::Subscription>,
_project_subscriptions: Vec<gpui::Subscription>, _project_subscriptions: Vec<gpui::Subscription>,
prompt_builder: Arc<PromptBuilder>,
} }
pub enum ContextStoreEvent { pub enum ContextStoreEvent {
@ -82,7 +83,11 @@ impl ContextHandle {
} }
impl ContextStore { impl ContextStore {
pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> { pub fn new(
project: Model<Project>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let fs = project.read(cx).fs().clone(); let fs = project.read(cx).fs().clone();
let languages = project.read(cx).languages().clone(); let languages = project.read(cx).languages().clone();
let telemetry = project.read(cx).client().telemetry().clone(); let telemetry = project.read(cx).client().telemetry().clone();
@ -117,6 +122,7 @@ impl ContextStore {
project_is_shared: false, project_is_shared: false,
client: project.read(cx).client(), client: project.read(cx).client(),
project: project.clone(), project: project.clone(),
prompt_builder,
}; };
this.handle_project_changed(project, cx); this.handle_project_changed(project, cx);
this.synchronize_contexts(cx); this.synchronize_contexts(cx);
@ -334,6 +340,7 @@ impl ContextStore {
self.languages.clone(), self.languages.clone(),
Some(self.project.clone()), Some(self.project.clone()),
Some(self.telemetry.clone()), Some(self.telemetry.clone()),
self.prompt_builder.clone(),
cx, cx,
) )
}); });
@ -358,6 +365,7 @@ impl ContextStore {
let language_registry = self.languages.clone(); let language_registry = self.languages.clone();
let project = self.project.clone(); let project = self.project.clone();
let telemetry = self.telemetry.clone(); let telemetry = self.telemetry.clone();
let prompt_builder = self.prompt_builder.clone();
let request = self.client.request(proto::CreateContext { project_id }); let request = self.client.request(proto::CreateContext { project_id });
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let response = request.await?; let response = request.await?;
@ -369,6 +377,7 @@ impl ContextStore {
replica_id, replica_id,
capability, capability,
language_registry, language_registry,
prompt_builder,
Some(project), Some(project),
Some(telemetry), Some(telemetry),
cx, cx,
@ -417,6 +426,7 @@ impl ContextStore {
SavedContext::from_json(&saved_context) SavedContext::from_json(&saved_context)
} }
}); });
let prompt_builder = self.prompt_builder.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let saved_context = load.await?; let saved_context = load.await?;
@ -425,6 +435,7 @@ impl ContextStore {
saved_context, saved_context,
path.clone(), path.clone(),
languages, languages,
prompt_builder,
Some(project), Some(project),
Some(telemetry), Some(telemetry),
cx, cx,
@ -493,6 +504,7 @@ impl ContextStore {
project_id, project_id,
context_id: context_id.to_proto(), context_id: context_id.to_proto(),
}); });
let prompt_builder = self.prompt_builder.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let response = request.await?; let response = request.await?;
let context_proto = response.context.context("invalid context")?; let context_proto = response.context.context("invalid context")?;
@ -502,6 +514,7 @@ impl ContextStore {
replica_id, replica_id,
capability, capability,
language_registry, language_registry,
prompt_builder,
Some(project), Some(project),
Some(telemetry), Some(telemetry),
cx, cx,

View file

@ -2,7 +2,6 @@ use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant, slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assets::Assets;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle}; use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle};
@ -12,8 +11,8 @@ use futures::{
}; };
use fuzzy::StringMatchCandidate; use fuzzy::StringMatchCandidate;
use gpui::{ use gpui::{
actions, point, size, transparent_black, AppContext, AssetSource, BackgroundExecutor, Bounds, actions, point, size, transparent_black, AppContext, BackgroundExecutor, Bounds, EventEmitter,
EventEmitter, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle,
TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions, TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions,
}; };
use heed::{ use heed::{
@ -1466,17 +1465,6 @@ impl PromptStore {
fn first(&self) -> Option<PromptMetadata> { fn first(&self) -> Option<PromptMetadata> {
self.metadata_cache.read().metadata.first().cloned() self.metadata_cache.read().metadata.first().cloned()
} }
pub fn step_resolution_prompt(&self) -> Result<String> {
let path = "prompts/step_resolution.md";
Ok(String::from_utf8(
Assets
.load(path)?
.ok_or_else(|| anyhow!("{path} not found"))?
.to_vec(),
)?)
}
} }
/// Wraps a shared future to a prompt store so it can be assigned as a context global. /// Wraps a shared future to a prompt store so it can be assigned as a context global.

View file

@ -125,22 +125,20 @@ impl PromptBuilder {
} }
fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> { fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> {
let content_prompt = Assets::get("prompts/content_prompt.hbs") let mut register_template = |id: &str| {
.expect("Content prompt template not found") let prompt = Assets::get(&format!("prompts/{}.hbs", id))
.data; .unwrap_or_else(|| panic!("{} prompt template not found", id))
let terminal_assistant_prompt = Assets::get("prompts/terminal_assistant_prompt.hbs") .data;
.expect("Terminal assistant prompt template not found") handlebars
.data; .register_template_string(id, String::from_utf8_lossy(&prompt))
.map_err(Box::new)
};
register_template("content_prompt")?;
register_template("terminal_assistant_prompt")?;
register_template("edit_workflow")?;
register_template("step_resolution")?;
handlebars
.register_template_string("content_prompt", String::from_utf8_lossy(&content_prompt))
.map_err(Box::new)?;
handlebars
.register_template_string(
"terminal_assistant_prompt",
String::from_utf8_lossy(&terminal_assistant_prompt),
)
.map_err(Box::new)?;
Ok(()) Ok(())
} }
@ -236,4 +234,12 @@ impl PromptBuilder {
.lock() .lock()
.render("terminal_assistant_prompt", &context) .render("terminal_assistant_prompt", &context)
} }
pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
self.handlebars.lock().render("edit_workflow", &())
}
pub fn generate_step_resolution_prompt(&self) -> Result<String, RenderError> {
self.handlebars.lock().render("step_resolution", &())
}
} }

View file

@ -1,18 +1,27 @@
use std::sync::atomic::AtomicBool; use crate::prompts::PromptBuilder;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{Context as _, Result}; use std::sync::atomic::AtomicBool;
use assets::Assets;
use anyhow::Result;
use assistant_slash_command::{ use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
}; };
use gpui::{AppContext, AssetSource, Task, WeakView}; use gpui::{AppContext, Task, WeakView};
use language::LspAdapterDelegate; use language::LspAdapterDelegate;
use text::LineEnding;
use ui::prelude::*; use ui::prelude::*;
use workspace::Workspace; use workspace::Workspace;
pub(crate) struct WorkflowSlashCommand; pub(crate) struct WorkflowSlashCommand {
prompt_builder: Arc<PromptBuilder>,
}
impl WorkflowSlashCommand {
pub fn new(prompt_builder: Arc<PromptBuilder>) -> Self {
Self { prompt_builder }
}
}
impl SlashCommand for WorkflowSlashCommand { impl SlashCommand for WorkflowSlashCommand {
fn name(&self) -> String { fn name(&self) -> String {
@ -46,26 +55,22 @@ impl SlashCommand for WorkflowSlashCommand {
_argument: Option<&str>, _argument: Option<&str>,
_workspace: WeakView<Workspace>, _workspace: WeakView<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>, _delegate: Option<Arc<dyn LspAdapterDelegate>>,
_cx: &mut WindowContext, cx: &mut WindowContext,
) -> Task<Result<SlashCommandOutput>> { ) -> Task<Result<SlashCommandOutput>> {
let mut text = match Assets let prompt_builder = self.prompt_builder.clone();
.load("prompts/edit_workflow.md") cx.spawn(|_cx| async move {
.and_then(|prompt| prompt.context("prompts/edit_workflow.md not found")) let text = prompt_builder.generate_workflow_prompt()?;
{ let range = 0..text.len();
Ok(prompt) => String::from_utf8_lossy(&prompt).into_owned(),
Err(error) => return Task::ready(Err(error)),
};
LineEnding::normalize(&mut text);
let range = 0..text.len();
Task::ready(Ok(SlashCommandOutput { Ok(SlashCommandOutput {
text, text,
sections: vec![SlashCommandOutputSection { sections: vec![SlashCommandOutputSection {
range, range,
icon: IconName::Route, icon: IconName::Route,
label: "Workflow".into(), label: "Workflow".into(),
}], }],
run_commands_in_text: false, run_commands_in_text: false,
})) })
})
} }
} }

View file

@ -6,7 +6,7 @@ use crate::{
}, },
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant::ContextStore; use assistant::{ContextStore, PromptBuilder};
use call::{room, ActiveCall, ParticipantLocation, Room}; use call::{room, ActiveCall, ParticipantLocation, Room};
use client::{User, RECEIVE_TIMEOUT}; use client::{User, RECEIVE_TIMEOUT};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
@ -6485,12 +6485,13 @@ async fn test_context_collaboration_with_reconnect(
assert_eq!(project.collaborators().len(), 1); assert_eq!(project.collaborators().len(), 1);
}); });
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context_store_a = cx_a let context_store_a = cx_a
.update(|cx| ContextStore::new(project_a.clone(), cx)) .update(|cx| ContextStore::new(project_a.clone(), prompt_builder.clone(), cx))
.await .await
.unwrap(); .unwrap();
let context_store_b = cx_b let context_store_b = cx_b
.update(|cx| ContextStore::new(project_b.clone(), cx)) .update(|cx| ContextStore::new(project_b.clone(), prompt_builder.clone(), cx))
.await .await
.unwrap(); .unwrap();

View file

@ -7,6 +7,7 @@ mod reliability;
mod zed; mod zed;
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use assistant::PromptBuilder;
use clap::{command, Parser}; use clap::{command, Parser};
use cli::FORCE_CLI_MODE_ENV_VAR_NAME; use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{parse_zed_link, Client, DevServerToken, UserStore}; use client::{parse_zed_link, Client, DevServerToken, UserStore};
@ -161,7 +162,7 @@ fn init_headless(
} }
// init_common is called for both headless and normal mode. // init_common is called for both headless and normal mode.
fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) { fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) -> Arc<PromptBuilder> {
SystemAppearance::init(cx); SystemAppearance::init(cx);
theme::init(theme::LoadThemes::All(Box::new(Assets)), cx); theme::init(theme::LoadThemes::All(Box::new(Assets)), cx);
command_palette::init(cx); command_palette::init(cx);
@ -182,7 +183,7 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) {
); );
snippet_provider::init(cx); snippet_provider::init(cx);
inline_completion_registry::init(app_state.client.telemetry().clone(), cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); let prompt_builder = assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
repl::init( repl::init(
app_state.fs.clone(), app_state.fs.clone(),
app_state.client.telemetry().clone(), app_state.client.telemetry().clone(),
@ -196,9 +197,14 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) {
ThemeRegistry::global(cx), ThemeRegistry::global(cx),
cx, cx,
); );
prompt_builder
} }
fn init_ui(app_state: Arc<AppState>, cx: &mut AppContext) -> Result<()> { fn init_ui(
app_state: Arc<AppState>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut AppContext,
) -> Result<()> {
match cx.try_global::<AppMode>() { match cx.try_global::<AppMode>() {
Some(AppMode::Headless(_)) => { Some(AppMode::Headless(_)) => {
return Err(anyhow!( return Err(anyhow!(
@ -289,7 +295,7 @@ fn init_ui(app_state: Arc<AppState>, cx: &mut AppContext) -> Result<()> {
watch_file_types(fs.clone(), cx); watch_file_types(fs.clone(), cx);
cx.set_menus(app_menus()); cx.set_menus(app_menus());
initialize_workspace(app_state.clone(), cx); initialize_workspace(app_state.clone(), prompt_builder, cx);
cx.activate(true); cx.activate(true);
@ -467,7 +473,7 @@ fn main() {
auto_update::init(client.http_client(), cx); auto_update::init(client.http_client(), cx);
reliability::init(client.http_client(), installation_id, cx); reliability::init(client.http_client(), installation_id, cx);
init_common(app_state.clone(), cx); let prompt_builder = init_common(app_state.clone(), cx);
let args = Args::parse(); let args = Args::parse();
let urls: Vec<_> = args let urls: Vec<_> = args
@ -487,7 +493,7 @@ fn main() {
.and_then(|urls| OpenRequest::parse(urls, cx).log_err()) .and_then(|urls| OpenRequest::parse(urls, cx).log_err())
{ {
Some(request) => { Some(request) => {
handle_open_request(request, app_state.clone(), cx); handle_open_request(request, app_state.clone(), prompt_builder.clone(), cx);
} }
None => { None => {
if let Some(dev_server_token) = args.dev_server_token { if let Some(dev_server_token) = args.dev_server_token {
@ -503,7 +509,7 @@ fn main() {
}) })
.detach(); .detach();
} else { } else {
init_ui(app_state.clone(), cx).unwrap(); init_ui(app_state.clone(), prompt_builder.clone(), cx).unwrap();
cx.spawn({ cx.spawn({
let app_state = app_state.clone(); let app_state = app_state.clone();
|mut cx| async move { |mut cx| async move {
@ -518,11 +524,12 @@ fn main() {
} }
let app_state = app_state.clone(); let app_state = app_state.clone();
let prompt_builder = prompt_builder.clone();
cx.spawn(move |cx| async move { cx.spawn(move |cx| async move {
while let Some(urls) = open_rx.next().await { while let Some(urls) = open_rx.next().await {
cx.update(|cx| { cx.update(|cx| {
if let Some(request) = OpenRequest::parse(urls, cx).log_err() { if let Some(request) = OpenRequest::parse(urls, cx).log_err() {
handle_open_request(request, app_state.clone(), cx); handle_open_request(request, app_state.clone(), prompt_builder.clone(), cx);
} }
}) })
.ok(); .ok();
@ -532,15 +539,20 @@ fn main() {
}); });
} }
fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut AppContext) { fn handle_open_request(
request: OpenRequest,
app_state: Arc<AppState>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut AppContext,
) {
if let Some(connection) = request.cli_connection { if let Some(connection) = request.cli_connection {
let app_state = app_state.clone(); let app_state = app_state.clone();
cx.spawn(move |cx| handle_cli_connection(connection, app_state, cx)) cx.spawn(move |cx| handle_cli_connection(connection, app_state, prompt_builder, cx))
.detach(); .detach();
return; return;
} }
if let Err(e) = init_ui(app_state.clone(), cx) { if let Err(e) = init_ui(app_state.clone(), prompt_builder, cx) {
fail_to_open_window(e, cx); fail_to_open_window(e, cx);
return; return;
}; };

View file

@ -7,6 +7,7 @@ pub(crate) mod only_instance;
mod open_listener; mod open_listener;
pub use app_menus::*; pub use app_menus::*;
use assistant::PromptBuilder;
use breadcrumbs::Breadcrumbs; use breadcrumbs::Breadcrumbs;
use client::ZED_URL_SCHEME; use client::ZED_URL_SCHEME;
use collections::VecDeque; use collections::VecDeque;
@ -119,7 +120,11 @@ pub fn build_window_options(display_uuid: Option<Uuid>, cx: &mut AppContext) ->
} }
} }
pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) { pub fn initialize_workspace(
app_state: Arc<AppState>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut AppContext,
) {
cx.observe_new_views(move |workspace: &mut Workspace, cx| { cx.observe_new_views(move |workspace: &mut Workspace, cx| {
let workspace_handle = cx.view().clone(); let workspace_handle = cx.view().clone();
let center_pane = workspace.active_pane().clone(); let center_pane = workspace.active_pane().clone();
@ -238,9 +243,10 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
}); });
} }
let prompt_builder = prompt_builder.clone();
cx.spawn(|workspace_handle, mut cx| async move { cx.spawn(|workspace_handle, mut cx| async move {
let assistant_panel = let assistant_panel =
assistant::AssistantPanel::load(workspace_handle.clone(), cx.clone()); assistant::AssistantPanel::load(workspace_handle.clone(), prompt_builder, cx.clone());
let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone()); let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone());
let outline_panel = OutlinePanel::load(workspace_handle.clone(), cx.clone()); let outline_panel = OutlinePanel::load(workspace_handle.clone(), cx.clone());
@ -3474,14 +3480,15 @@ mod tests {
app_state.fs.clone(), app_state.fs.clone(),
cx, cx,
); );
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); let prompt_builder =
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
repl::init( repl::init(
app_state.fs.clone(), app_state.fs.clone(),
app_state.client.telemetry().clone(), app_state.client.telemetry().clone(),
cx, cx,
); );
tasks_ui::init(cx); tasks_ui::init(cx);
initialize_workspace(app_state.clone(), cx); initialize_workspace(app_state.clone(), prompt_builder, cx);
app_state app_state
}) })
} }

View file

@ -1,6 +1,7 @@
use crate::restorable_workspace_locations; use crate::restorable_workspace_locations;
use crate::{handle_open_request, init_headless, init_ui}; use crate::{handle_open_request, init_headless, init_ui};
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use assistant::PromptBuilder;
use cli::{ipc, IpcHandshake}; use cli::{ipc, IpcHandshake};
use cli::{ipc::IpcSender, CliRequest, CliResponse}; use cli::{ipc::IpcSender, CliRequest, CliResponse};
use client::parse_zed_link; use client::parse_zed_link;
@ -245,6 +246,7 @@ pub async fn open_paths_with_positions(
pub async fn handle_cli_connection( pub async fn handle_cli_connection(
(mut requests, responses): (mpsc::Receiver<CliRequest>, IpcSender<CliResponse>), (mut requests, responses): (mpsc::Receiver<CliRequest>, IpcSender<CliResponse>),
app_state: Arc<AppState>, app_state: Arc<AppState>,
prompt_builder: Arc<PromptBuilder>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) { ) {
if let Some(request) = requests.next().await { if let Some(request) = requests.next().await {
@ -289,7 +291,12 @@ pub async fn handle_cli_connection(
cx.update(|cx| { cx.update(|cx| {
match OpenRequest::parse(urls, cx) { match OpenRequest::parse(urls, cx) {
Ok(open_request) => { Ok(open_request) => {
handle_open_request(open_request, app_state.clone(), cx); handle_open_request(
open_request,
app_state.clone(),
prompt_builder.clone(),
cx,
);
responses.send(CliResponse::Exit { status: 0 }).log_err(); responses.send(CliResponse::Exit { status: 0 }).log_err();
} }
Err(e) => { Err(e) => {
@ -307,7 +314,7 @@ pub async fn handle_cli_connection(
} }
if let Err(e) = cx if let Err(e) = cx
.update(|cx| init_ui(app_state.clone(), cx)) .update(|cx| init_ui(app_state.clone(), prompt_builder.clone(), cx))
.and_then(|r| r) .and_then(|r| r)
{ {
responses responses

View file

@ -221,6 +221,10 @@ Zed allows you to override the default prompts used for various assistant featur
given system information and latest terminal output if relevant. given system information and latest terminal output if relevant.
``` ```
You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. 3. `edit_workflow.hbs`: Used for generating the edit workflow prompt.
4. `step_resolution.hbs`: Used for generating the step resolution prompt.
You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. Consult Zed's assets/prompts directory for current versions you can play with.
Be sure you want to override these, as you'll miss out on iteration on our built in features. This should be primarily used when developing Zed. Be sure you want to override these, as you'll miss out on iteration on our built in features. This should be primarily used when developing Zed.