add tests for rust context parsing, and update rust embedding query

Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
KCaverly 2023-07-13 16:58:42 -04:00
parent 0a0e40fb24
commit 623cb9833c
3 changed files with 178 additions and 46 deletions

View file

@ -81,7 +81,11 @@ impl CodeContextRetriever {
if let Some((item, byte_range)) = item.zip(byte_range) { if let Some((item, byte_range)) = item.zip(byte_range) {
if !name.is_empty() { if !name.is_empty() {
let item = format!("{}\n{}", context_spans.join("\n"), item); let item = if context_spans.is_empty() {
item.to_string()
} else {
format!("{}\n{}", context_spans.join("\n"), item)
};
let document_text = CODE_CONTEXT_TEMPLATE let document_text = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_str().unwrap()) .replace("<path>", relative_path.to_str().unwrap())

View file

@ -1,5 +1,9 @@
use crate::{ use crate::{
db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore, db::dot,
embedding::EmbeddingProvider,
parsing::{CodeContextRetriever, Document},
vector_store_settings::VectorStoreSettings,
VectorStore,
}; };
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
@ -9,7 +13,7 @@ use project::{project_settings::ProjectSettings, FakeFs, Project};
use rand::{rngs::StdRng, Rng}; use rand::{rngs::StdRng, Rng};
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use std::sync::Arc; use std::{path::Path, sync::Arc};
use unindent::Unindent; use unindent::Unindent;
#[ctor::ctor] #[ctor::ctor]
@ -52,24 +56,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
.await; .await;
let languages = Arc::new(LanguageRegistry::new(Task::ready(()))); let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
let rust_language = Arc::new( let rust_language = rust_lang();
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".into()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
r#"
(function_item
name: (identifier) @name
body: (block)) @item
"#,
)
.unwrap(),
);
languages.add(rust_language); languages.add(rust_language);
let db_dir = tempdir::TempDir::new("vector-store").unwrap(); let db_dir = tempdir::TempDir::new("vector-store").unwrap();
@ -109,14 +96,59 @@ async fn test_vector_store(cx: &mut TestAppContext) {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval(cx: &mut TestAppContext) { async fn test_code_context_retrieval(cx: &mut TestAppContext) {
// let mut retriever = CodeContextRetriever::new(fs); let language = rust_lang();
let mut retriever = CodeContextRetriever::new();
// retriever::parse_file( let text = "
// " /// A doc comment
// // /// that spans multiple lines
// ", fn a() {
// ); b
// }
impl C for D {
}
"
.unindent();
let parsed_files = retriever
.parse_file(Path::new("foo.rs"), &text, language)
.unwrap();
assert_eq!(
parsed_files,
&[
Document {
name: "a".into(),
range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
content: "
The below code snippet is from file 'foo.rs'
```rust
/// A doc comment
/// that spans multiple lines
fn a() {
b
}
```"
.unindent(),
embedding: vec![],
},
Document {
name: "C for D".into(),
range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
content: "
The below code snippet is from file 'foo.rs'
```rust
impl C for D {
}
```"
.unindent(),
embedding: vec![],
}
]
);
} }
#[gpui::test] #[gpui::test]
@ -178,3 +210,71 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
.collect()) .collect())
} }
} }
/// Builds the Rust `Language` shared by the tests in this file.
///
/// The language is named "Rust", is registered for the `rs` path suffix, uses the
/// `tree_sitter_rust` grammar, and carries an embedding query with one pattern per
/// item kind: enums, structs, impls, traits, functions, macro definitions and
/// function signatures.
///
/// # Panics
///
/// Panics if `with_embedding_query` returns an error for the query below.
fn rust_lang() -> Arc<Language> {
    // Every pattern has the same shape: any run of line comments (captured as
    // `@context`) anchored directly before the item (`@item`), with the item's
    // identifying node(s) captured as `@name`. For impls, the optional trait,
    // the optional "for" keyword and the type are all captured as `@name`.
    const EMBEDDING_QUERY: &str = r#"
(
(line_comment)* @context
.
(enum_item
name: (_) @name) @item
)
(
(line_comment)* @context
.
(struct_item
name: (_) @name) @item
)
(
(line_comment)* @context
.
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name) @item
)
(
(line_comment)* @context
.
(trait_item
name: (_) @name) @item
)
(
(line_comment)* @context
.
(function_item
name: (_) @name) @item
)
(
(line_comment)* @context
.
(macro_definition
name: (_) @name) @item
)
(
(line_comment)* @context
.
(function_signature_item
name: (_) @name) @item
)
"#;

    let config = LanguageConfig {
        name: "Rust".into(),
        path_suffixes: vec!["rs".into()],
        ..Default::default()
    };
    let grammar = Some(tree_sitter_rust::language());

    // A malformed query is a bug in this test helper, so failing loudly is intended.
    let language = Language::new(config, grammar)
        .with_embedding_query(EMBEDDING_QUERY)
        .unwrap();

    Arc::new(language)
}

View file

@ -1,22 +1,50 @@
( (
(line_comment)* @context (line_comment)* @context
. .
[ (enum_item
(enum_item name: (_) @name) @item
name: (_) @name) @item )
(struct_item
name: (_) @name) @item (
(impl_item (line_comment)* @context
trait: (_)? @name .
"for"? @name (struct_item
type: (_) @name) @item name: (_) @name) @item
(trait_item )
name: (_) @name) @item
(function_item (
name: (_) @name) @item (line_comment)* @context
(macro_definition .
name: (_) @name) @item (impl_item
(function_signature_item trait: (_)? @name
name: (_) @name) @item "for"? @name
] type: (_) @name) @item
)
(
(line_comment)* @context
.
(trait_item
name: (_) @name) @item
)
(
(line_comment)* @context
.
(function_item
name: (_) @name) @item
)
(
(line_comment)* @context
.
(macro_definition
name: (_) @name) @item
)
(
(line_comment)* @context
.
(function_signature_item
name: (_) @name) @item
) )