Parallelize embedding API calls

This commit is contained in:
KCaverly 2023-07-18 16:09:44 -04:00
parent 342dbc6945
commit 0e071919a0
2 changed files with 42 additions and 18 deletions

View file

@ -106,7 +106,7 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
const BACKOFF_SECONDS: [usize; 3] = [45, 75, 125];
const MAX_RETRIES: usize = 3;
let api_key = OPENAI_API_KEY
@ -133,6 +133,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
log::trace!(
"open ai rate limiting, delaying request by {:?} seconds",
delay.as_secs()
);
self.executor.timer(delay).await;
}
StatusCode::BAD_REQUEST => {

View file

@ -24,7 +24,7 @@ use std::{
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Weak},
time::SystemTime,
time::{Instant, SystemTime},
};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@ -34,7 +34,7 @@ use util::{
};
const SEMANTIC_INDEX_VERSION: usize = 4;
const EMBEDDINGS_BATCH_SIZE: usize = 150;
const EMBEDDINGS_BATCH_SIZE: usize = 80;
pub fn init(
fs: Arc<dyn Fs>,
@ -84,7 +84,7 @@ pub struct SemanticIndex {
db_update_tx: channel::Sender<DbOperation>,
parsing_files_tx: channel::Sender<PendingFile>,
_db_update_task: Task<()>,
_embed_batch_task: Task<()>,
_embed_batch_tasks: Vec<Task<()>>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@ -189,6 +189,7 @@ impl SemanticIndex {
language_registry: Arc<LanguageRegistry>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
let t0 = Instant::now();
let database_url = Arc::new(database_url);
let db = cx
@ -196,7 +197,13 @@ impl SemanticIndex {
.spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
.await?;
log::trace!(
"db initialization took {:?} milliseconds",
t0.elapsed().as_millis()
);
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
// Perform database operations
let (db_update_tx, db_update_rx) = channel::unbounded();
let _db_update_task = cx.background().spawn({
@ -210,20 +217,24 @@ impl SemanticIndex {
// Group documents into batches and send them to the embedding provider.
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
Self::compute_embeddings_for_batch(
embeddings_queue,
&embedding_provider,
&db_update_tx,
)
.await;
let mut _embed_batch_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let embed_batch_rx = embed_batch_rx.clone();
_embed_batch_tasks.push(cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
Self::compute_embeddings_for_batch(
embeddings_queue,
&embedding_provider,
&db_update_tx,
)
.await;
}
}
}
});
}));
}
// Group documents into batches and send them to the embedding provider.
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
@ -264,6 +275,10 @@ impl SemanticIndex {
}));
}
log::trace!(
"semantic index task initialization took {:?} milliseconds",
t0.elapsed().as_millis()
);
Self {
fs,
database_url,
@ -272,7 +287,7 @@ impl SemanticIndex {
db_update_tx,
parsing_files_tx,
_db_update_task,
_embed_batch_task,
_embed_batch_tasks,
_batch_files_task,
_parsing_files_tasks,
projects: HashMap::new(),
@ -460,6 +475,7 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<(usize, watch::Receiver<usize>)>> {
let t0 = Instant::now();
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
@ -577,6 +593,10 @@ impl SemanticIndex {
}
}
log::trace!(
"walking worktree took {:?} milliseconds",
t0.elapsed().as_millis()
);
anyhow::Ok((count, job_count_rx))
})
.await