Allow prompt templates to be overridden in the zed configuration directory (#15887)

I need this to refine our prompts on the fly as I work.

Release Notes:

- Templates for the prompts that drive inline transformations in the editor and
  the terminal can now be overridden by placing files in the
  `~/.config/zed/prompts/templates` directory. This is an advanced feature
  intended for Zed developers: overriding a template prevents you from
  receiving upstream changes to it.
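
For illustration, a minimal standalone sketch of the override lookup described
above, assuming the `dirs` crate for locating the home directory; the function
name `load_template` and the template file name are illustrative, not Zed's
actual implementation:

use std::fs;

// Illustrative only: resolve a prompt template by file name, preferring
// a user override in ~/.config/zed/prompts/templates over the built-in.
fn load_template(name: &str, builtin: &str) -> String {
    let override_path = dirs::home_dir()
        .map(|home| home.join(".config/zed/prompts/templates").join(name));
    match override_path.and_then(|path| fs::read_to_string(path).ok()) {
        // A user file shadows the built-in template entirely, which is
        // why overriding opts you out of upstream template changes.
        Some(contents) => contents,
        None => builtin.to_string(),
    }
}

fn main() {
    let prompt = load_template("content_prompt.hbs", "<built-in template body>");
    println!("{prompt}");
}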
Nathan Sobo 2024-08-06 19:30:48 -06:00 committed by GitHub
parent 6065db174a
commit c8f1358629
12 changed files with 569 additions and 168 deletions


@@ -1,5 +1,5 @@
 use crate::{
-    humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
+    humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
     CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
 };
 use anyhow::{anyhow, Context as _, Result};
@@ -51,8 +51,13 @@ use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
 use util::{RangeExt, ResultExt};
 use workspace::{notifications::NotificationId, Toast, Workspace};
 
-pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
-    cx.set_global(InlineAssistant::new(fs, telemetry));
+pub fn init(
+    fs: Arc<dyn Fs>,
+    prompt_builder: Arc<PromptBuilder>,
+    telemetry: Arc<Telemetry>,
+    cx: &mut AppContext,
+) {
+    cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
 }
 
 const PROMPT_HISTORY_MAX_LEN: usize = 20;
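
`init` now takes the shared `PromptBuilder` explicitly instead of relying on
free prompt functions. A compile-checked sketch of the new call-site shape,
with stand-in types since the real `Fs`, `Telemetry`, and `AppContext` live in
Zed's own crates:

use std::sync::Arc;

// Stand-ins for Zed's fs::Fs, client::Telemetry, gpui::AppContext,
// and the PromptBuilder introduced by this commit.
struct Fs;
struct Telemetry;
struct AppContext;
struct PromptBuilder;

impl PromptBuilder {
    // Mirrors the constructor used in the tests below: `None` means
    // no overrides directory is watched.
    fn new(_overrides: Option<std::path::PathBuf>) -> Result<Self, String> {
        Ok(PromptBuilder)
    }
}

fn init(
    _fs: Arc<Fs>,
    _prompt_builder: Arc<PromptBuilder>,
    _telemetry: Arc<Telemetry>,
    _cx: &mut AppContext,
) {
}

fn main() {
    let mut cx = AppContext;
    // One PromptBuilder is created at startup and shared by every
    // consumer that renders prompt templates.
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    init(Arc::new(Fs), prompt_builder, Arc::new(Telemetry), &mut cx);
}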
@@ -64,6 +69,7 @@ pub struct InlineAssistant {
     assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
     assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
     prompt_history: VecDeque<String>,
+    prompt_builder: Arc<PromptBuilder>,
     telemetry: Option<Arc<Telemetry>>,
     fs: Arc<dyn Fs>,
 }
@@ -71,7 +77,11 @@ pub struct InlineAssistant {
 impl Global for InlineAssistant {}
 
 impl InlineAssistant {
-    pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
+    pub fn new(
+        fs: Arc<dyn Fs>,
+        prompt_builder: Arc<PromptBuilder>,
+        telemetry: Arc<Telemetry>,
+    ) -> Self {
         Self {
             next_assist_id: InlineAssistId::default(),
             next_assist_group_id: InlineAssistGroupId::default(),
@@ -79,6 +89,7 @@ impl InlineAssistant {
             assists_by_editor: HashMap::default(),
             assist_groups: HashMap::default(),
             prompt_history: VecDeque::default(),
+            prompt_builder,
             telemetry: Some(telemetry),
             fs,
         }
@@ -155,6 +166,7 @@ impl InlineAssistant {
                     range.clone(),
                     None,
                     self.telemetry.clone(),
+                    self.prompt_builder.clone(),
                     cx,
                 )
             });
@@ -260,6 +272,7 @@ impl InlineAssistant {
                     range.clone(),
                     initial_transaction_id,
                     self.telemetry.clone(),
+                    self.prompt_builder.clone(),
                     cx,
                 )
             });
@@ -2021,6 +2034,7 @@ pub struct Codegen {
     diff: Diff,
     telemetry: Option<Arc<Telemetry>>,
     _subscription: gpui::Subscription,
+    builder: Arc<PromptBuilder>,
 }
 
 pub enum CodegenStatus {
@@ -2050,6 +2064,7 @@ impl Codegen {
         range: Range<Anchor>,
         initial_transaction_id: Option<TransactionId>,
         telemetry: Option<Arc<Telemetry>>,
+        builder: Arc<PromptBuilder>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
         let snapshot = buffer.read(cx).snapshot(cx);
@@ -2087,6 +2102,7 @@ impl Codegen {
             telemetry,
             _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
             initial_transaction_id,
+            builder,
         }
     }
@@ -2118,7 +2134,10 @@ impl Codegen {
     ) -> BoxFuture<'static, Result<usize>> {
         if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
             let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
-            model.count_tokens(request, cx)
+            match request {
+                Ok(request) => model.count_tokens(request, cx),
+                Err(error) => futures::future::ready(Err(error)).boxed(),
+            }
         } else {
             future::ready(Err(anyhow!("no active model"))).boxed()
         }
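
Because `build_request` is now fallible, the `Err` arm above must still
produce a value of the same type as `model.count_tokens(...)`, i.e. a boxed
future; wrapping the error in an immediately-ready future does that. A
standalone sketch of the pattern using only the `futures` and `anyhow` crates:

use anyhow::{anyhow, Result};
use futures::future::{self, BoxFuture, FutureExt};

// Illustrative stand-in for Codegen::count_tokens: both arms must
// return the same BoxFuture type.
fn count_tokens(request: Result<String>) -> BoxFuture<'static, Result<usize>> {
    match request {
        // A real implementation would call an async model API here;
        // this sketch just counts characters.
        Ok(request) => future::ready(Ok(request.len())).boxed(),
        // An already-failed request becomes an immediately-ready
        // future carrying the error.
        Err(error) => future::ready(Err(error)).boxed(),
    }
}

fn main() {
    let result = futures::executor::block_on(count_tokens(Err(anyhow!("no active model"))));
    assert!(result.is_err());
}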
@@ -2152,7 +2171,8 @@ impl Codegen {
             async { Ok(stream::empty().boxed()) }.boxed_local()
         } else {
             let request =
-                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
+                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
             let chunks =
                 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
             async move { Ok(chunks.await?.boxed()) }.boxed_local()
@@ -2167,7 +2187,7 @@ impl Codegen {
         assistant_panel_context: Option<LanguageModelRequest>,
         edit_range: Range<Anchor>,
         cx: &AppContext,
-    ) -> LanguageModelRequest {
+    ) -> Result<LanguageModelRequest> {
         let buffer = self.buffer.read(cx).snapshot(cx);
         let language = buffer.language_at(edit_range.start);
         let language_name = if let Some(language) = language.as_ref() {
@@ -2202,12 +2222,15 @@ impl Codegen {
             if start_buffer.remote_id() == end_buffer.remote_id() {
                 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
             } else {
-                panic!("invalid transformation range");
+                return Err(anyhow::anyhow!("invalid transformation range"));
             }
         } else {
-            panic!("invalid transformation range");
+            return Err(anyhow::anyhow!("invalid transformation range"));
         };
-        let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
+        let prompt = self
+            .builder
+            .generate_content_prompt(user_prompt, language_name, buffer, range)
+            .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
 
         let mut messages = Vec::new();
         if let Some(context_request) = assistant_panel_context {
@@ -2219,11 +2242,11 @@ impl Codegen {
             content: prompt,
         });
 
-        LanguageModelRequest {
+        Ok(LanguageModelRequest {
             messages,
             stop: vec!["|END|>".to_string()],
             temperature,
-        }
+        })
     }
 
     pub fn handle_stream(
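
The last three hunks are one refactoring: `build_request` no longer panics on
an invalid range or a failed template render; each failure becomes an early
`Err`, and the success value is wrapped in `Ok`. A self-contained sketch of
the before/after shape, with `render_prompt` standing in for
`PromptBuilder::generate_content_prompt`:

use anyhow::{anyhow, Result};

struct Request {
    prompt: String,
}

// Before: `fn build_request(...) -> Request`, with panic! on bad input.
// After: every failure path returns Err and the caller decides what to
// do (count_tokens boxes the error; stream_completion propagates it
// with `?`).
fn build_request(range_is_valid: bool, user_prompt: &str) -> Result<Request> {
    if !range_is_valid {
        return Err(anyhow!("invalid transformation range"));
    }
    let prompt = render_prompt(user_prompt)
        .map_err(|e| anyhow!("Failed to generate content prompt: {}", e))?;
    Ok(Request { prompt })
}

// Stand-in for the template render, which can fail if an overridden
// template is missing or malformed.
fn render_prompt(user_prompt: &str) -> Result<String> {
    Ok(format!("Rewrite the selected text: {user_prompt}"))
}

fn main() -> Result<()> {
    let request = build_request(true, "make this function iterative")?;
    println!("{}", request.prompt);
    Ok(())
}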
@@ -2752,8 +2775,17 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let codegen =
-            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
+        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+        let codegen = cx.new_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                range.clone(),
+                None,
+                None,
+                prompt_builder,
+                cx,
+            )
+        });
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
@@ -2815,8 +2847,17 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
         });
-        let codegen =
-            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
+        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+        let codegen = cx.new_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                range.clone(),
+                None,
+                None,
+                prompt_builder,
+                cx,
+            )
+        });
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
@@ -2881,8 +2922,17 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
         });
-        let codegen =
-            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
+        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+        let codegen = cx.new_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                range.clone(),
+                None,
+                None,
+                prompt_builder,
+                cx,
+            )
+        });
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
@@ -2946,8 +2996,17 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
         });
-        let codegen =
-            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
+        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+        let codegen = cx.new_model(|cx| {
+            Codegen::new(
+                buffer.clone(),
+                range.clone(),
+                None,
+                None,
+                prompt_builder,
+                cx,
+            )
+        });
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
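
Each of the four tests now repeats the same `Codegen` construction with a
fresh `PromptBuilder::new(None)` (no overrides directory). A hypothetical
helper, not part of this commit and assuming the imports already present in
this test module (`Model`, `MultiBuffer`, `Anchor`, `TestAppContext`), could
collapse that repetition:

// Hypothetical test helper (not in this commit): build a Codegen over
// `range` with no initial transaction, no telemetry, and no template
// overrides.
fn test_codegen(
    buffer: Model<MultiBuffer>,
    range: Range<Anchor>,
    cx: &mut TestAppContext,
) -> Model<Codegen> {
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    cx.new_model(|cx| Codegen::new(buffer, range, None, None, prompt_builder, cx))
}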