Thread Anthropic errors into LanguageModelKnownError (#33261)
This PR is in preparation for doing automatic retries for certain errors, e.g. `Overloaded`. It doesn't change behavior yet, aside from making some of the error messages shown to the user more granular; rather, it mostly converts some error handling from `anyhow` downcasts into exhaustive enum matches, and leaves comments marking where the behavior change will land in a future PR.

Release Notes:

- N/A
Parent: aabfea4c10
Commit: c610ebfb03

5 changed files with 283 additions and 123 deletions
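To make the downcast-vs-match distinction concrete, here is a minimal standalone sketch (not code from this PR; `CompletionError` is an abridged stand-in for the real `LanguageModelCompletionError`): once every variant is listed in a match, adding a new error variant becomes a compile error at each match site, instead of a silently unhandled downcast.

    use std::time::Duration;

    // Abridged stand-in for LanguageModelCompletionError, for illustration only.
    enum CompletionError {
        Overloaded,
        RateLimitExceeded { retry_after: Duration },
        Other(String),
    }

    // Exhaustive match: if a variant is added to CompletionError, this function
    // stops compiling until the new case is handled explicitly. A downcast-based
    // handler would keep compiling and silently fall through instead.
    fn handle(error: CompletionError) {
        match error {
            CompletionError::Overloaded => eprintln!("overloaded; a future PR will retry"),
            CompletionError::RateLimitExceeded { retry_after } => {
                eprintln!("rate limited; could retry after {retry_after:?}")
            }
            CompletionError::Other(message) => eprintln!("{message}"),
        }
    }

    fn main() {
        handle(CompletionError::RateLimitExceeded {
            retry_after: Duration::from_secs(4),
        });
    }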
crates/agent/src/thread.rs:

@@ -1495,12 +1495,52 @@ impl Thread {
             thread.update(cx, |thread, cx| {
                 let event = match event {
                     Ok(event) => event,
-                    Err(LanguageModelCompletionError::BadInputJson {
+                    Err(error) => {
+                        match error {
+                            LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
+                                anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
+                            }
+                            LanguageModelCompletionError::Overloaded => {
+                                anyhow::bail!(LanguageModelKnownError::Overloaded);
+                            }
+                            LanguageModelCompletionError::ApiInternalServerError => {
+                                anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
+                            }
+                            LanguageModelCompletionError::PromptTooLarge { tokens } => {
+                                let tokens = tokens.unwrap_or_else(|| {
+                                    // We didn't get an exact token count from the API, so fall back on our estimate.
+                                    thread
+                                        .total_token_usage()
+                                        .map(|usage| usage.total)
+                                        .unwrap_or(0)
+                                        // We know the context window was exceeded in practice, so if our estimate was
+                                        // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
+                                        .max(model.max_token_count().saturating_add(1))
+                                });
+
+                                anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
+                            }
+                            LanguageModelCompletionError::ApiReadResponseError(io_error) => {
+                                anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
+                            }
+                            LanguageModelCompletionError::UnknownResponseFormat(error) => {
+                                anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
+                            }
+                            LanguageModelCompletionError::HttpResponseError { status, ref body } => {
+                                if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
+                                    anyhow::bail!(known_error);
+                                } else {
+                                    return Err(error.into());
+                                }
+                            }
+                            LanguageModelCompletionError::DeserializeResponse(error) => {
+                                anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
+                            }
+                            LanguageModelCompletionError::BadInputJson {
                                 id,
                                 tool_name,
                                 raw_input: invalid_input_json,
                                 json_parse_error,
-                    }) => {
+                            } => {
                                 thread.receive_invalid_tool_json(
                                     id,
                                     tool_name,
@@ -1511,11 +1551,20 @@ impl Thread {
                                 );
                                 return Ok(());
                             }
-                    Err(LanguageModelCompletionError::Other(error)) => {
+                            // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
+                            err @ LanguageModelCompletionError::BadRequestFormat |
+                            err @ LanguageModelCompletionError::AuthenticationError |
+                            err @ LanguageModelCompletionError::PermissionError |
+                            err @ LanguageModelCompletionError::ApiEndpointNotFound |
+                            err @ LanguageModelCompletionError::SerializeRequest(_) |
+                            err @ LanguageModelCompletionError::BuildRequestBody(_) |
+                            err @ LanguageModelCompletionError::HttpSend(_) => {
+                                anyhow::bail!(err);
+                            }
+                            LanguageModelCompletionError::Other(error) => {
                                 return Err(error);
                             }
-                    Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
-                        return Err(err.into());
-                    }
+                        }
+                    }
                 };

@@ -1751,6 +1800,18 @@ impl Thread {
                 project.set_agent_location(None, cx);
             });

+            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
+                let error_message = error
+                    .chain()
+                    .map(|err| err.to_string())
+                    .collect::<Vec<_>>()
+                    .join("\n");
+                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+                    header: "Error interacting with language model".into(),
+                    message: SharedString::from(error_message.clone()),
+                }));
+            }
+
             if error.is::<PaymentRequiredError>() {
                 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
             } else if let Some(error) =

@@ -1763,26 +1824,34 @@ impl Thread {
                 error.downcast_ref::<LanguageModelKnownError>()
             {
                 match known_error {
-                    LanguageModelKnownError::ContextWindowLimitExceeded {
-                        tokens,
-                    } => {
+                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                         thread.exceeded_window_error = Some(ExceededWindowError {
                             model_id: model.id(),
                             token_count: *tokens,
                         });
                         cx.notify();
                     }
+                    LanguageModelKnownError::RateLimitExceeded { .. } => {
+                        // In the future we will report the error to the user, wait retry_after, and then retry.
+                        emit_generic_error(error, cx);
+                    }
+                    LanguageModelKnownError::Overloaded => {
+                        // In the future we will wait and then retry, up to N times.
+                        emit_generic_error(error, cx);
+                    }
+                    LanguageModelKnownError::ApiInternalServerError => {
+                        // In the future we will retry the request, but only once.
+                        emit_generic_error(error, cx);
+                    }
+                    LanguageModelKnownError::ReadResponseError(_) |
+                    LanguageModelKnownError::DeserializeResponse(_) |
+                    LanguageModelKnownError::UnknownResponseFormat(_) => {
+                        // In the future we will attempt to re-roll response, but only once
+                        emit_generic_error(error, cx);
+                    }
                 }
             } else {
-                let error_message = error
-                    .chain()
-                    .map(|err| err.to_string())
-                    .collect::<Vec<_>>()
-                    .join("\n");
-                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
-                    header: "Error interacting with language model".into(),
-                    message: SharedString::from(error_message.clone()),
-                }));
+                emit_generic_error(error, cx);
             }

             thread.cancel_last_completion(window, cx);
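The new `emit_generic_error` helper flattens the whole `anyhow` context chain into one user-facing message, one line per link. A standalone sketch of just that chain-joining step (the error values here are invented for illustration):

    use anyhow::anyhow;

    fn main() {
        // Invented error chain; emit_generic_error performs the same join
        // over whatever error actually reaches it.
        let error = anyhow!("connection reset")
            .context("failed to read response body")
            .context("error interacting with language model");

        // chain() yields the outermost context first, root cause last.
        let message = error
            .chain()
            .map(|err| err.to_string())
            .collect::<Vec<_>>()
            .join("\n");
        println!("{message}");
    }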
crates/anthropic/src/anthropic.rs:

@@ -1,10 +1,11 @@
+use std::io;
 use std::str::FromStr;
 use std::time::Duration;

 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::http::{HeaderMap, HeaderValue};
+use http_client::http::{self, HeaderMap, HeaderValue};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use serde::{Deserialize, Serialize};
 use strum::{EnumIter, EnumString};

@@ -336,7 +337,7 @@ pub async fn complete(
     let uri = format!("{api_url}/v1/messages");
     let beta_headers = Model::from_id(&request.model)
         .map(|model| model.beta_headers())
-        .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
+        .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(","));
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)

@@ -346,39 +347,30 @@ pub async fn complete(
         .header("Content-Type", "application/json");

     let serialized_request =
-        serde_json::to_string(&request).context("failed to serialize request")?;
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
     let request = request_builder
         .body(AsyncBody::from(serialized_request))
-        .context("failed to construct request body")?;
+        .map_err(AnthropicError::BuildRequestBody)?;

     let mut response = client
         .send(request)
         .await
-        .context("failed to send request to Anthropic")?;
-    if response.status().is_success() {
-        let mut body = Vec::new();
-        response
-            .body_mut()
-            .read_to_end(&mut body)
-            .await
-            .context("failed to read response body")?;
-        let response_message: Response =
-            serde_json::from_slice(&body).context("failed to deserialize response body")?;
-        Ok(response_message)
+        .map_err(AnthropicError::HttpSend)?;
+    let status = response.status();
+    let mut body = String::new();
+    response
+        .body_mut()
+        .read_to_string(&mut body)
+        .await
+        .map_err(AnthropicError::ReadResponse)?;
+
+    if status.is_success() {
+        Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
     } else {
-        let mut body = Vec::new();
-        response
-            .body_mut()
-            .read_to_end(&mut body)
-            .await
-            .context("failed to read response body")?;
-        let body_str =
-            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
-        Err(AnthropicError::Other(anyhow!(
-            "Failed to connect to API: {} {}",
-            response.status(),
-            body_str
-        )))
+        Err(AnthropicError::HttpResponseError {
+            status: status.as_u16(),
+            body,
+        })
     }
 }

@@ -491,7 +483,7 @@ pub async fn stream_completion_with_rate_limit_info(
     let uri = format!("{api_url}/v1/messages");
     let beta_headers = Model::from_id(&request.base.model)
         .map(|model| model.beta_headers())
-        .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
+        .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(","));
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)

@@ -500,15 +492,15 @@ pub async fn stream_completion_with_rate_limit_info(
         .header("X-Api-Key", api_key)
         .header("Content-Type", "application/json");
     let serialized_request =
-        serde_json::to_string(&request).context("failed to serialize request")?;
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
     let request = request_builder
         .body(AsyncBody::from(serialized_request))
-        .context("failed to construct request body")?;
+        .map_err(AnthropicError::BuildRequestBody)?;

     let mut response = client
         .send(request)
         .await
-        .context("failed to send request to Anthropic")?;
+        .map_err(AnthropicError::HttpSend)?;
     let rate_limits = RateLimitInfo::from_headers(response.headers());
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());

@@ -520,37 +512,31 @@ pub async fn stream_completion_with_rate_limit_info(
                     let line = line.strip_prefix("data: ")?;
                     match serde_json::from_str(line) {
                         Ok(response) => Some(Ok(response)),
-                        Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
+                        Err(error) => Some(Err(AnthropicError::DeserializeResponse(error))),
                     }
                 }
-                Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
+                Err(error) => Some(Err(AnthropicError::ReadResponse(error))),
             }
         })
         .boxed();
         Ok((stream, Some(rate_limits)))
     } else if let Some(retry_after) = rate_limits.retry_after {
-        Err(AnthropicError::RateLimit(retry_after))
+        Err(AnthropicError::RateLimit { retry_after })
     } else {
-        let mut body = Vec::new();
+        let mut body = String::new();
         response
             .body_mut()
-            .read_to_end(&mut body)
+            .read_to_string(&mut body)
             .await
-            .context("failed to read response body")?;
+            .map_err(AnthropicError::ReadResponse)?;

-        let body_str =
-            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
-
-        match serde_json::from_str::<Event>(body_str) {
+        match serde_json::from_str::<Event>(&body) {
             Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
-            Ok(_) => Err(AnthropicError::Other(anyhow!(
-                "Unexpected success response while expecting an error: '{body_str}'",
-            ))),
-            Err(_) => Err(AnthropicError::Other(anyhow!(
-                "Failed to connect to API: {} {}",
-                response.status(),
-                body_str,
-            ))),
+            Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
+            Err(_) => Err(AnthropicError::HttpResponseError {
+                status: response.status().as_u16(),
+                body: body,
+            }),
         }
     }
 }

@@ -797,17 +783,38 @@ pub struct MessageDelta {
     pub stop_sequence: Option<String>,
 }

-#[derive(Error, Debug)]
+#[derive(Debug)]
 pub enum AnthropicError {
-    #[error("rate limit exceeded, retry after {0:?}")]
-    RateLimit(Duration),
-    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
+    /// Failed to serialize the HTTP request body to JSON
+    SerializeRequest(serde_json::Error),
+
+    /// Failed to construct the HTTP request body
+    BuildRequestBody(http::Error),
+
+    /// Failed to send the HTTP request
+    HttpSend(anyhow::Error),
+
+    /// Failed to deserialize the response from JSON
+    DeserializeResponse(serde_json::Error),
+
+    /// Failed to read from response stream
+    ReadResponse(io::Error),
+
+    /// HTTP error response from the API
+    HttpResponseError { status: u16, body: String },
+
+    /// Rate limit exceeded
+    RateLimit { retry_after: Duration },
+
+    /// API returned an error response
     ApiError(ApiError),
-    #[error("{0}")]
-    Other(#[from] anyhow::Error),
+
+    /// Unexpected response format
+    UnexpectedResponseFormat(String),
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Error)]
+#[error("Anthropic API Error: {error_type}: {message}")]
 pub struct ApiError {
     #[serde(rename = "type")]
     pub error_type: String,
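Since `AnthropicError` no longer derives `Error` (the derive moved to `ApiError`, which now has a proper `Display`), callers format the top-level error themselves. A caller-side sketch, assuming the `anthropic` crate's types from this diff are in scope (the `describe` helper is invented for illustration, not part of this PR):

    // Hypothetical helper, not part of this PR.
    fn describe(error: &AnthropicError) -> String {
        match error {
            AnthropicError::RateLimit { retry_after } => {
                // Typed data: the caller gets a Duration instead of parsing a string.
                format!("rate limited; retry after {retry_after:?}")
            }
            AnthropicError::HttpResponseError { status, body } => {
                format!("HTTP {status}: {body}")
            }
            AnthropicError::ApiError(api_error) => api_error.to_string(),
            other => format!("{other:?}"),
        }
    }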
crates/eval/src/eval.rs:

@@ -1659,13 +1659,13 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
             Ok(result) => return Ok(result),
             Err(err) => match err.downcast::<LanguageModelCompletionError>() {
                 Ok(err) => match err {
-                    LanguageModelCompletionError::RateLimit(duration) => {
+                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                         // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
-                        let jitter = duration.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
+                        let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
                         eprintln!(
-                            "Attempt #{attempt}: Rate limit exceeded. Retry after {duration:?} + jitter of {jitter:?}"
+                            "Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
                         );
-                        Timer::after(duration + jitter).await;
+                        Timer::after(retry_after + jitter).await;
                         continue;
                     }
                     _ => return Err(err.into()),
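For reference, the jitter computation bounds the total wait between `retry_after` and twice `retry_after`, which spreads out concurrent clients. A standalone sketch using the same `rand` calls as the diff:

    use rand::Rng;
    use std::time::Duration;

    fn main() {
        let retry_after = Duration::from_secs(4);
        // A random fraction of retry_after, so the total wait falls in
        // [retry_after, 2 * retry_after).
        let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
        println!("would wait {:?}", retry_after + jitter);
    }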
crates/language_model/src/language_model.rs:

@@ -8,19 +8,21 @@ mod telemetry;
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;

+use anthropic::{AnthropicError, parse_prompt_too_long};
 use anyhow::Result;
 use client::Client;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
+use http_client::http;
 use icons::IconName;
 use parking_lot::Mutex;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
-use std::fmt;
 use std::ops::{Add, Sub};
 use std::sync::Arc;
 use std::time::Duration;
+use std::{fmt, io};
 use thiserror::Error;
 use util::serde::is_default;
 use zed_llm_client::CompletionRequestStatus;

@@ -34,6 +36,10 @@ pub use crate::telemetry::*;

 pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

+/// If we get a rate limit error that doesn't tell us when we can retry,
+/// default to waiting this long before retrying.
+const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
+
 pub fn init(client: Arc<Client>, cx: &mut App) {
     init_settings(cx);
     RefreshLlmTokenListener::register(client.clone(), cx);

@@ -70,8 +76,8 @@ pub enum LanguageModelCompletionEvent {

 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
-    #[error("rate limit exceeded, retry after {0:?}")]
-    RateLimit(Duration),
+    #[error("rate limit exceeded, retry after {retry_after:?}")]
+    RateLimitExceeded { retry_after: Duration },
     #[error("received bad input JSON")]
     BadInputJson {
         id: LanguageModelToolUseId,

@@ -79,8 +85,78 @@ pub enum LanguageModelCompletionError {
         raw_input: Arc<str>,
         json_parse_error: String,
     },
+    #[error("language model provider's API is overloaded")]
+    Overloaded,
     #[error(transparent)]
     Other(#[from] anyhow::Error),
+    #[error("invalid request format to language model provider's API")]
+    BadRequestFormat,
+    #[error("authentication error with language model provider's API")]
+    AuthenticationError,
+    #[error("permission error with language model provider's API")]
+    PermissionError,
+    #[error("language model provider API endpoint not found")]
+    ApiEndpointNotFound,
+    #[error("prompt too large for context window")]
+    PromptTooLarge { tokens: Option<u64> },
+    #[error("internal server error in language model provider's API")]
+    ApiInternalServerError,
+    #[error("I/O error reading response from language model provider's API: {0:?}")]
+    ApiReadResponseError(io::Error),
+    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
+    HttpResponseError { status: u16, body: String },
+    #[error("error serializing request to language model provider API: {0}")]
+    SerializeRequest(serde_json::Error),
+    #[error("error building request body to language model provider API: {0}")]
+    BuildRequestBody(http::Error),
+    #[error("error sending HTTP request to language model provider API: {0}")]
+    HttpSend(anyhow::Error),
+    #[error("error deserializing language model provider API response: {0}")]
+    DeserializeResponse(serde_json::Error),
+    #[error("unexpected language model provider API response format: {0}")]
+    UnknownResponseFormat(String),
+}
+
+impl From<AnthropicError> for LanguageModelCompletionError {
+    fn from(error: AnthropicError) -> Self {
+        match error {
+            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
+            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
+            AnthropicError::HttpSend(error) => Self::HttpSend(error),
+            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
+            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
+            AnthropicError::HttpResponseError { status, body } => {
+                Self::HttpResponseError { status, body }
+            }
+            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
+            AnthropicError::ApiError(api_error) => api_error.into(),
+            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
+        }
+    }
+}
+
+impl From<anthropic::ApiError> for LanguageModelCompletionError {
+    fn from(error: anthropic::ApiError) -> Self {
+        use anthropic::ApiErrorCode::*;
+
+        match error.code() {
+            Some(code) => match code {
+                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
+                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
+                PermissionError => LanguageModelCompletionError::PermissionError,
+                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
+                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
+                    tokens: parse_prompt_too_long(&error.message),
+                },
+                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
+                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+                },
+                ApiError => LanguageModelCompletionError::ApiInternalServerError,
+                OverloadedError => LanguageModelCompletionError::Overloaded,
+            },
+            None => LanguageModelCompletionError::Other(error.into()),
+        }
+    }
 }

 /// Indicates the format used to define the input schema for a language model tool.

@@ -319,6 +395,33 @@ pub trait LanguageModel: Send + Sync {
 pub enum LanguageModelKnownError {
     #[error("Context window limit exceeded ({tokens})")]
     ContextWindowLimitExceeded { tokens: u64 },
+    #[error("Language model provider's API is currently overloaded")]
+    Overloaded,
+    #[error("Language model provider's API encountered an internal server error")]
+    ApiInternalServerError,
+    #[error("I/O error while reading response from language model provider's API: {0:?}")]
+    ReadResponseError(io::Error),
+    #[error("Error deserializing response from language model provider's API: {0:?}")]
+    DeserializeResponse(serde_json::Error),
+    #[error("Language model provider's API returned a response in an unknown format")]
+    UnknownResponseFormat(String),
+    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
+    RateLimitExceeded { retry_after: Duration },
+}
+
+impl LanguageModelKnownError {
+    /// Attempts to map an HTTP response status code to a known error type.
+    /// Returns None if the status code doesn't map to a specific known error.
+    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
+        match status {
+            429 => Some(Self::RateLimitExceeded {
+                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+            }),
+            503 => Some(Self::Overloaded),
+            500..=599 => Some(Self::ApiInternalServerError),
+            _ => None,
+        }
+    }
 }

 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
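A quick sanity check of the status mapping above, as a test-style sketch assuming `LanguageModelKnownError` is in scope. Note that the 429 and 503 arms are tried before the general 5xx range, so 503 maps to `Overloaded` rather than `ApiInternalServerError`:

    #[test]
    fn known_error_status_mapping() {
        assert!(matches!(
            LanguageModelKnownError::from_http_response(429, ""),
            Some(LanguageModelKnownError::RateLimitExceeded { .. })
        ));
        assert!(matches!(
            LanguageModelKnownError::from_http_response(503, ""),
            Some(LanguageModelKnownError::Overloaded)
        ));
        assert!(matches!(
            LanguageModelKnownError::from_http_response(500, ""),
            Some(LanguageModelKnownError::ApiInternalServerError)
        ));
        assert!(LanguageModelKnownError::from_http_response(404, "").is_none());
    }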
crates/language_models/src/provider/anthropic.rs:

@@ -16,10 +16,10 @@ use gpui::{
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
+    LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
+    RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;

@@ -407,14 +407,7 @@ impl AnthropicModel {
             let api_key = api_key.context("Missing Anthropic API Key")?;
             let request =
                 anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            request.await.map_err(|err| match err {
-                AnthropicError::RateLimit(duration) => {
-                    LanguageModelCompletionError::RateLimit(duration)
-                }
-                err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
-                    LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
-                }
-            })
+            request.await.map_err(Into::into)
         }
         .boxed()
     }

@@ -714,7 +707,7 @@ impl AnthropicEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(error.into())],
             })
         })
     }

@@ -859,9 +852,7 @@ impl AnthropicEventMapper {
                 vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
             }
             Event::Error { error } => {
-                vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                    AnthropicError::ApiError(error)
-                )))]
+                vec![Err(error.into())]
             }
             _ => Vec::new(),
         }

@@ -874,16 +865,6 @@ struct RawToolUse {
     input_json: String,
 }

-pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
-    if let AnthropicError::ApiError(api_err) = &err {
-        if let Some(tokens) = api_err.match_window_exceeded() {
-            return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
-        }
-    }
-
-    anyhow!(err)
-}
-
 /// Updates usage data by preferring counts from `new`.
 fn update_usage(usage: &mut Usage, new: &Usage) {
     if let Some(input_tokens) = new.input_tokens {
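The `map_err(Into::into)` above works because of the `From` chain added in language_model.rs. A sketch of the conversion path, with two labeled assumptions: that `ApiError`'s `message` field is public like `error_type`, and that the `"overloaded_error"` code string parses to `ApiErrorCode::OverloadedError` (construction here is illustrative only):

    // Illustrative only: an Anthropic API error with an overloaded code
    // converts straight into the typed variant via the new From impls.
    let api_error = anthropic::ApiError {
        error_type: "overloaded_error".to_string(),
        message: "Overloaded".to_string(),
    };
    let error: LanguageModelCompletionError = AnthropicError::ApiError(api_error).into();
    assert!(matches!(error, LanguageModelCompletionError::Overloaded));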