From 990774247ecb61ddf69f0325d28eb3b203fc3699 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 6 Aug 2024 21:47:42 -0600 Subject: [PATCH] Allow /workflow and step resolution prompts to be overridden (#15892) This will help us as we hit issues with the /workflow and step resolution. We can override the baked-in prompts and make tweaks, then import our refinements back into the source tree when we're ready. Release Notes: - N/A --- .../{edit_workflow.md => edit_workflow.hbs} | 0 ...step_resolution.md => step_resolution.hbs} | 0 crates/assistant/src/assistant.rs | 42 +++++++++----- crates/assistant/src/assistant_panel.rs | 6 +- crates/assistant/src/context.rs | 53 ++++++++++++++---- crates/assistant/src/context_store.rs | 19 ++++++- crates/assistant/src/prompt_library.rs | 16 +----- crates/assistant/src/prompts.rs | 36 +++++++----- .../src/slash_command/workflow_command.rs | 55 ++++++++++--------- crates/collab/src/tests/integration_tests.rs | 7 ++- crates/zed/src/main.rs | 34 ++++++++---- crates/zed/src/zed.rs | 15 +++-- crates/zed/src/zed/open_listener.rs | 11 +++- docs/src/language-model-integration.md | 6 +- 14 files changed, 196 insertions(+), 104 deletions(-) rename assets/prompts/{edit_workflow.md => edit_workflow.hbs} (100%) rename assets/prompts/{step_resolution.md => step_resolution.hbs} (100%) diff --git a/assets/prompts/edit_workflow.md b/assets/prompts/edit_workflow.hbs similarity index 100% rename from assets/prompts/edit_workflow.md rename to assets/prompts/edit_workflow.hbs diff --git a/assets/prompts/step_resolution.md b/assets/prompts/step_resolution.hbs similarity index 100% rename from assets/prompts/step_resolution.md rename to assets/prompts/step_resolution.hbs diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 4896aadeec..04e4f69505 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -25,6 +25,7 @@ use language_model::{ LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage, }; pub(crate) use model_selector::*; +pub use prompts::PromptBuilder; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; @@ -163,7 +164,7 @@ impl Assistant { } } -pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { +pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) -> Arc { cx.set_global(Assistant::default()); AssistantSettings::register(cx); @@ -196,19 +197,25 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { prompt_library::init(cx); init_language_model_settings(cx); assistant_slash_command::init(cx); - register_slash_commands(cx); assistant_panel::init(cx); - if let Some(prompt_builder) = prompts::PromptBuilder::new(Some((fs.clone(), cx))).log_err() { - let prompt_builder = Arc::new(prompt_builder); - inline_assistant::init( - fs.clone(), - prompt_builder.clone(), - client.telemetry().clone(), - cx, - ); - terminal_inline_assistant::init(fs.clone(), prompt_builder, client.telemetry().clone(), cx); - } + let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx))) + .log_err() + .map(Arc::new) + .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); + register_slash_commands(Some(prompt_builder.clone()), cx); + inline_assistant::init( + fs.clone(), + prompt_builder.clone(), + client.telemetry().clone(), + cx, + ); + terminal_inline_assistant::init( + fs.clone(), + prompt_builder.clone(), + client.telemetry().clone(), + cx, + ); IndexedDocsRegistry::init_global(cx); CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -226,6 +233,8 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { }); }) .detach(); + + prompt_builder } fn init_language_model_settings(cx: &mut AppContext) { @@ -256,7 +265,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) { }); } -fn register_slash_commands(cx: &mut AppContext) { +fn register_slash_commands(prompt_builder: Option>, cx: &mut AppContext) { let slash_command_registry = SlashCommandRegistry::global(cx); slash_command_registry.register_command(file_command::FileSlashCommand, true); slash_command_registry.register_command(active_command::ActiveSlashCommand, true); @@ -270,7 +279,12 @@ fn register_slash_commands(cx: &mut AppContext) { slash_command_registry.register_command(now_command::NowSlashCommand, true); slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true); slash_command_registry.register_command(docs_command::DocsSlashCommand, true); - slash_command_registry.register_command(workflow_command::WorkflowSlashCommand, true); + if let Some(prompt_builder) = prompt_builder { + slash_command_registry.register_command( + workflow_command::WorkflowSlashCommand::new(prompt_builder), + true, + ); + } slash_command_registry.register_command(fetch_command::FetchSlashCommand, false); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9ccd8356fc..c5c504c401 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -2,6 +2,7 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, humanize_token_count, prompt_library::open_prompt_library, + prompts::PromptBuilder, slash_command::{ default_command::DefaultSlashCommand, docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, @@ -315,14 +316,17 @@ impl PickerDelegate for SavedContextPickerDelegate { impl AssistantPanel { pub fn load( workspace: WeakView, + prompt_builder: Arc, cx: AsyncWindowContext, ) -> Task>> { cx.spawn(|mut cx| async move { let context_store = workspace .update(&mut cx, |workspace, cx| { - ContextStore::new(workspace.project().clone(), cx) + let project = workspace.project().clone(); + ContextStore::new(project, prompt_builder.clone(), cx) })? .await?; + workspace.update(&mut cx, |workspace, cx| { // TODO: deserialize state. cx.new_view(|cx| Self::new(workspace, context_store, cx)) diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 5b4e4a980c..b001794d9e 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,5 +1,5 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion, + prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion, InlineAssistId, InlineAssistant, MessageId, MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; @@ -611,6 +611,7 @@ pub struct Context { language_registry: Arc, workflow_steps: Vec, project: Option>, + prompt_builder: Arc, } impl EventEmitter for Context {} @@ -620,6 +621,7 @@ impl Context { language_registry: Arc, project: Option>, telemetry: Option>, + prompt_builder: Arc, cx: &mut ModelContext, ) -> Self { Self::new( @@ -627,17 +629,20 @@ impl Context { ReplicaId::default(), language::Capability::ReadWrite, language_registry, + prompt_builder, project, telemetry, cx, ) } + #[allow(clippy::too_many_arguments)] pub fn new( id: ContextId, replica_id: ReplicaId, capability: language::Capability, language_registry: Arc, + prompt_builder: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -680,6 +685,7 @@ impl Context { project, language_registry, workflow_steps: Vec::new(), + prompt_builder, }; let first_message_id = MessageId(clock::Lamport { @@ -749,6 +755,7 @@ impl Context { saved_context: SavedContext, path: PathBuf, language_registry: Arc, + prompt_builder: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -759,6 +766,7 @@ impl Context { ReplicaId::default(), language::Capability::ReadWrite, language_registry, + prompt_builder, project, telemetry, cx, @@ -1246,9 +1254,9 @@ impl Context { cx.spawn(|this, mut cx| { async move { - let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?; - - let mut prompt = prompt_store.step_resolution_prompt()?; + let mut prompt = this.update(&mut cx, |this, _| { + this.prompt_builder.generate_step_resolution_prompt() + })??; prompt.push_str(&step_text); request.messages.push(LanguageModelRequestMessage { @@ -2448,8 +2456,9 @@ mod tests { cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - - let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -2580,7 +2589,9 @@ mod tests { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -2673,7 +2684,9 @@ mod tests { cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let context = cx.new_model(|cx| Context::local(registry, None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = + cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -2778,7 +2791,10 @@ mod tests { slash_command_registry.register_command(active_command::ActiveSlashCommand, false); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local(registry.clone(), None, None, prompt_builder.clone(), cx) + }); let output_ranges = Rc::new(RefCell::new(HashSet::default())); context.update(cx, |_, cx| { @@ -2905,7 +2921,16 @@ mod tests { let registry = Arc::new(LanguageRegistry::test(cx.executor())); // Create a new context - let context = cx.new_model(|cx| Context::local(registry.clone(), Some(project), None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + Some(project), + None, + prompt_builder.clone(), + cx, + ) + }); let buffer = context.read_with(cx, |context, _| context.buffer.clone()); // Simulate user input @@ -3070,7 +3095,10 @@ mod tests { cx.update(LanguageModelRegistry::test); cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new_model(|cx| { + Context::local(registry.clone(), None, None, prompt_builder.clone(), cx) + }); let buffer = context.read_with(cx, |context, _| context.buffer.clone()); let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); let message_1 = context.update(cx, |context, cx| { @@ -3109,6 +3137,7 @@ mod tests { serialized_context, Default::default(), registry.clone(), + prompt_builder.clone(), None, None, cx, @@ -3158,6 +3187,7 @@ mod tests { let num_peers = rng.gen_range(min_peers..=max_peers); let context_id = ContextId::new(); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); for i in 0..num_peers { let context = cx.new_model(|cx| { Context::new( @@ -3165,6 +3195,7 @@ mod tests { i as ReplicaId, language::Capability::ReadWrite, registry.clone(), + prompt_builder.clone(), None, None, cx, diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index dab709bd20..ce82b5eca7 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -1,6 +1,6 @@ use crate::{ - Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, - SavedContextMetadata, + prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, + SavedContext, SavedContextMetadata, }; use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; @@ -52,6 +52,7 @@ pub struct ContextStore { project_is_shared: bool, client_subscription: Option, _project_subscriptions: Vec, + prompt_builder: Arc, } pub enum ContextStoreEvent { @@ -82,7 +83,11 @@ impl ContextHandle { } impl ContextStore { - pub fn new(project: Model, cx: &mut AppContext) -> Task>> { + pub fn new( + project: Model, + prompt_builder: Arc, + cx: &mut AppContext, + ) -> Task>> { let fs = project.read(cx).fs().clone(); let languages = project.read(cx).languages().clone(); let telemetry = project.read(cx).client().telemetry().clone(); @@ -117,6 +122,7 @@ impl ContextStore { project_is_shared: false, client: project.read(cx).client(), project: project.clone(), + prompt_builder, }; this.handle_project_changed(project, cx); this.synchronize_contexts(cx); @@ -334,6 +340,7 @@ impl ContextStore { self.languages.clone(), Some(self.project.clone()), Some(self.telemetry.clone()), + self.prompt_builder.clone(), cx, ) }); @@ -358,6 +365,7 @@ impl ContextStore { let language_registry = self.languages.clone(); let project = self.project.clone(); let telemetry = self.telemetry.clone(); + let prompt_builder = self.prompt_builder.clone(); let request = self.client.request(proto::CreateContext { project_id }); cx.spawn(|this, mut cx| async move { let response = request.await?; @@ -369,6 +377,7 @@ impl ContextStore { replica_id, capability, language_registry, + prompt_builder, Some(project), Some(telemetry), cx, @@ -417,6 +426,7 @@ impl ContextStore { SavedContext::from_json(&saved_context) } }); + let prompt_builder = self.prompt_builder.clone(); cx.spawn(|this, mut cx| async move { let saved_context = load.await?; @@ -425,6 +435,7 @@ impl ContextStore { saved_context, path.clone(), languages, + prompt_builder, Some(project), Some(telemetry), cx, @@ -493,6 +504,7 @@ impl ContextStore { project_id, context_id: context_id.to_proto(), }); + let prompt_builder = self.prompt_builder.clone(); cx.spawn(|this, mut cx| async move { let response = request.await?; let context_proto = response.context.context("invalid context")?; @@ -502,6 +514,7 @@ impl ContextStore { replica_id, capability, language_registry, + prompt_builder, Some(project), Some(telemetry), cx, diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 23f76177f7..a0b25bf679 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -2,7 +2,6 @@ use crate::{ slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant, }; use anyhow::{anyhow, Result}; -use assets::Assets; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet}; use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle}; @@ -12,8 +11,8 @@ use futures::{ }; use fuzzy::StringMatchCandidate; use gpui::{ - actions, point, size, transparent_black, AppContext, AssetSource, BackgroundExecutor, Bounds, - EventEmitter, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle, + actions, point, size, transparent_black, AppContext, BackgroundExecutor, Bounds, EventEmitter, + Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle, TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions, }; use heed::{ @@ -1466,17 +1465,6 @@ impl PromptStore { fn first(&self) -> Option { self.metadata_cache.read().metadata.first().cloned() } - - pub fn step_resolution_prompt(&self) -> Result { - let path = "prompts/step_resolution.md"; - - Ok(String::from_utf8( - Assets - .load(path)? - .ok_or_else(|| anyhow!("{path} not found"))? - .to_vec(), - )?) - } } /// Wraps a shared future to a prompt store so it can be assigned as a context global. diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index d61d127880..f6fb203e96 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -125,22 +125,20 @@ impl PromptBuilder { } fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box> { - let content_prompt = Assets::get("prompts/content_prompt.hbs") - .expect("Content prompt template not found") - .data; - let terminal_assistant_prompt = Assets::get("prompts/terminal_assistant_prompt.hbs") - .expect("Terminal assistant prompt template not found") - .data; + let mut register_template = |id: &str| { + let prompt = Assets::get(&format!("prompts/{}.hbs", id)) + .unwrap_or_else(|| panic!("{} prompt template not found", id)) + .data; + handlebars + .register_template_string(id, String::from_utf8_lossy(&prompt)) + .map_err(Box::new) + }; + + register_template("content_prompt")?; + register_template("terminal_assistant_prompt")?; + register_template("edit_workflow")?; + register_template("step_resolution")?; - handlebars - .register_template_string("content_prompt", String::from_utf8_lossy(&content_prompt)) - .map_err(Box::new)?; - handlebars - .register_template_string( - "terminal_assistant_prompt", - String::from_utf8_lossy(&terminal_assistant_prompt), - ) - .map_err(Box::new)?; Ok(()) } @@ -236,4 +234,12 @@ impl PromptBuilder { .lock() .render("terminal_assistant_prompt", &context) } + + pub fn generate_workflow_prompt(&self) -> Result { + self.handlebars.lock().render("edit_workflow", &()) + } + + pub fn generate_step_resolution_prompt(&self) -> Result { + self.handlebars.lock().render("step_resolution", &()) + } } diff --git a/crates/assistant/src/slash_command/workflow_command.rs b/crates/assistant/src/slash_command/workflow_command.rs index f55275f011..d2708c38d2 100644 --- a/crates/assistant/src/slash_command/workflow_command.rs +++ b/crates/assistant/src/slash_command/workflow_command.rs @@ -1,18 +1,27 @@ -use std::sync::atomic::AtomicBool; +use crate::prompts::PromptBuilder; use std::sync::Arc; -use anyhow::{Context as _, Result}; -use assets::Assets; +use std::sync::atomic::AtomicBool; + +use anyhow::Result; use assistant_slash_command::{ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, }; -use gpui::{AppContext, AssetSource, Task, WeakView}; +use gpui::{AppContext, Task, WeakView}; use language::LspAdapterDelegate; -use text::LineEnding; use ui::prelude::*; + use workspace::Workspace; -pub(crate) struct WorkflowSlashCommand; +pub(crate) struct WorkflowSlashCommand { + prompt_builder: Arc, +} + +impl WorkflowSlashCommand { + pub fn new(prompt_builder: Arc) -> Self { + Self { prompt_builder } + } +} impl SlashCommand for WorkflowSlashCommand { fn name(&self) -> String { @@ -46,26 +55,22 @@ impl SlashCommand for WorkflowSlashCommand { _argument: Option<&str>, _workspace: WeakView, _delegate: Option>, - _cx: &mut WindowContext, + cx: &mut WindowContext, ) -> Task> { - let mut text = match Assets - .load("prompts/edit_workflow.md") - .and_then(|prompt| prompt.context("prompts/edit_workflow.md not found")) - { - Ok(prompt) => String::from_utf8_lossy(&prompt).into_owned(), - Err(error) => return Task::ready(Err(error)), - }; - LineEnding::normalize(&mut text); - let range = 0..text.len(); + let prompt_builder = self.prompt_builder.clone(); + cx.spawn(|_cx| async move { + let text = prompt_builder.generate_workflow_prompt()?; + let range = 0..text.len(); - Task::ready(Ok(SlashCommandOutput { - text, - sections: vec![SlashCommandOutputSection { - range, - icon: IconName::Route, - label: "Workflow".into(), - }], - run_commands_in_text: false, - })) + Ok(SlashCommandOutput { + text, + sections: vec![SlashCommandOutputSection { + range, + icon: IconName::Route, + label: "Workflow".into(), + }], + run_commands_in_text: false, + }) + }) } } diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index f5b7b8903b..3e95ca7659 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6,7 +6,7 @@ use crate::{ }, }; use anyhow::{anyhow, Result}; -use assistant::ContextStore; +use assistant::{ContextStore, PromptBuilder}; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; use collections::{HashMap, HashSet}; @@ -6485,12 +6485,13 @@ async fn test_context_collaboration_with_reconnect( assert_eq!(project.collaborators().len(), 1); }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context_store_a = cx_a - .update(|cx| ContextStore::new(project_a.clone(), cx)) + .update(|cx| ContextStore::new(project_a.clone(), prompt_builder.clone(), cx)) .await .unwrap(); let context_store_b = cx_b - .update(|cx| ContextStore::new(project_b.clone(), cx)) + .update(|cx| ContextStore::new(project_b.clone(), prompt_builder.clone(), cx)) .await .unwrap(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 1570d75b40..e56bfe5b92 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -7,6 +7,7 @@ mod reliability; mod zed; use anyhow::{anyhow, Context as _, Result}; +use assistant::PromptBuilder; use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; use client::{parse_zed_link, Client, DevServerToken, UserStore}; @@ -161,7 +162,7 @@ fn init_headless( } // init_common is called for both headless and normal mode. -fn init_common(app_state: Arc, cx: &mut AppContext) { +fn init_common(app_state: Arc, cx: &mut AppContext) -> Arc { SystemAppearance::init(cx); theme::init(theme::LoadThemes::All(Box::new(Assets)), cx); command_palette::init(cx); @@ -182,7 +183,7 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { ); snippet_provider::init(cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx); - assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); + let prompt_builder = assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init( app_state.fs.clone(), app_state.client.telemetry().clone(), @@ -196,9 +197,14 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { ThemeRegistry::global(cx), cx, ); + prompt_builder } -fn init_ui(app_state: Arc, cx: &mut AppContext) -> Result<()> { +fn init_ui( + app_state: Arc, + prompt_builder: Arc, + cx: &mut AppContext, +) -> Result<()> { match cx.try_global::() { Some(AppMode::Headless(_)) => { return Err(anyhow!( @@ -289,7 +295,7 @@ fn init_ui(app_state: Arc, cx: &mut AppContext) -> Result<()> { watch_file_types(fs.clone(), cx); cx.set_menus(app_menus()); - initialize_workspace(app_state.clone(), cx); + initialize_workspace(app_state.clone(), prompt_builder, cx); cx.activate(true); @@ -467,7 +473,7 @@ fn main() { auto_update::init(client.http_client(), cx); reliability::init(client.http_client(), installation_id, cx); - init_common(app_state.clone(), cx); + let prompt_builder = init_common(app_state.clone(), cx); let args = Args::parse(); let urls: Vec<_> = args @@ -487,7 +493,7 @@ fn main() { .and_then(|urls| OpenRequest::parse(urls, cx).log_err()) { Some(request) => { - handle_open_request(request, app_state.clone(), cx); + handle_open_request(request, app_state.clone(), prompt_builder.clone(), cx); } None => { if let Some(dev_server_token) = args.dev_server_token { @@ -503,7 +509,7 @@ fn main() { }) .detach(); } else { - init_ui(app_state.clone(), cx).unwrap(); + init_ui(app_state.clone(), prompt_builder.clone(), cx).unwrap(); cx.spawn({ let app_state = app_state.clone(); |mut cx| async move { @@ -518,11 +524,12 @@ fn main() { } let app_state = app_state.clone(); + let prompt_builder = prompt_builder.clone(); cx.spawn(move |cx| async move { while let Some(urls) = open_rx.next().await { cx.update(|cx| { if let Some(request) = OpenRequest::parse(urls, cx).log_err() { - handle_open_request(request, app_state.clone(), cx); + handle_open_request(request, app_state.clone(), prompt_builder.clone(), cx); } }) .ok(); @@ -532,15 +539,20 @@ fn main() { }); } -fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut AppContext) { +fn handle_open_request( + request: OpenRequest, + app_state: Arc, + prompt_builder: Arc, + cx: &mut AppContext, +) { if let Some(connection) = request.cli_connection { let app_state = app_state.clone(); - cx.spawn(move |cx| handle_cli_connection(connection, app_state, cx)) + cx.spawn(move |cx| handle_cli_connection(connection, app_state, prompt_builder, cx)) .detach(); return; } - if let Err(e) = init_ui(app_state.clone(), cx) { + if let Err(e) = init_ui(app_state.clone(), prompt_builder, cx) { fail_to_open_window(e, cx); return; }; diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index e727f0e170..667e8c66bc 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -7,6 +7,7 @@ pub(crate) mod only_instance; mod open_listener; pub use app_menus::*; +use assistant::PromptBuilder; use breadcrumbs::Breadcrumbs; use client::ZED_URL_SCHEME; use collections::VecDeque; @@ -119,7 +120,11 @@ pub fn build_window_options(display_uuid: Option, cx: &mut AppContext) -> } } -pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { +pub fn initialize_workspace( + app_state: Arc, + prompt_builder: Arc, + cx: &mut AppContext, +) { cx.observe_new_views(move |workspace: &mut Workspace, cx| { let workspace_handle = cx.view().clone(); let center_pane = workspace.active_pane().clone(); @@ -238,9 +243,10 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }); } + let prompt_builder = prompt_builder.clone(); cx.spawn(|workspace_handle, mut cx| async move { let assistant_panel = - assistant::AssistantPanel::load(workspace_handle.clone(), cx.clone()); + assistant::AssistantPanel::load(workspace_handle.clone(), prompt_builder, cx.clone()); let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone()); let outline_panel = OutlinePanel::load(workspace_handle.clone(), cx.clone()); @@ -3474,14 +3480,15 @@ mod tests { app_state.fs.clone(), cx, ); - assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); + let prompt_builder = + assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init( app_state.fs.clone(), app_state.client.telemetry().clone(), cx, ); tasks_ui::init(cx); - initialize_workspace(app_state.clone(), cx); + initialize_workspace(app_state.clone(), prompt_builder, cx); app_state }) } diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index fb94f2430e..a9c741f7a0 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -1,6 +1,7 @@ use crate::restorable_workspace_locations; use crate::{handle_open_request, init_headless, init_ui}; use anyhow::{anyhow, Context, Result}; +use assistant::PromptBuilder; use cli::{ipc, IpcHandshake}; use cli::{ipc::IpcSender, CliRequest, CliResponse}; use client::parse_zed_link; @@ -245,6 +246,7 @@ pub async fn open_paths_with_positions( pub async fn handle_cli_connection( (mut requests, responses): (mpsc::Receiver, IpcSender), app_state: Arc, + prompt_builder: Arc, mut cx: AsyncAppContext, ) { if let Some(request) = requests.next().await { @@ -289,7 +291,12 @@ pub async fn handle_cli_connection( cx.update(|cx| { match OpenRequest::parse(urls, cx) { Ok(open_request) => { - handle_open_request(open_request, app_state.clone(), cx); + handle_open_request( + open_request, + app_state.clone(), + prompt_builder.clone(), + cx, + ); responses.send(CliResponse::Exit { status: 0 }).log_err(); } Err(e) => { @@ -307,7 +314,7 @@ pub async fn handle_cli_connection( } if let Err(e) = cx - .update(|cx| init_ui(app_state.clone(), cx)) + .update(|cx| init_ui(app_state.clone(), prompt_builder.clone(), cx)) .and_then(|r| r) { responses diff --git a/docs/src/language-model-integration.md b/docs/src/language-model-integration.md index bb8d0f2f53..f2a2fb7b7c 100644 --- a/docs/src/language-model-integration.md +++ b/docs/src/language-model-integration.md @@ -221,6 +221,10 @@ Zed allows you to override the default prompts used for various assistant featur given system information and latest terminal output if relevant. ``` -You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. +3. `edit_workflow.hbs`: Used for generating the edit workflow prompt. + +4. `step_resolution.hbs`: Used for generating the step resolution prompt. + +You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. Consult Zed's assets/prompts directory for current versions you can play with. Be sure you want to override these, as you'll miss out on iteration on our built in features. This should be primarily used when developing Zed.