use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding dimension of OpenAI's ada-002 model,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }
}

impl OpenAIEmbeddings {
    // Truncate a span to the model's input limit, returning the text
    // unchanged if it already fits or if decoding the clipped tokens fails.
    async fn truncate(span: String) -> String {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
        if tokens.len() > 8190 {
            tokens.truncate(8190);
            if let Ok(truncated) = OPENAI_BPE_TOKENIZER.decode(tokens) {
                return truncated;
            }
        }
        span
    }

    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;
        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
        const MAX_RETRIES: usize = 3;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut response: Response<AsyncBody>;
        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
        while request_number < MAX_RETRIES {
            response = self
                .send_request(api_key, spans.iter().map(|x| &**x).collect())
                .await?;
            request_number += 1;

            // Give up once the final attempt has failed; earlier failures
            // fall through to the status-specific handling below.
            if request_number == MAX_RETRIES && response.status() != StatusCode::OK {
                return Err(anyhow!(
                    "openai max retries, error: {:?}",
                    &response.status()
                ));
            }

            match response.status() {
                StatusCode::TOO_MANY_REQUESTS => {
                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                    self.executor.timer(delay).await;
                }
                StatusCode::BAD_REQUEST => {
                    log::info!("BAD REQUEST: {:?}", &response.status());
                    // Don't worry about delaying a bad request, as we can
                    // assume we haven't been rate limited yet.
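                    // Assume the 400 was caused by an over-long input and
                    // re-truncate every span to the tokenizer limit before
                    // the next attempt (a heuristic; the API can return 400
                    // for other reasons as well).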
                    for span in spans.iter_mut() {
                        *span = Self::truncate(span.to_string()).await;
                    }
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
                    log::info!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                _ => {
                    return Err(anyhow!("openai embedding failed {}", response.status()));
                }
            }
        }

        Err(anyhow!("openai embedding failed"))
    }
}
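
// A minimal usage sketch (hypothetical: it assumes an `HttpClient` and a
// gpui `Background` executor are already in scope inside an async context;
// the variable names are illustrative, not part of this module):
//
//     let provider = OpenAIEmbeddings {
//         client: http_client.clone(),
//         executor: cx.background().clone(),
//     };
//     let embeddings = provider.embed_batch(vec!["fn main() {}"]).await?;
//     assert_eq!(embeddings.len(), 1);
//     assert_eq!(embeddings[0].len(), 1536);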