From 76caea80f7543cf86eaf0f4e899f06ea478f3d8a Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Wed, 30 Aug 2023 11:58:45 -0400
Subject: [PATCH] add should_truncate to embedding providers

---
Review notes (this section is ignored by `git am`):
- Line breaks, indentation and the `Result<Vec<Vec<f32>>>` return type on the
  `embed_batch` context lines were reconstructed from a whitespace-collapsed
  copy of this patch. TODO: confirm them against the target tree before
  applying.
- The `index` blob ids below predate the edits made in review and are stale.
- Dropped the commented-out truncation sketch from
  `DummyEmbeddings::should_truncate`, renamed the unused `span` parameter of
  the fake provider to `_span`, and added doc comments.

 crates/semantic_index/src/embedding.rs            | 13 +++++++++++++
 .../src/semantic_index_tests.rs                   |  5 +++++
 2 files changed, 18 insertions(+)

diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index 72621d3138..3dd979f01b 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -55,6 +55,9 @@ struct OpenAIEmbeddingUsage {
 pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
     fn count_tokens(&self, span: &str) -> usize;
+    /// Reports whether `span` should be truncated; the tokenizer-backed
+    /// implementations compare its token count against `OPENAI_INPUT_LIMIT`.
+    fn should_truncate(&self, span: &str) -> bool;
     // fn truncate(&self, span: &str) -> Result<&str>;
 }
 
@@ -74,6 +77,11 @@ impl EmbeddingProvider for DummyEmbeddings {
         let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
         tokens.len()
     }
+
+    /// Reports whether `span` tokenizes to more than `OPENAI_INPUT_LIMIT` tokens.
+    fn should_truncate(&self, span: &str) -> bool {
+        self.count_tokens(span) > OPENAI_INPUT_LIMIT
+    }
 }
 
 const OPENAI_INPUT_LIMIT: usize = 8190;
@@ -125,6 +133,11 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         tokens.len()
     }
 
+    /// Reports whether `span` tokenizes to more than `OPENAI_INPUT_LIMIT` tokens.
+    fn should_truncate(&self, span: &str) -> bool {
+        self.count_tokens(span) > OPENAI_INPUT_LIMIT
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index cb318a9fd6..48cefd93b1 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -1228,6 +1228,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         span.len()
     }
 
+    /// The fake provider never requests truncation, whatever the span.
+    fn should_truncate(&self, _span: &str) -> bool {
+        false
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);