updated authentication for embedding provider

This commit is contained in:
KCaverly 2023-10-26 11:18:16 +02:00
parent 71bc35d241
commit 3447a9478c
16 changed files with 277 additions and 271 deletions

View file

@ -1,5 +1,5 @@
use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
api_key: Option<String>,
provider_credential: ProviderCredential,
}
#[derive(Clone)]
@ -54,7 +54,7 @@ impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: Arc<Background>,
api_key: Option<String>,
provider_credential: ProviderCredential,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
@ -64,12 +64,12 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
api_key,
provider_credential,
}
}
pub fn set_api_key(&mut self, api_key: Option<String>) {
self.api_key = api_key
pub fn set_credential(&mut self, credential: ProviderCredential) {
self.provider_credential = credential
}
pub fn push(&mut self, file: FileToEmbed) {
@ -118,7 +118,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
let credential = self.provider_credential.clone();
self.executor
.spawn(async move {
@ -143,7 +143,7 @@ impl EmbeddingQueue {
return;
};
match embedding_provider.embed_batch(spans, api_key).await {
match embedding_provider.embed_batch(spans, credential).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {