From 1362c5a3d9753702820bc615dcfd4a4b261f0a3f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 17 Jul 2023 14:43:29 -0400 Subject: [PATCH] add embedding treesitter query for cpp --- Cargo.lock | 1 + crates/vector_store/Cargo.toml | 1 + crates/vector_store/src/vector_store_tests.rs | 310 ++++++++++++++++-- crates/zed/src/languages/cpp/embedding.scm | 61 ++++ 4 files changed, 346 insertions(+), 27 deletions(-) create mode 100644 crates/zed/src/languages/cpp/embedding.scm diff --git a/Cargo.lock b/Cargo.lock index afd40fd308..28a0e76d14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8518,6 +8518,7 @@ dependencies = [ "theme", "tiktoken-rs 0.5.0", "tree-sitter", + "tree-sitter-cpp", "tree-sitter-rust", "tree-sitter-toml 0.20.0", "tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 31119a1ba6..0009665e26 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -54,3 +54,4 @@ env_logger.workspace = true tree-sitter-typescript = "*" tree-sitter-rust = "*" tree-sitter-toml = "*" +tree-sitter-cpp = "*" diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs index 84c9962493..3a9e1748c5 100644 --- a/crates/vector_store/src/vector_store_tests.rs +++ b/crates/vector_store/src/vector_store_tests.rs @@ -211,32 +211,33 @@ async fn test_code_context_retrieval_javascript() { let mut retriever = CodeContextRetriever::new(); let text = " -/* globals importScripts, backend */ -function _authorize() {} + /* globals importScripts, backend */ + function _authorize() {} -/** - * Sometimes the frontend build is way faster than backend. - */ -export async function authorizeBank() { - _authorize(pushModal, upgradingAccountId, {}); -} + /** + * Sometimes the frontend build is way faster than backend. + */ + export async function authorizeBank() { + _authorize(pushModal, upgradingAccountId, {}); + } -export class SettingsPage { - /* This is a test setting */ - constructor(page) { - this.page = page; - } -} + export class SettingsPage { + /* This is a test setting */ + constructor(page) { + this.page = page; + } + } -/* This is a test comment */ -class TestClass {} + /* This is a test comment */ + class TestClass {} -/* Schema for editor_events in Clickhouse. */ -export interface ClickhouseEditorEvent { - installation_id: string - operation: string -} -"; + /* Schema for editor_events in Clickhouse. */ + export interface ClickhouseEditorEvent { + installation_id: string + operation: string + } + " + .unindent(); let parsed_files = retriever .parse_file(Path::new("foo.js"), &text, language) @@ -258,7 +259,7 @@ export interface ClickhouseEditorEvent { }, Document { name: "async function authorizeBank".into(), - range: text.find("export async").unwrap()..224, + range: text.find("export async").unwrap()..223, content: " The below code snippet is from file 'foo.js' @@ -275,7 +276,7 @@ export interface ClickhouseEditorEvent { }, Document { name: "class SettingsPage".into(), - range: 226..344, + range: 225..343, content: " The below code snippet is from file 'foo.js' @@ -292,7 +293,7 @@ export interface ClickhouseEditorEvent { }, Document { name: "constructor".into(), - range: 291..342, + range: 290..341, content: " The below code snippet is from file 'foo.js' @@ -307,7 +308,7 @@ export interface ClickhouseEditorEvent { }, Document { name: "class TestClass".into(), - range: 375..393, + range: 374..392, content: " The below code snippet is from file 'foo.js' @@ -320,7 +321,7 @@ export interface ClickhouseEditorEvent { }, Document { name: "interface ClickhouseEditorEvent".into(), - range: 441..533, + range: 440..532, content: " The below code snippet is from file 'foo.js' @@ -341,6 +342,181 @@ export interface ClickhouseEditorEvent { } } +#[gpui::test] +async fn test_code_context_retrieval_cpp() { + let language = cpp_lang(); + let mut retriever = CodeContextRetriever::new(); + + let text = " + /** + * @brief Main function + * @returns 0 on exit + */ + int main() { return 0; } + + /** + * This is a test comment + */ + class MyClass { // The class + public: // Access specifier + int myNum; // Attribute (int variable) + string myString; // Attribute (string variable) + }; + + // This is a test comment + enum Color { red, green, blue }; + + /** This is a preceeding block comment + * This is the second line + */ + struct { // Structure declaration + int myNum; // Member (int variable) + string myString; // Member (string variable) + } myStructure; + + /** + * @brief Matrix class. + */ + template ::value || std::is_floating_point::value, + bool>::type> + class Matrix2 { + std::vector> _mat; + + public: + /** + * @brief Constructor + * @tparam Integer ensuring integers are being evaluated and not other + * data types. + * @param size denoting the size of Matrix as size x size + */ + template ::value, + Integer>::type> + explicit Matrix(const Integer size) { + for (size_t i = 0; i < size; ++i) { + _mat.emplace_back(std::vector(size, 0)); + } + } + }" + .unindent(); + + let parsed_files = retriever + .parse_file(Path::new("foo.cpp"), &text, language) + .unwrap(); + + let test_documents = &[ + Document { + name: "int main".into(), + range: 54..78, + content: " + The below code snippet is from file 'foo.cpp' + + ```cpp + /** + * @brief Main function + * @returns 0 on exit + */ + int main() { return 0; } + ```" + .unindent(), + embedding: vec![], + }, + Document { + name: "class MyClass".into(), + range: 112..295, + content: " + The below code snippet is from file 'foo.cpp' + + ```cpp + /** + * This is a test comment + */ + class MyClass { // The class + public: // Access specifier + int myNum; // Attribute (int variable) + string myString; // Attribute (string variable) + } + ```" + .unindent(), + embedding: vec![], + }, + Document { + name: "enum Color".into(), + range: 324..355, + content: " + The below code snippet is from file 'foo.cpp' + + ```cpp + // This is a test comment + enum Color { red, green, blue } + ```" + .unindent(), + embedding: vec![], + }, + Document { + name: "struct myStructure".into(), + range: 428..581, + content: " + The below code snippet is from file 'foo.cpp' + + ```cpp + /** This is a preceeding block comment + * This is the second line + */ + struct { // Structure declaration + int myNum; // Member (int variable) + string myString; // Member (string variable) + } myStructure; + ```" + .unindent(), + embedding: vec![], + }, + Document { + name: "class Matrix2".into(), + range: 613..1342, + content: " + The below code snippet is from file 'foo.cpp' + + ```cpp + /** + * @brief Matrix class. + */ + template ::value || std::is_floating_point::value, + bool>::type> + class Matrix2 { + std::vector> _mat; + + public: + /** + * @brief Constructor + * @tparam Integer ensuring integers are being evaluated and not other + * data types. + * @param size denoting the size of Matrix as size x size + */ + template ::value, + Integer>::type> + explicit Matrix(const Integer size) { + for (size_t i = 0; i < size; ++i) { + _mat.emplace_back(std::vector(size, 0)); + } + } + } + ```" + .unindent(), + embedding: vec![], + }, + ]; + + for idx in 0..test_documents.len() { + assert_eq!(test_documents[idx], parsed_files[idx]); + } +} + #[gpui::test] fn test_dot_product(mut rng: StdRng) { assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.); @@ -594,3 +770,83 @@ fn toml_lang() -> Arc { Some(tree_sitter_toml::language()), )) } + +fn cpp_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "CPP".into(), + path_suffixes: vec!["cpp".into()], + ..Default::default() + }, + Some(tree_sitter_cpp::language()), + ) + .with_embedding_query( + r#" + ( + (comment)* @context + . + (function_definition + (type_qualifier)? @name + type: (_)? @name + declarator: [ + (function_declarator + declarator: (_) @name) + (pointer_declarator + "*" @name + declarator: (function_declarator + declarator: (_) @name)) + (pointer_declarator + "*" @name + declarator: (pointer_declarator + "*" @name + declarator: (function_declarator + declarator: (_) @name))) + (reference_declarator + ["&" "&&"] @name + (function_declarator + declarator: (_) @name)) + ] + (type_qualifier)? @name) @item + ) + + ( + (comment)* @context + . + (template_declaration + (class_specifier + "class" @name + name: (_) @name) + ) @item + ) + + ( + (comment)* @context + . + (class_specifier + "class" @name + name: (_) @name) @item + ) + + ( + (comment)* @context + . + (enum_specifier + "enum" @name + name: (_) @name) @item + ) + + ( + (comment)* @context + . + (declaration + type: (struct_specifier + "struct" @name) + declarator: (_) @name) @item + ) + + "#, + ) + .unwrap(), + ) +} diff --git a/crates/zed/src/languages/cpp/embedding.scm b/crates/zed/src/languages/cpp/embedding.scm new file mode 100644 index 0000000000..bbd93f20db --- /dev/null +++ b/crates/zed/src/languages/cpp/embedding.scm @@ -0,0 +1,61 @@ +( + (comment)* @context + . + (function_definition + (type_qualifier)? @name + type: (_)? @name + declarator: [ + (function_declarator + declarator: (_) @name) + (pointer_declarator + "*" @name + declarator: (function_declarator + declarator: (_) @name)) + (pointer_declarator + "*" @name + declarator: (pointer_declarator + "*" @name + declarator: (function_declarator + declarator: (_) @name))) + (reference_declarator + ["&" "&&"] @name + (function_declarator + declarator: (_) @name)) + ] + (type_qualifier)? @name) @item + ) + +( + (comment)* @context + . + (template_declaration + (class_specifier + "class" @name + name: (_) @name) + ) @item +) + +( + (comment)* @context + . + (class_specifier + "class" @name + name: (_) @name) @item + ) + +( + (comment)* @context + . + (enum_specifier + "enum" @name + name: (_) @name) @item + ) + +( + (comment)* @context + . + (declaration + type: (struct_specifier + "struct" @name) + declarator: (_) @name) @item +)