add embedding query for json with nested arrays and strings

Co-authored-by: maxbrunsfeld <max@zed.dev>

Author: KCaverly
Date:   2023-07-19 16:52:44 -04:00
Parent: 9809ec3d70
Commit: efe973ebe2

7 changed files with 189 additions and 59 deletions

Cargo.lock (generated)

@@ -6502,6 +6502,7 @@ dependencies = [
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "tree-sitter-json 0.19.0",
  "tree-sitter-rust",
  "tree-sitter-toml 0.20.0",
  "tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)",


@@ -526,7 +526,7 @@ pub struct OutlineConfig {
 pub struct EmbeddingConfig {
     pub query: Query,
     pub item_capture_ix: u32,
-    pub name_capture_ix: u32,
+    pub name_capture_ix: Option<u32>,
     pub context_capture_ix: Option<u32>,
     pub collapse_capture_ix: Option<u32>,
     pub keep_capture_ix: Option<u32>,
@@ -1263,7 +1263,7 @@ impl Language {
                 ("collapse", &mut collapse_capture_ix),
             ],
         );
-        if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) {
+        if let Some(item_capture_ix) = item_capture_ix {
            grammar.embedding_config = Some(EmbeddingConfig {
                query,
                item_capture_ix,

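Note: with `name_capture_ix` now an `Option`, an embedding query only needs an `@item` capture; the JSON query added below has no `@name` at all. A minimal sketch of the comparison pattern this enables (hypothetical `EmbeddingConfigSketch` type, not Zed's actual `EmbeddingConfig`):

    // Sketch: matching a capture index against an optional @name capture.
    struct EmbeddingConfigSketch {
        item_capture_ix: u32,
        name_capture_ix: Option<u32>,
    }

    fn classify(capture_index: u32, config: &EmbeddingConfigSketch) -> &'static str {
        if capture_index == config.item_capture_ix {
            "item"
        } else if Some(capture_index) == config.name_capture_ix {
            // `Some(x) == None` is always false, so queries without a @name
            // capture (like the JSON one in this commit) never take this branch.
            "name"
        } else {
            "other"
        }
    }

    fn main() {
        let json_like = EmbeddingConfigSketch { item_capture_ix: 0, name_capture_ix: None };
        assert_eq!(classify(0, &json_like), "item");
        assert_eq!(classify(1, &json_like), "other");
    }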

@@ -54,6 +54,7 @@ ctor.workspace = true
 env_logger.workspace = true
 tree-sitter-typescript = "*"
+tree-sitter-json = "*"
 tree-sitter-rust = "*"
 tree-sitter-toml = "*"
 tree-sitter-cpp = "*"


@@ -1,6 +1,12 @@
 use anyhow::{anyhow, Ok, Result};
 use language::{Grammar, Language};
-use std::{cmp, collections::HashSet, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::{self, Reverse},
+    collections::HashSet,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+};
 use tree_sitter::{Parser, QueryCursor};
 
 #[derive(Debug, PartialEq, Clone)]
@@ -15,7 +21,7 @@ const CODE_CONTEXT_TEMPLATE: &str =
     "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 const ENTIRE_FILE_TEMPLATE: &str =
     "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];
+pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
 
 pub struct CodeContextRetriever {
     pub parser: Parser,
@@ -30,8 +36,8 @@ pub struct CodeContextRetriever {
 #[derive(Debug, Clone)]
 pub struct CodeContextMatch {
     pub start_col: usize,
-    pub item_range: Range<usize>,
-    pub name_range: Range<usize>,
+    pub item_range: Option<Range<usize>>,
+    pub name_range: Option<Range<usize>>,
     pub context_ranges: Vec<Range<usize>>,
     pub collapse_ranges: Vec<Range<usize>>,
 }
@@ -44,7 +50,7 @@ impl CodeContextRetriever {
         }
     }
 
-    fn _parse_entire_file(
+    fn parse_entire_file(
         &self,
         relative_path: &Path,
         language_name: Arc<str>,
@@ -97,7 +103,7 @@ impl CodeContextRetriever {
             if capture.index == embedding_config.item_capture_ix {
                 item_range = Some(capture.node.byte_range());
                 start_col = capture.node.start_position().column;
-            } else if capture.index == embedding_config.name_capture_ix {
+            } else if Some(capture.index) == embedding_config.name_capture_ix {
                 name_range = Some(capture.node.byte_range());
             } else if Some(capture.index) == embedding_config.context_capture_ix {
                 context_ranges.push(capture.node.byte_range());
@@ -108,16 +114,13 @@ impl CodeContextRetriever {
                 }
             }
 
-            if item_range.is_some() && name_range.is_some() {
-                let item_range = item_range.unwrap();
-                captures.push(CodeContextMatch {
-                    start_col,
-                    item_range,
-                    name_range: name_range.unwrap(),
-                    context_ranges,
-                    collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
-                });
-            }
+            captures.push(CodeContextMatch {
+                start_col,
+                item_range,
+                name_range,
+                context_ranges,
+                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
+            });
         }
         Ok(captures)
     }
@@ -129,7 +132,12 @@ impl CodeContextRetriever {
         language: Arc<Language>,
     ) -> Result<Vec<Document>> {
         let language_name = language.name();
-        let mut documents = self.parse_file(relative_path, content, language)?;
+
+        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
+            return self.parse_entire_file(relative_path, language_name, &content);
+        }
+
+        let mut documents = self.parse_file(content, language)?;
         for document in &mut documents {
             document.content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
@@ -139,16 +147,7 @@ impl CodeContextRetriever {
         Ok(documents)
     }
 
-    pub fn parse_file(
-        &mut self,
-        relative_path: &Path,
-        content: &str,
-        language: Arc<Language>,
-    ) -> Result<Vec<Document>> {
-        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
-            return self._parse_entire_file(relative_path, language.name(), &content);
-        }
-
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
         let grammar = language
             .grammar()
             .ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -163,32 +162,49 @@ impl CodeContextRetriever {
         let mut collapsed_ranges_within = Vec::new();
         let mut parsed_name_ranges = HashSet::new();
         for (i, context_match) in matches.iter().enumerate() {
-            if parsed_name_ranges.contains(&context_match.name_range) {
-                continue;
-            }
+            // Items which are collapsible but not embeddable have no item range
+            let item_range = if let Some(item_range) = context_match.item_range.clone() {
+                item_range
+            } else {
+                continue;
+            };
+
+            // Checks for deduplication
+            let name;
+            if let Some(name_range) = context_match.name_range.clone() {
+                name = content
+                    .get(name_range.clone())
+                    .map_or(String::new(), |s| s.to_string());
+                if parsed_name_ranges.contains(&name_range) {
+                    continue;
+                }
+                parsed_name_ranges.insert(name_range);
+            } else {
+                name = String::new();
+            }
 
             collapsed_ranges_within.clear();
-            for remaining_match in &matches[(i + 1)..] {
-                if context_match
-                    .item_range
-                    .contains(&remaining_match.item_range.start)
-                    && context_match
-                        .item_range
-                        .contains(&remaining_match.item_range.end)
-                {
-                    collapsed_ranges_within.extend(remaining_match.collapse_ranges.iter().cloned());
-                } else {
-                    break;
-                }
-            }
+            'outer: for remaining_match in &matches[(i + 1)..] {
+                for collapsed_range in &remaining_match.collapse_ranges {
+                    if item_range.start <= collapsed_range.start
+                        && item_range.end >= collapsed_range.end
+                    {
+                        collapsed_ranges_within.push(collapsed_range.clone());
+                    } else {
+                        break 'outer;
+                    }
+                }
+            }
 
+            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
+
             let mut document_content = String::new();
             for context_range in &context_match.context_ranges {
                 document_content.push_str(&content[context_range.clone()]);
                 document_content.push_str("\n");
             }
 
-            let mut offset = context_match.item_range.start;
+            let mut offset = item_range.start;
             for collapsed_range in &collapsed_ranges_within {
                 if collapsed_range.start > offset {
                     add_content_from_range(
@@ -197,29 +213,30 @@ impl CodeContextRetriever {
                         offset..collapsed_range.start,
                         context_match.start_col,
                     );
+                    offset = collapsed_range.start;
                 }
-                document_content.push_str(placeholder);
-                offset = collapsed_range.end;
+
+                if collapsed_range.end > offset {
+                    document_content.push_str(placeholder);
+                    offset = collapsed_range.end;
+                }
             }
 
-            if offset < context_match.item_range.end {
+            if offset < item_range.end {
                 add_content_from_range(
                     &mut document_content,
                     content,
-                    offset..context_match.item_range.end,
+                    offset..item_range.end,
                     context_match.start_col,
                 );
             }
 
-            if let Some(name) = content.get(context_match.name_range.clone()) {
-                parsed_name_ranges.insert(context_match.name_range.clone());
-                documents.push(Document {
-                    name: name.to_string(),
-                    content: document_content,
-                    range: context_match.item_range.clone(),
-                    embedding: vec![],
-                })
-            }
+            documents.push(Document {
+                name,
+                content: document_content,
+                range: item_range.clone(),
+                embedding: vec![],
+            })
         }
 
         return Ok(documents);

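The rewritten collapse pass above depends on sorting collapsed ranges by `(start, Reverse(end))`: ranges are walked left to right, an enclosing range is visited before anything nested inside it, and the `offset` cursor then emits one placeholder per outer region while silently skipping the nested ones. A self-contained sketch of that cursor walk with toy data (not the Zed types):

    use std::cmp::Reverse;
    use std::ops::Range;

    fn main() {
        // 3..8 and 10..14 are nested inside 3..20.
        let mut ranges: Vec<Range<usize>> = vec![10..14, 3..20, 3..8];
        ranges.sort_by_key(|r| (r.start, Reverse(r.end)));
        assert_eq!(ranges, vec![3..20, 3..8, 10..14]);

        let placeholder = "…";
        let content = "abcdefghijklmnopqrstuvwxyz";
        let mut offset = 0;
        let mut out = String::new();
        for range in &ranges {
            // Copy anything between the cursor and the next collapsed range.
            if range.start > offset {
                out.push_str(&content[offset..range.start]);
                offset = range.start;
            }
            // Only ranges that extend past the cursor produce a placeholder,
            // so the nested ranges are skipped entirely.
            if range.end > offset {
                out.push_str(placeholder);
                offset = range.end;
            }
        }
        out.push_str(&content[offset..]);
        assert_eq!(out, "abc…uvwxyz");
    }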

@@ -33,7 +33,7 @@ use util::{
     ResultExt,
 };
 
-const SEMANTIC_INDEX_VERSION: usize = 4;
+const SEMANTIC_INDEX_VERSION: usize = 5;
 const EMBEDDINGS_BATCH_SIZE: usize = 80;
 
 pub fn init(


@@ -170,9 +170,7 @@ async fn test_code_context_retrieval_rust() {
     "
     .unindent();
 
-    let documents = retriever
-        .parse_file(Path::new("foo.rs"), &text, language)
-        .unwrap();
+    let documents = retriever.parse_file(&text, language).unwrap();
 
     assert_documents_eq(
         &documents,
@@ -229,6 +227,76 @@ async fn test_code_context_retrieval_rust() {
     );
 }
 
+#[gpui::test]
+async fn test_code_context_retrieval_json() {
+    let language = json_lang();
+    let mut retriever = CodeContextRetriever::new();
+
+    let text = r#"
+        {
+            "array": [1, 2, 3, 4],
+            "string": "abcdefg",
+            "nested_object": {
+                "array_2": [5, 6, 7, 8],
+                "string_2": "hijklmnop",
+                "boolean": true,
+                "none": null
+            }
+        }
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+            {
+                "array": [],
+                "string": "",
+                "nested_object": {
+                    "array_2": [],
+                    "string_2": "",
+                    "boolean": true,
+                    "none": null
+                }
+            }"#
+            .unindent(),
+            text.find("{").unwrap(),
+        )],
+    );
+
+    let text = r#"
+        [
+            {
+                "name": "somebody",
+                "age": 42
+            },
+            {
+                "name": "somebody else",
+                "age": 43
+            }
+        ]
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+            [{
+                "name": "",
+                "age": 42
+            }]"#
+            .unindent(),
+            text.find("[").unwrap(),
+        )],
+    );
+}
+
 fn assert_documents_eq(
     documents: &[Document],
     expected_contents_and_start_offsets: &[(String, usize)],
@@ -913,6 +981,35 @@ fn rust_lang() -> Arc<Language> {
     )
 }
 
+fn json_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "JSON".into(),
+                path_suffixes: vec!["json".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_json::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (document) @item
+
+            (array
+                "[" @keep
+                .
+                (object)? @keep
+                "]" @keep) @collapse
+
+            (pair value: (string
+                "\"" @keep
+                "\"" @keep) @collapse)
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
 fn toml_lang() -> Arc<Language> {
     Arc::new(Language::new(
         LanguageConfig {


@@ -0,0 +1,14 @@
+; Only produce one embedding for the entire file.
+(document) @item
+
+; Collapse arrays, except for the first object.
+(array
+    "[" @keep
+    .
+    (object)? @keep
+    "]" @keep) @collapse
+
+; Collapse string values (but not keys).
+(pair value: (string
+    "\"" @keep
+    "\"" @keep) @collapse)