update semantic search to use keychain as fallback

This commit is contained in:
KCaverly 2023-10-21 10:19:50 -04:00
parent 9c49191031
commit 106115676d
4 changed files with 73 additions and 12 deletions

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use gpui::{serde_json, ViewContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
@ -20,9 +20,11 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::completion::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
@ -87,6 +89,7 @@ impl Embedding {
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub api_key: Option<String>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
if self.api_key.is_none() {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
};
if let Some(api_key) = api_key {
self.api_key = Some(api_key);
}
}
}
pub fn new(
api_key: Option<String>,
client: Arc<dyn HttpClient>,
executor: Arc<Background>,
) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
api_key,
client,
executor,
rate_limit_count_rx,
@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
fn is_authenticated(&self) -> bool {
OPENAI_API_KEY.as_ref().is_some()
self.api_key.is_some()
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_key = OPENAI_API_KEY
.as_ref()
.ok_or_else(|| anyhow!("no api key"))?;
let Some(api_key) = self.api_key.clone() else {
return Err(anyhow!("no open ai key provided"));
};
let mut request_number = 0;
let mut rate_limiting = false;
@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
while request_number < MAX_RETRIES {
response = self
.send_request(
api_key,
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)