More resilient eval (#32257)
Bubbles up rate limit information so that we can retry after a certain duration if needed higher up in the stack. Also caps the number of concurrent evals running at once to also help. Release Notes: - N/A
This commit is contained in:
parent
fa54fa80d0
commit
e4bd115a63
22 changed files with 147 additions and 56 deletions
|
@ -185,6 +185,7 @@ impl LanguageModel for FakeLanguageModel {
|
|||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
|
|
|
@ -22,6 +22,7 @@ use std::fmt;
|
|||
use std::ops::{Add, Sub};
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use util::serde::is_default;
|
||||
use zed_llm_client::{
|
||||
|
@ -74,6 +75,8 @@ pub enum LanguageModelCompletionEvent {
|
|||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("rate limit exceeded, retry after {0:?}")]
|
||||
RateLimit(Duration),
|
||||
#[error("received bad input JSON")]
|
||||
BadInputJson {
|
||||
id: LanguageModelToolUseId,
|
||||
|
@ -270,6 +273,7 @@ pub trait LanguageModel: Send + Sync {
|
|||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
>;
|
||||
|
||||
|
@ -277,7 +281,7 @@ pub trait LanguageModel: Send + Sync {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
|
||||
) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
|
||||
let future = self.stream_completion(request, cx);
|
||||
|
||||
async move {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use anyhow::Result;
|
||||
use futures::Stream;
|
||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||
use std::{
|
||||
|
@ -8,6 +7,8 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use crate::LanguageModelCompletionError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RateLimiter {
|
||||
semaphore: Arc<Semaphore>,
|
||||
|
@ -36,9 +37,12 @@ impl RateLimiter {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
|
||||
pub fn run<'a, Fut, T>(
|
||||
&self,
|
||||
future: Fut,
|
||||
) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
|
||||
where
|
||||
Fut: 'a + Future<Output = Result<T>>,
|
||||
Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
|
||||
{
|
||||
let guard = self.semaphore.acquire_arc();
|
||||
async move {
|
||||
|
@ -52,9 +56,12 @@ impl RateLimiter {
|
|||
pub fn stream<'a, Fut, T>(
|
||||
&self,
|
||||
future: Fut,
|
||||
) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
|
||||
) -> impl 'a
|
||||
+ Future<
|
||||
Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
|
||||
>
|
||||
where
|
||||
Fut: 'a + Future<Output = Result<T>>,
|
||||
Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
|
||||
T: Stream,
|
||||
{
|
||||
let guard = self.semaphore.acquire_arc();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue