diff --git a/.github/workflows/unit_evals.yml b/.github/workflows/unit_evals.yml
index e033ba40ce..6ffae74e4e 100644
--- a/.github/workflows/unit_evals.yml
+++ b/.github/workflows/unit_evals.yml
@@ -62,7 +62,7 @@ jobs:
       - name: Run unit evals
         shell: bash -euxo pipefail {0}
-        run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)' --test-threads 1
+        run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)'
         env:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
diff --git a/Cargo.lock b/Cargo.lock
index c9f91959dc..3272bba810 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -705,6 +705,7 @@ dependencies = [
  "serde_json",
  "settings",
  "smallvec",
+ "smol",
  "streaming_diff",
  "strsim",
  "task",
diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs
index 166a002be2..e566ea9d86 100644
--- a/crates/agent/src/buffer_codegen.rs
+++ b/crates/agent/src/buffer_codegen.rs
@@ -386,8 +386,10 @@ impl CodegenAlternative {
             async { Ok(LanguageModelTextStream::default()) }.boxed_local()
         } else {
             let request = self.build_request(&model, user_prompt, cx)?;
-            cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
-                .boxed_local()
+            cx.spawn(async move |_, cx| {
+                Ok(model.stream_completion_text(request.await, &cx).await?)
+            })
+            .boxed_local()
         };
         self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
         Ok(())
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index ad0e9260dc..eac99eefbe 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -1563,6 +1563,9 @@ impl Thread {
                 Err(LanguageModelCompletionError::Other(error)) => {
                     return Err(error);
                 }
+                Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
+                    return Err(err.into());
+                }
             };

             match event {
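The changes above set up the theme of this patch: rate limits become a first-class, typed error that carries the server's suggested delay instead of an opaque anyhow error, and the eval suite drops --test-threads 1 in favor of a concurrency cap added later in evals.rs. A minimal sketch of the idea with stand-in names (the patch's real pieces are LanguageModelCompletionError and the retry_on_rate_limit helper further down):

    use std::time::Duration;

    #[derive(Debug, thiserror::Error)]
    enum CompletionError {
        // Carries the delay suggested by the server's retry-after header.
        #[error("rate limit exceeded, retry after {0:?}")]
        RateLimit(Duration),
        #[error(transparent)]
        Other(#[from] anyhow::Error),
    }

    // Retry only rate-limit errors, sleeping exactly as long as the API asked.
    async fn with_retry<T, Fut>(mut op: impl FnMut() -> Fut) -> Result<T, CompletionError>
    where
        Fut: std::future::Future<Output = Result<T, CompletionError>>,
    {
        loop {
            match op().await {
                Err(CompletionError::RateLimit(after)) => {
                    smol::Timer::after(after).await;
                }
                other => return other,
            }
        }
    }

Because the variant is typed, each call site can decide whether a 429 is retryable: the evals retry, while the agent thread above surfaces it to the caller.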
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index be111235ec..58369ea252 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -1,4 +1,5 @@
 use std::str::FromStr;
+use std::time::Duration;

 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
@@ -406,6 +407,7 @@ impl RateLimit {
 ///
 #[derive(Debug)]
 pub struct RateLimitInfo {
+    pub retry_after: Option<Duration>,
     pub requests: Option<RateLimit>,
     pub tokens: Option<RateLimit>,
     pub input_tokens: Option<RateLimit>,
@@ -417,10 +419,11 @@ impl RateLimitInfo {
         // Check if any rate limit headers exist
         let has_rate_limit_headers = headers
             .keys()
-            .any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
+            .any(|k| k == "retry-after" || k.as_str().starts_with("anthropic-ratelimit-"));

         if !has_rate_limit_headers {
             return Self {
+                retry_after: None,
                 requests: None,
                 tokens: None,
                 input_tokens: None,
@@ -429,6 +432,11 @@ impl RateLimitInfo {
         }

         Self {
+            retry_after: headers
+                .get("retry-after")
+                .and_then(|v| v.to_str().ok())
+                .and_then(|v| v.parse::<u64>().ok())
+                .map(Duration::from_secs),
             requests: RateLimit::from_headers("requests", headers).ok(),
             tokens: RateLimit::from_headers("tokens", headers).ok(),
             input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@@ -481,8 +489,8 @@
         .send(request)
         .await
         .context("failed to send request to Anthropic")?;
+    let rate_limits = RateLimitInfo::from_headers(response.headers());
     if response.status().is_success() {
-        let rate_limits = RateLimitInfo::from_headers(response.headers());
         let reader = BufReader::new(response.into_body());
         let stream = reader
             .lines()
@@ -500,6 +508,8 @@ pub async fn stream_completion_with_rate_limit_info(
             })
             .boxed();
         Ok((stream, Some(rate_limits)))
+    } else if let Some(retry_after) = rate_limits.retry_after {
+        Err(AnthropicError::RateLimit(retry_after))
     } else {
         let mut body = Vec::new();
         response
@@ -769,6 +779,8 @@ pub struct MessageDelta {

 #[derive(Error, Debug)]
 pub enum AnthropicError {
+    #[error("rate limit exceeded, retry after {0:?}")]
+    RateLimit(Duration),
     #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
     ApiError(ApiError),
     #[error("{0}")]
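Two details worth noting above: the header probe now treats a bare retry-after as a rate-limit signal even when no anthropic-ratelimit-* headers are present, and the parsed value short-circuits stream_completion_with_rate_limit_info into the new AnthropicError::RateLimit before the error body is read. The parsing logic in isolation (a sketch using the http crate's HeaderMap; the real code lives in RateLimitInfo::from_headers):

    use std::time::Duration;
    use http::HeaderMap;

    // retry-after arrives as integer seconds; anything missing or malformed
    // is ignored rather than treated as an error.
    fn retry_after(headers: &HeaderMap) -> Option<Duration> {
        headers
            .get("retry-after")?
            .to_str()
            .ok()?
            .parse::<u64>()
            .ok()
            .map(Duration::from_secs)
    }

RFC 9110 also allows Retry-After to be an HTTP-date; this parser, like the patch, only accepts the delta-seconds form.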
diff --git a/crates/assistant_context_editor/src/language_model_selector.rs b/crates/assistant_context_editor/src/language_model_selector.rs
index 049c4d24bd..732d8a326e 100644
--- a/crates/assistant_context_editor/src/language_model_selector.rs
+++ b/crates/assistant_context_editor/src/language_model_selector.rs
@@ -682,11 +682,12 @@ mod tests {
             _: &AsyncApp,
         ) -> BoxFuture<
             'static,
-            http_client::Result<
+            Result<
                 BoxStream<
                     'static,
-                    http_client::Result<LanguageModelCompletionEvent>,
+                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                 >,
+                LanguageModelCompletionError,
             >,
         > {
             unimplemented!()
diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml
index ded54460d7..2b8958feb1 100644
--- a/crates/assistant_tools/Cargo.toml
+++ b/crates/assistant_tools/Cargo.toml
@@ -80,6 +80,7 @@ rand.workspace = true
 pretty_assertions.workspace = true
 reqwest_client.workspace = true
 settings = { workspace = true, features = ["test-support"] }
+smol.workspace = true
 task = { workspace = true, features = ["test-support"]}
 tempfile.workspace = true
 theme.workspace = true
diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs
index f07edff09e..023febbdae 100644
--- a/crates/assistant_tools/src/edit_agent/evals.rs
+++ b/crates/assistant_tools/src/edit_agent/evals.rs
@@ -11,7 +11,7 @@ use client::{Client, UserStore};
 use collections::HashMap;
 use fs::FakeFs;
 use futures::{FutureExt, future::LocalBoxFuture};
-use gpui::{AppContext, TestAppContext};
+use gpui::{AppContext, TestAppContext, Timer};
 use indoc::{formatdoc, indoc};
 use language_model::{
     LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
@@ -1255,9 +1255,12 @@ impl EvalAssertion {
                 }],
                 ..Default::default()
             };
-            let mut response = judge
-                .stream_completion_text(request, &cx.to_async())
-                .await?;
+            let mut response = retry_on_rate_limit(async || {
+                Ok(judge
+                    .stream_completion_text(request.clone(), &cx.to_async())
+                    .await?)
+            })
+            .await?;
             let mut output = String::new();
             while let Some(chunk) = response.stream.next().await {
                 let chunk = chunk?;
@@ -1308,10 +1311,17 @@ fn eval(
     run_eval(eval.clone(), tx.clone());

     let executor = gpui::background_executor();
+    let semaphore = Arc::new(smol::lock::Semaphore::new(32));
     for _ in 1..iterations {
         let eval = eval.clone();
         let tx = tx.clone();
-        executor.spawn(async move { run_eval(eval, tx) }).detach();
+        let semaphore = semaphore.clone();
+        executor
+            .spawn(async move {
+                let _guard = semaphore.acquire().await;
+                run_eval(eval, tx)
+            })
+            .detach();
     }

     drop(tx);
@@ -1577,21 +1587,31 @@ impl EditAgentTest {
             if let Some(input_content) = eval.input_content.as_deref() {
                 buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
             }
-            let (edit_output, _) = self.agent.edit(
-                buffer.clone(),
-                eval.edit_file_input.display_description,
-                &conversation,
-                &mut cx.to_async(),
-            );
-            edit_output.await?
+            retry_on_rate_limit(async || {
+                self.agent
+                    .edit(
+                        buffer.clone(),
+                        eval.edit_file_input.display_description.clone(),
+                        &conversation,
+                        &mut cx.to_async(),
+                    )
+                    .0
+                    .await
+            })
+            .await?
         } else {
-            let (edit_output, _) = self.agent.overwrite(
-                buffer.clone(),
-                eval.edit_file_input.display_description,
-                &conversation,
-                &mut cx.to_async(),
-            );
-            edit_output.await?
+            retry_on_rate_limit(async || {
+                self.agent
+                    .overwrite(
+                        buffer.clone(),
+                        eval.edit_file_input.display_description.clone(),
+                        &conversation,
+                        &mut cx.to_async(),
+                    )
+                    .0
+                    .await
+            })
+            .await?
         };

         let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
@@ -1613,6 +1633,26 @@ impl EditAgentTest {
     }
 }

+async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) -> Result<R> {
+    loop {
+        match request().await {
+            Ok(result) => return Ok(result),
+            Err(err) => match err.downcast::<LanguageModelCompletionError>() {
+                Ok(err) => match err {
+                    LanguageModelCompletionError::RateLimit(duration) => {
+                        // Wait until after we are allowed to try again
+                        eprintln!("Rate limit exceeded. Waiting for {duration:?}...");
+                        Timer::after(duration).await;
+                        continue;
+                    }
+                    _ => return Err(err.into()),
+                },
+                Err(err) => return Err(err),
+            },
+        }
+    }
+}
+
 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
 struct EvalAssertionOutcome {
     score: usize,
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index b4ba0c057f..f04f568b72 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -185,6 +185,7 @@ impl LanguageModel for FakeLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let (tx, rx) = mpsc::unbounded();
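Two pieces in evals.rs do the work that previously fell to --test-threads 1: retry_on_rate_limit turns the typed rate-limit error into a sleep-and-retry loop, and a 32-permit semaphore bounds how many eval iterations run concurrently. The capping pattern in isolation (a self-contained sketch; the permit count and task body here are illustrative):

    use std::sync::Arc;

    fn main() {
        // Each task holds a permit for its whole lifetime, so at most 4 of
        // the 16 tasks below make progress at once.
        let semaphore = Arc::new(smol::lock::Semaphore::new(4));
        let mut tasks = Vec::new();
        for i in 0..16 {
            let semaphore = semaphore.clone();
            tasks.push(smol::spawn(async move {
                let _guard = semaphore.acquire().await;
                println!("task {i} running");
            }));
        }
        smol::block_on(async {
            for task in tasks {
                task.await;
            }
        });
    }

One subtlety in retry_on_rate_limit: its errors arrive as anyhow::Error, so it must downcast to LanguageModelCompletionError before it can match on RateLimit. The trait changes in the next file make the typed error available without that downcast at the provider layer.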
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 7fb2f57585..01f005d73c 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -22,6 +22,7 @@ use std::fmt;
 use std::ops::{Add, Sub};
 use std::str::FromStr as _;
 use std::sync::Arc;
+use std::time::Duration;
 use thiserror::Error;
 use util::serde::is_default;
 use zed_llm_client::{
@@ -74,6 +75,8 @@ pub enum LanguageModelCompletionEvent {

 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
+    #[error("rate limit exceeded, retry after {0:?}")]
+    RateLimit(Duration),
     #[error("received bad input JSON")]
     BadInputJson {
         id: LanguageModelToolUseId,
@@ -270,6 +273,7 @@ pub trait LanguageModel: Send + Sync {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     >;

@@ -277,7 +281,7 @@ pub trait LanguageModel: Send + Sync {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
+    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
         let future = self.stream_completion(request, cx);

         async move {
diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs
index a48d34488b..790be05ac0 100644
--- a/crates/language_model/src/rate_limiter.rs
+++ b/crates/language_model/src/rate_limiter.rs
@@ -1,4 +1,3 @@
-use anyhow::Result;
 use futures::Stream;
 use smol::lock::{Semaphore, SemaphoreGuardArc};
 use std::{
@@ -8,6 +7,8 @@ use std::{
     task::{Context, Poll},
 };

+use crate::LanguageModelCompletionError;
+
 #[derive(Clone)]
 pub struct RateLimiter {
     semaphore: Arc<Semaphore>,
@@ -36,9 +37,12 @@ impl RateLimiter {
         }
     }

-    pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
+    pub fn run<'a, Fut, T>(
+        &self,
+        future: Fut,
+    ) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
     where
-        Fut: 'a + Future<Output = Result<T>>,
+        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
     {
         let guard = self.semaphore.acquire_arc();
         async move {
@@ -52,9 +56,12 @@ impl RateLimiter {
     pub fn stream<'a, Fut, T>(
         &self,
         future: Fut,
-    ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
+    ) -> impl 'a
+    + Future<
+        Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
+    >
     where
-        Fut: 'a + Future<Output = Result<T>>,
+        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
         T: Stream,
     {
         let guard = self.semaphore.acquire_arc();
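With those changes, the completion API reports a typed error at two distinct levels, which is worth spelling out because most of the remaining diff is mechanical fallout from it. A simplified rendering of the new shape (Event and CompletionError are stand-ins):

    use futures::{future::BoxFuture, stream::BoxStream};

    struct Event;
    struct CompletionError;

    // The outer Result is a typed failure to *start* the stream (for example
    // an immediate 429); the inner Results are per-event failures after
    // streaming has begun.
    type Completion = BoxFuture<
        'static,
        Result<BoxStream<'static, Result<Event, CompletionError>>, CompletionError>,
    >;

Every provider below now has to produce LanguageModelCompletionError instead of anyhow::Error in its early-return paths, hence the recurring .into() calls. These presumably go through a From<anyhow::Error> conversion on the error's catch-all Other variant; that impl is not shown in this diff.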
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 4a524eb452..c581f01f4c 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -387,22 +387,34 @@ impl AnthropicModel {
         &self,
         request: anthropic::Request,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<Event, AnthropicError>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<
+            BoxStream<'static, Result<Event, AnthropicError>>,
+            LanguageModelCompletionError,
+        >,
+    > {
         let http_client = self.http_client.clone();

         let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
             (state.api_key.clone(), settings.api_url.clone())
         }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
         };

         async move {
             let api_key = api_key.context("Missing Anthropic API Key")?;
             let request =
                 anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            request.await.context("failed to stream completion")
+            request.await.map_err(|err| match err {
+                AnthropicError::RateLimit(duration) => {
+                    LanguageModelCompletionError::RateLimit(duration)
+                }
+                err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
+                    LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
+                }
+            })
         }
         .boxed()
     }
@@ -473,6 +485,7 @@ impl LanguageModel for AnthropicModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = into_anthropic(
@@ -484,12 +497,7 @@ impl LanguageModel for AnthropicModel {
         );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
-            let response = request
-                .await
-                .map_err(|err| match err.downcast::<AnthropicError>() {
-                    Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
-                    Err(err) => anyhow!(err),
-                })?;
+            let response = request.await?;
             Ok(AnthropicEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs
index 2ee786fbb6..8ec659c97a 100644
--- a/crates/language_models/src/provider/bedrock.rs
+++ b/crates/language_models/src/provider/bedrock.rs
@@ -527,6 +527,7 @@ impl LanguageModel for BedrockModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
@@ -539,16 +540,13 @@ impl LanguageModel for BedrockModel {
                 .or(settings_region)
                 .unwrap_or(String::from("us-east-1"))
         }) else {
-            return async move {
-                anyhow::bail!("App State Dropped");
-            }
-            .boxed();
+            return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
         };

         let model_id = match self.model.cross_region_inference_id(&region) {
             Ok(s) => s,
             Err(e) => {
-                return async move { Err(e) }.boxed();
+                return async move { Err(e.into()) }.boxed();
             }
         };

@@ -560,7 +558,7 @@ impl LanguageModel for BedrockModel {
             self.model.mode(),
         ) {
             Ok(request) => request,
-            Err(err) => return futures::future::ready(Err(err)).boxed(),
+            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
         };

         let owned_handle = self.handler.clone();
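The Anthropic provider maps its crate-level error into the completion error without a catch-all arm: the err @ (ApiError(..) | Other(..)) pattern names every remaining variant, so if AnthropicError grows a new case the match stops compiling and someone must decide whether it is retryable. The same pattern with simplified payload types:

    use std::time::Duration;

    enum ApiError { RateLimit(Duration), Api(String), Other(String) }
    enum CompletionError { RateLimit(Duration), Other(String) }

    fn map_err(err: ApiError) -> CompletionError {
        match err {
            ApiError::RateLimit(duration) => CompletionError::RateLimit(duration),
            // No `_` arm: a new ApiError variant becomes a compile error here.
            ApiError::Api(msg) | ApiError::Other(msg) => CompletionError::Other(msg),
        }
    }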
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index ee6fe8d484..b04971c5ed 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -807,6 +807,7 @@ impl LanguageModel for CloudLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let thread_id = request.thread_id.clone();
@@ -848,7 +849,8 @@ impl LanguageModel for CloudLanguageModel {
                             mode,
                             provider: zed_llm_client::LanguageModelProvider::Anthropic,
                             model: request.model.clone(),
-                            provider_request: serde_json::to_value(&request)?,
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
                         },
                     )
                     .await
@@ -884,7 +886,7 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let model = match open_ai::Model::from_id(&self.model.id.0) {
                     Ok(model) => model,
-                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                 };
                 let request = into_open_ai(request, &model, None);
                 let llm_api_token = self.llm_api_token.clone();
@@ -905,7 +907,8 @@ impl LanguageModel for CloudLanguageModel {
                             mode,
                             provider: zed_llm_client::LanguageModelProvider::OpenAi,
                             model: request.model.clone(),
-                            provider_request: serde_json::to_value(&request)?,
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
                         },
                     )
                     .await?;
@@ -944,7 +947,8 @@ impl LanguageModel for CloudLanguageModel {
                             mode,
                             provider: zed_llm_client::LanguageModelProvider::Google,
                             model: request.model.model_id.clone(),
-                            provider_request: serde_json::to_value(&request)?,
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
                         },
                     )
                     .await?;
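cloud.rs shows the one place the typed error makes call sites slightly noisier: serde_json::to_value(&request)? used to rely on ? converting serde_json::Error into anyhow::Error, but there is no From<serde_json::Error> for the new error type, so the value is routed through anyhow explicitly. Reduced to its essence (assuming the From<anyhow::Error> conversion noted earlier):

    use anyhow::anyhow;

    #[derive(Debug, thiserror::Error)]
    enum CompletionError {
        #[error(transparent)]
        Other(#[from] anyhow::Error),
    }

    fn encode<T: serde::Serialize>(request: &T) -> Result<serde_json::Value, CompletionError> {
        // map_err through anyhow reuses the single Other conversion instead
        // of adding a From impl for every foreign error type.
        Ok(serde_json::to_value(request).map_err(|e| anyhow!(e))?)
    }

The remaining provider files (Copilot Chat, DeepSeek, Google, LM Studio, Mistral, Ollama, OpenAI, OpenRouter) repeat the same two mechanical edits: the widened return type and .into() on early-return errors.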
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index fc655e0c6f..11015aa455 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -265,13 +265,15 @@ impl LanguageModel for CopilotChatLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         if let Some(message) = request.messages.last() {
             if message.contents_empty() {
                 const EMPTY_PROMPT_MSG: &str =
                     "Empty prompts aren't allowed. Please provide a non-empty prompt.";
-                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
+                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG).into()))
+                    .boxed();
             }

             // Copilot Chat has a restriction that the final message must be from the user.
@@ -279,13 +281,13 @@ impl LanguageModel for CopilotChatLanguageModel {
             // and provide a more helpful error message.
             if !matches!(message.role, Role::User) {
                 const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
-                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
+                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG).into())).boxed();
             }
         }

         let copilot_request = match into_copilot_chat(&self.model, request) {
             Ok(request) => request,
-            Err(err) => return futures::future::ready(Err(err)).boxed(),
+            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
         };

         let is_streaming = copilot_request.stream;
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index 6a16ec019f..4a4faa13a7 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -348,6 +348,7 @@ impl LanguageModel for DeepSeekLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = into_deepseek(request, &self.model, self.max_output_tokens());
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 8608666c46..aeee2ca73e 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -409,6 +409,7 @@ impl LanguageModel for GoogleLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = into_google(
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 792d39bfed..8f0704c5bc 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -420,6 +420,7 @@ impl LanguageModel for LmStudioLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = self.to_lmstudio_request(request);
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 6debead977..d00af8ecd6 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -364,6 +364,7 @@ impl LanguageModel for MistralLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = into_mistral(
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index bf2a2e1597..fca78e4791 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -406,6 +406,7 @@ impl LanguageModel for OllamaLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = self.to_ollama_request(request);
@@ -415,7 +416,7 @@ impl LanguageModel for OllamaLanguageModel {
             let settings = &AllLanguageModelSettings::get_global(cx).ollama;
             settings.api_url.clone()
         }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
         };

         let future = self.request_limiter.stream(async move {
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 48812edcc8..146aac9cfa 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -339,6 +339,7 @@ impl LanguageModel for OpenAiLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
         >,
     > {
         let request = into_open_ai(request, &self.model, self.max_output_tokens());
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
index 7af265544a..623316916b 100644
--- a/crates/language_models/src/provider/open_router.rs
+++ b/crates/language_models/src/provider/open_router.rs
@@ -367,6 +367,7 @@ impl LanguageModel for OpenRouterLanguageModel {
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+            LanguageModelCompletionError,
        >,
    > {
        let request = into_open_router(request, &self.model, self.max_output_tokens());