Rely on model to determine indentation level and always rewrite the full line (#16145)

This PR simplifies our approach to indentation in the inline assistant
in hopes of improving our experience for Python. We tell the model to
generate the correct indentation in the prompt, and always start
generating at the start of the line. This may fall down for less capable
models, but I want to get a solid experience on the best models and then
figure the rest out later.

Also: We now prefer `./assets/prompts` as an overrides directory when
stdout is a PTY, so you can do `cargo run` and then iterate prompts for
the current run inside the current working copy.

cc @trishume @dsp-ant 

Release Notes:

- Zed now allows the model to control indentation when performing inline
transformation. We're hoping this improves the indentation experience in
Python and other indentation-sensitive languages, but it does require
more from the model.

---------

Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
Nathan Sobo 2024-08-12 22:41:24 -06:00 committed by GitHub
parent e662bfc74f
commit a515442a36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 235 additions and 565 deletions

View file

@ -12,11 +12,13 @@ use util::ResultExt;
pub struct ContentPromptContext {
pub content_type: String,
pub language_name: Option<String>,
pub is_insert: bool,
pub is_truncated: bool,
pub document_content: String,
pub user_prompt: String,
pub rewrite_section: Option<String>,
pub rewrite_section: String,
pub rewrite_section_with_selections: String,
pub has_insertion: bool,
pub has_replacement: bool,
}
#[derive(Serialize)]
@ -33,41 +35,54 @@ pub struct PromptBuilder {
handlebars: Arc<Mutex<Handlebars<'static>>>,
}
pub struct PromptOverrideContext<'a> {
pub dev_mode: bool,
pub fs: Arc<dyn Fs>,
pub cx: &'a mut gpui::AppContext,
}
impl PromptBuilder {
pub fn new(
fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
) -> Result<Self, Box<TemplateError>> {
pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
let mut handlebars = Handlebars::new();
Self::register_templates(&mut handlebars)?;
let handlebars = Arc::new(Mutex::new(handlebars));
if let Some((fs, cx)) = fs_and_cx {
Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
if let Some(override_cx) = override_cx {
Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
}
Ok(Self { handlebars })
}
fn watch_fs_for_template_overrides(
fs: Arc<dyn Fs>,
cx: &gpui::AppContext,
PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
handlebars: Arc<Mutex<Handlebars<'static>>>,
) {
let templates_dir = paths::prompt_overrides_dir();
cx.background_executor()
.spawn(async move {
let templates_dir = if dev_mode {
std::env::current_dir()
.ok()
.and_then(|pwd| {
let pwd_assets_prompts = pwd.join("assets").join("prompts");
pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
})
.unwrap_or_else(|| paths::prompt_overrides_dir().clone())
} else {
paths::prompt_overrides_dir().clone()
};
// Create the prompt templates directory if it doesn't exist
if !fs.is_dir(templates_dir).await {
if let Err(e) = fs.create_dir(templates_dir).await {
if !fs.is_dir(&templates_dir).await {
if let Err(e) = fs.create_dir(&templates_dir).await {
log::error!("Failed to create prompt templates directory: {}", e);
return;
}
}
// Initial scan of the prompts directory
if let Ok(mut entries) = fs.read_dir(templates_dir).await {
if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
while let Some(Ok(file_path)) = entries.next().await {
if file_path.to_string_lossy().ends_with(".hbs") {
if let Ok(content) = fs.load(&file_path).await {
@ -95,7 +110,7 @@ impl PromptBuilder {
}
// Watch for changes
let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
while let Some(changed_paths) = changes.next().await {
for changed_path in changed_paths {
if changed_path.extension().map_or(false, |ext| ext == "hbs") {
@ -147,7 +162,8 @@ impl PromptBuilder {
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
transform_range: Range<usize>,
selected_ranges: Vec<Range<usize>>,
) -> Result<String, RenderError> {
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text",
@ -155,21 +171,20 @@ impl PromptBuilder {
};
const MAX_CTX: usize = 50000;
let is_insert = range.is_empty();
let mut is_truncated = false;
let before_range = 0..range.start;
let before_range = 0..transform_range.start;
let truncated_before = if before_range.len() > MAX_CTX {
is_truncated = true;
range.start - MAX_CTX..range.start
transform_range.start - MAX_CTX..transform_range.start
} else {
before_range
};
let after_range = range.end..buffer.len();
let after_range = transform_range.end..buffer.len();
let truncated_after = if after_range.len() > MAX_CTX {
is_truncated = true;
range.end..range.end + MAX_CTX
transform_range.end..transform_range.end + MAX_CTX
} else {
after_range
};
@ -178,37 +193,61 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
if is_insert {
document_content.push_str("<insert_here></insert_here>");
} else {
document_content.push_str("<rewrite_this>\n");
for chunk in buffer.text_for_range(range.clone()) {
document_content.push_str(chunk);
}
document_content.push_str("\n</rewrite_this>");
document_content.push_str("<rewrite_this>\n");
for chunk in buffer.text_for_range(transform_range.clone()) {
document_content.push_str(chunk);
}
document_content.push_str("\n</rewrite_this>");
for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk);
}
let rewrite_section = if !is_insert {
let mut section = String::new();
for chunk in buffer.text_for_range(range.clone()) {
section.push_str(chunk);
let mut rewrite_section = String::new();
for chunk in buffer.text_for_range(transform_range.clone()) {
rewrite_section.push_str(chunk);
}
let rewrite_section_with_selections = {
let mut section_with_selections = String::new();
let mut last_end = 0;
for selected_range in &selected_ranges {
if selected_range.start > last_end {
section_with_selections.push_str(
&rewrite_section[last_end..selected_range.start - transform_range.start],
);
}
if selected_range.start == selected_range.end {
section_with_selections.push_str("<insert_here></insert_here>");
} else {
section_with_selections.push_str("<edit_here>");
section_with_selections.push_str(
&rewrite_section[selected_range.start - transform_range.start
..selected_range.end - transform_range.start],
);
section_with_selections.push_str("</edit_here>");
}
last_end = selected_range.end - transform_range.start;
}
Some(section)
} else {
None
if last_end < rewrite_section.len() {
section_with_selections.push_str(&rewrite_section[last_end..]);
}
section_with_selections
};
let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
let context = ContentPromptContext {
content_type: content_type.to_string(),
language_name: language_name.map(|s| s.to_string()),
is_insert,
is_truncated,
document_content,
user_prompt,
rewrite_section,
rewrite_section_with_selections,
has_insertion,
has_replacement,
};
self.handlebars.lock().render("content_prompt", &context)