More resilient eval (#32257)

Bubbles up rate limit information so that we can retry after a certain
duration if needed higher up in the stack.

Also caps the number of concurrent evals running at once to also help.

Release Notes:

- N/A
This commit is contained in:
Ben Brandt 2025-06-09 20:07:22 +02:00 committed by GitHub
parent fa54fa80d0
commit e4bd115a63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 147 additions and 56 deletions

View file

@ -62,7 +62,7 @@ jobs:
- name: Run unit evals
shell: bash -euxo pipefail {0}
run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)' --test-threads 1
run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)'
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

1
Cargo.lock generated
View file

@ -705,6 +705,7 @@ dependencies = [
"serde_json",
"settings",
"smallvec",
"smol",
"streaming_diff",
"strsim",
"task",

View file

@ -386,8 +386,10 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(&model, user_prompt, cx)?;
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
.boxed_local()
cx.spawn(async move |_, cx| {
Ok(model.stream_completion_text(request.await, &cx).await?)
})
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())

View file

@ -1563,6 +1563,9 @@ impl Thread {
Err(LanguageModelCompletionError::Other(error)) => {
return Err(error);
}
Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
return Err(err.into());
}
};
match event {

View file

@ -1,4 +1,5 @@
use std::str::FromStr;
use std::time::Duration;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
@ -406,6 +407,7 @@ impl RateLimit {
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
#[derive(Debug)]
pub struct RateLimitInfo {
pub retry_after: Option<Duration>,
pub requests: Option<RateLimit>,
pub tokens: Option<RateLimit>,
pub input_tokens: Option<RateLimit>,
@ -417,10 +419,11 @@ impl RateLimitInfo {
// Check if any rate limit headers exist
let has_rate_limit_headers = headers
.keys()
.any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
.any(|k| k == "retry-after" || k.as_str().starts_with("anthropic-ratelimit-"));
if !has_rate_limit_headers {
return Self {
retry_after: None,
requests: None,
tokens: None,
input_tokens: None,
@ -429,6 +432,11 @@ impl RateLimitInfo {
}
Self {
retry_after: headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@ -481,8 +489,8 @@ pub async fn stream_completion_with_rate_limit_info(
.send(request)
.await
.context("failed to send request to Anthropic")?;
let rate_limits = RateLimitInfo::from_headers(response.headers());
if response.status().is_success() {
let rate_limits = RateLimitInfo::from_headers(response.headers());
let reader = BufReader::new(response.into_body());
let stream = reader
.lines()
@ -500,6 +508,8 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
} else if let Some(retry_after) = rate_limits.retry_after {
Err(AnthropicError::RateLimit(retry_after))
} else {
let mut body = Vec::new();
response
@ -769,6 +779,8 @@ pub struct MessageDelta {
#[derive(Error, Debug)]
pub enum AnthropicError {
#[error("rate limit exceeded, retry after {0:?}")]
RateLimit(Duration),
#[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
ApiError(ApiError),
#[error("{0}")]

View file

@ -682,11 +682,12 @@ mod tests {
_: &AsyncApp,
) -> BoxFuture<
'static,
http_client::Result<
Result<
BoxStream<
'static,
http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
unimplemented!()

View file

@ -80,6 +80,7 @@ rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
smol.workspace = true
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true
theme.workspace = true

View file

@ -11,7 +11,7 @@ use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use gpui::{AppContext, TestAppContext, Timer};
use indoc::{formatdoc, indoc};
use language_model::{
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
@ -1255,9 +1255,12 @@ impl EvalAssertion {
}],
..Default::default()
};
let mut response = judge
.stream_completion_text(request, &cx.to_async())
.await?;
let mut response = retry_on_rate_limit(async || {
Ok(judge
.stream_completion_text(request.clone(), &cx.to_async())
.await?)
})
.await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
@ -1308,10 +1311,17 @@ fn eval(
run_eval(eval.clone(), tx.clone());
let executor = gpui::background_executor();
let semaphore = Arc::new(smol::lock::Semaphore::new(32));
for _ in 1..iterations {
let eval = eval.clone();
let tx = tx.clone();
executor.spawn(async move { run_eval(eval, tx) }).detach();
let semaphore = semaphore.clone();
executor
.spawn(async move {
let _guard = semaphore.acquire().await;
run_eval(eval, tx)
})
.detach();
}
drop(tx);
@ -1577,21 +1587,31 @@ impl EditAgentTest {
if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
}
let (edit_output, _) = self.agent.edit(
buffer.clone(),
eval.edit_file_input.display_description,
&conversation,
&mut cx.to_async(),
);
edit_output.await?
retry_on_rate_limit(async || {
self.agent
.edit(
buffer.clone(),
eval.edit_file_input.display_description.clone(),
&conversation,
&mut cx.to_async(),
)
.0
.await
})
.await?
} else {
let (edit_output, _) = self.agent.overwrite(
buffer.clone(),
eval.edit_file_input.display_description,
&conversation,
&mut cx.to_async(),
);
edit_output.await?
retry_on_rate_limit(async || {
self.agent
.overwrite(
buffer.clone(),
eval.edit_file_input.display_description.clone(),
&conversation,
&mut cx.to_async(),
)
.0
.await
})
.await?
};
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
@ -1613,6 +1633,26 @@ impl EditAgentTest {
}
}
async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) -> Result<R> {
loop {
match request().await {
Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
Ok(err) => match err {
LanguageModelCompletionError::RateLimit(duration) => {
// Wait until after we are allowed to try again
eprintln!("Rate limit exceeded. Waiting for {duration:?}...",);
Timer::after(duration).await;
continue;
}
_ => return Err(err.into()),
},
Err(err) => return Err(err),
},
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
score: usize,

View file

@ -185,6 +185,7 @@ impl LanguageModel for FakeLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let (tx, rx) = mpsc::unbounded();

View file

@ -22,6 +22,7 @@ use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
@ -74,6 +75,8 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("rate limit exceeded, retry after {0:?}")]
RateLimit(Duration),
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
@ -270,6 +273,7 @@ pub trait LanguageModel: Send + Sync {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
>;
@ -277,7 +281,7 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
let future = self.stream_completion(request, cx);
async move {

View file

@ -1,4 +1,3 @@
use anyhow::Result;
use futures::Stream;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{
@ -8,6 +7,8 @@ use std::{
task::{Context, Poll},
};
use crate::LanguageModelCompletionError;
#[derive(Clone)]
pub struct RateLimiter {
semaphore: Arc<Semaphore>,
@ -36,9 +37,12 @@ impl RateLimiter {
}
}
pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
pub fn run<'a, Fut, T>(
&self,
future: Fut,
) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
where
Fut: 'a + Future<Output = Result<T>>,
Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
{
let guard = self.semaphore.acquire_arc();
async move {
@ -52,9 +56,12 @@ impl RateLimiter {
pub fn stream<'a, Fut, T>(
&self,
future: Fut,
) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
) -> impl 'a
+ Future<
Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
>
where
Fut: 'a + Future<Output = Result<T>>,
Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
T: Stream,
{
let guard = self.semaphore.acquire_arc();

View file

@ -387,22 +387,34 @@ impl AnthropicModel {
&self,
request: anthropic::Request,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
{
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
LanguageModelCompletionError,
>,
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
async move {
let api_key = api_key.context("Missing Anthropic API Key")?;
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.context("failed to stream completion")
request.await.map_err(|err| match err {
AnthropicError::RateLimit(duration) => {
LanguageModelCompletionError::RateLimit(duration)
}
err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
}
})
}
.boxed()
}
@ -473,6 +485,7 @@ impl LanguageModel for AnthropicModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let request = into_anthropic(
@ -484,12 +497,7 @@ impl LanguageModel for AnthropicModel {
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
.await
.map_err(|err| match err.downcast::<AnthropicError>() {
Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
Err(err) => anyhow!(err),
})?;
let response = request.await?;
Ok(AnthropicEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()

View file

@ -527,6 +527,7 @@ impl LanguageModel for BedrockModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
@ -539,16 +540,13 @@ impl LanguageModel for BedrockModel {
.or(settings_region)
.unwrap_or(String::from("us-east-1"))
}) else {
return async move {
anyhow::bail!("App State Dropped");
}
.boxed();
return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
};
let model_id = match self.model.cross_region_inference_id(&region) {
Ok(s) => s,
Err(e) => {
return async move { Err(e) }.boxed();
return async move { Err(e.into()) }.boxed();
}
};
@ -560,7 +558,7 @@ impl LanguageModel for BedrockModel {
self.model.mode(),
) {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err)).boxed(),
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
};
let owned_handle = self.handler.clone();

View file

@ -807,6 +807,7 @@ impl LanguageModel for CloudLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let thread_id = request.thread_id.clone();
@ -848,7 +849,8 @@ impl LanguageModel for CloudLanguageModel {
mode,
provider: zed_llm_client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
provider_request: serde_json::to_value(&request)?,
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
)
.await
@ -884,7 +886,7 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let model = match open_ai::Model::from_id(&self.model.id.0) {
Ok(model) => model,
Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
};
let request = into_open_ai(request, &model, None);
let llm_api_token = self.llm_api_token.clone();
@ -905,7 +907,8 @@ impl LanguageModel for CloudLanguageModel {
mode,
provider: zed_llm_client::LanguageModelProvider::OpenAi,
model: request.model.clone(),
provider_request: serde_json::to_value(&request)?,
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
)
.await?;
@ -944,7 +947,8 @@ impl LanguageModel for CloudLanguageModel {
mode,
provider: zed_llm_client::LanguageModelProvider::Google,
model: request.model.model_id.clone(),
provider_request: serde_json::to_value(&request)?,
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
)
.await?;

View file

@ -265,13 +265,15 @@ impl LanguageModel for CopilotChatLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
"Empty prompts aren't allowed. Please provide a non-empty prompt.";
return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG).into()))
.boxed();
}
// Copilot Chat has a restriction that the final message must be from the user.
@ -279,13 +281,13 @@ impl LanguageModel for CopilotChatLanguageModel {
// and provide a more helpful error message.
if !matches!(message.role, Role::User) {
const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG).into())).boxed();
}
}
let copilot_request = match into_copilot_chat(&self.model, request) {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err)).boxed(),
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
};
let is_streaming = copilot_request.stream;

View file

@ -348,6 +348,7 @@ impl LanguageModel for DeepSeekLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let request = into_deepseek(request, &self.model, self.max_output_tokens());

View file

@ -409,6 +409,7 @@ impl LanguageModel for GoogleLanguageModel {
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = into_google(

View file

@ -420,6 +420,7 @@ impl LanguageModel for LmStudioLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let request = self.to_lmstudio_request(request);

View file

@ -364,6 +364,7 @@ impl LanguageModel for MistralLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let request = into_mistral(

View file

@ -406,6 +406,7 @@ impl LanguageModel for OllamaLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let request = self.to_ollama_request(request);
@ -415,7 +416,7 @@ impl LanguageModel for OllamaLanguageModel {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
let future = self.request_limiter.stream(async move {

View file

@ -339,6 +339,7 @@ impl LanguageModel for OpenAiLanguageModel {
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());

View file

@ -367,6 +367,7 @@ impl LanguageModel for OpenRouterLanguageModel {
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = into_open_router(request, &self.model, self.max_output_tokens());