move embedding truncation to base model

This commit is contained in:
KCaverly 2023-10-23 14:07:45 +02:00
parent 2b780ee7b2
commit 4e90e45999
5 changed files with 48 additions and 34 deletions

View file

@@ -72,7 +72,6 @@ pub trait EmbeddingProvider: Sync + Send {
fn is_authenticated(&self) -> bool; fn is_authenticated(&self) -> bool;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>; async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize; fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Option<Instant>; fn rate_limit_expiration(&self) -> Option<Instant>;
} }

View file

@@ -23,6 +23,10 @@ impl LanguageModel for DummyLanguageModel {
length: usize, length: usize,
direction: crate::models::TruncationDirection, direction: crate::models::TruncationDirection,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
if content.len() < length {
return anyhow::Ok(content.to_string());
}
let truncated = match direction { let truncated = match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length] TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.iter() .iter()
@@ -73,11 +77,4 @@ impl EmbeddingProvider for DummyEmbeddingProvider {
fn max_tokens_per_batch(&self) -> usize { fn max_tokens_per_batch(&self) -> usize {
8190 8190
} }
fn truncate(&self, span: &str) -> (String, usize) {
let truncated = span.chars().collect::<Vec<char>>()[..8190]
.iter()
.collect::<String>();
(truncated, 8190)
}
} }

View file

@@ -61,8 +61,6 @@ struct OpenAIEmbeddingUsage {
total_tokens: usize, total_tokens: usize,
} }
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddingProvider { impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self { pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
@@ -151,20 +149,20 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn rate_limit_expiration(&self) -> Option<Instant> { fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow() *self.rate_limit_count_rx.borrow()
} }
fn truncate(&self, span: &str) -> (String, usize) { // fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let output = if tokens.len() > OPENAI_INPUT_LIMIT { // let output = if tokens.len() > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT); // tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER // OPENAI_BPE_TOKENIZER
.decode(tokens.clone()) // .decode(tokens.clone())
.ok() // .ok()
.unwrap_or_else(|| span.to_string()) // .unwrap_or_else(|| span.to_string())
} else { // } else {
span.to_string() // span.to_string()
}; // };
(output, tokens.len()) // (output, tokens.len())
} // }
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> { async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];

View file

@@ -1,4 +1,7 @@
use ai::embedding::{Embedding, EmbeddingProvider}; use ai::{
embedding::{Embedding, EmbeddingProvider},
models::TruncationDirection,
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use language::{Grammar, Language}; use language::{Grammar, Language};
use rusqlite::{ use rusqlite::{
@@ -108,7 +111,14 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref()) .replace("<language>", language_name.as_ref())
.replace("<item>", &content); .replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str()); let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span); let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span { Ok(vec![Span {
range: 0..content.len(), range: 0..content.len(),
content: document_span, content: document_span,
@@ -131,7 +141,15 @@ impl CodeContextRetriever {
) )
.replace("<item>", &content); .replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str()); let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span { Ok(vec![Span {
range: 0..content.len(), range: 0..content.len(),
content: document_span, content: document_span,
@@ -222,8 +240,13 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref()) .replace("<language>", language_name.as_ref())
.replace("item", &span.content); .replace("item", &span.content);
let (document_content, token_count) = let model = self.embedding_provider.base_model();
self.embedding_provider.truncate(&document_content); let document_content = model.truncate(
&document_content,
model.capacity()?,
TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_content)?;
span.content = document_content; span.content = document_content;
span.token_count = token_count; span.token_count = token_count;

View file

@@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
fn is_authenticated(&self) -> bool { fn is_authenticated(&self) -> bool {
true true
} }
fn truncate(&self, span: &str) -> (String, usize) {
(span.to_string(), 1)
}
fn max_tokens_per_batch(&self) -> usize { fn max_tokens_per_batch(&self) -> usize {
200 1000
} }
fn rate_limit_expiration(&self) -> Option<Instant> { fn rate_limit_expiration(&self) -> Option<Instant> {
@@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> { async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst); .fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
} }
} }