move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch"
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
9781047156
commit
76ce52df4e
5 changed files with 295 additions and 91 deletions
|
@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage {
|
|||
|
||||
#[async_trait]
|
||||
pub trait EmbeddingProvider: Sync + Send {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
|
||||
fn count_tokens(&self, span: &str) -> usize;
|
||||
fn should_truncate(&self, span: &str) -> bool;
|
||||
fn truncate(&self, span: &str) -> String;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn truncate(&self, span: &str) -> (String, usize);
|
||||
}
|
||||
|
||||
pub struct DummyEmbeddings {}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for DummyEmbeddings {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
// 1024 is the OpenAI Embeddings size for ada models.
|
||||
// the model we will likely be starting with.
|
||||
let dummy_vec = vec![0.32 as f32; 1536];
|
||||
return Ok(vec![dummy_vec; spans.len()]);
|
||||
}
|
||||
|
||||
fn count_tokens(&self, span: &str) -> usize {
|
||||
// For Dummy Providers, we are going to use OpenAI tokenization for ease
|
||||
let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
tokens.len()
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
fn should_truncate(&self, span: &str) -> bool {
|
||||
self.count_tokens(span) > OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> String {
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
let token_count = tokens.len();
|
||||
let output = if token_count > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
OPENAI_BPE_TOKENIZER
|
||||
.decode(tokens)
|
||||
|
@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings {
|
|||
span.to_string()
|
||||
};
|
||||
|
||||
output
|
||||
(output, token_count)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,19 +119,14 @@ impl OpenAIEmbeddings {
|
|||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
fn count_tokens(&self, span: &str) -> usize {
|
||||
// For Dummy Providers, we are going to use OpenAI tokenization for ease
|
||||
let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
tokens.len()
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
fn should_truncate(&self, span: &str) -> bool {
|
||||
self.count_tokens(span) > OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> String {
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
let token_count = tokens.len();
|
||||
let output = if token_count > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
OPENAI_BPE_TOKENIZER
|
||||
.decode(tokens)
|
||||
|
@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
span.to_string()
|
||||
};
|
||||
|
||||
output
|
||||
(output, token_count)
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
|
@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
|
||||
let mut request_number = 0;
|
||||
let mut request_timeout: u64 = 10;
|
||||
let mut truncated = false;
|
||||
let mut response: Response<AsyncBody>;
|
||||
let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
|
||||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue