use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::BackgroundExecutor;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;

use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;

lazy_static! {
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: BackgroundExecutor,
    // Watch channel holding the instant at which the current rate limit
    // (if any) is expected to lift.
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

impl OpenAIEmbeddingProvider {
    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        let model = OpenAILanguageModel::load("text-embedding-ada-002");
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));

        OpenAIEmbeddingProvider {
            model,
            credential,
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

    fn get_api_key(&self) -> Result<String> {
        match self.credential.read().clone() {
            ProviderCredential::Credentials { api_key } => Ok(api_key),
            _ => Err(anyhow!("api credentials not provided")),
        }
    }

    // Clear the stored reset time once it has passed, so consumers stop
    // reporting an active rate limit.
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

    // Record a new reset time, keeping the earlier of the stored instant
    // and the incoming one.
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }

    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

impl CredentialProvider for OpenAIEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read() {
            ProviderCredential::Credentials { .. } => true,
            _ => false,
        }
    }

    fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
        let existing_credential = self.credential.read().clone();
        let retrieved_credential = match existing_credential {
            ProviderCredential::Credentials { .. } => existing_credential.clone(),
            _ => {
                // Prefer the OPENAI_API_KEY environment variable, then fall
                // back to the platform credential store.
                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                    ProviderCredential::Credentials { api_key }
                } else if let Some(Some((_, api_key))) =
                    cx.read_credentials(OPENAI_API_URL).log_err()
                {
                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
                        ProviderCredential::Credentials { api_key }
                    } else {
                        ProviderCredential::NoCredentials
                    }
                } else {
                    ProviderCredential::NoCredentials
                }
            }
        };
        *self.credential.write() = retrieved_credential.clone();
        retrieved_credential
    }

    fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
        *self.credential.write() = credential.clone();
        match credential {
            ProviderCredential::Credentials { api_key } => {
                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
            }
            _ => {}
        }
    }

    fn delete_credentials(&self, cx: &mut AppContext) {
        cx.delete_credentials(OPENAI_API_URL).log_err();
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = self.get_api_key()?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server more time on the next attempt.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Back off using the fixed schedule, preferring the reset
                    // interval reported in the x-ratelimit-reset-tokens header
                    // when it parses cleanly.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}