Improve workflow step pruning and symbol similarity matching (#16036)

This PR improves workflow step management and symbol matching. We've
optimized step pruning to remove any step that intersects an edit and
switched to normalized Levenshtein distance for more accurate symbol
matching.

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-08-12 11:09:07 +02:00 committed by GitHub
parent 355aebd0e4
commit 48f6193628
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 124 additions and 20 deletions

View file

@ -33,7 +33,7 @@ use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
cmp, cmp::{self, Ordering},
fmt::Debug, fmt::Debug,
iter, mem, iter, mem,
ops::Range, ops::Range,
@ -618,6 +618,7 @@ pub struct Context {
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
workflow_steps: Vec<WorkflowStep>, workflow_steps: Vec<WorkflowStep>,
edits_since_last_workflow_step_prune: language::Subscription,
project: Option<Model<Project>>, project: Option<Model<Project>>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
} }
@ -667,6 +668,8 @@ impl Context {
}); });
let edits_since_last_slash_command_parse = let edits_since_last_slash_command_parse =
buffer.update(cx, |buffer, _| buffer.subscribe()); buffer.update(cx, |buffer, _| buffer.subscribe());
let edits_since_last_workflow_step_prune =
buffer.update(cx, |buffer, _| buffer.subscribe());
let mut this = Self { let mut this = Self {
id, id,
timestamp: clock::Lamport::new(replica_id), timestamp: clock::Lamport::new(replica_id),
@ -693,6 +696,7 @@ impl Context {
project, project,
language_registry, language_registry,
workflow_steps: Vec::new(), workflow_steps: Vec::new(),
edits_since_last_workflow_step_prune,
prompt_builder, prompt_builder,
}; };
@ -1058,7 +1062,9 @@ impl Context {
language::Event::Edited => { language::Event::Edited => {
self.count_remaining_tokens(cx); self.count_remaining_tokens(cx);
self.reparse_slash_commands(cx); self.reparse_slash_commands(cx);
self.prune_invalid_workflow_steps(cx); // Use `inclusive = true` to invalidate a step when an edit occurs
// at the start/end of a parsed step.
self.prune_invalid_workflow_steps(true, cx);
cx.emit(ContextEvent::MessagesEdited); cx.emit(ContextEvent::MessagesEdited);
} }
_ => {} _ => {}
@ -1165,24 +1171,62 @@ impl Context {
} }
} }
fn prune_invalid_workflow_steps(&mut self, cx: &mut ModelContext<Self>) { fn prune_invalid_workflow_steps(&mut self, inclusive: bool, cx: &mut ModelContext<Self>) {
let buffer = self.buffer.read(cx);
let prev_len = self.workflow_steps.len();
let mut removed = Vec::new(); let mut removed = Vec::new();
self.workflow_steps.retain(|step| {
if step.tagged_range.start.is_valid(buffer) && step.tagged_range.end.is_valid(buffer) { for edit_range in self.edits_since_last_workflow_step_prune.consume() {
true let intersecting_range = self.find_intersecting_steps(edit_range.new, inclusive, cx);
} else { removed.extend(
removed.push(step.tagged_range.clone()); self.workflow_steps
false .drain(intersecting_range)
.map(|step| step.tagged_range),
);
} }
});
if self.workflow_steps.len() != prev_len { if !removed.is_empty() {
cx.emit(ContextEvent::WorkflowStepsRemoved(removed)); cx.emit(ContextEvent::WorkflowStepsRemoved(removed));
cx.notify(); cx.notify();
} }
} }
fn find_intersecting_steps(
&self,
range: Range<usize>,
inclusive: bool,
cx: &AppContext,
) -> Range<usize> {
let buffer = self.buffer.read(cx);
let start_ix = match self.workflow_steps.binary_search_by(|probe| {
probe
.tagged_range
.end
.to_offset(buffer)
.cmp(&range.start)
.then(if inclusive {
Ordering::Greater
} else {
Ordering::Less
})
}) {
Ok(ix) | Err(ix) => ix,
};
let end_ix = match self.workflow_steps.binary_search_by(|probe| {
probe
.tagged_range
.start
.to_offset(buffer)
.cmp(&range.end)
.then(if inclusive {
Ordering::Less
} else {
Ordering::Greater
})
}) {
Ok(ix) | Err(ix) => ix,
};
start_ix..end_ix
}
fn parse_workflow_steps_in_range( fn parse_workflow_steps_in_range(
&mut self, &mut self,
range: Range<usize>, range: Range<usize>,
@ -1248,8 +1292,12 @@ impl Context {
self.workflow_steps.insert(index, step); self.workflow_steps.insert(index, step);
self.resolve_workflow_step(step_range, project.clone(), cx); self.resolve_workflow_step(step_range, project.clone(), cx);
} }
// Delete <step> tags, making sure we don't accidentally invalidate
// the step we just parsed.
self.buffer self.buffer
.update(cx, |buffer, cx| buffer.edit(edits, None, cx)); .update(cx, |buffer, cx| buffer.edit(edits, None, cx));
self.edits_since_last_workflow_step_prune.consume();
} }
pub fn resolve_workflow_step( pub fn resolve_workflow_step(
@ -1629,6 +1677,8 @@ impl Context {
message_start_offset..message_new_end_offset message_start_offset..message_new_end_offset
}); });
if let Some(project) = this.project.clone() { if let Some(project) = this.project.clone() {
// Use `inclusive = false` as edits might occur at the end of a parsed step.
this.prune_invalid_workflow_steps(false, cx);
this.parse_workflow_steps_in_range(message_range, project, cx); this.parse_workflow_steps_in_range(message_range, project, cx);
} }
cx.emit(ContextEvent::StreamedCompletion); cx.emit(ContextEvent::StreamedCompletion);

