agent: Improve error handling and retry for zed-provided models (#33565)

* Updates to `zed_llm_client-0.8.5` which adds support for `retry_after`
when anthropic provides it.

* Distinguishes upstream provider errors and rate limits from errors
that originate from zed's servers

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably
this is an error case, the logic in thread is cleaner with this move.
There is also precedent for inclusion of errors in the event type -
`CompletionRequestStatus::Failed` is how cloud errors arrive.

* Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in a const fashion.

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`
as the server no longer reads this header and just defaults to that
behavior.

Release notes for this are covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Michael Sloan 2025-06-30 21:01:32 -06:00 committed by GitHub
parent f022a13091
commit d497f52e17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 656 additions and 479 deletions

View file

@ -23,11 +23,10 @@ use gpui::{
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
TokenUsage,
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
Role, SelectedModel, StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::{
@ -1531,82 +1530,7 @@ impl Thread {
}
thread.update(cx, |thread, cx| {
let event = match event {
Ok(event) => event,
Err(error) => {
match error {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
}
LanguageModelCompletionError::Overloaded => {
anyhow::bail!(LanguageModelKnownError::Overloaded);
}
LanguageModelCompletionError::ApiInternalServerError =>{
anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
}
LanguageModelCompletionError::PromptTooLarge { tokens } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(model.max_token_count().saturating_add(1))
});
anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
}
LanguageModelCompletionError::ApiReadResponseError(io_error) => {
anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
}
LanguageModelCompletionError::UnknownResponseFormat(error) => {
anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
}
LanguageModelCompletionError::HttpResponseError { status, ref body } => {
if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
anyhow::bail!(known_error);
} else {
return Err(error.into());
}
}
LanguageModelCompletionError::DeserializeResponse(error) => {
anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
}
LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
// These are all errors we can't automatically attempt to recover from (e.g. by retrying)
err @ LanguageModelCompletionError::BadRequestFormat |
err @ LanguageModelCompletionError::AuthenticationError |
err @ LanguageModelCompletionError::PermissionError |
err @ LanguageModelCompletionError::ApiEndpointNotFound |
err @ LanguageModelCompletionError::SerializeRequest(_) |
err @ LanguageModelCompletionError::BuildRequestBody(_) |
err @ LanguageModelCompletionError::HttpSend(_) => {
anyhow::bail!(err);
}
LanguageModelCompletionError::Other(error) => {
return Err(error);
}
}
}
};
match event {
match event? {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id =
Some(thread.insert_assistant_message(
@ -1683,9 +1607,7 @@ impl Thread {
};
}
}
LanguageModelCompletionEvent::RedactedThinking {
data
} => {
LanguageModelCompletionEvent::RedactedThinking { data } => {
thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() {
@ -1734,6 +1656,21 @@ impl Thread {
});
}
}
LanguageModelCompletionEvent::ToolUseJsonParseError {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
}
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread
.pending_completions
@ -1741,23 +1678,34 @@ impl Thread {
.find(|completion| completion.id == pending_completion_id)
{
match status_update {
CompletionRequestStatus::Queued {
position,
} => {
completion.queue_state = QueueState::Queued { position };
CompletionRequestStatus::Queued { position } => {
completion.queue_state =
QueueState::Queued { position };
}
CompletionRequestStatus::Started => {
completion.queue_state = QueueState::Started;
completion.queue_state = QueueState::Started;
}
CompletionRequestStatus::Failed {
code, message, request_id
code,
message,
request_id: _,
retry_after,
} => {
anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
return Err(
LanguageModelCompletionError::from_cloud_failure(
model.upstream_provider_name(),
code,
message,
retry_after.map(Duration::from_secs_f64),
),
);
}
CompletionRequestStatus::UsageUpdated {
amount, limit
} => {
thread.update_model_request_usage(amount as u32, limit, cx);
CompletionRequestStatus::UsageUpdated { amount, limit } => {
thread.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
@ -1808,10 +1756,11 @@ impl Thread {
Ok(stop_reason) => {
match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
let tool_uses =
thread.use_pending_tools(window, model.clone(), cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn | StopReason::MaxTokens => {
StopReason::EndTurn | StopReason::MaxTokens => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
@ -1827,7 +1776,9 @@ impl Thread {
{
let mut messages_to_remove = Vec::new();
for (ix, message) in thread.messages.iter().enumerate().rev() {
for (ix, message) in
thread.messages.iter().enumerate().rev()
{
messages_to_remove.push(message.id);
if message.role == Role::User {
@ -1835,7 +1786,9 @@ impl Thread {
break;
}
if let Some(prev_message) = thread.messages.get(ix - 1) {
if let Some(prev_message) =
thread.messages.get(ix - 1)
{
if prev_message.role == Role::Assistant {
break;
}
@ -1850,14 +1803,16 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Language model refusal".into(),
message: "Model refused to generate content for safety reasons.".into(),
message:
"Model refused to generate content for safety reasons."
.into(),
}));
}
}
// We successfully completed, so cancel any remaining retries.
thread.retry_state = None;
},
}
Err(error) => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
@ -1883,26 +1838,38 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan },
));
} else if let Some(known_error) =
error.downcast_ref::<LanguageModelKnownError>()
} else if let Some(completion_error) =
error.downcast_ref::<LanguageModelCompletionError>()
{
match known_error {
LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
use LanguageModelCompletionError::*;
match &completion_error {
PromptTooLarge { tokens, .. } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread
.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(model.max_token_count().saturating_add(1))
});
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: *tokens,
token_count: tokens,
});
cx.notify();
}
LanguageModelKnownError::RateLimitExceeded { retry_after } => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API rate limit exceeded",
provider_name.0.as_ref()
);
RateLimitExceeded {
retry_after: Some(retry_after),
..
}
| ServerOverloaded {
retry_after: Some(retry_after),
..
} => {
thread.handle_rate_limit_error(
&error_message,
&completion_error,
*retry_after,
model.clone(),
intent,
@ -1911,15 +1878,9 @@ impl Thread {
);
retry_scheduled = true;
}
LanguageModelKnownError::Overloaded => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API servers are overloaded right now",
provider_name.0.as_ref()
);
RateLimitExceeded { .. } | ServerOverloaded { .. } => {
retry_scheduled = thread.handle_retryable_error(
&error_message,
&completion_error,
model.clone(),
intent,
window,
@ -1929,15 +1890,11 @@ impl Thread {
emit_generic_error(error, cx);
}
}
LanguageModelKnownError::ApiInternalServerError => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API server reported an internal server error",
provider_name.0.as_ref()
);
ApiInternalServerError { .. }
| ApiReadResponseError { .. }
| HttpSend { .. } => {
retry_scheduled = thread.handle_retryable_error(
&error_message,
&completion_error,
model.clone(),
intent,
window,
@ -1947,12 +1904,16 @@ impl Thread {
emit_generic_error(error, cx);
}
}
LanguageModelKnownError::ReadResponseError(_) |
LanguageModelKnownError::DeserializeResponse(_) |
LanguageModelKnownError::UnknownResponseFormat(_) => {
// In the future we will attempt to re-roll response, but only once
emit_generic_error(error, cx);
}
NoApiKey { .. }
| HttpResponseError { .. }
| BadRequestFormat { .. }
| AuthenticationError { .. }
| PermissionError { .. }
| ApiEndpointNotFound { .. }
| SerializeRequest { .. }
| BuildRequestBody { .. }
| DeserializeResponse { .. }
| Other { .. } => emit_generic_error(error, cx),
}
} else {
emit_generic_error(error, cx);
@ -2084,7 +2045,7 @@ impl Thread {
fn handle_rate_limit_error(
&mut self,
error_message: &str,
error: &LanguageModelCompletionError,
retry_after: Duration,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
@ -2092,9 +2053,10 @@ impl Thread {
cx: &mut Context<Self>,
) {
// For rate limit errors, we only retry once with the specified duration
let retry_message = format!(
"{error_message}. Retrying in {} seconds…",
retry_after.as_secs()
let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
log::warn!(
"Retrying completion request in {} seconds: {error:?}",
retry_after.as_secs(),
);
// Add a UI-only message instead of a regular message
@ -2127,18 +2089,18 @@ impl Thread {
fn handle_retryable_error(
&mut self,
error_message: &str,
error: &LanguageModelCompletionError,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
}
fn handle_retryable_error_with_delay(
&mut self,
error_message: &str,
error: &LanguageModelCompletionError,
custom_delay: Option<Duration>,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
@ -2168,8 +2130,12 @@ impl Thread {
// Add a transient message to inform the user
let delay_secs = delay.as_secs();
let retry_message = format!(
"{}. Retrying (attempt {} of {}) in {} seconds...",
error_message, attempt, max_attempts, delay_secs
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds..."
);
log::warn!(
"Retrying completion request (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds: {error:?}",
);
// Add a UI-only message instead of a regular message
@ -4139,9 +4105,15 @@ fn main() {{
>,
> {
let error = match self.error_type {
TestError::Overloaded => LanguageModelCompletionError::Overloaded,
TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
provider: self.provider_name(),
retry_after: None,
},
TestError::InternalServerError => {
LanguageModelCompletionError::ApiInternalServerError
LanguageModelCompletionError::ApiInternalServerError {
provider: self.provider_name(),
message: "I'm a teapot orbiting the sun".to_string(),
}
}
};
async move {
@ -4649,9 +4621,13 @@ fn main() {{
> {
if !*self.failed_once.lock() {
*self.failed_once.lock() = true;
let provider = self.provider_name();
// Return error on first attempt
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::Overloaded)
Err(LanguageModelCompletionError::ServerOverloaded {
provider,
retry_after: None,
})
});
async move { Ok(stream.boxed()) }.boxed()
} else {
@ -4814,9 +4790,13 @@ fn main() {{
> {
if !*self.failed_once.lock() {
*self.failed_once.lock() = true;
let provider = self.provider_name();
// Return error on first attempt
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::Overloaded)
Err(LanguageModelCompletionError::ServerOverloaded {
provider,
retry_after: None,
})
});
async move { Ok(stream.boxed()) }.boxed()
} else {
@ -4969,10 +4949,12 @@ fn main() {{
LanguageModelCompletionError,
>,
> {
let provider = self.provider_name();
async move {
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::RateLimitExceeded {
retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
provider,
retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
})
});
Ok(stream.boxed())