parallelize embedding api calls
This commit is contained in:
parent 342dbc6945
commit 0e071919a0
2 changed files with 42 additions and 18 deletions
@@ -106,7 +106,7 @@ impl OpenAIEmbeddings {
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
-        const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
+        const BACKOFF_SECONDS: [usize; 3] = [45, 75, 125];
         const MAX_RETRIES: usize = 3;
 
         let api_key = OPENAI_API_KEY
@@ -133,6 +133,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             match response.status() {
                 StatusCode::TOO_MANY_REQUESTS => {
                     let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                    log::trace!(
+                        "open ai rate limiting, delaying request by {:?} seconds",
+                        delay.as_secs()
+                    );
                     self.executor.timer(delay).await;
                 }
                 StatusCode::BAD_REQUEST => {
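For context, a minimal, self-contained sketch of the retry loop these two hunks tune: on HTTP 429 the caller waits for the next entry in the backoff table (now 45/75/125 seconds instead of 65/180/360) and logs the delay at trace level before retrying. Only BACKOFF_SECONDS, MAX_RETRIES, and the log message come from the diff; the u16 status codes, the blocking sleep, and the send_request stub are assumptions standing in for the real async HTTP client and executor.

// Illustrative retry-with-backoff loop modeled on the two hunks above.
// Requires the `log` crate; everything else is std.
use std::thread::sleep;
use std::time::Duration;

const BACKOFF_SECONDS: [usize; 3] = [45, 75, 125];
const MAX_RETRIES: usize = 3;

// Hypothetical stand-in for the embedding API call; returns an HTTP-like status.
fn send_request(_spans: &[&str]) -> u16 {
    200 // return 429 here to exercise the backoff path
}

fn embed_with_backoff(spans: &[&str]) -> Result<(), String> {
    for request_number in 1..=MAX_RETRIES {
        match send_request(spans) {
            200 => return Ok(()),
            429 => {
                // Same indexing as the diff: first retry waits 45s, then 75s, then 125s.
                // A blocking sleep stands in for `self.executor.timer(delay).await`.
                let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                log::trace!(
                    "open ai rate limiting, delaying request by {:?} seconds",
                    delay.as_secs()
                );
                sleep(delay);
            }
            status => return Err(format!("unexpected status {status}")),
        }
    }
    Err("retries exhausted".into())
}

fn main() {
    let _ = embed_with_backoff(&["example span"]);
}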
@@ -24,7 +24,7 @@ use std::{
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -34,7 +34,7 @@ use util::{
 };
 
 const SEMANTIC_INDEX_VERSION: usize = 4;
-const EMBEDDINGS_BATCH_SIZE: usize = 150;
+const EMBEDDINGS_BATCH_SIZE: usize = 80;
 
 pub fn init(
     fs: Arc<dyn Fs>,
@@ -84,7 +84,7 @@ pub struct SemanticIndex {
     db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
     _db_update_task: Task<()>,
-    _embed_batch_task: Task<()>,
+    _embed_batch_tasks: Vec<Task<()>>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -189,6 +189,7 @@ impl SemanticIndex {
         language_registry: Arc<LanguageRegistry>,
         mut cx: AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
+        let t0 = Instant::now();
         let database_url = Arc::new(database_url);
 
         let db = cx
@@ -196,7 +197,13 @@ impl SemanticIndex {
             .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
             .await?;
 
+        log::trace!(
+            "db initialization took {:?} milliseconds",
+            t0.elapsed().as_millis()
+        );
+
         Ok(cx.add_model(|cx| {
+            let t0 = Instant::now();
             // Perform database operations
             let (db_update_tx, db_update_rx) = channel::unbounded();
             let _db_update_task = cx.background().spawn({
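The hunk above, and the timing hunks further down near index_project, all use the same idiom: capture an Instant up front, then emit the elapsed milliseconds with log::trace!. A minimal standalone version follows, assuming `log` plus `env_logger` as the backend; the real crate just uses whatever logger the host application installs, and trace output only shows up if that logger is configured for it (e.g. RUST_LOG=trace with env_logger).

// Minimal version of the timing pattern added throughout this diff.
// Assumes `log` + `env_logger` as dependencies; run with RUST_LOG=trace.
use std::thread::sleep;
use std::time::{Duration, Instant};

fn main() {
    env_logger::init();

    let t0 = Instant::now();
    sleep(Duration::from_millis(25)); // stand-in for the work being measured
    log::trace!(
        "db initialization took {:?} milliseconds",
        t0.elapsed().as_millis()
    );
}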
@@ -210,20 +217,24 @@ impl SemanticIndex {
             // Group documents into batches and send them to the embedding provider.
             let (embed_batch_tx, embed_batch_rx) =
                 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
-            let _embed_batch_task = cx.background().spawn({
-                let db_update_tx = db_update_tx.clone();
-                let embedding_provider = embedding_provider.clone();
-                async move {
-                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
-                        Self::compute_embeddings_for_batch(
-                            embeddings_queue,
-                            &embedding_provider,
-                            &db_update_tx,
-                        )
-                        .await;
-                    }
-                }
-            });
+            let mut _embed_batch_tasks = Vec::new();
+            for _ in 0..cx.background().num_cpus() {
+                let embed_batch_rx = embed_batch_rx.clone();
+                _embed_batch_tasks.push(cx.background().spawn({
+                    let db_update_tx = db_update_tx.clone();
+                    let embedding_provider = embedding_provider.clone();
+                    async move {
+                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+                            Self::compute_embeddings_for_batch(
+                                embeddings_queue,
+                                &embedding_provider,
+                                &db_update_tx,
+                            )
+                            .await;
+                        }
+                    }
+                }));
+            }
 
             // Group documents into batches and send them to the embedding provider.
             let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
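This hunk is the core of the change: instead of one background task draining the embedding queue, the loop now spawns one task per CPU, each holding a clone of the channel receiver, so embedding API calls run concurrently. A rough std-only analogue of that fan-out shape is sketched below; the thread pool, the Mutex-wrapped receiver, and all names are illustrative assumptions (the real code clones an async channel receiver and spawns futures on gpui's background executor rather than threads).

// Fan-out sketch: N workers, one shared queue, each worker pulling jobs until
// the sending side is dropped. Std-only; std's mpsc Receiver is not Clone, so
// it is shared behind a Mutex here, unlike the cloneable async receiver above.
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

// Placeholder for the per-batch work (an embedding API call in the diff).
fn embed_batch(batch: Vec<String>) {
    println!("embedding {} spans", batch.len());
}

fn main() {
    let (tx, rx) = mpsc::channel::<Vec<String>>();
    let rx = Arc::new(Mutex::new(rx));

    let worker_count = thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
    let workers: Vec<_> = (0..worker_count)
        .map(|_| {
            let rx = Arc::clone(&rx);
            thread::spawn(move || loop {
                // Take the next batch; stop once the channel is closed and drained.
                let batch = match rx.lock().unwrap().recv() {
                    Ok(batch) => batch,
                    Err(_) => break,
                };
                embed_batch(batch); // runs outside the lock, so batches overlap
            })
        })
        .collect();

    for i in 0..10 {
        tx.send(vec![format!("span {i}")]).unwrap();
    }
    drop(tx); // close the channel so the workers exit

    for worker in workers {
        worker.join().unwrap();
    }
}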
@@ -264,6 +275,10 @@ impl SemanticIndex {
                 }));
             }
 
+            log::trace!(
+                "semantic index task initialization took {:?} milliseconds",
+                t0.elapsed().as_millis()
+            );
             Self {
                 fs,
                 database_url,
@@ -272,7 +287,7 @@ impl SemanticIndex {
                 db_update_tx,
                 parsing_files_tx,
                 _db_update_task,
-                _embed_batch_task,
+                _embed_batch_tasks,
                 _batch_files_task,
                 _parsing_files_tasks,
                 projects: HashMap::new(),
@@ -460,6 +475,7 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
+        let t0 = Instant::now();
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -577,6 +593,10 @@ impl SemanticIndex {
                 }
             }
 
+            log::trace!(
+                "walking worktree took {:?} milliseconds",
+                t0.elapsed().as_millis()
+            );
             anyhow::Ok((count, job_count_rx))
         })
         .await