add embedding query for json with nested arrays and strings
Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
parent
9809ec3d70
commit
efe973ebe2
7 changed files with 189 additions and 59 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -6502,6 +6502,7 @@ dependencies = [
|
||||||
"tree-sitter",
|
"tree-sitter",
|
||||||
"tree-sitter-cpp",
|
"tree-sitter-cpp",
|
||||||
"tree-sitter-elixir 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"tree-sitter-elixir 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"tree-sitter-json 0.19.0",
|
||||||
"tree-sitter-rust",
|
"tree-sitter-rust",
|
||||||
"tree-sitter-toml 0.20.0",
|
"tree-sitter-toml 0.20.0",
|
||||||
"tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
"tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
|
|
@ -526,7 +526,7 @@ pub struct OutlineConfig {
|
||||||
pub struct EmbeddingConfig {
|
pub struct EmbeddingConfig {
|
||||||
pub query: Query,
|
pub query: Query,
|
||||||
pub item_capture_ix: u32,
|
pub item_capture_ix: u32,
|
||||||
pub name_capture_ix: u32,
|
pub name_capture_ix: Option<u32>,
|
||||||
pub context_capture_ix: Option<u32>,
|
pub context_capture_ix: Option<u32>,
|
||||||
pub collapse_capture_ix: Option<u32>,
|
pub collapse_capture_ix: Option<u32>,
|
||||||
pub keep_capture_ix: Option<u32>,
|
pub keep_capture_ix: Option<u32>,
|
||||||
|
@ -1263,7 +1263,7 @@ impl Language {
|
||||||
("collapse", &mut collapse_capture_ix),
|
("collapse", &mut collapse_capture_ix),
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) {
|
if let Some(item_capture_ix) = item_capture_ix {
|
||||||
grammar.embedding_config = Some(EmbeddingConfig {
|
grammar.embedding_config = Some(EmbeddingConfig {
|
||||||
query,
|
query,
|
||||||
item_capture_ix,
|
item_capture_ix,
|
||||||
|
|
|
@ -54,6 +54,7 @@ ctor.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
|
||||||
tree-sitter-typescript = "*"
|
tree-sitter-typescript = "*"
|
||||||
|
tree-sitter-json = "*"
|
||||||
tree-sitter-rust = "*"
|
tree-sitter-rust = "*"
|
||||||
tree-sitter-toml = "*"
|
tree-sitter-toml = "*"
|
||||||
tree-sitter-cpp = "*"
|
tree-sitter-cpp = "*"
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
use anyhow::{anyhow, Ok, Result};
|
use anyhow::{anyhow, Ok, Result};
|
||||||
use language::{Grammar, Language};
|
use language::{Grammar, Language};
|
||||||
use std::{cmp, collections::HashSet, ops::Range, path::Path, sync::Arc};
|
use std::{
|
||||||
|
cmp::{self, Reverse},
|
||||||
|
collections::HashSet,
|
||||||
|
ops::Range,
|
||||||
|
path::Path,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
use tree_sitter::{Parser, QueryCursor};
|
use tree_sitter::{Parser, QueryCursor};
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
@ -15,7 +21,7 @@ const CODE_CONTEXT_TEMPLATE: &str =
|
||||||
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||||
const ENTIRE_FILE_TEMPLATE: &str =
|
const ENTIRE_FILE_TEMPLATE: &str =
|
||||||
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||||
pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];
|
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
|
||||||
|
|
||||||
pub struct CodeContextRetriever {
|
pub struct CodeContextRetriever {
|
||||||
pub parser: Parser,
|
pub parser: Parser,
|
||||||
|
@ -30,8 +36,8 @@ pub struct CodeContextRetriever {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CodeContextMatch {
|
pub struct CodeContextMatch {
|
||||||
pub start_col: usize,
|
pub start_col: usize,
|
||||||
pub item_range: Range<usize>,
|
pub item_range: Option<Range<usize>>,
|
||||||
pub name_range: Range<usize>,
|
pub name_range: Option<Range<usize>>,
|
||||||
pub context_ranges: Vec<Range<usize>>,
|
pub context_ranges: Vec<Range<usize>>,
|
||||||
pub collapse_ranges: Vec<Range<usize>>,
|
pub collapse_ranges: Vec<Range<usize>>,
|
||||||
}
|
}
|
||||||
|
@ -44,7 +50,7 @@ impl CodeContextRetriever {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _parse_entire_file(
|
fn parse_entire_file(
|
||||||
&self,
|
&self,
|
||||||
relative_path: &Path,
|
relative_path: &Path,
|
||||||
language_name: Arc<str>,
|
language_name: Arc<str>,
|
||||||
|
@ -97,7 +103,7 @@ impl CodeContextRetriever {
|
||||||
if capture.index == embedding_config.item_capture_ix {
|
if capture.index == embedding_config.item_capture_ix {
|
||||||
item_range = Some(capture.node.byte_range());
|
item_range = Some(capture.node.byte_range());
|
||||||
start_col = capture.node.start_position().column;
|
start_col = capture.node.start_position().column;
|
||||||
} else if capture.index == embedding_config.name_capture_ix {
|
} else if Some(capture.index) == embedding_config.name_capture_ix {
|
||||||
name_range = Some(capture.node.byte_range());
|
name_range = Some(capture.node.byte_range());
|
||||||
} else if Some(capture.index) == embedding_config.context_capture_ix {
|
} else if Some(capture.index) == embedding_config.context_capture_ix {
|
||||||
context_ranges.push(capture.node.byte_range());
|
context_ranges.push(capture.node.byte_range());
|
||||||
|
@ -108,16 +114,13 @@ impl CodeContextRetriever {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if item_range.is_some() && name_range.is_some() {
|
captures.push(CodeContextMatch {
|
||||||
let item_range = item_range.unwrap();
|
start_col,
|
||||||
captures.push(CodeContextMatch {
|
item_range,
|
||||||
start_col,
|
name_range,
|
||||||
item_range,
|
context_ranges,
|
||||||
name_range: name_range.unwrap(),
|
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
|
||||||
context_ranges,
|
});
|
||||||
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(captures)
|
Ok(captures)
|
||||||
}
|
}
|
||||||
|
@ -129,7 +132,12 @@ impl CodeContextRetriever {
|
||||||
language: Arc<Language>,
|
language: Arc<Language>,
|
||||||
) -> Result<Vec<Document>> {
|
) -> Result<Vec<Document>> {
|
||||||
let language_name = language.name();
|
let language_name = language.name();
|
||||||
let mut documents = self.parse_file(relative_path, content, language)?;
|
|
||||||
|
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||||
|
return self.parse_entire_file(relative_path, language_name, &content);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut documents = self.parse_file(content, language)?;
|
||||||
for document in &mut documents {
|
for document in &mut documents {
|
||||||
document.content = CODE_CONTEXT_TEMPLATE
|
document.content = CODE_CONTEXT_TEMPLATE
|
||||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||||
|
@ -139,16 +147,7 @@ impl CodeContextRetriever {
|
||||||
Ok(documents)
|
Ok(documents)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_file(
|
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
|
||||||
&mut self,
|
|
||||||
relative_path: &Path,
|
|
||||||
content: &str,
|
|
||||||
language: Arc<Language>,
|
|
||||||
) -> Result<Vec<Document>> {
|
|
||||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
|
|
||||||
return self._parse_entire_file(relative_path, language.name(), &content);
|
|
||||||
}
|
|
||||||
|
|
||||||
let grammar = language
|
let grammar = language
|
||||||
.grammar()
|
.grammar()
|
||||||
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
||||||
|
@ -163,32 +162,49 @@ impl CodeContextRetriever {
|
||||||
let mut collapsed_ranges_within = Vec::new();
|
let mut collapsed_ranges_within = Vec::new();
|
||||||
let mut parsed_name_ranges = HashSet::new();
|
let mut parsed_name_ranges = HashSet::new();
|
||||||
for (i, context_match) in matches.iter().enumerate() {
|
for (i, context_match) in matches.iter().enumerate() {
|
||||||
if parsed_name_ranges.contains(&context_match.name_range) {
|
// Items which are collapsible but not embeddable have no item range
|
||||||
|
let item_range = if let Some(item_range) = context_match.item_range.clone() {
|
||||||
|
item_range
|
||||||
|
} else {
|
||||||
continue;
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Checks for deduplication
|
||||||
|
let name;
|
||||||
|
if let Some(name_range) = context_match.name_range.clone() {
|
||||||
|
name = content
|
||||||
|
.get(name_range.clone())
|
||||||
|
.map_or(String::new(), |s| s.to_string());
|
||||||
|
if parsed_name_ranges.contains(&name_range) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
parsed_name_ranges.insert(name_range);
|
||||||
|
} else {
|
||||||
|
name = String::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
collapsed_ranges_within.clear();
|
collapsed_ranges_within.clear();
|
||||||
for remaining_match in &matches[(i + 1)..] {
|
'outer: for remaining_match in &matches[(i + 1)..] {
|
||||||
if context_match
|
for collapsed_range in &remaining_match.collapse_ranges {
|
||||||
.item_range
|
if item_range.start <= collapsed_range.start
|
||||||
.contains(&remaining_match.item_range.start)
|
&& item_range.end >= collapsed_range.end
|
||||||
&& context_match
|
{
|
||||||
.item_range
|
collapsed_ranges_within.push(collapsed_range.clone());
|
||||||
.contains(&remaining_match.item_range.end)
|
} else {
|
||||||
{
|
break 'outer;
|
||||||
collapsed_ranges_within.extend(remaining_match.collapse_ranges.iter().cloned());
|
}
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
|
||||||
|
|
||||||
let mut document_content = String::new();
|
let mut document_content = String::new();
|
||||||
for context_range in &context_match.context_ranges {
|
for context_range in &context_match.context_ranges {
|
||||||
document_content.push_str(&content[context_range.clone()]);
|
document_content.push_str(&content[context_range.clone()]);
|
||||||
document_content.push_str("\n");
|
document_content.push_str("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut offset = context_match.item_range.start;
|
let mut offset = item_range.start;
|
||||||
for collapsed_range in &collapsed_ranges_within {
|
for collapsed_range in &collapsed_ranges_within {
|
||||||
if collapsed_range.start > offset {
|
if collapsed_range.start > offset {
|
||||||
add_content_from_range(
|
add_content_from_range(
|
||||||
|
@ -197,29 +213,30 @@ impl CodeContextRetriever {
|
||||||
offset..collapsed_range.start,
|
offset..collapsed_range.start,
|
||||||
context_match.start_col,
|
context_match.start_col,
|
||||||
);
|
);
|
||||||
|
offset = collapsed_range.start;
|
||||||
|
}
|
||||||
|
|
||||||
|
if collapsed_range.end > offset {
|
||||||
|
document_content.push_str(placeholder);
|
||||||
|
offset = collapsed_range.end;
|
||||||
}
|
}
|
||||||
document_content.push_str(placeholder);
|
|
||||||
offset = collapsed_range.end;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if offset < context_match.item_range.end {
|
if offset < item_range.end {
|
||||||
add_content_from_range(
|
add_content_from_range(
|
||||||
&mut document_content,
|
&mut document_content,
|
||||||
content,
|
content,
|
||||||
offset..context_match.item_range.end,
|
offset..item_range.end,
|
||||||
context_match.start_col,
|
context_match.start_col,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(name) = content.get(context_match.name_range.clone()) {
|
documents.push(Document {
|
||||||
parsed_name_ranges.insert(context_match.name_range.clone());
|
name,
|
||||||
documents.push(Document {
|
content: document_content,
|
||||||
name: name.to_string(),
|
range: item_range.clone(),
|
||||||
content: document_content,
|
embedding: vec![],
|
||||||
range: context_match.item_range.clone(),
|
})
|
||||||
embedding: vec![],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Ok(documents);
|
return Ok(documents);
|
||||||
|
|
|
@ -33,7 +33,7 @@ use util::{
|
||||||
ResultExt,
|
ResultExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
const SEMANTIC_INDEX_VERSION: usize = 4;
|
const SEMANTIC_INDEX_VERSION: usize = 5;
|
||||||
const EMBEDDINGS_BATCH_SIZE: usize = 80;
|
const EMBEDDINGS_BATCH_SIZE: usize = 80;
|
||||||
|
|
||||||
pub fn init(
|
pub fn init(
|
||||||
|
|
|
@ -170,9 +170,7 @@ async fn test_code_context_retrieval_rust() {
|
||||||
"
|
"
|
||||||
.unindent();
|
.unindent();
|
||||||
|
|
||||||
let documents = retriever
|
let documents = retriever.parse_file(&text, language).unwrap();
|
||||||
.parse_file(Path::new("foo.rs"), &text, language)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_documents_eq(
|
assert_documents_eq(
|
||||||
&documents,
|
&documents,
|
||||||
|
@ -229,6 +227,76 @@ async fn test_code_context_retrieval_rust() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Exercises `CodeContextRetriever::parse_file` on JSON text using the language from `json_lang`.
///
/// Two inputs are parsed (an object and a top-level array). For each, the resulting documents
/// are compared against a single expected (content, start offset) pair.
#[gpui::test]
|
||||||
|
async fn test_code_context_retrieval_json() {
|
||||||
|
let language = json_lang();
|
||||||
|
let mut retriever = CodeContextRetriever::new();
|
||||||
|
|
||||||
|
// Case 1: an object holding an array, a string, and a nested object with the same kinds of values.
let text = r#"
|
||||||
|
{
|
||||||
|
"array": [1, 2, 3, 4],
|
||||||
|
"string": "abcdefg",
|
||||||
|
"nested_object": {
|
||||||
|
"array_2": [5, 6, 7, 8],
|
||||||
|
"string_2": "hijklmnop",
|
||||||
|
"boolean": true,
|
||||||
|
"none": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"#
|
||||||
|
.unindent();
|
||||||
|
|
||||||
|
let documents = retriever.parse_file(&text, language.clone()).unwrap();
|
||||||
|
|
||||||
|
// Expected: array contents and string values are emptied; keys, the boolean and null are kept.
assert_documents_eq(
|
||||||
|
&documents,
|
||||||
|
&[(
|
||||||
|
r#"
|
||||||
|
{
|
||||||
|
"array": [],
|
||||||
|
"string": "",
|
||||||
|
"nested_object": {
|
||||||
|
"array_2": [],
|
||||||
|
"string_2": "",
|
||||||
|
"boolean": true,
|
||||||
|
"none": null
|
||||||
|
}
|
||||||
|
}"#
|
||||||
|
.unindent(),
|
||||||
|
// Expected start offset of the document: the first opening brace in the input.
text.find("{").unwrap(),
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Case 2: a top-level array of two objects.
let text = r#"
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "somebody",
|
||||||
|
"age": 42
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "somebody else",
|
||||||
|
"age": 43
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"#
|
||||||
|
.unindent();
|
||||||
|
|
||||||
|
let documents = retriever.parse_file(&text, language.clone()).unwrap();
|
||||||
|
|
||||||
|
// Expected: only the first object of the array remains, with its string value emptied.
assert_documents_eq(
|
||||||
|
&documents,
|
||||||
|
&[(
|
||||||
|
r#"
|
||||||
|
[{
|
||||||
|
"name": "",
|
||||||
|
"age": 42
|
||||||
|
}]"#
|
||||||
|
.unindent(),
|
||||||
|
// Expected start offset of the document: the first opening bracket in the input.
text.find("[").unwrap(),
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
fn assert_documents_eq(
|
fn assert_documents_eq(
|
||||||
documents: &[Document],
|
documents: &[Document],
|
||||||
expected_contents_and_start_offsets: &[(String, usize)],
|
expected_contents_and_start_offsets: &[(String, usize)],
|
||||||
|
@ -913,6 +981,35 @@ fn rust_lang() -> Arc<Language> {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds the JSON language used by the tests: a `Language` named JSON with the `json` path
/// suffix, backed by the `tree_sitter_json` grammar, with an embedding query attached.
///
/// The query captures the whole `document` as `@item`, captures arrays and the string values
/// of pairs as `@collapse`, and captures the array brackets, an optional object following the
/// opening bracket, and the string quotes as `@keep`.
///
/// Panics through `unwrap` if `with_embedding_query` returns an error.
fn json_lang() -> Arc<Language> {
|
||||||
|
Arc::new(
|
||||||
|
Language::new(
|
||||||
|
LanguageConfig {
|
||||||
|
name: "JSON".into(),
|
||||||
|
path_suffixes: vec!["json".into()],
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
Some(tree_sitter_json::language()),
|
||||||
|
)
|
||||||
|
// The query text below is the same as the one added in crates/zed/src/languages/json/embedding.scm.
.with_embedding_query(
|
||||||
|
r#"
|
||||||
|
(document) @item
|
||||||
|
|
||||||
|
(array
|
||||||
|
"[" @keep
|
||||||
|
.
|
||||||
|
(object)? @keep
|
||||||
|
"]" @keep) @collapse
|
||||||
|
|
||||||
|
(pair value: (string
|
||||||
|
"\"" @keep
|
||||||
|
"\"" @keep) @collapse)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn toml_lang() -> Arc<Language> {
|
fn toml_lang() -> Arc<Language> {
|
||||||
Arc::new(Language::new(
|
Arc::new(Language::new(
|
||||||
LanguageConfig {
|
LanguageConfig {
|
||||||
|
|
14
crates/zed/src/languages/json/embedding.scm
Normal file
14
crates/zed/src/languages/json/embedding.scm
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
; Only produce one embedding for the entire file.
|
||||||
|
(document) @item
|
||||||
|
|
||||||
|
; Collapse arrays, except for the first object.
; The brackets and that first object are captured as @keep, so they are
; excluded from the collapsed range and stay visible in the output.
|
||||||
|
(array
|
||||||
|
"[" @keep
|
||||||
|
.
|
||||||
|
(object)? @keep
|
||||||
|
"]" @keep) @collapse
|
||||||
|
|
||||||
|
; Collapse string values (but not keys).
; The surrounding quotes are captured as @keep, so a collapsed value is
; rendered as an empty string literal.
|
||||||
|
(pair value: (string
|
||||||
|
"\"" @keep
|
||||||
|
"\"" @keep) @collapse)
|
Loading…
Add table
Add a link
Reference in a new issue