ensure OpenAIEmbeddingProvider is using the provider credentials

This commit is contained in:
KCaverly 2023-10-27 08:33:35 +02:00
parent ca82ec8e8e
commit 6c8bb4b05e
3 changed files with 13 additions and 17 deletions

View file

@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
async fn embed_batch( async fn embed_batch(
&self, &self,
spans: Vec<String>, spans: Vec<String>,
_credential: ProviderCredential, credential: ProviderCredential,
) -> Result<Vec<Embedding>> { ) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4; const MAX_RETRIES: usize = 4;
let api_key = OPENAI_API_KEY let api_key = match credential {
.as_ref() ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
.ok_or_else(|| anyhow!("no api key"))?; _ => Err(anyhow!("no api key provided")),
}?;
let mut request_number = 0; let mut request_number = 0;
let mut rate_limiting = false; let mut rate_limiting = false;
@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
while request_number < MAX_RETRIES { while request_number < MAX_RETRIES {
response = self response = self
.send_request( .send_request(
api_key, &api_key,
spans.iter().map(|x| &**x).collect(), spans.iter().map(|x| &**x).collect(),
request_timeout, request_timeout,
) )

View file

@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize, pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>, finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>, finished_files_rx: channel::Receiver<FileToEmbed>,
provider_credential: ProviderCredential, pub provider_credential: ProviderCredential,
} }
#[derive(Clone)] #[derive(Clone)]

View file

@ -281,15 +281,13 @@ impl SemanticIndex {
} }
pub fn authenticate(&mut self, cx: &AppContext) -> bool { pub fn authenticate(&mut self, cx: &AppContext) -> bool {
let credential = self.provider_credential.clone(); let existing_credential = self.provider_credential.clone();
match credential { let credential = match existing_credential {
ProviderCredential::NoCredentials => { ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
let credential = self.embedding_provider.retrieve_credentials(cx); _ => existing_credential,
self.provider_credential = credential; };
}
_ => {}
}
self.provider_credential = credential.clone();
self.embedding_queue.lock().set_credential(credential); self.embedding_queue.lock().set_credential(credential);
self.is_authenticated() self.is_authenticated()
} }
@ -1020,14 +1018,11 @@ impl SemanticIndex {
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if !self.is_authenticated() { if !self.is_authenticated() {
println!("Authenticating");
if !self.authenticate(cx) { if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated"))); return Task::ready(Err(anyhow!("user is not authenticated")));
} }
} }
println!("SHOULD NOW BE AUTHENTICATED");
if !self.projects.contains_key(&project.downgrade()) { if !self.projects.contains_key(&project.downgrade()) {
let subscription = cx.subscribe(&project, |this, project, event, cx| match event { let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {