added token count to documents during parsing

This commit is contained in:
KCaverly 2023-08-30 11:05:46 -04:00
parent a7e6a65deb
commit e377ada1a9
4 changed files with 54 additions and 12 deletions

View file

@ -1,3 +1,4 @@
use crate::embedding::EmbeddingProvider;
use anyhow::{anyhow, Ok, Result};
use language::{Grammar, Language};
use sha1::{Digest, Sha1};
@ -17,6 +18,7 @@ pub struct Document {
pub content: String,
pub embedding: Vec<f32>,
pub sha1: [u8; 20],
pub token_count: usize,
}
const CODE_CONTEXT_TEMPLATE: &str =
@ -30,6 +32,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
/// Parsing state used by the `impl CodeContextRetriever` methods that build `Document` values.
///
/// All fields are public. `new` creates the parser and cursor and stores the
/// embedding provider it is given.
pub struct CodeContextRetriever {
/// Parser created with `Parser::new()` (presumably the tree-sitter parser; confirm against the imports).
pub parser: Parser,
/// Query cursor created with `QueryCursor::new()`.
pub cursor: QueryCursor,
/// Shared embedding provider; its `count_tokens` method is called on document text
/// to fill in `Document::token_count`.
pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
// Every match has an item; it represents the fundamental tree-sitter symbol and anchors the search.
@ -47,10 +50,11 @@ pub struct CodeContextMatch {
}
impl CodeContextRetriever {
pub fn new() -> Self {
/// Creates a retriever with a fresh `Parser` and `QueryCursor`.
///
/// `embedding_provider` is stored on the returned value unchanged; other methods
/// in this impl call its `count_tokens` to compute each document's `token_count`.
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
embedding_provider,
}
}
@ -68,12 +72,15 @@ impl CodeContextRetriever {
let mut sha1 = Sha1::new();
sha1.update(&document_span);
let token_count = self.embedding_provider.count_tokens(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
name: language_name.to_string(),
sha1: sha1.finalize().into(),
token_count,
}])
}
@ -85,12 +92,15 @@ impl CodeContextRetriever {
let mut sha1 = Sha1::new();
sha1.update(&document_span);
let token_count = self.embedding_provider.count_tokens(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
name: "Markdown".to_string(),
sha1: sha1.finalize().into(),
token_count,
}])
}
@ -166,10 +176,14 @@ impl CodeContextRetriever {
let mut documents = self.parse_file(content, language)?;
for document in &mut documents {
document.content = CODE_CONTEXT_TEMPLATE
let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("item", &document.content);
let token_count = self.embedding_provider.count_tokens(&document_content);
document.content = document_content;
document.token_count = token_count;
}
Ok(documents)
}
@ -272,6 +286,7 @@ impl CodeContextRetriever {
range: item_range.clone(),
embedding: vec![],
sha1: sha1.finalize().into(),
token_count: 0,
})
}