Allow /workflow and step resolution prompts to be overridden (#15892)

This will help us as we hit issues with the /workflow and step
resolution. We can override the baked-in prompts and make tweaks, then
import our refinements back into the source tree when we're ready.

Release Notes:

- N/A
This commit is contained in:
Nathan Sobo 2024-08-06 21:47:42 -06:00 committed by GitHub
parent c8f1358629
commit 990774247e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 196 additions and 104 deletions

View file

@ -1,5 +1,5 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
InlineAssistId, InlineAssistant, MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
@ -611,6 +611,7 @@ pub struct Context {
language_registry: Arc<LanguageRegistry>,
workflow_steps: Vec<WorkflowStep>,
project: Option<Model<Project>>,
prompt_builder: Arc<PromptBuilder>,
}
impl EventEmitter<ContextEvent> for Context {}
@ -620,6 +621,7 @@ impl Context {
language_registry: Arc<LanguageRegistry>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut ModelContext<Self>,
) -> Self {
Self::new(
@ -627,17 +629,20 @@ impl Context {
ReplicaId::default(),
language::Capability::ReadWrite,
language_registry,
prompt_builder,
project,
telemetry,
cx,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new(
id: ContextId,
replica_id: ReplicaId,
capability: language::Capability,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
@ -680,6 +685,7 @@ impl Context {
project,
language_registry,
workflow_steps: Vec::new(),
prompt_builder,
};
let first_message_id = MessageId(clock::Lamport {
@ -749,6 +755,7 @@ impl Context {
saved_context: SavedContext,
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
@ -759,6 +766,7 @@ impl Context {
ReplicaId::default(),
language::Capability::ReadWrite,
language_registry,
prompt_builder,
project,
telemetry,
cx,
@ -1246,9 +1254,9 @@ impl Context {
cx.spawn(|this, mut cx| {
async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
let mut prompt = prompt_store.step_resolution_prompt()?;
let mut prompt = this.update(&mut cx, |this, _| {
this.prompt_builder.generate_step_resolution_prompt()
})??;
prompt.push_str(&step_text);
request.messages.push(LanguageModelRequestMessage {
@ -2448,8 +2456,9 @@ mod tests {
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let context = cx.new_model(|cx| Context::local(registry, None, None, cx));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
@ -2580,7 +2589,9 @@ mod tests {
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let context = cx.new_model(|cx| Context::local(registry, None, None, cx));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
@ -2673,7 +2684,9 @@ mod tests {
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let context = cx.new_model(|cx| Context::local(registry, None, None, cx));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
@ -2778,7 +2791,10 @@ mod tests {
slash_command_registry.register_command(active_command::ActiveSlashCommand, false);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| {
Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)
});
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
context.update(cx, |_, cx| {
@ -2905,7 +2921,16 @@ mod tests {
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
// Create a new context
let context = cx.new_model(|cx| Context::local(registry.clone(), Some(project), None, cx));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| {
Context::local(
registry.clone(),
Some(project),
None,
prompt_builder.clone(),
cx,
)
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone());
// Simulate user input
@ -3070,7 +3095,10 @@ mod tests {
cx.update(LanguageModelRegistry::test);
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, None, cx));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new_model(|cx| {
Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone());
let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
let message_1 = context.update(cx, |context, cx| {
@ -3109,6 +3137,7 @@ mod tests {
serialized_context,
Default::default(),
registry.clone(),
prompt_builder.clone(),
None,
None,
cx,
@ -3158,6 +3187,7 @@ mod tests {
let num_peers = rng.gen_range(min_peers..=max_peers);
let context_id = ContextId::new();
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
for i in 0..num_peers {
let context = cx.new_model(|cx| {
Context::new(
@ -3165,6 +3195,7 @@ mod tests {
i as ReplicaId,
language::Capability::ReadWrite,
registry.clone(),
prompt_builder.clone(),
None,
None,
cx,