ZIm/crates/language_model/src/rate_limiter.rs
Marshall Bowers d93141bded
agent: Extract usage information from response headers (#29002)
This PR updates the Agent to extract the usage information from the
response headers, if they are present.

For now we just log the information, but we'll be using this soon to
populate some UI.

Release Notes:

- N/A
2025-04-17 20:11:07 +00:00

100 lines
2.3 KiB
Rust

use anyhow::Result;
use futures::Stream;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use crate::RequestUsage;
/// Bounds how many operations run at the same time by gating them on a semaphore.
///
/// The semaphore lives behind an `Arc`, so every clone of a `RateLimiter`
/// draws permits from the same semaphore as the value it was cloned from.
#[derive(Clone)]
pub struct RateLimiter {
/// Shared semaphore; a permit is acquired from it before each wrapped future is awaited.
semaphore: Arc<Semaphore>,
}
/// A value bundled with the semaphore permit that was acquired to produce it.
///
/// The permit is owned by this wrapper, so it stays held for as long as the
/// wrapper is alive. When `T` is a `Stream`, the wrapper is itself a `Stream`
/// that yields the items of the wrapped value.
pub struct RateLimitGuard<T> {
/// The wrapped value that polling is delegated to.
inner: T,
/// Never read in this file; stored only so the permit lives as long as `inner` does.
_guard: SemaphoreGuardArc,
}
/// Makes the guard a transparent `Stream`: every item comes straight from the
/// wrapped stream, while the semaphore permit stored next to it stays held.
impl<T> Stream for RateLimitGuard<T>
where
    T: Stream,
{
    type Item = T::Item;

    /// Polls the wrapped stream for its next item.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // SAFETY: `inner` is treated as structurally pinned. This projection is
        // the only place in this file that reaches `inner` through a pinned
        // `RateLimitGuard`, and it never moves the value out. The type has no
        // `Drop` impl and no manual `Unpin` impl here, so the wrapper is `Unpin`
        // only when `T` is.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
        inner.poll_next(cx)
    }

    /// Forwards the wrapped stream's size hint instead of the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
impl RateLimiter {
    /// Creates a limiter whose semaphore is constructed with `limit`.
    pub fn new(limit: usize) -> Self {
        let semaphore = Arc::new(Semaphore::new(limit));
        Self { semaphore }
    }

    /// Awaits `future` while holding a permit from the semaphore.
    ///
    /// The returned future first waits for a permit, then drives `future` to
    /// completion and hands back its result unchanged. The permit is dropped
    /// once `future` has resolved, whether it succeeded or failed.
    pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
    where
        Fut: 'a + Future<Output = Result<T>>,
    {
        // The acquisition future is created here, outside the `async` block, so
        // the returned future owns it and does not borrow `self`.
        let acquire = self.semaphore.acquire_arc();
        async move {
            let permit = acquire.await;
            let outcome = future.await;
            drop(permit);
            outcome
        }
    }

    /// Awaits `future` while holding a permit, then keeps the permit attached
    /// to the stream that `future` produced.
    ///
    /// On success the stream is returned wrapped in a [`RateLimitGuard`] that
    /// owns the permit. On failure the error is returned as-is and the permit
    /// is dropped.
    pub fn stream<'a, Fut, T>(
        &self,
        future: Fut,
    ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
    where
        Fut: 'a + Future<Output = Result<T>>,
        T: Stream,
    {
        let acquire = self.semaphore.acquire_arc();
        async move {
            let permit = acquire.await;
            future.await.map(|inner| RateLimitGuard {
                inner,
                _guard: permit,
            })
        }
    }

    /// Same as [`RateLimiter::stream`], for futures that also resolve to an
    /// optional [`RequestUsage`].
    ///
    /// The usage value is passed through untouched next to the guarded stream.
    pub fn stream_with_usage<'a, Fut, T>(
        &self,
        future: Fut,
    ) -> impl 'a
    + Future<
        Output = Result<(
            impl Stream<Item = T::Item> + use<Fut, T>,
            Option<RequestUsage>,
        )>,
    >
    where
        Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
        T: Stream,
    {
        let acquire = self.semaphore.acquire_arc();
        async move {
            let permit = acquire.await;
            future.await.map(|(inner, usage)| {
                let guarded = RateLimitGuard {
                    inner,
                    _guard: permit,
                };
                (guarded, usage)
            })
        }
    }
}