// crates/vector_store/src/embedding.rs

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
use isahc::prelude::Configurable;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{env, time::Instant};
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
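
// The structs above mirror the JSON returned by the OpenAI embeddings
// endpoint, which (abridged, illustrative) looks roughly like:
//
//     {
//       "object": "list",
//       "data": [{ "object": "embedding", "embedding": [0.0023, ...], "index": 0 }],
//       "model": "text-embedding-ada-002",
//       "usage": { "prompt_tokens": 8, "total_tokens": 8 }
//     }
//
// Fields not declared here (e.g. the top-level `object` and `model`) are
// simply ignored by serde's default deserialization behavior.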

#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}
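
// Illustrative only: a minimal sketch of a caller that stays generic over the
// provider. `embed_documents` is a hypothetical helper name, not something the
// rest of the crate relies on; the trait's `Sync + Send` bounds and
// `#[async_trait]` are what make this dynamic dispatch possible.
#[allow(dead_code)]
async fn embed_documents(
    provider: &dyn EmbeddingProvider,
    documents: &[String],
) -> Result<Vec<Vec<f32>>> {
    let spans: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
    provider.embed_batch(spans).await
}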

pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding size for OpenAI's ada-002 model,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32f32; 1536];
        Ok(vec![dummy_vec; spans.len()])
    }
}
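
// A minimal test sketch for the dummy provider, assuming the `futures` crate's
// default `executor` feature is available for `block_on`.
#[cfg(test)]
mod dummy_embedding_tests {
    use super::*;

    #[test]
    fn dummy_embeddings_match_span_count_and_dimensions() {
        let provider = DummyEmbeddings {};
        let embeddings =
            futures::executor::block_on(provider.embed_batch(vec!["one", "two"])).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert!(embeddings.iter().all(|embedding| embedding.len() == 1536));
    }
}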

// Disabled helper: would truncate a span to 8192 tokens with the `cl100k_base`
// tokenizer before sending it to the API. Its call site below is commented out
// as well.
// impl OpenAIEmbeddings {
//     async fn truncate(span: &str) -> String {
//         let bpe = cl100k_base().unwrap();
//         let mut tokens = bpe.encode_with_special_tokens(span);
//         if tokens.len() > 8192 {
//             tokens.truncate(8192);
//             let result = bpe.decode(tokens);
//             if result.is_ok() {
//                 return result.unwrap();
//             }
//         }
//         return span.to_string();
//     }
// }

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // Truncate spans to 8192 tokens if needed (disabled along with the
        // `truncate` helper above).
        // let t0 = Instant::now();
        // let mut truncated_spans = vec![];
        // for span in spans {
        //     truncated_spans.push(Self::truncate(span));
        // }
        // let spans = futures::future::join_all(truncated_spans).await;
        // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        let mut response = self.client.send(request).await?;
        if !response.status().is_success() {
            return Err(anyhow!("openai embedding failed {}", response.status()));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

        log::info!(
            "openai embedding completed. tokens: {:?}",
            response.usage.total_tokens
        );

        Ok(response
            .data
            .into_iter()
            .map(|embedding| embedding.embedding)
            .collect())
    }
}
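
// Illustrative usage sketch: `embed_with_openai` is a hypothetical wrapper, not
// an API used elsewhere in the crate. The caller supplies an
// `Arc<dyn HttpClient>`, and the `OPENAI_API_KEY` environment variable must be
// set or `embed_batch` returns an error.
#[allow(dead_code)]
async fn embed_with_openai(
    client: Arc<dyn HttpClient>,
    spans: Vec<&str>,
) -> Result<Vec<Vec<f32>>> {
    let provider = OpenAIEmbeddings { client };
    provider.embed_batch(spans).await
}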