use crate::{
    db::dot,
    embedding::EmbeddingProvider,
    parsing::{CodeContextRetriever, Document},
    semantic_index_settings::SemanticIndexSettings,
    SemanticIndex,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
use std::{
    path::Path,
    sync::{
        atomic::{self, AtomicUsize},
        Arc,
    },
};
use unindent::Unindent;

#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}

#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
    cx.update(|cx| {
        cx.set_global(SettingsStore::test(cx));
        settings::register::<SemanticIndexSettings>(cx);
        settings::register::<ProjectSettings>(cx);
    });

    let fs = FakeFs::new(cx.background());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaa!\");
                    }

                    fn zzzzzzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbb!\");
                    }
                ".unindent(),
                "file3.toml": "
                    ZZZZZZZ = 5
                ".unindent(),
            }
        }),
    )
    .await;

    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = rust_lang();
    let toml_language = toml_lang();
    languages.add(rust_language);
    languages.add(toml_language);

    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
    let db_path = db_dir.path().join("db.sqlite");

    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let store = SemanticIndex::new(
        fs.clone(),
        db_path,
        embedding_provider.clone(),
        languages,
        cx.to_async(),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
    let (file_count, outstanding_file_count) = store
        .update(cx, |store, cx| store.index_project(project.clone(), cx))
        .await
        .unwrap();
    assert_eq!(file_count, 3);
    cx.foreground().run_until_parked();
    assert_eq!(*outstanding_file_count.borrow(), 0);

    let search_results = store
        .update(cx, |store, cx| {
            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
        })
        .await
        .unwrap();

    search_results[0].buffer.read_with(cx, |buffer, _cx| {
        assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
        assert_eq!(
            buffer.file().unwrap().path().as_ref(),
            Path::new("file1.rs")
        );
    });

    // Modify one file on disk, then re-index the project.
    fs.save(
        "/the-root/src/file2.rs".as_ref(),
        &"
            fn dddd() { println!(\"ddddd!\"); }
            struct pqpqpqp {}
        "
        .unindent()
        .into(),
        Default::default(),
    )
    .await
    .unwrap();
    cx.foreground().run_until_parked();

    let prev_embedding_count = embedding_provider.embedding_count();
    let (file_count, outstanding_file_count) = store
        .update(cx, |store, cx| store.index_project(project.clone(), cx))
        .await
        .unwrap();
    assert_eq!(file_count, 1);
    cx.foreground().run_until_parked();
    assert_eq!(*outstanding_file_count.borrow(), 0);

    // Only the modified file is re-indexed, and its two documents are re-embedded.
    assert_eq!(
        embedding_provider.embedding_count() - prev_embedding_count,
        2
    );
}

#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /// A doc comment
        /// that spans multiple lines
        fn a() {
            b
        }

        impl C for D {
        }
    "
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.rs"), &text, language)
        .unwrap();

    assert_eq!(
        parsed_files,
        &[
            Document {
                name: "a".into(),
                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
                content: "
                    The below code snippet is from file 'foo.rs'

                    ```rust
                    /// A doc comment
                    /// that spans multiple lines
                    fn a() {
                        b
                    }
                    ```"
                .unindent(),
                embedding: vec![],
            },
            Document {
                name: "C for D".into(),
                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
                content: "
                    The below code snippet is from file 'foo.rs'

                    ```rust
                    impl C for D {
                    }
                    ```"
                .unindent(),
                embedding: vec![],
            }
        ]
    );
}

#[gpui::test]
async fn test_code_context_retrieval_javascript() {
    let language = js_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /* globals importScripts, backend */
        function _authorize() {}

        /**
         * Sometimes the frontend build is way faster than backend.
         */
        export async function authorizeBank() {
            _authorize(pushModal, upgradingAccountId, {});
        }

        export class SettingsPage {
            /* This is a test setting */
            constructor(page) {
                this.page = page;
            }
        }

        /* This is a test comment */
        class TestClass {}

        /* Schema for editor_events in Clickhouse. */
        export interface ClickhouseEditorEvent {
            installation_id: string
            operation: string
        }
        "
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.js"), &text, language)
        .unwrap();

    let test_documents = &[
        Document {
            name: "function _authorize".into(),
            range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* globals importScripts, backend */
                function _authorize() {}
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "async function authorizeBank".into(),
            range: text.find("export async").unwrap()..223,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /**
                 * Sometimes the frontend build is way faster than backend.
                 */
                export async function authorizeBank() {
                    _authorize(pushModal, upgradingAccountId, {});
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class SettingsPage".into(),
            range: 225..343,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                export class SettingsPage {
                    /* This is a test setting */
                    constructor(page) {
                        this.page = page;
                    }
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "constructor".into(),
            range: 290..341,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* This is a test setting */
                constructor(page) {
                    this.page = page;
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class TestClass".into(),
            range: 374..392,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* This is a test comment */
                class TestClass {}
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "interface ClickhouseEditorEvent".into(),
            range: 440..532,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* Schema for editor_events in Clickhouse. */
                export interface ClickhouseEditorEvent {
                    installation_id: string
                    operation: string
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
    ];

    for idx in 0..test_documents.len() {
        assert_eq!(test_documents[idx], parsed_files[idx]);
    }
}

#[gpui::test]
async fn test_code_context_retrieval_elixir() {
    let language = elixir_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = r#"
        defmodule File.Stream do
          @moduledoc """
          Defines a `File.Stream` struct returned by `File.stream!/3`.
          The following fields are public:

            * `path` - the file path
            * `modes` - the file modes
            * `raw` - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node` - the node the file belongs to

          """

          defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

          @type t :: %__MODULE__{}

          @doc false
          def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
              case raw do
                true ->
                  case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                  end

                false ->
                  modes
              end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
          end
    "#
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.ex"), &text, language)
        .unwrap();

    let test_documents = &[
        Document {
            name: "defmodule File.Stream".into(),
            range: 0..1132,
            content: r#"
                The below code snippet is from file 'foo.ex'

                ```elixir
                defmodule File.Stream do
                  @moduledoc """
                  Defines a `File.Stream` struct returned by `File.stream!/3`.

                  The following fields are public:

                    * `path` - the file path
                    * `modes` - the file modes
                    * `raw` - a boolean indicating if bin functions should be used
                    * `line_or_bytes` - if reading should read lines or a given number of bytes
                    * `node` - the node the file belongs to

                  """

                  defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

                  @type t :: %__MODULE__{}

                  @doc false
                  def __build__(path, modes, line_or_bytes) do
                    raw = :lists.keyfind(:encoding, 1, modes) == false

                    modes =
                      case raw do
                        true ->
                          case :lists.keyfind(:read_ahead, 1, modes) do
                            {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                            {:read_ahead, _} -> [:raw | modes]
                            false -> [:raw, :read_ahead | modes]
                          end

                        false ->
                          modes
                      end

                    %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
                  end
                ```"#
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "def __build__".into(),
            range: 574..1132,
            content: r#"
                The below code snippet is from file 'foo.ex'

                ```elixir
                @doc false
                def __build__(path, modes, line_or_bytes) do
                  raw = :lists.keyfind(:encoding, 1, modes) == false

                  modes =
                    case raw do
                      true ->
                        case :lists.keyfind(:read_ahead, 1, modes) do
                          {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                          {:read_ahead, _} -> [:raw | modes]
                          false -> [:raw, :read_ahead | modes]
                        end

                      false ->
                        modes
                    end

                  %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
                end
                ```"#
            .unindent(),
            embedding: vec![],
        },
    ];

    for idx in 0..test_documents.len() {
        assert_eq!(test_documents[idx], parsed_files[idx]);
    }
}

#[gpui::test]
async fn test_code_context_retrieval_cpp() {
    let language = cpp_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /**
         * @brief Main function
         * @returns 0 on exit
         */
        int main() { return 0; }

        /**
         * This is a test comment
         */
        class MyClass {        // The class
          public:              // Access specifier
            int myNum;         // Attribute (int variable)
            string myString;   // Attribute (string variable)
        };

        // This is a test comment
        enum Color { red, green, blue };

        /** This is a preceeding block comment
         * This is the second line
         */
        struct {             // Structure declaration
            int myNum;       // Member (int variable)
            string myString; // Member (string variable)
        } myStructure;

        /**
         * @brief Matrix class.
         */
        template <typename T,
                  typename = typename std::enable_if<std::is_integral<T>::value ||
                                                     std::is_floating_point<T>::value,
                                                     bool>::type>
        class Matrix2 {
            std::vector<std::vector<T>> _mat;

           public:
            /**
             * @brief Constructor
             * @tparam Integer ensuring integers are being evaluated and not other
             * data types.
             * @param size denoting the size of Matrix as size x size
             */
            template <typename Integer,
                      typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                         Integer>::type>
            explicit Matrix(const Integer size) {
                for (size_t i = 0; i < size; ++i) {
                    _mat.emplace_back(std::vector<T>(size, 0));
                }
            }
        }"
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.cpp"), &text, language)
        .unwrap();

    let test_documents = &[
        Document {
            name: "int main".into(),
            range: 54..78,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /**
                 * @brief Main function
                 * @returns 0 on exit
                 */
                int main() { return 0; }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class MyClass".into(),
            range: 112..295,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /**
                 * This is a test comment
                 */
                class MyClass {        // The class
                  public:              // Access specifier
                    int myNum;         // Attribute (int variable)
                    string myString;   // Attribute (string variable)
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "enum Color".into(),
            range: 324..355,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                // This is a test comment
                enum Color { red, green, blue }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "struct myStructure".into(),
            range: 428..581,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /** This is a preceeding block comment
                 * This is the second line
                 */
                struct {             // Structure declaration
                    int myNum;       // Member (int variable)
                    string myString; // Member (string variable)
                } myStructure;
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class Matrix2".into(),
            range: 613..1342,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /**
                 * @brief Matrix class.
                 */
                template <typename T,
                          typename = typename std::enable_if<std::is_integral<T>::value ||
                                                             std::is_floating_point<T>::value,
                                                             bool>::type>
                class Matrix2 {
                    std::vector<std::vector<T>> _mat;

                   public:
                    /**
                     * @brief Constructor
                     * @tparam Integer ensuring integers are being evaluated and not other
                     * data types.
                     * @param size denoting the size of Matrix as size x size
                     */
                    template <typename Integer,
                              typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                                 Integer>::type>
                    explicit Matrix(const Integer size) {
                        for (size_t i = 0; i < size; ++i) {
                            _mat.emplace_back(std::vector<T>(size, 0));
                        }
                    }
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
    ];

    for idx in 0..test_documents.len() {
        assert_eq!(test_documents[idx], parsed_files[idx]);
    }
}

#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    // The optimized `dot` from the db module should agree with a
    // straightforward reference implementation on random vectors.
    for _ in 0..100 {
        let size = 1536;
        let mut a = vec![0.; size];
        let mut b = vec![0.; size];
        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
            *a = rng.gen();
            *b = rng.gen();
        }

        assert_eq!(
            round_to_decimals(dot(&a, &b), 1),
            round_to_decimals(reference_dot(&a, &b), 1)
        );
    }

    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
        let factor = (10.0 as f32).powi(decimal_places);
        (n * factor).round() / factor
    }

    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }
}

#[derive(Default)]
struct FakeEmbeddingProvider {
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
        Ok(spans
            .iter()
            .map(|span| {
                // Produce a deterministic fake embedding: a 26-dimensional
                // letter-frequency vector, normalized to unit length.
                let mut result = vec![1.0; 26];
                for letter in span.chars() {
                    let letter = letter.to_ascii_lowercase();
                    if letter as u32 >= 'a' as u32 {
                        let ix = (letter as u32) - ('a' as u32);
                        if ix < 26 {
                            result[ix as usize] += 1.0;
                        }
                    }
                }

                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
                for x in &mut result {
                    *x /= norm;
                }

                result
            })
            .collect())
    }
}

fn js_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Javascript".into(),
                path_suffixes: vec!["js".into()],
                ..Default::default()
            },
            Some(tree_sitter_typescript::language_tsx()),
        )
        .with_embedding_query(
            &r#"
            (
                (comment)* @context
                .
                (export_statement
                    (function_declaration
                        "async"? @name
                        "function" @name
                        name: (_) @name)) @item
            )

            (
                (comment)* @context
                .
                (function_declaration
                    "async"? @name
                    "function" @name
                    name: (_) @name) @item
            )

            (
                (comment)* @context
                .
                (export_statement
                    (class_declaration
                        "class" @name
                        name: (_) @name)) @item
            )

            (
                (comment)* @context
                .
                (class_declaration
                    "class" @name
                    name: (_) @name) @item
            )

            (
                (comment)* @context
                .
                (method_definition
                    [
                        "get"
                        "set"
                        "async"
                        "*"
                        "static"
                    ]* @name
                    name: (_) @name) @item
            )

            (
                (comment)* @context
                .
                (export_statement
                    (interface_declaration
                        "interface" @name
                        name: (_) @name)) @item
            )

            (
                (comment)* @context
                .
                (interface_declaration
                    "interface" @name
                    name: (_) @name) @item
            )

            (
                (comment)* @context
                .
                (export_statement
                    (enum_declaration
                        "enum" @name
                        name: (_) @name)) @item
            )

            (
                (comment)* @context
                .
                (enum_declaration
                    "enum" @name
                    name: (_) @name) @item
            )
            "#
            .unindent(),
        )
        .unwrap(),
    )
}

