Allow prompt templates to be overridden in the zed configuration directory (#15887)
I need this to refine our prompts on the fly as I work. Release Notes: - Templates for prompts driving inline transformation in editors and the terminal can now be overridden in the `~/.config/zed/prompts/templates` directory. This is an advanced feature, and prevents you from getting upstream changes. It's intended for use by Zed developers.
This commit is contained in:
parent
6065db174a
commit
c8f1358629
12 changed files with 569 additions and 168 deletions
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
|
||||
humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
|
||||
CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
|
@ -51,8 +51,13 @@ use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
|
|||
use util::{RangeExt, ResultExt};
|
||||
use workspace::{notifications::NotificationId, Toast, Workspace};
|
||||
|
||||
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
|
||||
cx.set_global(InlineAssistant::new(fs, telemetry));
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
|
||||
}
|
||||
|
||||
const PROMPT_HISTORY_MAX_LEN: usize = 20;
|
||||
|
@ -64,6 +69,7 @@ pub struct InlineAssistant {
|
|||
assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
|
||||
assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
|
||||
prompt_history: VecDeque<String>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
@ -71,7 +77,11 @@ pub struct InlineAssistant {
|
|||
impl Global for InlineAssistant {}
|
||||
|
||||
impl InlineAssistant {
|
||||
pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
next_assist_id: InlineAssistId::default(),
|
||||
next_assist_group_id: InlineAssistGroupId::default(),
|
||||
|
@ -79,6 +89,7 @@ impl InlineAssistant {
|
|||
assists_by_editor: HashMap::default(),
|
||||
assist_groups: HashMap::default(),
|
||||
prompt_history: VecDeque::default(),
|
||||
prompt_builder,
|
||||
telemetry: Some(telemetry),
|
||||
fs,
|
||||
}
|
||||
|
@ -155,6 +166,7 @@ impl InlineAssistant {
|
|||
range.clone(),
|
||||
None,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -260,6 +272,7 @@ impl InlineAssistant {
|
|||
range.clone(),
|
||||
initial_transaction_id,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -2021,6 +2034,7 @@ pub struct Codegen {
|
|||
diff: Diff,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
pub enum CodegenStatus {
|
||||
|
@ -2050,6 +2064,7 @@ impl Codegen {
|
|||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
|
@ -2087,6 +2102,7 @@ impl Codegen {
|
|||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
initial_transaction_id,
|
||||
builder,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2118,7 +2134,10 @@ impl Codegen {
|
|||
) -> BoxFuture<'static, Result<usize>> {
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||
model.count_tokens(request, cx)
|
||||
match request {
|
||||
Ok(request) => model.count_tokens(request, cx),
|
||||
Err(error) => futures::future::ready(Err(error)).boxed(),
|
||||
}
|
||||
} else {
|
||||
future::ready(Err(anyhow!("no active model"))).boxed()
|
||||
}
|
||||
|
@ -2152,7 +2171,8 @@ impl Codegen {
|
|||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request =
|
||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
|
||||
|
||||
let chunks =
|
||||
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
|
||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||
|
@ -2167,7 +2187,7 @@ impl Codegen {
|
|||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
edit_range: Range<Anchor>,
|
||||
cx: &AppContext,
|
||||
) -> LanguageModelRequest {
|
||||
) -> Result<LanguageModelRequest> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(edit_range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
|
@ -2202,12 +2222,15 @@ impl Codegen {
|
|||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||
} else {
|
||||
panic!("invalid transformation range");
|
||||
return Err(anyhow::anyhow!("invalid transformation range"));
|
||||
}
|
||||
} else {
|
||||
panic!("invalid transformation range");
|
||||
return Err(anyhow::anyhow!("invalid transformation range"));
|
||||
};
|
||||
let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
|
||||
let prompt = self
|
||||
.builder
|
||||
.generate_content_prompt(user_prompt, language_name, buffer, range)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if let Some(context_request) = assistant_panel_context {
|
||||
|
@ -2219,11 +2242,11 @@ impl Codegen {
|
|||
content: prompt,
|
||||
});
|
||||
|
||||
LanguageModelRequest {
|
||||
Ok(LanguageModelRequest {
|
||||
messages,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
|
@ -2752,8 +2775,17 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
|
@ -2815,8 +2847,17 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
|
@ -2881,8 +2922,17 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
|
@ -2946,8 +2996,17 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue