Rely on model to determine indentation level and always rewrite the full line (#16145)
This PR simplifies our approach to indentation in the inline assistant in hopes of improving our experience for Python. We tell the model to generate the correct indentation in the prompt, and always start generating at the start of the line. This may fall down for less capable models, but I want to get a solid experience on the best models and then figure the rest out later. Also: We now prefer `./assets/prompts` as an overrides directory when stdout is a PTY, so you can do `cargo run` and then iterate prompts for the current run inside the current working copy. cc @trishume @dsp-ant Release Notes: - Zed now allows the model to control indentation when performing inline transformation. We're hoping this improves the indentation experience in Python and other indentation-sensitive languages, but it does require more from the model. --------- Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
parent
e662bfc74f
commit
a515442a36
6 changed files with 235 additions and 565 deletions
|
@ -12,11 +12,13 @@ use util::ResultExt;
|
|||
pub struct ContentPromptContext {
|
||||
pub content_type: String,
|
||||
pub language_name: Option<String>,
|
||||
pub is_insert: bool,
|
||||
pub is_truncated: bool,
|
||||
pub document_content: String,
|
||||
pub user_prompt: String,
|
||||
pub rewrite_section: Option<String>,
|
||||
pub rewrite_section: String,
|
||||
pub rewrite_section_with_selections: String,
|
||||
pub has_insertion: bool,
|
||||
pub has_replacement: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -33,41 +35,54 @@ pub struct PromptBuilder {
|
|||
handlebars: Arc<Mutex<Handlebars<'static>>>,
|
||||
}
|
||||
|
||||
pub struct PromptOverrideContext<'a> {
|
||||
pub dev_mode: bool,
|
||||
pub fs: Arc<dyn Fs>,
|
||||
pub cx: &'a mut gpui::AppContext,
|
||||
}
|
||||
|
||||
impl PromptBuilder {
|
||||
pub fn new(
|
||||
fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
|
||||
) -> Result<Self, Box<TemplateError>> {
|
||||
pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
Self::register_templates(&mut handlebars)?;
|
||||
|
||||
let handlebars = Arc::new(Mutex::new(handlebars));
|
||||
|
||||
if let Some((fs, cx)) = fs_and_cx {
|
||||
Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
|
||||
if let Some(override_cx) = override_cx {
|
||||
Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
|
||||
}
|
||||
|
||||
Ok(Self { handlebars })
|
||||
}
|
||||
|
||||
fn watch_fs_for_template_overrides(
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &gpui::AppContext,
|
||||
PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
|
||||
handlebars: Arc<Mutex<Handlebars<'static>>>,
|
||||
) {
|
||||
let templates_dir = paths::prompt_overrides_dir();
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let templates_dir = if dev_mode {
|
||||
std::env::current_dir()
|
||||
.ok()
|
||||
.and_then(|pwd| {
|
||||
let pwd_assets_prompts = pwd.join("assets").join("prompts");
|
||||
pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
|
||||
})
|
||||
.unwrap_or_else(|| paths::prompt_overrides_dir().clone())
|
||||
} else {
|
||||
paths::prompt_overrides_dir().clone()
|
||||
};
|
||||
|
||||
// Create the prompt templates directory if it doesn't exist
|
||||
if !fs.is_dir(templates_dir).await {
|
||||
if let Err(e) = fs.create_dir(templates_dir).await {
|
||||
if !fs.is_dir(&templates_dir).await {
|
||||
if let Err(e) = fs.create_dir(&templates_dir).await {
|
||||
log::error!("Failed to create prompt templates directory: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Initial scan of the prompts directory
|
||||
if let Ok(mut entries) = fs.read_dir(templates_dir).await {
|
||||
if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
|
||||
while let Some(Ok(file_path)) = entries.next().await {
|
||||
if file_path.to_string_lossy().ends_with(".hbs") {
|
||||
if let Ok(content) = fs.load(&file_path).await {
|
||||
|
@ -95,7 +110,7 @@ impl PromptBuilder {
|
|||
}
|
||||
|
||||
// Watch for changes
|
||||
let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
|
||||
let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
|
||||
while let Some(changed_paths) = changes.next().await {
|
||||
for changed_path in changed_paths {
|
||||
if changed_path.extension().map_or(false, |ext| ext == "hbs") {
|
||||
|
@ -147,7 +162,8 @@ impl PromptBuilder {
|
|||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
transform_range: Range<usize>,
|
||||
selected_ranges: Vec<Range<usize>>,
|
||||
) -> Result<String, RenderError> {
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => "text",
|
||||
|
@ -155,21 +171,20 @@ impl PromptBuilder {
|
|||
};
|
||||
|
||||
const MAX_CTX: usize = 50000;
|
||||
let is_insert = range.is_empty();
|
||||
let mut is_truncated = false;
|
||||
|
||||
let before_range = 0..range.start;
|
||||
let before_range = 0..transform_range.start;
|
||||
let truncated_before = if before_range.len() > MAX_CTX {
|
||||
is_truncated = true;
|
||||
range.start - MAX_CTX..range.start
|
||||
transform_range.start - MAX_CTX..transform_range.start
|
||||
} else {
|
||||
before_range
|
||||
};
|
||||
|
||||
let after_range = range.end..buffer.len();
|
||||
let after_range = transform_range.end..buffer.len();
|
||||
let truncated_after = if after_range.len() > MAX_CTX {
|
||||
is_truncated = true;
|
||||
range.end..range.end + MAX_CTX
|
||||
transform_range.end..transform_range.end + MAX_CTX
|
||||
} else {
|
||||
after_range
|
||||
};
|
||||
|
@ -178,37 +193,61 @@ impl PromptBuilder {
|
|||
for chunk in buffer.text_for_range(truncated_before) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
if is_insert {
|
||||
document_content.push_str("<insert_here></insert_here>");
|
||||
} else {
|
||||
document_content.push_str("<rewrite_this>\n");
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
document_content.push_str("\n</rewrite_this>");
|
||||
document_content.push_str("<rewrite_this>\n");
|
||||
for chunk in buffer.text_for_range(transform_range.clone()) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
document_content.push_str("\n</rewrite_this>");
|
||||
|
||||
for chunk in buffer.text_for_range(truncated_after) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
|
||||
let rewrite_section = if !is_insert {
|
||||
let mut section = String::new();
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
section.push_str(chunk);
|
||||
let mut rewrite_section = String::new();
|
||||
for chunk in buffer.text_for_range(transform_range.clone()) {
|
||||
rewrite_section.push_str(chunk);
|
||||
}
|
||||
|
||||
let rewrite_section_with_selections = {
|
||||
let mut section_with_selections = String::new();
|
||||
let mut last_end = 0;
|
||||
for selected_range in &selected_ranges {
|
||||
if selected_range.start > last_end {
|
||||
section_with_selections.push_str(
|
||||
&rewrite_section[last_end..selected_range.start - transform_range.start],
|
||||
);
|
||||
}
|
||||
if selected_range.start == selected_range.end {
|
||||
section_with_selections.push_str("<insert_here></insert_here>");
|
||||
} else {
|
||||
section_with_selections.push_str("<edit_here>");
|
||||
section_with_selections.push_str(
|
||||
&rewrite_section[selected_range.start - transform_range.start
|
||||
..selected_range.end - transform_range.start],
|
||||
);
|
||||
section_with_selections.push_str("</edit_here>");
|
||||
}
|
||||
last_end = selected_range.end - transform_range.start;
|
||||
}
|
||||
Some(section)
|
||||
} else {
|
||||
None
|
||||
if last_end < rewrite_section.len() {
|
||||
section_with_selections.push_str(&rewrite_section[last_end..]);
|
||||
}
|
||||
section_with_selections
|
||||
};
|
||||
|
||||
let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
|
||||
let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
|
||||
|
||||
let context = ContentPromptContext {
|
||||
content_type: content_type.to_string(),
|
||||
language_name: language_name.map(|s| s.to_string()),
|
||||
is_insert,
|
||||
is_truncated,
|
||||
document_content,
|
||||
user_prompt,
|
||||
rewrite_section,
|
||||
rewrite_section_with_selections,
|
||||
has_insertion,
|
||||
has_replacement,
|
||||
};
|
||||
|
||||
self.handlebars.lock().render("content_prompt", &context)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue