added token count to documents during parsing

KCaverly 2023-08-30 11:05:46 -04:00
parent a7e6a65deb
commit e377ada1a9
4 changed files with 54 additions and 12 deletions

View file

@@ -54,6 +54,8 @@ struct OpenAIEmbeddingUsage {
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
+    fn count_tokens(&self, span: &str) -> usize;
+    // fn truncate(&self, span: &str) -> Result<&str>;
 }

 pub struct DummyEmbeddings {}
@@ -66,6 +68,12 @@ impl EmbeddingProvider for DummyEmbeddings {
         let dummy_vec = vec![0.32 as f32; 1536];
         return Ok(vec![dummy_vec; spans.len()]);
     }
+
+    fn count_tokens(&self, span: &str) -> usize {
+        // For Dummy Providers, we are going to use OpenAI tokenization for ease
+        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        tokens.len()
+    }
 }

 const OPENAI_INPUT_LIMIT: usize = 8190;
@@ -111,6 +119,12 @@ impl OpenAIEmbeddings {
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
+    fn count_tokens(&self, span: &str) -> usize {
+        // Count tokens with the OpenAI BPE tokenizer
+        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        tokens.len()
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
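
Note: the diff references OPENAI_BPE_TOKENIZER without showing its definition. A minimal sketch of a compatible definition and a count_tokens call site, assuming the tiktoken_rs crate with the cl100k_base encoding and lazy_static (assumptions, not confirmed by this commit):

// Sketch only: OPENAI_BPE_TOKENIZER as a lazily initialized BPE encoder.
// tiktoken_rs / cl100k_base / lazy_static are assumptions, not part of this diff.
use lazy_static::lazy_static;
use tiktoken_rs::{cl100k_base, CoreBPE};

lazy_static! {
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

// Example call site: check a span against the embedding input limit before sending it.
fn fits_input_limit(provider: &dyn EmbeddingProvider, span: &str) -> bool {
    provider.count_tokens(span) <= OPENAI_INPUT_LIMIT
}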

View file

@@ -1,3 +1,4 @@
+use crate::embedding::EmbeddingProvider;
 use anyhow::{anyhow, Ok, Result};
 use language::{Grammar, Language};
 use sha1::{Digest, Sha1};
@@ -17,6 +18,7 @@ pub struct Document {
     pub content: String,
     pub embedding: Vec<f32>,
     pub sha1: [u8; 20],
+    pub token_count: usize,
 }

 const CODE_CONTEXT_TEMPLATE: &str =
@@ -30,6 +32,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
 pub struct CodeContextRetriever {
     pub parser: Parser,
     pub cursor: QueryCursor,
+    pub embedding_provider: Arc<dyn EmbeddingProvider>,
 }

 // Every match has an item, this represents the fundamental treesitter symbol and anchors the search
@@ -47,10 +50,11 @@ pub struct CodeContextMatch {
 }

 impl CodeContextRetriever {
-    pub fn new() -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
         Self {
             parser: Parser::new(),
             cursor: QueryCursor::new(),
+            embedding_provider,
         }
     }
@@ -68,12 +72,15 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);

+        let token_count = self.embedding_provider.count_tokens(&document_span);
+
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
             embedding: Vec::new(),
             name: language_name.to_string(),
             sha1: sha1.finalize().into(),
+            token_count,
         }])
     }
@@ -85,12 +92,15 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);

+        let token_count = self.embedding_provider.count_tokens(&document_span);
+
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
             embedding: Vec::new(),
             name: "Markdown".to_string(),
             sha1: sha1.finalize().into(),
+            token_count,
         }])
     }
@@ -166,10 +176,14 @@ impl CodeContextRetriever {
         let mut documents = self.parse_file(content, language)?;
         for document in &mut documents {
-            document.content = CODE_CONTEXT_TEMPLATE
+            let document_content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &document.content);
+            let token_count = self.embedding_provider.count_tokens(&document_content);
+
+            document.content = document_content;
+            document.token_count = token_count;
         }
         Ok(documents)
     }
@@ -272,6 +286,7 @@ impl CodeContextRetriever {
             range: item_range.clone(),
             embedding: vec![],
             sha1: sha1.finalize().into(),
+            token_count: 0,
         })
     }
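
With token_count stored on each Document, a downstream consumer can group spans into batches that respect a model's input limit before calling embed_batch. A hypothetical sketch; the budget constant and helper below are illustrative, not part of this commit:

// Hypothetical sketch: group parsed documents into batches that stay under a
// token budget before embedding. The constant and function are assumptions.
const EMBEDDING_BATCH_TOKEN_BUDGET: usize = 8190;

fn batch_by_token_count(documents: &[Document]) -> Vec<Vec<&Document>> {
    let mut batches = Vec::new();
    let mut current = Vec::new();
    let mut current_tokens = 0;

    for document in documents {
        // Flush the current batch when adding this document would exceed the budget.
        if current_tokens + document.token_count > EMBEDDING_BATCH_TOKEN_BUDGET
            && !current.is_empty()
        {
            batches.push(std::mem::take(&mut current));
            current_tokens = 0;
        }
        current.push(document);
        current_tokens += document.token_count;
    }
    if !current.is_empty() {
        batches.push(current);
    }
    batches
}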

View file

@@ -332,8 +332,9 @@ impl SemanticIndex {
             let parsing_files_rx = parsing_files_rx.clone();
             let batch_files_tx = batch_files_tx.clone();
             let db_update_tx = db_update_tx.clone();
+            let embedding_provider = embedding_provider.clone();
             _parsing_files_tasks.push(cx.background().spawn(async move {
-                let mut retriever = CodeContextRetriever::new();
+                let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                 while let Ok(pending_file) = parsing_files_rx.recv().await {
                     Self::parse_file(
                         &fs,
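
The extra clone exists because the async move block takes ownership of whatever it captures, so each spawned parsing task needs its own Arc handle to the shared provider. A minimal illustration of that pattern using plain threads (gpui's background executor is not shown in this diff):

// Illustration only: Arc-clone-per-task so every worker owns a handle to the
// same provider. This uses std::thread, not Zed's gpui executor.
use std::sync::Arc;
use std::thread;

fn spawn_workers(provider: Arc<dyn EmbeddingProvider>) {
    let mut handles = Vec::new();
    for _ in 0..4 {
        // Each task gets its own clone; the original Arc stays usable here.
        let provider = provider.clone();
        handles.push(thread::spawn(move || {
            let _tokens = provider.count_tokens("fn main() {}");
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
}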

View file

@@ -1,6 +1,6 @@
 use crate::{
     db::dot,
-    embedding::EmbeddingProvider,
+    embedding::{DummyEmbeddings, EmbeddingProvider},
     parsing::{subtract_ranges, CodeContextRetriever, Document},
     semantic_index_settings::SemanticIndexSettings,
     SearchResult, SemanticIndex,
@@ -227,7 +227,8 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = "
         /// A doc comment
@@ -314,7 +315,8 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"
         {
@@ -397,7 +399,8 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = "
         /* globals importScripts, backend */
@@ -495,7 +498,8 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"
         -- Creates a new class
@@ -568,7 +572,8 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"
         defmodule File.Stream do
@@ -684,7 +689,8 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = "
         /**
@@ -836,7 +842,8 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"
         # This concern is inspired by "sudo mode" on GitHub. It
@@ -1026,7 +1033,8 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"
         <?php
@@ -1216,6 +1224,10 @@ impl FakeEmbeddingProvider {
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn count_tokens(&self, span: &str) -> usize {
+        span.len()
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
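
A hypothetical follow-up test (not part of this commit) could assert that parsed documents actually carry a non-zero token_count; the parse entry point name and signature below are assumptions for illustration:

// Hypothetical test sketch; `parse_file_with_template` and its signature are
// assumed here and may not match the real retriever API.
use std::path::Path;

#[gpui::test]
async fn test_documents_include_token_counts() {
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let documents = retriever
        .parse_file_with_template(Path::new("lib.rs"), "fn main() {}", rust_lang())
        .unwrap();

    // DummyEmbeddings counts real BPE tokens, so any non-empty span yields a count > 0.
    for document in documents {
        assert!(document.token_count > 0);
    }
}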