agent: Improve error handling and retry for zed-provided models (#33565)

* Updates to `zed_llm_client-0.8.5` which adds support for `retry_after`
when anthropic provides it.

* Distinguishes upstream provider errors and rate limits from errors
that originate from zed's servers

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably
this is an error case, the logic in thread is cleaner with this move.
There is also precedent for inclusion of errors in the event type -
`CompletionRequestStatus::Failed` is how cloud errors arrive.

* Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in a const fashion.

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`
as the server no longer reads this header and just defaults to that
behavior.

Release notes for this is covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Michael Sloan 2025-06-30 21:01:32 -06:00 committed by GitHub
parent f022a13091
commit d497f52e17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 656 additions and 479 deletions

4
Cargo.lock generated
View file

@ -20139,9 +20139,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
dependencies = [
"anyhow",
"serde",

View file

@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
zed_llm_client = "0.8.4"
zed_llm_client = "0.8.5"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View file

@ -23,11 +23,10 @@ use gpui::{
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
TokenUsage,
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
Role, SelectedModel, StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::{
@ -1531,82 +1530,7 @@ impl Thread {
}
thread.update(cx, |thread, cx| {
let event = match event {
Ok(event) => event,
Err(error) => {
match error {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
}
LanguageModelCompletionError::Overloaded => {
anyhow::bail!(LanguageModelKnownError::Overloaded);
}
LanguageModelCompletionError::ApiInternalServerError =>{
anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
}
LanguageModelCompletionError::PromptTooLarge { tokens } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(model.max_token_count().saturating_add(1))
});
anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
}
LanguageModelCompletionError::ApiReadResponseError(io_error) => {
anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
}
LanguageModelCompletionError::UnknownResponseFormat(error) => {
anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
}
LanguageModelCompletionError::HttpResponseError { status, ref body } => {
if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
anyhow::bail!(known_error);
} else {
return Err(error.into());
}
}
LanguageModelCompletionError::DeserializeResponse(error) => {
anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
}
LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
// These are all errors we can't automatically attempt to recover from (e.g. by retrying)
err @ LanguageModelCompletionError::BadRequestFormat |
err @ LanguageModelCompletionError::AuthenticationError |
err @ LanguageModelCompletionError::PermissionError |
err @ LanguageModelCompletionError::ApiEndpointNotFound |
err @ LanguageModelCompletionError::SerializeRequest(_) |
err @ LanguageModelCompletionError::BuildRequestBody(_) |
err @ LanguageModelCompletionError::HttpSend(_) => {
anyhow::bail!(err);
}
LanguageModelCompletionError::Other(error) => {
return Err(error);
}
}
}
};
match event {
match event? {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id =
Some(thread.insert_assistant_message(
@ -1683,9 +1607,7 @@ impl Thread {
};
}
}
LanguageModelCompletionEvent::RedactedThinking {
data
} => {
LanguageModelCompletionEvent::RedactedThinking { data } => {
thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() {
@ -1734,6 +1656,21 @@ impl Thread {
});
}
}
LanguageModelCompletionEvent::ToolUseJsonParseError {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
}
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread
.pending_completions
@ -1741,23 +1678,34 @@ impl Thread {
.find(|completion| completion.id == pending_completion_id)
{
match status_update {
CompletionRequestStatus::Queued {
position,
} => {
completion.queue_state = QueueState::Queued { position };
CompletionRequestStatus::Queued { position } => {
completion.queue_state =
QueueState::Queued { position };
}
CompletionRequestStatus::Started => {
completion.queue_state = QueueState::Started;
}
CompletionRequestStatus::Failed {
code, message, request_id
code,
message,
request_id: _,
retry_after,
} => {
anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
return Err(
LanguageModelCompletionError::from_cloud_failure(
model.upstream_provider_name(),
code,
message,
retry_after.map(Duration::from_secs_f64),
),
);
}
CompletionRequestStatus::UsageUpdated {
amount, limit
} => {
thread.update_model_request_usage(amount as u32, limit, cx);
CompletionRequestStatus::UsageUpdated { amount, limit } => {
thread.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
@ -1808,7 +1756,8 @@ impl Thread {
Ok(stop_reason) => {
match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
let tool_uses =
thread.use_pending_tools(window, model.clone(), cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn | StopReason::MaxTokens => {
@ -1827,7 +1776,9 @@ impl Thread {
{
let mut messages_to_remove = Vec::new();
for (ix, message) in thread.messages.iter().enumerate().rev() {
for (ix, message) in
thread.messages.iter().enumerate().rev()
{
messages_to_remove.push(message.id);
if message.role == Role::User {
@ -1835,7 +1786,9 @@ impl Thread {
break;
}
if let Some(prev_message) = thread.messages.get(ix - 1) {
if let Some(prev_message) =
thread.messages.get(ix - 1)
{
if prev_message.role == Role::Assistant {
break;
}
@ -1850,14 +1803,16 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Language model refusal".into(),
message: "Model refused to generate content for safety reasons.".into(),
message:
"Model refused to generate content for safety reasons."
.into(),
}));
}
}
// We successfully completed, so cancel any remaining retries.
thread.retry_state = None;
},
}
Err(error) => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
@ -1883,26 +1838,38 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan },
));
} else if let Some(known_error) =
error.downcast_ref::<LanguageModelKnownError>()
} else if let Some(completion_error) =
error.downcast_ref::<LanguageModelCompletionError>()
{
match known_error {
LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
use LanguageModelCompletionError::*;
match &completion_error {
PromptTooLarge { tokens, .. } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread
.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(model.max_token_count().saturating_add(1))
});
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: *tokens,
token_count: tokens,
});
cx.notify();
}
LanguageModelKnownError::RateLimitExceeded { retry_after } => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API rate limit exceeded",
provider_name.0.as_ref()
);
RateLimitExceeded {
retry_after: Some(retry_after),
..
}
| ServerOverloaded {
retry_after: Some(retry_after),
..
} => {
thread.handle_rate_limit_error(
&error_message,
&completion_error,
*retry_after,
model.clone(),
intent,
@ -1911,15 +1878,9 @@ impl Thread {
);
retry_scheduled = true;
}
LanguageModelKnownError::Overloaded => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API servers are overloaded right now",
provider_name.0.as_ref()
);
RateLimitExceeded { .. } | ServerOverloaded { .. } => {
retry_scheduled = thread.handle_retryable_error(
&error_message,
&completion_error,
model.clone(),
intent,
window,
@ -1929,15 +1890,11 @@ impl Thread {
emit_generic_error(error, cx);
}
}
LanguageModelKnownError::ApiInternalServerError => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API server reported an internal server error",
provider_name.0.as_ref()
);
ApiInternalServerError { .. }
| ApiReadResponseError { .. }
| HttpSend { .. } => {
retry_scheduled = thread.handle_retryable_error(
&error_message,
&completion_error,
model.clone(),
intent,
window,
@ -1947,12 +1904,16 @@ impl Thread {
emit_generic_error(error, cx);
}
}
LanguageModelKnownError::ReadResponseError(_) |
LanguageModelKnownError::DeserializeResponse(_) |
LanguageModelKnownError::UnknownResponseFormat(_) => {
// In the future we will attempt to re-roll response, but only once
emit_generic_error(error, cx);
}
NoApiKey { .. }
| HttpResponseError { .. }
| BadRequestFormat { .. }
| AuthenticationError { .. }
| PermissionError { .. }
| ApiEndpointNotFound { .. }
| SerializeRequest { .. }
| BuildRequestBody { .. }
| DeserializeResponse { .. }
| Other { .. } => emit_generic_error(error, cx),
}
} else {
emit_generic_error(error, cx);
@ -2084,7 +2045,7 @@ impl Thread {
fn handle_rate_limit_error(
&mut self,
error_message: &str,
error: &LanguageModelCompletionError,
retry_after: Duration,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
@ -2092,9 +2053,10 @@ impl Thread {
cx: &mut Context<Self>,
) {
// For rate limit errors, we only retry once with the specified duration
let retry_message = format!(
"{error_message}. Retrying in {} seconds…",
retry_after.as_secs()
let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
log::warn!(
"Retrying completion request in {} seconds: {error:?}",
retry_after.as_secs(),
);
// Add a UI-only message instead of a regular message
@ -2127,18 +2089,18 @@ impl Thread {
fn handle_retryable_error(
&mut self,
error_message: &str,
error: &LanguageModelCompletionError,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
}
fn handle_retryable_error_with_delay(
&mut self,
error_message: &str,
error: &LanguageModelCompletionError,
custom_delay: Option<Duration>,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
@ -2168,8 +2130,12 @@ impl Thread {
// Add a transient message to inform the user
let delay_secs = delay.as_secs();
let retry_message = format!(
"{}. Retrying (attempt {} of {}) in {} seconds...",
error_message, attempt, max_attempts, delay_secs
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds..."
);
log::warn!(
"Retrying completion request (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds: {error:?}",
);
// Add a UI-only message instead of a regular message
@ -4139,9 +4105,15 @@ fn main() {{
>,
> {
let error = match self.error_type {
TestError::Overloaded => LanguageModelCompletionError::Overloaded,
TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
provider: self.provider_name(),
retry_after: None,
},
TestError::InternalServerError => {
LanguageModelCompletionError::ApiInternalServerError
LanguageModelCompletionError::ApiInternalServerError {
provider: self.provider_name(),
message: "I'm a teapot orbiting the sun".to_string(),
}
}
};
async move {
@ -4649,9 +4621,13 @@ fn main() {{
> {
if !*self.failed_once.lock() {
*self.failed_once.lock() = true;
let provider = self.provider_name();
// Return error on first attempt
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::Overloaded)
Err(LanguageModelCompletionError::ServerOverloaded {
provider,
retry_after: None,
})
});
async move { Ok(stream.boxed()) }.boxed()
} else {
@ -4814,9 +4790,13 @@ fn main() {{
> {
if !*self.failed_once.lock() {
*self.failed_once.lock() = true;
let provider = self.provider_name();
// Return error on first attempt
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::Overloaded)
Err(LanguageModelCompletionError::ServerOverloaded {
provider,
retry_after: None,
})
});
async move { Ok(stream.boxed()) }.boxed()
} else {
@ -4969,10 +4949,12 @@ fn main() {{
LanguageModelCompletionError,
>,
> {
let provider = self.provider_name();
async move {
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::RateLimitExceeded {
retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
provider,
retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
})
});
Ok(stream.boxed())

View file

@ -2025,9 +2025,7 @@ impl AgentPanel {
.thread()
.read(cx)
.configured_model()
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
});
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
if !is_using_zed_provider {
return false;

View file

@ -1250,9 +1250,7 @@ impl MessageEditor {
self.thread
.read(cx)
.configured_model()
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
})
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
}
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {

View file

@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@ -356,7 +356,7 @@ pub async fn complete(
.send(request)
.await
.map_err(AnthropicError::HttpSend)?;
let status = response.status();
let status_code = response.status();
let mut body = String::new();
response
.body_mut()
@ -364,12 +364,12 @@ pub async fn complete(
.await
.map_err(AnthropicError::ReadResponse)?;
if status.is_success() {
if status_code.is_success() {
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
} else {
Err(AnthropicError::HttpResponseError {
status: status.as_u16(),
body,
status_code,
message: body,
})
}
}
@ -444,11 +444,7 @@ impl RateLimitInfo {
}
Self {
retry_after: headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs),
retry_after: parse_retry_after(headers),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@ -457,6 +453,17 @@ impl RateLimitInfo {
}
}
/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
/// seconds). Note that other services might specify an HTTP date or some other format for this
/// header. Returns `None` if the header is not present or cannot be parsed.
pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs)
}
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
Ok(headers
.get(key)
@ -520,6 +527,10 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
} else if response.status().as_u16() == 529 {
Err(AnthropicError::ServerOverloaded {
retry_after: rate_limits.retry_after,
})
} else if let Some(retry_after) = rate_limits.retry_after {
Err(AnthropicError::RateLimit { retry_after })
} else {
@ -532,10 +543,9 @@ pub async fn stream_completion_with_rate_limit_info(
match serde_json::from_str::<Event>(&body) {
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
Err(_) => Err(AnthropicError::HttpResponseError {
status: response.status().as_u16(),
body: body,
Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
status_code: response.status(),
message: body,
}),
}
}
@ -801,16 +811,19 @@ pub enum AnthropicError {
ReadResponse(io::Error),
/// HTTP error response from the API
HttpResponseError { status: u16, body: String },
HttpResponseError {
status_code: StatusCode,
message: String,
},
/// Rate limit exceeded
RateLimit { retry_after: Duration },
/// Server overloaded
ServerOverloaded { retry_after: Option<Duration> },
/// API returned an error response
ApiError(ApiError),
/// Unexpected response format
UnexpectedResponseFormat(String),
}
#[derive(Debug, Serialize, Deserialize, Error)]

View file

@ -2140,6 +2140,7 @@ impl AssistantContext {
);
}
LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
}
});

View file

@ -29,6 +29,7 @@ use std::{
path::Path,
str::FromStr,
sync::mpsc,
time::Duration,
};
use util::path;
@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
match request().await {
Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
Ok(err) => match err {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
Ok(err) => match &err {
LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
| LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
eprintln!(
"Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
"Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
);
Timer::after(retry_after + jitter).await;
continue;

View file

@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown(
| LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. },
) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, ..
}) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
response.push_str(&format!(
"**Error**: parse error in tool use JSON: {}\n\n",
json_parse_error
));
}
Err(error) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
response.push_str(&format!("**Error**: {}\n\n", error));
@ -1132,6 +1141,17 @@ impl ThreadDialog {
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error,
..
}) => {
flush_text(&mut current_text, &mut content);
content.push(MessageContent::Text(format!(
"ERROR: parse error in tool use JSON: {}",
json_parse_error
)));
}
Err(error) => {
flush_text(&mut current_text, &mut content);
content.push(MessageContent::Text(format!("ERROR: {}", error)));

View file

@ -9,17 +9,18 @@ mod telemetry;
pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result;
use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http;
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
@ -34,11 +35,22 @@ pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Anthropic");
/// If we get a rate limit error that doesn't tell us when we can retry,
/// default to waiting this long before retrying.
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Google AI");
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("OpenAI");
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");
pub fn init(client: Arc<Client>, cx: &mut App) {
init_settings(cx);
@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
data: String,
},
ToolUse(LanguageModelToolUse),
ToolUseJsonParseError {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
StartMessage {
message_id: String,
},
@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("rate limit exceeded, retry after {retry_after:?}")]
RateLimitExceeded { retry_after: Duration },
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
#[error("language model provider's API is overloaded")]
Overloaded,
#[error(transparent)]
Other(#[from] anyhow::Error),
#[error("invalid request format to language model provider's API")]
BadRequestFormat,
#[error("authentication error with language model provider's API")]
AuthenticationError,
#[error("permission error with language model provider's API")]
PermissionError,
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound,
#[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> },
#[error("internal server error in language model provider's API")]
ApiInternalServerError,
#[error("I/O error reading response from language model provider's API: {0:?}")]
ApiReadResponseError(io::Error),
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
HttpResponseError { status: u16, body: String },
#[error("error serializing request to language model provider API: {0}")]
SerializeRequest(serde_json::Error),
#[error("error building request body to language model provider API: {0}")]
BuildRequestBody(http::Error),
#[error("error sending HTTP request to language model provider API: {0}")]
HttpSend(anyhow::Error),
#[error("error deserializing language model provider API response: {0}")]
DeserializeResponse(serde_json::Error),
#[error("unexpected language model provider API response format: {0}")]
UnknownResponseFormat(String),
#[error("missing {provider} API key")]
NoApiKey { provider: LanguageModelProviderName },
#[error("{provider}'s API rate limit exceeded")]
RateLimitExceeded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API servers are overloaded right now")]
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
// Client errors
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl LanguageModelCompletionError {
pub fn from_cloud_failure(
upstream_provider: LanguageModelProviderName,
code: String,
message: String,
retry_after: Option<Duration>,
) -> Self {
if let Some(tokens) = parse_prompt_too_long(&message) {
// TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
// to be reported. This is a temporary workaround to handle this in the case where the
// token limit has been exceeded.
Self::PromptTooLarge {
tokens: Some(tokens),
}
} else if let Some(status_code) = code
.strip_prefix("upstream_http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(upstream_provider, status_code, message, retry_after)
} else if let Some(status_code) = code
.strip_prefix("http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
} else {
anyhow!("completion request failed, code: {code}, message: {message}").into()
}
}
pub fn from_http_status(
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
retry_after: Option<Duration>,
) -> Self {
match status_code {
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&message),
},
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
provider,
retry_after,
},
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
provider,
retry_after,
},
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
provider,
retry_after,
},
_ => Self::HttpResponseError {
provider,
status_code,
message,
},
}
}
}
impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self {
let provider = ANTHROPIC_PROVIDER_NAME;
match error {
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
AnthropicError::HttpSend(error) => Self::HttpSend(error),
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
AnthropicError::HttpResponseError { status, body } => {
Self::HttpResponseError { status, body }
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
AnthropicError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
AnthropicError::HttpResponseError {
status_code,
message,
} => Self::HttpResponseError {
provider,
status_code,
message,
},
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after: retry_after,
},
AnthropicError::ApiError(api_error) => api_error.into(),
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
}
}
}
@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*;
let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() {
Some(code) => match code {
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
AuthenticationError => LanguageModelCompletionError::AuthenticationError,
PermissionError => LanguageModelCompletionError::PermissionError,
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
NotFoundError => Self::ApiEndpointNotFound { provider },
RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message),
},
RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => LanguageModelCompletionError::ApiInternalServerError,
OverloadedError => LanguageModelCompletionError::Overloaded,
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
None => LanguageModelCompletionError::Other(error.into()),
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
},
None => Self::Other(error.into()),
}
}
}
@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn upstream_provider_id(&self) -> LanguageModelProviderId {
self.provider_id()
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
self.provider_name()
}
fn telemetry_id(&self) -> String;
fn api_key(&self, _cx: &App) -> Option<String> {
@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
..
}) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
}
}
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
#[error("Context window limit exceeded ({tokens})")]
ContextWindowLimitExceeded { tokens: u64 },
#[error("Language model provider's API is currently overloaded")]
Overloaded,
#[error("Language model provider's API encountered an internal server error")]
ApiInternalServerError,
#[error("I/O error while reading response from language model provider's API: {0:?}")]
ReadResponseError(io::Error),
#[error("Error deserializing response from language model provider's API: {0:?}")]
DeserializeResponse(serde_json::Error),
#[error("Language model provider's API returned a response in an unknown format")]
UnknownResponseFormat(String),
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
RateLimitExceeded { retry_after: Duration },
}
impl LanguageModelKnownError {
/// Attempts to map an HTTP response status code to a known error type.
/// Returns None if the status code doesn't map to a specific known error.
pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
match status {
429 => Some(Self::RateLimitExceeded {
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
}),
503 => Some(Self::Overloaded),
500..=599 => Some(Self::ApiInternalServerError),
_ => None,
}
}
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
impl LanguageModelProviderId {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl LanguageModelProviderName {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl fmt::Display for LanguageModelProviderId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl fmt::Display for LanguageModelProviderName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<String> for LanguageModelId {
fn from(value: String) -> Self {
Self(SharedString::from(value))

View file

@ -98,7 +98,7 @@ impl ConfiguredModel {
}
pub fn is_provided_by_zed(&self) -> bool {
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
}
}

View file

@ -1,3 +1,4 @@
use crate::ANTHROPIC_PROVIDER_ID;
use anthropic::ANTHROPIC_API_URL;
use anyhow::{Context as _, anyhow};
use client::telemetry::Telemetry;
@ -8,8 +9,6 @@ use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use util::ResultExt;
pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
pub fn report_assistant_event(
event: AssistantEventData,
telemetry: Option<Arc<Telemetry>>,
@ -19,7 +18,7 @@ pub fn report_assistant_event(
) {
if let Some(telemetry) = telemetry.as_ref() {
telemetry.report_assistant_event(event.clone());
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
if let Some(api_key) = model_api_key {
executor
.spawn(async move {

View file

@ -33,8 +33,8 @@ use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic";
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -403,7 +403,11 @@ impl AnthropicModel {
};
async move {
let api_key = api_key.context("Missing Anthropic API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(Into::into)
@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -806,12 +810,14 @@ impl AnthropicEventMapper {
raw_input: tool_use.input_json.clone(),
},
)),
Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
Err(json_parse_err) => {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
}),
})
}
};
vec![event_result]

View file

@ -52,8 +52,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
const PROVIDER_ID: &str = "amazon-bedrock";
const PROVIDER_NAME: &str = "Amazon Bedrock";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider {
impl LanguageModelProvider for BedrockLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {

View file

@ -1,4 +1,4 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
RefreshLlmTokenListener,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_NAME: &str = "Zed";
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
}
impl CloudLanguageModel {
const MAX_RETRIES: usize = 3;
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
@ -547,8 +542,7 @@ impl CloudLanguageModel {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
let mut retries_remaining = Self::MAX_RETRIES;
let mut retry_delay = Duration::from_secs(1);
let mut refreshed_token = false;
loop {
let request_builder = http_client::Request::builder()
@ -590,14 +584,20 @@ impl CloudLanguageModel {
includes_status_messages,
tool_use_limit_reached,
});
} else if response
}
if !refreshed_token
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?;
} else if status == StatusCode::FORBIDDEN
refreshed_token = true;
continue;
}
if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
@ -622,35 +622,18 @@ impl CloudLanguageModel {
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
}
anyhow::bail!("Forbidden");
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
// If we encounter an error in the 500 range, retry after a delay.
// We've seen at least these in the wild from API providers:
// * 500 Internal Server Error
// * 502 Bad Gateway
// * 529 Service Overloaded
if retries_remaining == 0 {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"cloud language model completion failed after {} retries with status {status}: {body}",
Self::MAX_RETRIES
);
}
Timer::after(retry_delay).await;
retries_remaining -= 1;
retry_delay *= 2; // If it fails again, wait longer.
} else if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError { status, body }));
}
let mut body = String::new();
let headers = response.headers().clone();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError {
status,
body,
headers
}));
}
}
}
@ -660,6 +643,19 @@ impl CloudLanguageModel {
struct ApiError {
status: StatusCode,
body: String,
headers: HeaderMap<HeaderValue>,
}
impl From<ApiError> for LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
}
}
impl LanguageModel for CloudLanguageModel {
@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn upstream_provider_id(&self) -> LanguageModelProviderId {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
Google => language_model::GOOGLE_PROVIDER_ID,
}
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
Google => language_model::GOOGLE_PROVIDER_NAME,
}
}
fn supports_tools(&self) -> bool {
@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
.body(serde_json::to_string(&request_body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
let headers = response.headers().clone();
let mut response_body = String::new();
response
.body_mut()
@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
} else {
Err(anyhow!(ApiError {
status,
body: response_body
body: response_body,
headers
}))
}
}
@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
)
.await
.map_err(|err| match err.downcast::<ApiError>() {
Ok(api_err) => {
if api_err.status == StatusCode::BAD_REQUEST {
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
return anyhow!(
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens
}
);
}
}
anyhow!(api_err)
}
Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
Err(err) => anyhow!(err),
})?;
@ -995,7 +1000,7 @@ where
.flat_map(move |event| {
futures::stream::iter(match event {
Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))]
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CloudCompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]

View file

@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
use super::open_ai::count_open_ai_tokens;
const PROVIDER_ID: &str = "copilot_chat";
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("GitHub Copilot Chat");
pub struct CopilotChatLanguageModelProvider {
state: Entity<State>,
@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -400,14 +401,14 @@ pub fn map_to_language_model_completion_events(
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
Err(error) => Ok(
LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
},
),
}
},
));

View file

@ -28,8 +28,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "deepseek";
const PROVIDER_NAME: &str = "DeepSeek";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
#[derive(Default)]
@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
impl LanguageModelProvider for DeepSeekLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -466,7 +466,7 @@ impl DeepSeekEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@ -476,7 +476,7 @@ impl DeepSeekEventMapper {
event: deepseek::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@ -538,8 +538,8 @@ impl DeepSeekEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),

View file

@ -37,8 +37,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
impl LanguageModelProvider for GoogleLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
let response = request.await.map_err(LanguageModelCompletionError::from)?;
Ok(GoogleEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@ -622,7 +620,7 @@ impl GoogleEventMapper {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
vec![Err(LanguageModelCompletionError::from(error))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})

View file

@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
impl LanguageModelProvider for LmStudioLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -474,7 +474,7 @@ impl LmStudioEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@ -484,7 +484,7 @@ impl LmStudioEventMapper {
event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@ -553,7 +553,7 @@ impl LmStudioEventMapper {
raw_input: tool_call.arguments,
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),

View file

@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
@ -15,6 +14,7 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
};
use mistral::StreamResponse;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "mistral";
const PROVIDER_NAME: &str = "Mistral";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
impl LanguageModelProvider for MistralLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -579,13 +579,13 @@ impl MistralEventMapper {
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@ -595,7 +595,7 @@ impl MistralEventMapper {
event: mistral::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@ -660,7 +660,7 @@ impl MistralEventMapper {
for (_, tool_call) in self.tool_calls_by_index.drain() {
if tool_call.id.is_empty() || tool_call.name.is_empty() {
results.push(Err(LanguageModelCompletionError::Other(anyhow!(
results.push(Err(LanguageModelCompletionError::from(anyhow!(
"Received incomplete tool call: missing id or name"
))));
continue;
@ -676,12 +676,14 @@ impl MistralEventMapper {
raw_input: tool_call.arguments,
},
))),
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
Err(error) => {
results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})),
}))
}
}
}

View file

@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
let delta = match response {
Ok(delta) => delta,
Err(e) => {
let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
return Some((vec![event], state));
}
};

View file

@ -31,8 +31,8 @@ use util::ResultExt;
use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openai";
const PROVIDER_NAME: &str = "OpenAI";
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing OpenAI API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -525,7 +529,7 @@ impl OpenAiEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
})
})
}
@ -588,10 +592,10 @@ impl OpenAiEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(),
}),
}

View file

@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openrouter";
const PROVIDER_NAME: &str = "OpenRouter";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
impl LanguageModelProvider for OpenRouterLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
})
})
}
@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(),
}),
}

View file

@ -25,8 +25,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "vercel";
const PROVIDER_NAME: &str = "Vercel";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
impl LanguageModelProvider for VercelLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -269,7 +269,11 @@ impl VercelLanguageModel {
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing Vercel API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {

View file

@ -7,10 +7,7 @@ use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{
CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
WebSearchBody, WebSearchResponse,
};
use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
pub struct CloudWebSearchProvider {
state: Entity<State>,
@ -92,7 +89,6 @@ async fn perform_web_search(
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client
.send(request)