View file

@ -84,13 +84,24 @@ impl<T> Outline<T> {
} }
} }
/// Find the most similar symbol to the provided query according to the Jaro-Winkler distance measure. /// Find the most similar symbol to the provided query using normalized Levenshtein distance.
pub fn find_most_similar(&self, query: &str) -> Option<&OutlineItem<T>> { pub fn find_most_similar(&self, query: &str) -> Option<&OutlineItem<T>> {
let candidate = self.path_candidates.iter().max_by(|a, b| { const SIMILARITY_THRESHOLD: f64 = 0.6;
strsim::jaro_winkler(&a.string, query)
.total_cmp(&strsim::jaro_winkler(&b.string, query)) let (item, similarity) = self
})?; .items
Some(&self.items[candidate.id]) .iter()
.map(|item| {
let similarity = strsim::normalized_levenshtein(&item.text, query);
(item, similarity)
})
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())?;
if similarity >= SIMILARITY_THRESHOLD {
Some(item)
} else {
None
}
} }
/// Find all outline symbols according to a longest subsequence match with the query, ordered descending by match score. /// Find all outline symbols according to a longest subsequence match with the query, ordered descending by match score.
@ -208,3 +219,46 @@ pub fn render_item<T>(
StyledText::new(outline_item.text.clone()).with_highlights(&text_style, highlights) StyledText::new(outline_item.text.clone()).with_highlights(&text_style, highlights)
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_find_most_similar_with_low_similarity() {
let outline = Outline::new(vec![
OutlineItem {
depth: 0,
range: Point::new(0, 0)..Point::new(5, 0),
text: "fn process".to_string(),
highlight_ranges: vec![],
name_ranges: vec![3..10],
body_range: None,
annotation_range: None,
},
OutlineItem {
depth: 0,
range: Point::new(7, 0)..Point::new(12, 0),
text: "struct DataProcessor".to_string(),
highlight_ranges: vec![],
name_ranges: vec![7..20],
body_range: None,
annotation_range: None,
},
]);
assert_eq!(
outline.find_most_similar("pub fn process"),
Some(&outline.items[0])
);
assert_eq!(
outline.find_most_similar("async fn process"),
Some(&outline.items[0])
);
assert_eq!(
outline.find_most_similar("struct Processor"),
Some(&outline.items[1])
);
assert_eq!(outline.find_most_similar("struct User"), None);
assert_eq!(outline.find_most_similar("struct"), None);
}
}