initial outline for rate limiting status updates
This commit is contained in:
parent
e9747d0fea
commit
a5ee8fc805
4 changed files with 106 additions and 8 deletions
|
@ -7,7 +7,9 @@ use isahc::http::StatusCode;
|
|||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::Mutex;
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||
use rusqlite::ToSql;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -82,6 +84,8 @@ impl ToSql for Embedding {
|
|||
/// Embedding provider for OpenAI that also carries a shared rate-limit state.
///
/// The rate-limit state is a `(delay, count)` pair published through a watch
/// channel: `update_rate_limit` and `resolve_rate_limit` write it, and
/// `rate_limit_expiration` reads the delay component.
pub struct OpenAIEmbeddings {
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
pub executor: Arc<Background>,
|
||||
// Read side of the rate-limit watch channel created in `new`.
rate_limit_count_rx: watch::Receiver<(Duration, usize)>,
|
||||
// Write side of the same channel, guarded by a mutex so it can be
// mutated through `&self`.
rate_limit_count_tx: Arc<Mutex<watch::Sender<(Duration, usize)>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -114,12 +118,16 @@ pub trait EmbeddingProvider: Sync + Send {
|
|||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn truncate(&self, span: &str) -> (String, usize);
|
||||
fn rate_limit_expiration(&self) -> Duration;
|
||||
}
|
||||
|
||||
/// Stateless stand-in `EmbeddingProvider`; its `rate_limit_expiration` is always zero.
pub struct DummyEmbeddings {}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for DummyEmbeddings {
|
||||
fn rate_limit_expiration(&self) -> Duration {
|
||||
Duration::ZERO
|
||||
}
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
// 1024 is the OpenAI Embeddings size for ada models.
|
||||
// the model we will likely be starting with.
|
||||
|
@ -149,6 +157,53 @@ impl EmbeddingProvider for DummyEmbeddings {
|
|||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
||||
|
||||
impl OpenAIEmbeddings {
|
||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with((Duration::ZERO, 0));
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
OpenAIEmbeddings {
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
rate_limit_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_rate_limit(&self) {
|
||||
let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow();
|
||||
let updated_count = delay_count - 1;
|
||||
let updated_duration = if updated_count == 0 {
|
||||
Duration::ZERO
|
||||
} else {
|
||||
current_delay
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"resolving rate limit: Count: {:?} Duration: {:?}",
|
||||
updated_count,
|
||||
updated_duration
|
||||
);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count);
|
||||
}
|
||||
|
||||
fn update_rate_limit(&self, delay_duration: Duration, count_increase: usize) {
|
||||
let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow();
|
||||
let updated_count = delay_count + count_increase;
|
||||
let updated_duration = if current_delay < delay_duration {
|
||||
delay_duration
|
||||
} else {
|
||||
current_delay
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"updating rate limit: Count: {:?} Duration: {:?}",
|
||||
updated_count,
|
||||
updated_duration
|
||||
);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count);
|
||||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_key: &str,
|
||||
|
@ -179,6 +234,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
50000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Duration {
|
||||
let (duration, _) = *self.rate_limit_count_rx.borrow();
|
||||
duration
|
||||
}
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
|
@ -203,6 +262,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
.ok_or_else(|| anyhow!("no api key"))?;
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut rate_limiting = false;
|
||||
let mut request_timeout: u64 = 15;
|
||||
let mut response: Response<AsyncBody>;
|
||||
while request_number < MAX_RETRIES {
|
||||
|
@ -229,6 +289,12 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
// If we complete a request successfully that was previously rate_limited
|
||||
// resolve the rate limit
|
||||
if rate_limiting {
|
||||
self.resolve_rate_limit()
|
||||
}
|
||||
|
||||
return Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
|
@ -254,6 +320,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
|||
}
|
||||
};
|
||||
|
||||
// If we've previously rate limited, increment the duration but not the count
|
||||
if rate_limiting {
|
||||
self.update_rate_limit(delay_duration, 0);
|
||||
} else {
|
||||
self.update_rate_limit(delay_duration, 1);
|
||||
}
|
||||
|
||||
rate_limiting = true;
|
||||
|
||||
log::trace!(
|
||||
"openai rate limiting: waiting {:?} until lifted",
|
||||
&delay_duration
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue