agent: Improve error handling and retry for zed-provided models (#33565)
* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after` when Anthropic provides it.
* Distinguishes upstream provider errors and rate limits from errors that originate from Zed's servers.
* Moves `LanguageModelCompletionError::BadInputJson` to `LanguageModelCompletionEvent::ToolUseJsonParseError`. While this is arguably an error case, the logic in thread is cleaner with this move. There is also precedent for including errors in the event type: `CompletionRequestStatus::Failed` is how cloud errors arrive.
* Updates the `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types instead of `&str`, since they can be constructed in a `const` fashion (a sketch of the pattern follows the commit metadata below).
* Removes the use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`, as the server no longer reads this header and simply defaults to that behavior.

Release notes for this change are covered by #33275.

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
parent f022a13091
commit d497f52e17
25 changed files with 656 additions and 479 deletions
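The typed-constants bullet comes down to a small Rust pattern: a newtype over a `&'static str` (or another const-constructible payload) can be built in a `const` item, so call sites return the constant directly instead of re-wrapping a `&str` at runtime. A minimal sketch, with a hypothetical `ProviderId` newtype standing in for `LanguageModelProviderId` and an illustrative constant value:

```rust
// Hypothetical newtype standing in for LanguageModelProviderId; the payload is
// a &'static str, so the constructor is usable in a const context.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ProviderId(pub &'static str);

// Constructed once, at compile time; no runtime conversion from &str.
pub const ZED_CLOUD_PROVIDER_ID: ProviderId = ProviderId("zed.dev");

fn provider_id() -> ProviderId {
    // Return the typed constant directly, mirroring the diff's
    // `fn id(&self) -> LanguageModelProviderId { PROVIDER_ID }`.
    ZED_CLOUD_PROVIDER_ID
}

fn main() {
    assert_eq!(provider_id(), ProviderId("zed.dev"));
}
```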
```diff
@@ -1,4 +1,4 @@
-use anthropic::{AnthropicModelMode, parse_prompt_too_long};
+use anthropic::AnthropicModelMode;
 use anyhow::{Context as _, Result, anyhow};
 use client::{Client, ModelRequestUsage, UserStore, zed_urls};
 use futures::{
```
```diff
@@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
 use gpui::{
     AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
 };
+use http_client::http::{HeaderMap, HeaderValue};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
-    ZED_CLOUD_PROVIDER_ID,
-};
-use language_model::{
-    LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
-    RefreshLlmTokenListener,
+    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
+    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
 };
 use proto::Plan;
 use release_channel::AppVersion;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use settings::SettingsStore;
-use smol::Timer;
 use smol::io::{AsyncReadExt, BufReader};
 use std::pin::Pin;
 use std::str::FromStr as _;
```
```diff
@@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
 
-pub const PROVIDER_NAME: &str = "Zed";
+const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
```
```diff
@@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
 
 impl LanguageModelProvider for CloudLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
```
```diff
@@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
 }
 
 impl CloudLanguageModel {
-    const MAX_RETRIES: usize = 3;
-
     async fn perform_llm_completion(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
```
```diff
@@ -547,8 +542,7 @@ impl CloudLanguageModel {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
-        let mut retries_remaining = Self::MAX_RETRIES;
-        let mut retry_delay = Duration::from_secs(1);
+        let mut refreshed_token = false;
 
         loop {
             let request_builder = http_client::Request::builder()
```
```diff
@@ -590,14 +584,20 @@ impl CloudLanguageModel {
                     includes_status_messages,
                     tool_use_limit_reached,
                 });
-            } else if response
-                .headers()
-                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                .is_some()
+            }
+
+            if !refreshed_token
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
             {
-                retries_remaining -= 1;
                 token = llm_api_token.refresh(&client).await?;
-            } else if status == StatusCode::FORBIDDEN
+                refreshed_token = true;
+                continue;
+            }
+
+            if status == StatusCode::FORBIDDEN
                 && response
                     .headers()
                     .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
```
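The hunk above replaces counted retries with a refresh-once flag: an expired-token response triggers a single `llm_api_token.refresh(...)` and a `continue`, and a second expiry falls through to ordinary error handling. A self-contained sketch of that control flow, with stand-in types and a fake status code instead of real HTTP:

```rust
struct Token(String);

// Stand-in for the real HTTP call; returns an HTTP-like status code.
fn send(_token: &Token) -> Result<u16, String> {
    Ok(401)
}

fn refresh() -> Token {
    Token("fresh".to_string())
}

fn request_with_refresh() -> Result<u16, String> {
    let mut token = Token("stale".to_string());
    let mut refreshed_token = false;
    loop {
        let status = send(&token)?;
        if status == 401 && !refreshed_token {
            token = refresh();
            refreshed_token = true;
            continue; // retry exactly once with the new token
        }
        // Any other outcome (including a second 401) exits the loop.
        return Ok(status);
    }
}
```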
```diff
@@ -622,35 +622,18 @@ impl CloudLanguageModel {
                         return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                     }
                 }
 
                 anyhow::bail!("Forbidden");
-            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
-                // If we encounter an error in the 500 range, retry after a delay.
-                // We've seen at least these in the wild from API providers:
-                // * 500 Internal Server Error
-                // * 502 Bad Gateway
-                // * 529 Service Overloaded
-
-                if retries_remaining == 0 {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    anyhow::bail!(
-                        "cloud language model completion failed after {} retries with status {status}: {body}",
-                        Self::MAX_RETRIES
-                    );
-                }
-
-                Timer::after(retry_delay).await;
-
-                retries_remaining -= 1;
-                retry_delay *= 2; // If it fails again, wait longer.
-            } else if status == StatusCode::PAYMENT_REQUIRED {
-                return Err(anyhow!(PaymentRequiredError));
-            } else {
-                let mut body = String::new();
-                response.body_mut().read_to_string(&mut body).await?;
-                return Err(anyhow!(ApiError { status, body }));
-            }
+            }
+
+            let mut body = String::new();
+            let headers = response.headers().clone();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Err(anyhow!(ApiError {
+                status,
+                body,
+                headers
+            }));
         }
     }
 }
```
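With the in-loop 5xx backoff deleted, the status, body, and headers now travel out in `ApiError`, leaving retry policy to the caller. A hedged sketch of what a caller-side delay computation might look like, honoring an explicit `retry_after` hint when one is supplied and otherwise falling back to the doubling the removed block used (`retry_delay *= 2`); the function name is illustrative, not zed's API:

```rust
use std::time::Duration;

// Illustrative helper: prefer the provider-supplied hint when present,
// otherwise double a 1-second base per attempt (1s, 2s, 4s, ...).
fn next_retry_delay(attempt: u32, retry_after: Option<Duration>) -> Duration {
    retry_after.unwrap_or_else(|| Duration::from_secs(1) * 2u32.saturating_pow(attempt))
}

fn main() {
    assert_eq!(next_retry_delay(2, None), Duration::from_secs(4));
    assert_eq!(
        next_retry_delay(2, Some(Duration::from_secs(30))),
        Duration::from_secs(30)
    );
}
```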
```diff
@@ -660,6 +643,19 @@ impl CloudLanguageModel {
 struct ApiError {
     status: StatusCode,
     body: String,
+    headers: HeaderMap<HeaderValue>,
+}
+
+impl From<ApiError> for LanguageModelCompletionError {
+    fn from(error: ApiError) -> Self {
+        let retry_after = None;
+        LanguageModelCompletionError::from_http_status(
+            PROVIDER_NAME,
+            error.status,
+            error.body,
+            retry_after,
+        )
+    }
 }
 
 impl LanguageModel for CloudLanguageModel {
```
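`ApiError` now carries the response headers, although this conversion still passes `retry_after: None`. Per the commit message, `zed_llm_client-0.8.5` understands a `retry_after` value when Anthropic provides one; a hedged sketch of how such a hint could be derived from a seconds-valued `retry-after` header, using plain tuples rather than the `http` crate's types (helper name illustrative):

```rust
use std::time::Duration;

// Illustrative only: scan (name, value) header pairs for a seconds-valued
// retry-after and parse it; anything malformed simply yields None.
fn retry_after_hint(headers: &[(String, String)]) -> Option<Duration> {
    headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("retry-after"))
        .and_then(|(_, value)| value.trim().parse::<u64>().ok())
        .map(Duration::from_secs)
}

fn main() {
    let headers = vec![("Retry-After".to_string(), "30".to_string())];
    assert_eq!(retry_after_hint(&headers), Some(Duration::from_secs(30)));
}
```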
```diff
@@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
+    fn upstream_provider_id(&self) -> LanguageModelProviderId {
+        use zed_llm_client::LanguageModelProvider::*;
+        match self.model.provider {
+            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
+            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
+            Google => language_model::GOOGLE_PROVIDER_ID,
+        }
+    }
+
+    fn upstream_provider_name(&self) -> LanguageModelProviderName {
+        use zed_llm_client::LanguageModelProvider::*;
+        match self.model.provider {
+            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
+            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
+            Google => language_model::GOOGLE_PROVIDER_NAME,
+        }
+    }
+
     fn supports_tools(&self) -> bool {
```
```diff
@@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
             .body(serde_json::to_string(&request_body)?.into())?;
         let mut response = http_client.send(request).await?;
         let status = response.status();
+        let headers = response.headers().clone();
         let mut response_body = String::new();
         response
             .body_mut()
```
```diff
@@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
         } else {
             Err(anyhow!(ApiError {
                 status,
-                body: response_body
+                body: response_body,
+                headers
             }))
         }
     }
```
```diff
@@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
             )
             .await
             .map_err(|err| match err.downcast::<ApiError>() {
-                Ok(api_err) => {
-                    if api_err.status == StatusCode::BAD_REQUEST {
-                        if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
-                            return anyhow!(
-                                LanguageModelKnownError::ContextWindowLimitExceeded {
-                                    tokens
-                                }
-                            );
-                        }
-                    }
-                    anyhow!(api_err)
-                }
+                Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                 Err(err) => anyhow!(err),
             })?;
 
```
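The mapping above leans on `anyhow::Error::downcast`: if the boxed error really is an `ApiError`, convert it into the typed completion error; otherwise pass the opaque error through. A compact sketch of the same downcast-then-convert shape, with stand-in types in place of zed's:

```rust
use anyhow::anyhow;

#[derive(Debug)]
struct ApiError {
    status: u16,
    body: String,
}

impl std::fmt::Display for ApiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "API error {}: {}", self.status, self.body)
    }
}

impl std::error::Error for ApiError {}

#[derive(Debug)]
enum CompletionError {
    Http { status: u16, body: String },
    Other(anyhow::Error),
}

// Mirror of the map_err closure: recover the concrete ApiError if that is
// what the anyhow::Error holds, otherwise keep the error as-is.
fn normalize(err: anyhow::Error) -> CompletionError {
    match err.downcast::<ApiError>() {
        Ok(api_err) => CompletionError::Http {
            status: api_err.status,
            body: api_err.body,
        },
        Err(err) => CompletionError::Other(err),
    }
}

fn main() {
    let err = anyhow!(ApiError { status: 502, body: "bad gateway".into() });
    assert!(matches!(normalize(err), CompletionError::Http { status: 502, .. }));
}
```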
```diff
@@ -995,7 +1000,7 @@ where
         .flat_map(move |event| {
             futures::stream::iter(match event {
                 Err(error) => {
-                    vec![Err(LanguageModelCompletionError::Other(error))]
+                    vec![Err(LanguageModelCompletionError::from(error))]
                 }
                 Ok(CloudCompletionEvent::Status(event)) => {
                     vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
```