diff --git a/crates/vector_store/src/parsing.rs b/crates/vector_store/src/parsing.rs index 23dcf505c9..8d6e03d6eb 100644 --- a/crates/vector_store/src/parsing.rs +++ b/crates/vector_store/src/parsing.rs @@ -81,7 +81,11 @@ impl CodeContextRetriever { if let Some((item, byte_range)) = item.zip(byte_range) { if !name.is_empty() { - let item = format!("{}\n{}", context_spans.join("\n"), item); + let item = if context_spans.is_empty() { + item.to_string() + } else { + format!("{}\n{}", context_spans.join("\n"), item) + }; let document_text = CODE_CONTEXT_TEMPLATE .replace("", relative_path.to_str().unwrap()) diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index c4349c7280..ccdd9fdaf0 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -1,5 +1,9 @@ use crate::{ - db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore, + db::dot, + embedding::EmbeddingProvider, + parsing::{CodeContextRetriever, Document}, + vector_store_settings::VectorStoreSettings, + VectorStore, }; use anyhow::Result; use async_trait::async_trait; @@ -9,7 +13,7 @@ use project::{project_settings::ProjectSettings, FakeFs, Project}; use rand::{rngs::StdRng, Rng}; use serde_json::json; use settings::SettingsStore; -use std::sync::Arc; +use std::{path::Path, sync::Arc}; use unindent::Unindent; #[ctor::ctor] @@ -52,24 +56,7 @@ async fn test_vector_store(cx: &mut TestAppContext) { .await; let languages = Arc::new(LanguageRegistry::new(Task::ready(()))); - let rust_language = Arc::new( - Language::new( - LanguageConfig { - name: "Rust".into(), - path_suffixes: vec!["rs".into()], - ..Default::default() - }, - Some(tree_sitter_rust::language()), - ) - .with_embedding_query( - r#" - (function_item - name: (identifier) @name - body: (block)) @item - "#, - ) - .unwrap(), - ); + let rust_language = rust_lang(); languages.add(rust_language); let db_dir = tempdir::TempDir::new("vector-store").unwrap(); @@ -109,14 +96,59 @@ async fn test_vector_store(cx: &mut TestAppContext) { #[gpui::test] async fn test_code_context_retrieval(cx: &mut TestAppContext) { - // let mut retriever = CodeContextRetriever::new(fs); + let language = rust_lang(); + let mut retriever = CodeContextRetriever::new(); - // retriever::parse_file( - // " - // // - // ", - // ); - // + let text = " + /// A doc comment + /// that spans multiple lines + fn a() { + b + } + + impl C for D { + } + " + .unindent(); + + let parsed_files = retriever + .parse_file(Path::new("foo.rs"), &text, language) + .unwrap(); + + assert_eq!( + parsed_files, + &[ + Document { + name: "a".into(), + range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1), + content: " + The below code snippet is from file 'foo.rs' + + ```rust + /// A doc comment + /// that spans multiple lines + fn a() { + b + } + ```" + .unindent(), + embedding: vec![], + }, + Document { + name: "C for D".into(), + range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1), + content: " + The below code snippet is from file 'foo.rs' + + ```rust + impl C for D { + } + ```" + .unindent(), + embedding: vec![], + } + ] + ); } #[gpui::test] @@ -178,3 +210,71 @@ impl EmbeddingProvider for FakeEmbeddingProvider { .collect()) } } + +fn rust_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".into()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_embedding_query( + r#" + ( + (line_comment)* @context + . + (enum_item + name: (_) @name) @item + ) + + ( + (line_comment)* @context + . + (struct_item + name: (_) @name) @item + ) + + ( + (line_comment)* @context + . + (impl_item + trait: (_)? @name + "for"? @name + type: (_) @name) @item + ) + + ( + (line_comment)* @context + . + (trait_item + name: (_) @name) @item + ) + + ( + (line_comment)* @context + . + (function_item + name: (_) @name) @item + ) + + ( + (line_comment)* @context + . + (macro_definition + name: (_) @name) @item + ) + + ( + (line_comment)* @context + . + (function_signature_item + name: (_) @name) @item + ) + "#, + ) + .unwrap(), + ) +} diff --git a/crates/zed/src/languages/rust/embedding.scm b/crates/zed/src/languages/rust/embedding.scm index 3aec101e9f..66e4083de5 100644 --- a/crates/zed/src/languages/rust/embedding.scm +++ b/crates/zed/src/languages/rust/embedding.scm @@ -1,22 +1,50 @@ ( (line_comment)* @context . - [ - (enum_item - name: (_) @name) @item - (struct_item - name: (_) @name) @item - (impl_item - trait: (_)? @name - "for"? @name - type: (_) @name) @item - (trait_item - name: (_) @name) @item - (function_item - name: (_) @name) @item - (macro_definition - name: (_) @name) @item - (function_signature_item - name: (_) @name) @item - ] + (enum_item + name: (_) @name) @item +) + +( + (line_comment)* @context + . + (struct_item + name: (_) @name) @item +) + +( + (line_comment)* @context + . + (impl_item + trait: (_)? @name + "for"? @name + type: (_) @name) @item +) + +( + (line_comment)* @context + . + (trait_item + name: (_) @name) @item +) + +( + (line_comment)* @context + . + (function_item + name: (_) @name) @item +) + +( + (line_comment)* @context + . + (macro_definition + name: (_) @name) @item +) + +( + (line_comment)* @context + . + (function_signature_item + name: (_) @name) @item )