fn rust_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".into()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                (line_comment)* @context
                .
                (enum_item
                    name: (_) @name) @item
            )

            (
                (line_comment)* @context
                .
                (struct_item
                    name: (_) @name) @item
            )

            (
                (line_comment)* @context
                .
                (impl_item
                    trait: (_)? @name
                    "for"? @name
                    type: (_) @name) @item
            )

            (
                (line_comment)* @context
                .
                (trait_item
                    name: (_) @name) @item
            )

            (
                (line_comment)* @context
                .
                (function_item
                    name: (_) @name) @item
            )

            (
                (line_comment)* @context
                .
                (macro_definition
                    name: (_) @name) @item
            )

            (
                (line_comment)* @context
                .
                (function_signature_item
                    name: (_) @name) @item
            )
            "#,
        )
        .unwrap(),
    )
}

fn toml_lang() -> Arc<Language> {
    Arc::new(Language::new(
        LanguageConfig {
            name: "TOML".into(),
            path_suffixes: vec!["toml".into()],
            ..Default::default()
        },
        Some(tree_sitter_toml::language()),
    ))
}

fn cpp_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "CPP".into(),
                path_suffixes: vec!["cpp".into()],
                ..Default::default()
            },
            Some(tree_sitter_cpp::language()),
        )
        .with_embedding_query(
            r#"
            (
                (comment)* @context
                .
                (function_definition
                    (type_qualifier)? @name
                    type: (_)? @name
                    declarator: [
                        (function_declarator
                            declarator: (_) @name)
                        (pointer_declarator
                            "*" @name
                            declarator: (function_declarator
                                declarator: (_) @name))
                        (pointer_declarator
                            "*" @name
                            declarator: (pointer_declarator
                                "*" @name
                                declarator: (function_declarator
                                    declarator: (_) @name)))
                        (reference_declarator
                            ["&" "&&"] @name
                            (function_declarator
                                declarator: (_) @name))
                    ]
                    (type_qualifier)? @name) @item
            )

            (
                (comment)* @context
                .
                (template_declaration
                    (class_specifier
                        "class" @name
                        name: (_) @name)
                ) @item
            )

            (
                (comment)* @context
                .
                (class_specifier
                    "class" @name
                    name: (_) @name) @item
            )

            (
                (comment)* @context
                .
                (enum_specifier
                    "enum" @name
                    name: (_) @name) @item
            )

            (
                (comment)* @context
                .
                (declaration
                    type: (struct_specifier
                        "struct" @name)
                    declarator: (_) @name) @item
            )
            "#,
        )
        .unwrap(),
    )
}

fn elixir_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Elixir".into(),
                path_suffixes: vec!["ex".into()],
                ..Default::default()
            },
            Some(tree_sitter_elixir::language()),
        )
        .with_embedding_query(
            r#"
            (
                (unary_operator
                    operator: "@"
                    operand: (call
                        target: (identifier) @unary
                        (#match? @unary "^(doc)$"))
                ) @context
                .
                (call
                    target: (identifier) @name
                    (arguments
                    [
                        (identifier) @name
                        (call
                            target: (identifier) @name)
                        (binary_operator
                            left: (call
                                target: (identifier) @name)
                            operator: "when")
                    ])
                    (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
            )

            (call
                target: (identifier) @name
                (arguments (alias) @name)
                (#match? @name "^(defmodule|defprotocol)$")) @item
            "#,
        )
        .unwrap(),
    )
}