move embedding truncation to base model

This commit is contained in:
KCaverly 2023-10-23 14:07:45 +02:00
parent 2b780ee7b2
commit 4e90e45999
5 changed files with 48 additions and 34 deletions

View file

@ -1,4 +1,7 @@
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::{
embedding::{Embedding, EmbeddingProvider},
models::TruncationDirection,
};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
@ -108,7 +111,14 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@ -131,7 +141,15 @@ impl CodeContextRetriever {
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@ -222,8 +240,13 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
let (document_content, token_count) =
self.embedding_provider.truncate(&document_content);
let model = self.embedding_provider.base_model();
let document_content = model.truncate(
&document_content,
model.capacity()?,
TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_content)?;
span.content = document_content;
span.token_count = token_count;

View file

@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
/// Reports whether this provider is authenticated.
///
/// This implementation unconditionally returns `true`; it performs no
/// credential check of any kind.
fn is_authenticated(&self) -> bool {
true
}
/// Returns `span` unchanged together with a fixed count of `1`.
///
/// No truncation is actually performed: the input is copied into an owned
/// `String` as-is, and the second tuple element (which callers bind as the
/// token count) is always the constant `1`, regardless of the input length.
fn truncate(&self, span: &str) -> (String, usize) {
    let untouched = span.to_owned();
    let reported_tokens: usize = 1;
    (untouched, reported_tokens)
}
fn max_tokens_per_batch(&self) -> usize {
200
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}