add embedding query for json with nested arrays and strings
Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
parent
9809ec3d70
commit
efe973ebe2
7 changed files with 189 additions and 59 deletions
|
@ -1,6 +1,12 @@
|
|||
use anyhow::{anyhow, Ok, Result};
|
||||
use language::{Grammar, Language};
|
||||
use std::{cmp, collections::HashSet, ops::Range, path::Path, sync::Arc};
|
||||
use std::{
|
||||
cmp::{self, Reverse},
|
||||
collections::HashSet,
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
|
@ -15,7 +21,7 @@ const CODE_CONTEXT_TEMPLATE: &str =
|
|||
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||
const ENTIRE_FILE_TEMPLATE: &str =
|
||||
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||
pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];
|
||||
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
|
||||
|
||||
pub struct CodeContextRetriever {
|
||||
pub parser: Parser,
|
||||
|
@ -30,8 +36,8 @@ pub struct CodeContextRetriever {
|
|||
#[derive(Debug, Clone)]
|
||||
pub struct CodeContextMatch {
|
||||
pub start_col: usize,
|
||||
pub item_range: Range<usize>,
|
||||
pub name_range: Range<usize>,
|
||||
pub item_range: Option<Range<usize>>,
|
||||
pub name_range: Option<Range<usize>>,
|
||||
pub context_ranges: Vec<Range<usize>>,
|
||||
pub collapse_ranges: Vec<Range<usize>>,
|
||||
}
|
||||
|
@ -44,7 +50,7 @@ impl CodeContextRetriever {
|
|||
}
|
||||
}
|
||||
|
||||
fn _parse_entire_file(
|
||||
fn parse_entire_file(
|
||||
&self,
|
||||
relative_path: &Path,
|
||||
language_name: Arc<str>,
|
||||
|
@ -97,7 +103,7 @@ impl CodeContextRetriever {
|
|||
if capture.index == embedding_config.item_capture_ix {
|
||||
item_range = Some(capture.node.byte_range());
|
||||
start_col = capture.node.start_position().column;
|
||||
} else if capture.index == embedding_config.name_capture_ix {
|
||||
} else if Some(capture.index) == embedding_config.name_capture_ix {
|
||||
name_range = Some(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.context_capture_ix {
|
||||
context_ranges.push(capture.node.byte_range());
|
||||
|
@ -108,16 +114,13 @@ impl CodeContextRetriever {
|
|||
}
|
||||
}
|
||||
|
||||
if item_range.is_some() && name_range.is_some() {
|
||||
let item_range = item_range.unwrap();
|
||||
captures.push(CodeContextMatch {
|
||||
start_col,
|
||||
item_range,
|
||||
name_range: name_range.unwrap(),
|
||||
context_ranges,
|
||||
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
|
||||
});
|
||||
}
|
||||
captures.push(CodeContextMatch {
|
||||
start_col,
|
||||
item_range,
|
||||
name_range,
|
||||
context_ranges,
|
||||
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
|
||||
});
|
||||
}
|
||||
Ok(captures)
|
||||
}
|
||||
|
@ -129,7 +132,12 @@ impl CodeContextRetriever {
|
|||
language: Arc<Language>,
|
||||
) -> Result<Vec<Document>> {
|
||||
let language_name = language.name();
|
||||
let mut documents = self.parse_file(relative_path, content, language)?;
|
||||
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||
return self.parse_entire_file(relative_path, language_name, &content);
|
||||
}
|
||||
|
||||
let mut documents = self.parse_file(content, language)?;
|
||||
for document in &mut documents {
|
||||
document.content = CODE_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
|
@ -139,16 +147,7 @@ impl CodeContextRetriever {
|
|||
Ok(documents)
|
||||
}
|
||||
|
||||
pub fn parse_file(
|
||||
&mut self,
|
||||
relative_path: &Path,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Document>> {
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
|
||||
return self._parse_entire_file(relative_path, language.name(), &content);
|
||||
}
|
||||
|
||||
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
|
||||
let grammar = language
|
||||
.grammar()
|
||||
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
||||
|
@ -163,32 +162,49 @@ impl CodeContextRetriever {
|
|||
let mut collapsed_ranges_within = Vec::new();
|
||||
let mut parsed_name_ranges = HashSet::new();
|
||||
for (i, context_match) in matches.iter().enumerate() {
|
||||
if parsed_name_ranges.contains(&context_match.name_range) {
|
||||
// Items which are collapsible but not embeddable have no item range
|
||||
let item_range = if let Some(item_range) = context_match.item_range.clone() {
|
||||
item_range
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Checks for deduplication
|
||||
let name;
|
||||
if let Some(name_range) = context_match.name_range.clone() {
|
||||
name = content
|
||||
.get(name_range.clone())
|
||||
.map_or(String::new(), |s| s.to_string());
|
||||
if parsed_name_ranges.contains(&name_range) {
|
||||
continue;
|
||||
}
|
||||
parsed_name_ranges.insert(name_range);
|
||||
} else {
|
||||
name = String::new();
|
||||
}
|
||||
|
||||
collapsed_ranges_within.clear();
|
||||
for remaining_match in &matches[(i + 1)..] {
|
||||
if context_match
|
||||
.item_range
|
||||
.contains(&remaining_match.item_range.start)
|
||||
&& context_match
|
||||
.item_range
|
||||
.contains(&remaining_match.item_range.end)
|
||||
{
|
||||
collapsed_ranges_within.extend(remaining_match.collapse_ranges.iter().cloned());
|
||||
} else {
|
||||
break;
|
||||
'outer: for remaining_match in &matches[(i + 1)..] {
|
||||
for collapsed_range in &remaining_match.collapse_ranges {
|
||||
if item_range.start <= collapsed_range.start
|
||||
&& item_range.end >= collapsed_range.end
|
||||
{
|
||||
collapsed_ranges_within.push(collapsed_range.clone());
|
||||
} else {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
|
||||
|
||||
let mut document_content = String::new();
|
||||
for context_range in &context_match.context_ranges {
|
||||
document_content.push_str(&content[context_range.clone()]);
|
||||
document_content.push_str("\n");
|
||||
}
|
||||
|
||||
let mut offset = context_match.item_range.start;
|
||||
let mut offset = item_range.start;
|
||||
for collapsed_range in &collapsed_ranges_within {
|
||||
if collapsed_range.start > offset {
|
||||
add_content_from_range(
|
||||
|
@ -197,29 +213,30 @@ impl CodeContextRetriever {
|
|||
offset..collapsed_range.start,
|
||||
context_match.start_col,
|
||||
);
|
||||
offset = collapsed_range.start;
|
||||
}
|
||||
|
||||
if collapsed_range.end > offset {
|
||||
document_content.push_str(placeholder);
|
||||
offset = collapsed_range.end;
|
||||
}
|
||||
document_content.push_str(placeholder);
|
||||
offset = collapsed_range.end;
|
||||
}
|
||||
|
||||
if offset < context_match.item_range.end {
|
||||
if offset < item_range.end {
|
||||
add_content_from_range(
|
||||
&mut document_content,
|
||||
content,
|
||||
offset..context_match.item_range.end,
|
||||
offset..item_range.end,
|
||||
context_match.start_col,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(name) = content.get(context_match.name_range.clone()) {
|
||||
parsed_name_ranges.insert(context_match.name_range.clone());
|
||||
documents.push(Document {
|
||||
name: name.to_string(),
|
||||
content: document_content,
|
||||
range: context_match.item_range.clone(),
|
||||
embedding: vec![],
|
||||
})
|
||||
}
|
||||
documents.push(Document {
|
||||
name,
|
||||
content: document_content,
|
||||
range: item_range.clone(),
|
||||
embedding: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
return Ok(documents);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue