diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index a9cb0245c4..72621d3138 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -54,6 +54,8 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>; + fn count_tokens(&self, span: &str) -> usize; + // fn truncate(&self, span: &str) -> Result<&str>; } pub struct DummyEmbeddings {} @@ -66,6 +68,12 @@ impl EmbeddingProvider for DummyEmbeddings { let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } + + fn count_tokens(&self, span: &str) -> usize { + // For Dummy Providers, we are going to use OpenAI tokenization for ease + let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + tokens.len() + } } const OPENAI_INPUT_LIMIT: usize = 8190; @@ -111,6 +119,12 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { + fn count_tokens(&self, span: &str) -> usize { + // Count tokens with the same OpenAI BPE tokenizer used for embedding requests + let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + tokens.len() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 4aefb0b00d..b106e5055b 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,3 +1,4 @@ +use crate::embedding::EmbeddingProvider; use anyhow::{anyhow, Ok, Result}; use language::{Grammar, Language}; use sha1::{Digest, Sha1}; @@ -17,6 +18,7 @@ pub struct Document { pub content: String, pub embedding: Vec<f32>, pub sha1: [u8; 20], + pub token_count: usize, } const CODE_CONTEXT_TEMPLATE: &str = @@ -30,6 +32,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = pub struct 
CodeContextRetriever { pub parser: Parser, pub cursor: QueryCursor, + pub embedding_provider: Arc<dyn EmbeddingProvider>, } // Every match has an item, this represents the fundamental treesitter symbol and anchors the search @@ -47,10 +50,11 @@ pub struct CodeContextMatch { } impl CodeContextRetriever { - pub fn new() -> Self { + pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self { Self { parser: Parser::new(), cursor: QueryCursor::new(), + embedding_provider, } } @@ -68,12 +72,15 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); + let token_count = self.embedding_provider.count_tokens(&document_span); + Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Vec::new(), name: language_name.to_string(), sha1: sha1.finalize().into(), + token_count, }]) } @@ -85,12 +92,15 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); + let token_count = self.embedding_provider.count_tokens(&document_span); + Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Vec::new(), name: "Markdown".to_string(), sha1: sha1.finalize().into(), + token_count, }]) } @@ -166,10 +176,14 @@ impl CodeContextRetriever { let mut documents = self.parse_file(content, language)?; for document in &mut documents { - document.content = CODE_CONTEXT_TEMPLATE + let document_content = CODE_CONTEXT_TEMPLATE .replace("<path>", relative_path.to_string_lossy().as_ref()) .replace("<language>", language_name.as_ref()) .replace("<item>", &document.content); + + let token_count = self.embedding_provider.count_tokens(&document_content); + document.content = document_content; + document.token_count = token_count; } Ok(documents) } @@ -272,6 +286,7 @@ impl CodeContextRetriever { range: item_range.clone(), embedding: vec![], sha1: sha1.finalize().into(), + token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 2da0d84baf..ab05ca7581 100644 --- 
a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -332,8 +332,9 @@ impl SemanticIndex { let parsing_files_rx = parsing_files_rx.clone(); let batch_files_tx = batch_files_tx.clone(); let db_update_tx = db_update_tx.clone(); + let embedding_provider = embedding_provider.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { - let mut retriever = CodeContextRetriever::new(); + let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { Self::parse_file( &fs, diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 32d8bb0fb8..cb318a9fd6 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,6 +1,6 @@ use crate::{ db::dot, - embedding::EmbeddingProvider, + embedding::{DummyEmbeddings, EmbeddingProvider}, parsing::{subtract_ranges, CodeContextRetriever, Document}, semantic_index_settings::SemanticIndexSettings, SearchResult, SemanticIndex, @@ -227,7 +227,8 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /// A doc comment @@ -314,7 +315,8 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" { @@ -397,7 +399,8 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let mut retriever = CodeContextRetriever::new(); + 
let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /* globals importScripts, backend */ @@ -495,7 +498,8 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" -- Creates a new class @@ -568,7 +572,8 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" defmodule File.Stream do @@ -684,7 +689,8 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /** @@ -836,7 +842,8 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" # This concern is inspired by "sudo mode" on GitHub. 
It @@ -1026,7 +1033,8 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" <?php @@ -1186,6 +1194,10 @@ impl FakeEmbeddingProvider { #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { + fn count_tokens(&self, span: &str) -> usize { + span.len() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst);