move OpenAIEmbeddings to OpenAIEmbeddingProvider in providers folder
parent d813ae8845
commit d1dec8314a
7 changed files with 308 additions and 299 deletions
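In short: OpenAIEmbeddings is renamed to OpenAIEmbeddingProvider and moves from ai::embedding into the new ai::providers::open_ai module, and the DummyEmbeddings test double becomes DummyEmbeddingProvider under ai::providers::dummy. Call sites only change an import path and a type name; a minimal before/after sketch, with paths taken from the hunks below:

    // Before this commit:
    use ai::embedding::OpenAIEmbeddings;
    let provider = Arc::new(OpenAIEmbeddings::new(http_client, cx.background()));

    // After this commit:
    use ai::providers::open_ai::OpenAIEmbeddingProvider;
    let provider = Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background()));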
crates/ai/src/embedding.rs
@@ -1,30 +1,9 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::serde_json;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
 use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-
-lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use std::time::Instant;
 
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -85,39 +64,6 @@ impl Embedding {
     }
 }
 
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
-    pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     fn is_authenticated(&self) -> bool;
@@ -127,235 +73,6 @@ pub trait EmbeddingProvider: Sync + Send {
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let token_count = tokens.len();
-        let output = if token_count > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            new_input.ok().unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        OpenAIEmbeddings {
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post("https://api.openai.com/v1/embeddings")
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        OPENAI_API_KEY.as_ref().is_some()
-    }
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            OPENAI_BPE_TOKENIZER
-                .decode(tokens.clone())
-                .ok()
-                .unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
-                        "openai embedding completed. tokens: {:?}",
-                        response.usage.total_tokens
-                    );
-
-                    // If we complete a request successfully that was previously rate_limited
-                    // resolve the rate limit
-                    if rate_limiting {
-                        self.resolve_rate_limit()
-                    }
-
-                    return Ok(response
-                        .data
-                        .into_iter()
-                        .map(|embedding| Embedding::from(embedding.embedding))
-                        .collect());
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limiting = true;
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-
-                    let delay_duration = {
-                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                        if let Some(time_to_reset) =
-                            response.headers().get("x-ratelimit-reset-tokens")
-                        {
-                            if let Ok(time_str) = time_to_reset.to_str() {
-                                parse(time_str).unwrap_or(delay)
-                            } else {
-                                delay
-                            }
-                        } else {
-                            delay
-                        }
-                    };
-
-                    // If we've previously rate limited, increment the duration but not the count
-                    let reset_time = Instant::now().add(delay_duration);
-                    self.update_reset_time(reset_time);
-
-                    log::trace!(
-                        "openai rate limiting: waiting {:?} until lifted",
-                        &delay_duration
-                    );
-
-                    self.executor.timer(delay_duration).await;
-                }
-                _ => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    return Err(anyhow!(
-                        "open ai bad request: {:?} {:?}",
-                        &response.status(),
-                        body
-                    ));
-                }
-            }
-        }
-        Err(anyhow!("openai max retries"))
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
crates/ai/src/providers/dummy.rs
@@ -1,4 +1,10 @@
-use crate::completion::CompletionRequest;
+use std::time::Instant;
+
+use crate::{
+    completion::CompletionRequest,
+    embedding::{Embedding, EmbeddingProvider},
+};
+use async_trait::async_trait;
 use serde::Serialize;
 
 #[derive(Serialize)]
@@ -11,3 +17,32 @@ impl CompletionRequest for DummyCompletionRequest {
         serde_json::to_string(self)
     }
 }
+
+pub struct DummyEmbeddingProvider {}
+
+#[async_trait]
+impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+        // 1024 is the OpenAI Embeddings size for ada models.
+        // the model we will likely be starting with.
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
+        return Ok(vec![dummy_vec; spans.len()]);
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        8190
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let truncated = span.chars().collect::<Vec<char>>()[..8190]
+            .iter()
+            .collect::<String>();
+        (truncated, 8190)
+    }
+}
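One caveat in the new DummyEmbeddingProvider: truncate slices the collected chars with [..8190] unconditionally, which panics for spans shorter than 8190 characters (the OpenAI provider clamps via tokens.truncate instead). A minimal sketch of a clamped variant, for illustration only and not part of this commit:

    fn truncate(&self, span: &str) -> (String, usize) {
        let chars = span.chars().collect::<Vec<char>>();
        // Clamp the bound so spans shorter than the limit don't panic the slice.
        let end = chars.len().min(8190);
        (chars[..end].iter().collect::<String>(), end)
    }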
crates/ai/src/providers/open_ai/embedding.rs (new file, 252 lines)
@@ -0,0 +1,252 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::serde_json;
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::Mutex;
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+
+use crate::embedding::{Embedding, EmbeddingProvider};
+
+lazy_static! {
+    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        OpenAIEmbeddingProvider {
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans.clone(),
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        OPENAI_API_KEY.as_ref().is_some()
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens.clone())
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        (output, tokens.len())
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}
crates/ai/src/providers/open_ai/mod.rs
@@ -1,4 +1,7 @@
 pub mod completion;
+pub mod embedding;
 pub mod model;
 
 pub use completion::*;
+pub use embedding::*;
 pub use model::OpenAILanguageModel;
crates/semantic_index/src/semantic_index.rs
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
     let semantic_index = SemanticIndex::new(
         fs,
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         language_registry,
         cx.clone(),
     )
crates/semantic_index/src/semantic_index_tests.rs
@@ -4,7 +4,8 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::dummy::DummyEmbeddingProvider;
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -280,7 +281,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -382,7 +383,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -466,7 +467,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -565,7 +566,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -639,7 +640,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -756,7 +757,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -909,7 +910,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1100,7 +1101,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
crates/zed/src/main.rs
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -474,7 +474,7 @@ fn main() {
     let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         languages.clone(),
         cx.clone(),
     )
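Because both providers implement the EmbeddingProvider trait, callers can hold the provider behind an Arc and swap implementations freely: the tests above construct DummyEmbeddingProvider while semantic_index and zed construct OpenAIEmbeddingProvider. A hypothetical helper sketching that substitution (illustrative only, not part of this commit; it assumes callers accept an Arc<dyn EmbeddingProvider>, which the call sites above suggest):

    fn embedding_provider(
        use_dummy: bool,
        http_client: Arc<dyn HttpClient>,
        executor: Arc<Background>,
    ) -> Arc<dyn EmbeddingProvider> {
        if use_dummy {
            // The dummy returns fixed vectors, keeping tests offline and deterministic.
            Arc::new(DummyEmbeddingProvider {})
        } else {
            Arc::new(OpenAIEmbeddingProvider::new(http_client, executor))
        }
    }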