move embedding truncation to base model
This commit is contained in:
parent
2b780ee7b2
commit
4e90e45999
5 changed files with 48 additions and 34 deletions
|
@ -1,4 +1,7 @@
|
|||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::{
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::TruncationDirection,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
|
@ -108,7 +111,14 @@ impl CodeContextRetriever {
|
|||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
|
@ -131,7 +141,15 @@ impl CodeContextRetriever {
|
|||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
|
@ -222,8 +240,13 @@ impl CodeContextRetriever {
|
|||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
let (document_content, token_count) =
|
||||
self.embedding_provider.truncate(&document_content);
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_content = model.truncate(
|
||||
&document_content,
|
||||
model.capacity()?,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_content)?;
|
||||
|
||||
span.content = document_content;
|
||||
span.token_count = token_count;
|
||||
|
|
|
@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
|||
fn is_authenticated(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
(span.to_string(), 1)
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
200
|
||||
1000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
|
@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
|||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
|
||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue