agent: Improve error handling and retry for zed-provided models (#33565)

* Updates to `zed_llm_client-0.8.5` which adds support for `retry_after`
when Anthropic provides it.

* Distinguishes upstream provider errors and rate limits from errors
that originate from Zed's servers

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably
this is an error case, the logic in thread is cleaner with this move.
There is also precedent for inclusion of errors in the event type -
`CompletionRequestStatus::Failed` is how cloud errors arrive.

* Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in a const fashion.

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`
as the server no longer reads this header and just defaults to that
behavior.

Release notes for this are covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Michael Sloan 2025-06-30 21:01:32 -06:00 committed by GitHub
parent f022a13091
commit d497f52e17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 656 additions and 479 deletions

4
Cargo.lock generated
View file

@ -20139,9 +20139,9 @@ dependencies = [
[[package]] [[package]]
name = "zed_llm_client" name = "zed_llm_client"
version = "0.8.4" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203" checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"serde", "serde",

View file

@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29" wasmtime-wasi = "29"
which = "6.0.0" which = "6.0.0"
workspace-hack = "0.1.0" workspace-hack = "0.1.0"
zed_llm_client = "0.8.4" zed_llm_client = "0.8.5"
zstd = "0.11" zstd = "0.11"
[workspace.dependencies.async-stripe] [workspace.dependencies.async-stripe]

View file

@ -23,11 +23,10 @@ use gpui::{
}; };
use language_model::{ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, Role, SelectedModel, StopReason, TokenUsage,
TokenUsage,
}; };
use postage::stream::Stream as _; use postage::stream::Stream as _;
use project::{ use project::{
@ -1531,82 +1530,7 @@ impl Thread {
} }
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
let event = match event { match event? {
Ok(event) => event,
Err(error) => {
match error {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
}
LanguageModelCompletionError::Overloaded => {
anyhow::bail!(LanguageModelKnownError::Overloaded);
}
LanguageModelCompletionError::ApiInternalServerError =>{
anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
}
LanguageModelCompletionError::PromptTooLarge { tokens } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(model.max_token_count().saturating_add(1))
});
anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
}
LanguageModelCompletionError::ApiReadResponseError(io_error) => {
anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
}
LanguageModelCompletionError::UnknownResponseFormat(error) => {
anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
}
LanguageModelCompletionError::HttpResponseError { status, ref body } => {
if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
anyhow::bail!(known_error);
} else {
return Err(error.into());
}
}
LanguageModelCompletionError::DeserializeResponse(error) => {
anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
}
LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
// These are all errors we can't automatically attempt to recover from (e.g. by retrying)
err @ LanguageModelCompletionError::BadRequestFormat |
err @ LanguageModelCompletionError::AuthenticationError |
err @ LanguageModelCompletionError::PermissionError |
err @ LanguageModelCompletionError::ApiEndpointNotFound |
err @ LanguageModelCompletionError::SerializeRequest(_) |
err @ LanguageModelCompletionError::BuildRequestBody(_) |
err @ LanguageModelCompletionError::HttpSend(_) => {
anyhow::bail!(err);
}
LanguageModelCompletionError::Other(error) => {
return Err(error);
}
}
}
};
match event {
LanguageModelCompletionEvent::StartMessage { .. } => { LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id = request_assistant_message_id =
Some(thread.insert_assistant_message( Some(thread.insert_assistant_message(
@ -1683,9 +1607,7 @@ impl Thread {
}; };
} }
} }
LanguageModelCompletionEvent::RedactedThinking { LanguageModelCompletionEvent::RedactedThinking { data } => {
data
} => {
thread.received_chunk(); thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() { if let Some(last_message) = thread.messages.last_mut() {
@ -1734,6 +1656,21 @@ impl Thread {
}); });
} }
} }
LanguageModelCompletionEvent::ToolUseJsonParseError {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
}
LanguageModelCompletionEvent::StatusUpdate(status_update) => { LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread if let Some(completion) = thread
.pending_completions .pending_completions
@ -1741,23 +1678,34 @@ impl Thread {
.find(|completion| completion.id == pending_completion_id) .find(|completion| completion.id == pending_completion_id)
{ {
match status_update { match status_update {
CompletionRequestStatus::Queued { CompletionRequestStatus::Queued { position } => {
position, completion.queue_state =
} => { QueueState::Queued { position };
completion.queue_state = QueueState::Queued { position };
} }
CompletionRequestStatus::Started => { CompletionRequestStatus::Started => {
completion.queue_state = QueueState::Started; completion.queue_state = QueueState::Started;
} }
CompletionRequestStatus::Failed { CompletionRequestStatus::Failed {
code, message, request_id code,
message,
request_id: _,
retry_after,
} => { } => {
anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"); return Err(
LanguageModelCompletionError::from_cloud_failure(
model.upstream_provider_name(),
code,
message,
retry_after.map(Duration::from_secs_f64),
),
);
} }
CompletionRequestStatus::UsageUpdated { CompletionRequestStatus::UsageUpdated { amount, limit } => {
amount, limit thread.update_model_request_usage(
} => { amount as u32,
thread.update_model_request_usage(amount as u32, limit, cx); limit,
cx,
);
} }
CompletionRequestStatus::ToolUseLimitReached => { CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true; thread.tool_use_limit_reached = true;
@ -1808,10 +1756,11 @@ impl Thread {
Ok(stop_reason) => { Ok(stop_reason) => {
match stop_reason { match stop_reason {
StopReason::ToolUse => { StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(window, model.clone(), cx); let tool_uses =
thread.use_pending_tools(window, model.clone(), cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses }); cx.emit(ThreadEvent::UsePendingTools { tool_uses });
} }
StopReason::EndTurn | StopReason::MaxTokens => { StopReason::EndTurn | StopReason::MaxTokens => {
thread.project.update(cx, |project, cx| { thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx); project.set_agent_location(None, cx);
}); });
@ -1827,7 +1776,9 @@ impl Thread {
{ {
let mut messages_to_remove = Vec::new(); let mut messages_to_remove = Vec::new();
for (ix, message) in thread.messages.iter().enumerate().rev() { for (ix, message) in
thread.messages.iter().enumerate().rev()
{
messages_to_remove.push(message.id); messages_to_remove.push(message.id);
if message.role == Role::User { if message.role == Role::User {
@ -1835,7 +1786,9 @@ impl Thread {
break; break;
} }
if let Some(prev_message) = thread.messages.get(ix - 1) { if let Some(prev_message) =
thread.messages.get(ix - 1)
{
if prev_message.role == Role::Assistant { if prev_message.role == Role::Assistant {
break; break;
} }
@ -1850,14 +1803,16 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(ThreadError::Message { cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Language model refusal".into(), header: "Language model refusal".into(),
message: "Model refused to generate content for safety reasons.".into(), message:
"Model refused to generate content for safety reasons."
.into(),
})); }));
} }
} }
// We successfully completed, so cancel any remaining retries. // We successfully completed, so cancel any remaining retries.
thread.retry_state = None; thread.retry_state = None;
}, }
Err(error) => { Err(error) => {
thread.project.update(cx, |project, cx| { thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx); project.set_agent_location(None, cx);
@ -1883,26 +1838,38 @@ impl Thread {
cx.emit(ThreadEvent::ShowError( cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan }, ThreadError::ModelRequestLimitReached { plan: error.plan },
)); ));
} else if let Some(known_error) = } else if let Some(completion_error) =
error.downcast_ref::<LanguageModelKnownError>() error.downcast_ref::<LanguageModelCompletionError>()
{ {
match known_error { use LanguageModelCompletionError::*;
LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => { match &completion_error {
PromptTooLarge { tokens, .. } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread
.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(model.max_token_count().saturating_add(1))
});
thread.exceeded_window_error = Some(ExceededWindowError { thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(), model_id: model.id(),
token_count: *tokens, token_count: tokens,
}); });
cx.notify(); cx.notify();
} }
LanguageModelKnownError::RateLimitExceeded { retry_after } => { RateLimitExceeded {
let provider_name = model.provider_name(); retry_after: Some(retry_after),
let error_message = format!( ..
"{}'s API rate limit exceeded", }
provider_name.0.as_ref() | ServerOverloaded {
); retry_after: Some(retry_after),
..
} => {
thread.handle_rate_limit_error( thread.handle_rate_limit_error(
&error_message, &completion_error,
*retry_after, *retry_after,
model.clone(), model.clone(),
intent, intent,
@ -1911,15 +1878,9 @@ impl Thread {
); );
retry_scheduled = true; retry_scheduled = true;
} }
LanguageModelKnownError::Overloaded => { RateLimitExceeded { .. } | ServerOverloaded { .. } => {
let provider_name = model.provider_name();
let error_message = format!(
"{}'s API servers are overloaded right now",
provider_name.0.as_ref()
);
retry_scheduled = thread.handle_retryable_error( retry_scheduled = thread.handle_retryable_error(
&error_message, &completion_error,
model.clone(), model.clone(),
intent, intent,
window, window,
@ -1929,15 +1890,11 @@ impl Thread {
emit_generic_error(error, cx); emit_generic_error(error, cx);
} }
} }
LanguageModelKnownError::ApiInternalServerError => { ApiInternalServerError { .. }
let provider_name = model.provider_name(); | ApiReadResponseError { .. }
let error_message = format!( | HttpSend { .. } => {
"{}'s API server reported an internal server error",
provider_name.0.as_ref()
);
retry_scheduled = thread.handle_retryable_error( retry_scheduled = thread.handle_retryable_error(
&error_message, &completion_error,
model.clone(), model.clone(),
intent, intent,
window, window,
@ -1947,12 +1904,16 @@ impl Thread {
emit_generic_error(error, cx); emit_generic_error(error, cx);
} }
} }
LanguageModelKnownError::ReadResponseError(_) | NoApiKey { .. }
LanguageModelKnownError::DeserializeResponse(_) | | HttpResponseError { .. }
LanguageModelKnownError::UnknownResponseFormat(_) => { | BadRequestFormat { .. }
// In the future we will attempt to re-roll response, but only once | AuthenticationError { .. }
emit_generic_error(error, cx); | PermissionError { .. }
} | ApiEndpointNotFound { .. }
| SerializeRequest { .. }
| BuildRequestBody { .. }
| DeserializeResponse { .. }
| Other { .. } => emit_generic_error(error, cx),
} }
} else { } else {
emit_generic_error(error, cx); emit_generic_error(error, cx);
@ -2084,7 +2045,7 @@ impl Thread {
fn handle_rate_limit_error( fn handle_rate_limit_error(
&mut self, &mut self,
error_message: &str, error: &LanguageModelCompletionError,
retry_after: Duration, retry_after: Duration,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
intent: CompletionIntent, intent: CompletionIntent,
@ -2092,9 +2053,10 @@ impl Thread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
// For rate limit errors, we only retry once with the specified duration // For rate limit errors, we only retry once with the specified duration
let retry_message = format!( let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
"{error_message}. Retrying in {} seconds…", log::warn!(
retry_after.as_secs() "Retrying completion request in {} seconds: {error:?}",
retry_after.as_secs(),
); );
// Add a UI-only message instead of a regular message // Add a UI-only message instead of a regular message
@ -2127,18 +2089,18 @@ impl Thread {
fn handle_retryable_error( fn handle_retryable_error(
&mut self, &mut self,
error_message: &str, error: &LanguageModelCompletionError,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
intent: CompletionIntent, intent: CompletionIntent,
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> bool { ) -> bool {
self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx) self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
} }
fn handle_retryable_error_with_delay( fn handle_retryable_error_with_delay(
&mut self, &mut self,
error_message: &str, error: &LanguageModelCompletionError,
custom_delay: Option<Duration>, custom_delay: Option<Duration>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
intent: CompletionIntent, intent: CompletionIntent,
@ -2168,8 +2130,12 @@ impl Thread {
// Add a transient message to inform the user // Add a transient message to inform the user
let delay_secs = delay.as_secs(); let delay_secs = delay.as_secs();
let retry_message = format!( let retry_message = format!(
"{}. Retrying (attempt {} of {}) in {} seconds...", "{error}. Retrying (attempt {attempt} of {max_attempts}) \
error_message, attempt, max_attempts, delay_secs in {delay_secs} seconds..."
);
log::warn!(
"Retrying completion request (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds: {error:?}",
); );
// Add a UI-only message instead of a regular message // Add a UI-only message instead of a regular message
@ -4139,9 +4105,15 @@ fn main() {{
>, >,
> { > {
let error = match self.error_type { let error = match self.error_type {
TestError::Overloaded => LanguageModelCompletionError::Overloaded, TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
provider: self.provider_name(),
retry_after: None,
},
TestError::InternalServerError => { TestError::InternalServerError => {
LanguageModelCompletionError::ApiInternalServerError LanguageModelCompletionError::ApiInternalServerError {
provider: self.provider_name(),
message: "I'm a teapot orbiting the sun".to_string(),
}
} }
}; };
async move { async move {
@ -4649,9 +4621,13 @@ fn main() {{
> { > {
if !*self.failed_once.lock() { if !*self.failed_once.lock() {
*self.failed_once.lock() = true; *self.failed_once.lock() = true;
let provider = self.provider_name();
// Return error on first attempt // Return error on first attempt
let stream = futures::stream::once(async move { let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::Overloaded) Err(LanguageModelCompletionError::ServerOverloaded {
provider,
retry_after: None,
})
}); });
async move { Ok(stream.boxed()) }.boxed() async move { Ok(stream.boxed()) }.boxed()
} else { } else {
@ -4814,9 +4790,13 @@ fn main() {{
> { > {
if !*self.failed_once.lock() { if !*self.failed_once.lock() {
*self.failed_once.lock() = true; *self.failed_once.lock() = true;
let provider = self.provider_name();
// Return error on first attempt // Return error on first attempt
let stream = futures::stream::once(async move { let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::Overloaded) Err(LanguageModelCompletionError::ServerOverloaded {
provider,
retry_after: None,
})
}); });
async move { Ok(stream.boxed()) }.boxed() async move { Ok(stream.boxed()) }.boxed()
} else { } else {
@ -4969,10 +4949,12 @@ fn main() {{
LanguageModelCompletionError, LanguageModelCompletionError,
>, >,
> { > {
let provider = self.provider_name();
async move { async move {
let stream = futures::stream::once(async move { let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::RateLimitExceeded { Err(LanguageModelCompletionError::RateLimitExceeded {
retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS), provider,
retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
}) })
}); });
Ok(stream.boxed()) Ok(stream.boxed())

View file

@ -2025,9 +2025,7 @@ impl AgentPanel {
.thread() .thread()
.read(cx) .read(cx)
.configured_model() .configured_model()
.map_or(false, |model| { .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
});
if !is_using_zed_provider { if !is_using_zed_provider {
return false; return false;

View file

@ -1250,9 +1250,7 @@ impl MessageEditor {
self.thread self.thread
.read(cx) .read(cx)
.configured_model() .configured_model()
.map_or(false, |model| { .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
})
} }
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> { fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {

View file

@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{self, HeaderMap, HeaderValue}; use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString}; use strum::{EnumIter, EnumString};
use thiserror::Error; use thiserror::Error;
@ -356,7 +356,7 @@ pub async fn complete(
.send(request) .send(request)
.await .await
.map_err(AnthropicError::HttpSend)?; .map_err(AnthropicError::HttpSend)?;
let status = response.status(); let status_code = response.status();
let mut body = String::new(); let mut body = String::new();
response response
.body_mut() .body_mut()
@ -364,12 +364,12 @@ pub async fn complete(
.await .await
.map_err(AnthropicError::ReadResponse)?; .map_err(AnthropicError::ReadResponse)?;
if status.is_success() { if status_code.is_success() {
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?) Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
} else { } else {
Err(AnthropicError::HttpResponseError { Err(AnthropicError::HttpResponseError {
status: status.as_u16(), status_code,
body, message: body,
}) })
} }
} }
@ -444,11 +444,7 @@ impl RateLimitInfo {
} }
Self { Self {
retry_after: headers retry_after: parse_retry_after(headers),
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs),
requests: RateLimit::from_headers("requests", headers).ok(), requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(), tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(), input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@ -457,6 +453,17 @@ impl RateLimitInfo {
} }
} }
/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
/// seconds). Note that other services might specify an HTTP date or some other format for this
/// header. Returns `None` if the header is not present or cannot be parsed.
pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs)
}
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> { fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
Ok(headers Ok(headers
.get(key) .get(key)
@ -520,6 +527,10 @@ pub async fn stream_completion_with_rate_limit_info(
}) })
.boxed(); .boxed();
Ok((stream, Some(rate_limits))) Ok((stream, Some(rate_limits)))
} else if response.status().as_u16() == 529 {
Err(AnthropicError::ServerOverloaded {
retry_after: rate_limits.retry_after,
})
} else if let Some(retry_after) = rate_limits.retry_after { } else if let Some(retry_after) = rate_limits.retry_after {
Err(AnthropicError::RateLimit { retry_after }) Err(AnthropicError::RateLimit { retry_after })
} else { } else {
@ -532,10 +543,9 @@ pub async fn stream_completion_with_rate_limit_info(
match serde_json::from_str::<Event>(&body) { match serde_json::from_str::<Event>(&body) {
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)), Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)), Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
Err(_) => Err(AnthropicError::HttpResponseError { status_code: response.status(),
status: response.status().as_u16(), message: body,
body: body,
}), }),
} }
} }
@ -801,16 +811,19 @@ pub enum AnthropicError {
ReadResponse(io::Error), ReadResponse(io::Error),
/// HTTP error response from the API /// HTTP error response from the API
HttpResponseError { status: u16, body: String }, HttpResponseError {
status_code: StatusCode,
message: String,
},
/// Rate limit exceeded /// Rate limit exceeded
RateLimit { retry_after: Duration }, RateLimit { retry_after: Duration },
/// Server overloaded
ServerOverloaded { retry_after: Option<Duration> },
/// API returned an error response /// API returned an error response
ApiError(ApiError), ApiError(ApiError),
/// Unexpected response format
UnexpectedResponseFormat(String),
} }
#[derive(Debug, Serialize, Deserialize, Error)] #[derive(Debug, Serialize, Deserialize, Error)]

View file

@ -2140,7 +2140,8 @@ impl AssistantContext {
); );
} }
LanguageModelCompletionEvent::ToolUse(_) | LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::UsageUpdate(_) => {} LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
} }
}); });

View file

@ -29,6 +29,7 @@ use std::{
path::Path, path::Path,
str::FromStr, str::FromStr,
sync::mpsc, sync::mpsc,
time::Duration,
}; };
use util::path; use util::path;
@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
match request().await { match request().await {
Ok(result) => return Ok(result), Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() { Err(err) => match err.downcast::<LanguageModelCompletionError>() {
Ok(err) => match err { Ok(err) => match &err {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => { LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
| LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time. // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
eprintln!( eprintln!(
"Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}" "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
); );
Timer::after(retry_after + jitter).await; Timer::after(retry_after + jitter).await;
continue; continue;

View file

@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown(
| LanguageModelCompletionEvent::StartMessage { .. } | LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. }, | LanguageModelCompletionEvent::StatusUpdate { .. },
) => {} ) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, ..
}) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
response.push_str(&format!(
"**Error**: parse error in tool use JSON: {}\n\n",
json_parse_error
));
}
Err(error) => { Err(error) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
response.push_str(&format!("**Error**: {}\n\n", error)); response.push_str(&format!("**Error**: {}\n\n", error));
@ -1132,6 +1141,17 @@ impl ThreadDialog {
| Ok(LanguageModelCompletionEvent::StartMessage { .. }) | Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {} | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error,
..
}) => {
flush_text(&mut current_text, &mut content);
content.push(MessageContent::Text(format!(
"ERROR: parse error in tool use JSON: {}",
json_parse_error
)));
}
Err(error) => { Err(error) => {
flush_text(&mut current_text, &mut content); flush_text(&mut current_text, &mut content);
content.push(MessageContent::Text(format!("ERROR: {}", error))); content.push(MessageContent::Text(format!("ERROR: {}", error)));

View file

@ -9,17 +9,18 @@ mod telemetry;
pub mod fake_provider; pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long}; use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result; use anyhow::{Result, anyhow};
use client::Client; use client::Client;
use futures::FutureExt; use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http; use http_client::{StatusCode, http};
use icons::IconName; use icons::IconName;
use parking_lot::Mutex; use parking_lot::Mutex;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub}; use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::{fmt, io}; use std::{fmt, io};
@ -34,11 +35,22 @@ pub use crate::request::*;
pub use crate::role::*; pub use crate::role::*;
pub use crate::telemetry::*; pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Anthropic");
/// If we get a rate limit error that doesn't tell us when we can retry, pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// default to waiting this long before retrying. pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4); LanguageModelProviderName::new("Google AI");
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("OpenAI");
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");
pub fn init(client: Arc<Client>, cx: &mut App) { pub fn init(client: Arc<Client>, cx: &mut App) {
init_settings(cx); init_settings(cx);
@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
data: String, data: String,
}, },
ToolUse(LanguageModelToolUse), ToolUse(LanguageModelToolUse),
ToolUseJsonParseError {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
StartMessage { StartMessage {
message_id: String, message_id: String,
}, },
@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum LanguageModelCompletionError { pub enum LanguageModelCompletionError {
#[error("rate limit exceeded, retry after {retry_after:?}")]
RateLimitExceeded { retry_after: Duration },
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
#[error("language model provider's API is overloaded")]
Overloaded,
#[error(transparent)]
Other(#[from] anyhow::Error),
#[error("invalid request format to language model provider's API")]
BadRequestFormat,
#[error("authentication error with language model provider's API")]
AuthenticationError,
#[error("permission error with language model provider's API")]
PermissionError,
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound,
#[error("prompt too large for context window")] #[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> }, PromptTooLarge { tokens: Option<u64> },
#[error("internal server error in language model provider's API")] #[error("missing {provider} API key")]
ApiInternalServerError, NoApiKey { provider: LanguageModelProviderName },
#[error("I/O error reading response from language model provider's API: {0:?}")] #[error("{provider}'s API rate limit exceeded")]
ApiReadResponseError(io::Error), RateLimitExceeded {
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")] provider: LanguageModelProviderName,
HttpResponseError { status: u16, body: String }, retry_after: Option<Duration>,
#[error("error serializing request to language model provider API: {0}")] },
SerializeRequest(serde_json::Error), #[error("{provider}'s API servers are overloaded right now")]
#[error("error building request body to language model provider API: {0}")] ServerOverloaded {
BuildRequestBody(http::Error), provider: LanguageModelProviderName,
#[error("error sending HTTP request to language model provider API: {0}")] retry_after: Option<Duration>,
HttpSend(anyhow::Error), },
#[error("error deserializing language model provider API response: {0}")] #[error("{provider}'s API server reported an internal server error: {message}")]
DeserializeResponse(serde_json::Error), ApiInternalServerError {
#[error("unexpected language model provider API response format: {0}")] provider: LanguageModelProviderName,
UnknownResponseFormat(String), message: String,
},
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
// Client errors
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl LanguageModelCompletionError {
pub fn from_cloud_failure(
upstream_provider: LanguageModelProviderName,
code: String,
message: String,
retry_after: Option<Duration>,
) -> Self {
if let Some(tokens) = parse_prompt_too_long(&message) {
// TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
// to be reported. This is a temporary workaround to handle this in the case where the
// token limit has been exceeded.
Self::PromptTooLarge {
tokens: Some(tokens),
}
} else if let Some(status_code) = code
.strip_prefix("upstream_http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(upstream_provider, status_code, message, retry_after)
} else if let Some(status_code) = code
.strip_prefix("http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
} else {
anyhow!("completion request failed, code: {code}, message: {message}").into()
}
}
pub fn from_http_status(
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
retry_after: Option<Duration>,
) -> Self {
match status_code {
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&message),
},
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
provider,
retry_after,
},
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
provider,
retry_after,
},
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
provider,
retry_after,
},
_ => Self::HttpResponseError {
provider,
status_code,
message,
},
}
}
} }
impl From<AnthropicError> for LanguageModelCompletionError { impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self { fn from(error: AnthropicError) -> Self {
let provider = ANTHROPIC_PROVIDER_NAME;
match error { match error {
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error), AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error), AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
AnthropicError::HttpSend(error) => Self::HttpSend(error), AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error), AnthropicError::DeserializeResponse(error) => {
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error), Self::DeserializeResponse { provider, error }
AnthropicError::HttpResponseError { status, body } => {
Self::HttpResponseError { status, body }
} }
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after }, AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
AnthropicError::HttpResponseError {
status_code,
message,
} => Self::HttpResponseError {
provider,
status_code,
message,
},
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after: retry_after,
},
AnthropicError::ApiError(api_error) => api_error.into(), AnthropicError::ApiError(api_error) => api_error.into(),
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
} }
} }
} }
@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError { impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self { fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*; use anthropic::ApiErrorCode::*;
let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() { match error.code() {
Some(code) => match code { Some(code) => match code {
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat, InvalidRequestError => Self::BadRequestFormat {
AuthenticationError => LanguageModelCompletionError::AuthenticationError, provider,
PermissionError => LanguageModelCompletionError::PermissionError, message: error.message,
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound, },
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge { AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
NotFoundError => Self::ApiEndpointNotFound { provider },
RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message), tokens: parse_prompt_too_long(&error.message),
}, },
RateLimitError => LanguageModelCompletionError::RateLimitExceeded { RateLimitError => Self::RateLimitExceeded {
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER, provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
}, },
ApiError => LanguageModelCompletionError::ApiInternalServerError,
OverloadedError => LanguageModelCompletionError::Overloaded,
}, },
None => LanguageModelCompletionError::Other(error.into()), None => Self::Other(error.into()),
} }
} }
} }
@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
fn name(&self) -> LanguageModelName; fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId; fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName; fn provider_name(&self) -> LanguageModelProviderName;
fn upstream_provider_id(&self) -> LanguageModelProviderId {
self.provider_id()
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
self.provider_name()
}
fn telemetry_id(&self) -> String; fn telemetry_id(&self) -> String;
fn api_key(&self, _cx: &App) -> Option<String> { fn api_key(&self, _cx: &App) -> Option<String> {
@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None, Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None, Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
..
}) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage; *last_token_usage.lock() = token_usage;
None None
@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
} }
} }
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
#[error("Context window limit exceeded ({tokens})")]
ContextWindowLimitExceeded { tokens: u64 },
#[error("Language model provider's API is currently overloaded")]
Overloaded,
#[error("Language model provider's API encountered an internal server error")]
ApiInternalServerError,
#[error("I/O error while reading response from language model provider's API: {0:?}")]
ReadResponseError(io::Error),
#[error("Error deserializing response from language model provider's API: {0:?}")]
DeserializeResponse(serde_json::Error),
#[error("Language model provider's API returned a response in an unknown format")]
UnknownResponseFormat(String),
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
RateLimitExceeded { retry_after: Duration },
}
impl LanguageModelKnownError {
/// Attempts to map an HTTP response status code to a known error type.
/// Returns None if the status code doesn't map to a specific known error.
pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
match status {
429 => Some(Self::RateLimitExceeded {
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
}),
503 => Some(Self::Overloaded),
500..=599 => Some(Self::ApiInternalServerError),
_ => None,
}
}
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String; fn name() -> String;
fn description() -> String; fn description() -> String;
@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString); pub struct LanguageModelProviderName(pub SharedString);
impl LanguageModelProviderId {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl LanguageModelProviderName {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl fmt::Display for LanguageModelProviderId { impl fmt::Display for LanguageModelProviderId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0) write!(f, "{}", self.0)
} }
} }
impl fmt::Display for LanguageModelProviderName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<String> for LanguageModelId { impl From<String> for LanguageModelId {
fn from(value: String) -> Self { fn from(value: String) -> Self {
Self(SharedString::from(value)) Self(SharedString::from(value))

View file

@ -98,7 +98,7 @@ impl ConfiguredModel {
} }
pub fn is_provided_by_zed(&self) -> bool { pub fn is_provided_by_zed(&self) -> bool {
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
} }
} }

View file

@ -1,3 +1,4 @@
use crate::ANTHROPIC_PROVIDER_ID;
use anthropic::ANTHROPIC_API_URL; use anthropic::ANTHROPIC_API_URL;
use anyhow::{Context as _, anyhow}; use anyhow::{Context as _, anyhow};
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
@ -8,8 +9,6 @@ use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use util::ResultExt; use util::ResultExt;
pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
pub fn report_assistant_event( pub fn report_assistant_event(
event: AssistantEventData, event: AssistantEventData,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
@ -19,7 +18,7 @@ pub fn report_assistant_event(
) { ) {
if let Some(telemetry) = telemetry.as_ref() { if let Some(telemetry) = telemetry.as_ref() {
telemetry.report_assistant_event(event.clone()); telemetry.report_assistant_event(event.clone());
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID { if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
if let Some(api_key) = model_api_key { if let Some(api_key) = model_api_key {
executor executor
.spawn(async move { .spawn(async move {

View file

@ -33,8 +33,8 @@ use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*}; use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt; use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID; const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic"; const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings { pub struct AnthropicSettings {
@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
impl LanguageModelProvider for AnthropicLanguageModelProvider { impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -403,7 +403,11 @@ impl AnthropicModel {
}; };
async move { async move {
let api_key = api_key.context("Missing Anthropic API Key")?; let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request); anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(Into::into) request.await.map_err(Into::into)
@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -806,12 +810,14 @@ impl AnthropicEventMapper {
raw_input: tool_use.input_json.clone(), raw_input: tool_use.input_json.clone(),
}, },
)), )),
Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson { Err(json_parse_err) => {
id: tool_use.id.into(), Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
tool_name: tool_use.name.into(), id: tool_use.id.into(),
raw_input: input_json.into(), tool_name: tool_use.name.into(),
json_parse_error: json_parse_err.to_string(), raw_input: input_json.into(),
}), json_parse_error: json_parse_err.to_string(),
})
}
}; };
vec![event_result] vec![event_result]

View file

@ -52,8 +52,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings; use crate::AllLanguageModelSettings;
const PROVIDER_ID: &str = "amazon-bedrock"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
const PROVIDER_NAME: &str = "Amazon Bedrock"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)] #[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials { pub struct BedrockCredentials {
@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider {
impl LanguageModelProvider for BedrockLanguageModelProvider { impl LanguageModelProvider for BedrockLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {

View file

@ -1,4 +1,4 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long}; use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{ use futures::{
@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
use gpui::{ use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task, AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
}; };
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
ZED_CLOUD_PROVIDER_ID, ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use language_model::{
LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
RefreshLlmTokenListener,
}; };
use proto::Plan; use proto::Plan;
use release_channel::AppVersion; use release_channel::AppVersion;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore; use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader}; use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin; use std::pin::Pin;
use std::str::FromStr as _; use std::str::FromStr as _;
@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google}; use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai}; use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_NAME: &str = "Zed"; const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings { pub struct ZedDotDevSettings {
@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
impl LanguageModelProvider for CloudLanguageModelProvider { impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
} }
impl CloudLanguageModel { impl CloudLanguageModel {
const MAX_RETRIES: usize = 3;
async fn perform_llm_completion( async fn perform_llm_completion(
client: Arc<Client>, client: Arc<Client>,
llm_api_token: LlmApiToken, llm_api_token: LlmApiToken,
@ -547,8 +542,7 @@ impl CloudLanguageModel {
let http_client = &client.http_client(); let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?; let mut token = llm_api_token.acquire(&client).await?;
let mut retries_remaining = Self::MAX_RETRIES; let mut refreshed_token = false;
let mut retry_delay = Duration::from_secs(1);
loop { loop {
let request_builder = http_client::Request::builder() let request_builder = http_client::Request::builder()
@ -590,14 +584,20 @@ impl CloudLanguageModel {
includes_status_messages, includes_status_messages,
tool_use_limit_reached, tool_use_limit_reached,
}); });
} else if response }
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME) if !refreshed_token
.is_some() && response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{ {
retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?; token = llm_api_token.refresh(&client).await?;
} else if status == StatusCode::FORBIDDEN refreshed_token = true;
continue;
}
if status == StatusCode::FORBIDDEN
&& response && response
.headers() .headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME) .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
@ -622,35 +622,18 @@ impl CloudLanguageModel {
return Err(anyhow!(ModelRequestLimitReachedError { plan })); return Err(anyhow!(ModelRequestLimitReachedError { plan }));
} }
} }
anyhow::bail!("Forbidden");
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
// If we encounter an error in the 500 range, retry after a delay.
// We've seen at least these in the wild from API providers:
// * 500 Internal Server Error
// * 502 Bad Gateway
// * 529 Service Overloaded
if retries_remaining == 0 {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"cloud language model completion failed after {} retries with status {status}: {body}",
Self::MAX_RETRIES
);
}
Timer::after(retry_delay).await;
retries_remaining -= 1;
retry_delay *= 2; // If it fails again, wait longer.
} else if status == StatusCode::PAYMENT_REQUIRED { } else if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError)); return Err(anyhow!(PaymentRequiredError));
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError { status, body }));
} }
let mut body = String::new();
let headers = response.headers().clone();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError {
status,
body,
headers
}));
} }
} }
} }
@ -660,6 +643,19 @@ impl CloudLanguageModel {
struct ApiError { struct ApiError {
status: StatusCode, status: StatusCode,
body: String, body: String,
headers: HeaderMap<HeaderValue>,
}
impl From<ApiError> for LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
}
} }
impl LanguageModel for CloudLanguageModel { impl LanguageModel for CloudLanguageModel {
@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
}
fn upstream_provider_id(&self) -> LanguageModelProviderId {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
Google => language_model::GOOGLE_PROVIDER_ID,
}
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
Google => language_model::GOOGLE_PROVIDER_NAME,
}
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
.body(serde_json::to_string(&request_body)?.into())?; .body(serde_json::to_string(&request_body)?.into())?;
let mut response = http_client.send(request).await?; let mut response = http_client.send(request).await?;
let status = response.status(); let status = response.status();
let headers = response.headers().clone();
let mut response_body = String::new(); let mut response_body = String::new();
response response
.body_mut() .body_mut()
@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
} else { } else {
Err(anyhow!(ApiError { Err(anyhow!(ApiError {
status, status,
body: response_body body: response_body,
headers
})) }))
} }
} }
@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
) )
.await .await
.map_err(|err| match err.downcast::<ApiError>() { .map_err(|err| match err.downcast::<ApiError>() {
Ok(api_err) => { Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
if api_err.status == StatusCode::BAD_REQUEST {
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
return anyhow!(
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens
}
);
}
}
anyhow!(api_err)
}
Err(err) => anyhow!(err), Err(err) => anyhow!(err),
})?; })?;
@ -995,7 +1000,7 @@ where
.flat_map(move |event| { .flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Err(error) => { Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))] vec![Err(LanguageModelCompletionError::from(error))]
} }
Ok(CloudCompletionEvent::Status(event)) => { Ok(CloudCompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]

View file

@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens; use super::google::count_google_tokens;
use super::open_ai::count_open_ai_tokens; use super::open_ai::count_open_ai_tokens;
const PROVIDER_ID: &str = "copilot_chat"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: &str = "GitHub Copilot Chat"; const PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("GitHub Copilot Chat");
pub struct CopilotChatLanguageModelProvider { pub struct CopilotChatLanguageModelProvider {
state: Entity<State>, state: Entity<State>,
@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
impl LanguageModelProvider for CopilotChatLanguageModelProvider { impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -391,24 +392,24 @@ pub fn map_to_language_model_completion_events(
serde_json::Value::from_str(&tool_call.arguments) serde_json::Value::from_str(&tool_call.arguments)
}; };
match arguments { match arguments {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse { LanguageModelToolUse {
id: tool_call.id.clone().into(), id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(), name: tool_call.name.as_str().into(),
is_input_complete: true, is_input_complete: true,
input, input,
raw_input: tool_call.arguments.clone(), raw_input: tool_call.arguments.clone(),
}, },
)), )),
Err(error) => { Err(error) => Ok(
Err(LanguageModelCompletionError::BadInputJson { LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(), id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(), tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(), raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(), json_parse_error: error.to_string(),
}) },
} ),
} }
}, },
)); ));

View file

@ -28,8 +28,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem}; use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "deepseek"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: &str = "DeepSeek"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
#[derive(Default)] #[derive(Default)]
@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
impl LanguageModelProvider for DeepSeekLanguageModelProvider { impl LanguageModelProvider for DeepSeekLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -466,7 +466,7 @@ impl DeepSeekEventMapper {
events.flat_map(move |event| { events.flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Ok(event) => self.map_event(event), Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
}) })
}) })
} }
@ -476,7 +476,7 @@ impl DeepSeekEventMapper {
event: deepseek::StreamResponse, event: deepseek::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else { let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!( return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices" "Response contained no choices"
)))]; )))];
}; };
@ -538,8 +538,8 @@ impl DeepSeekEventMapper {
raw_input: tool_call.arguments.clone(), raw_input: tool_call.arguments.clone(),
}, },
)), )),
Err(error) => Err(LanguageModelCompletionError::BadInputJson { Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(), id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(), tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(), raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(), json_parse_error: error.to_string(),

View file

@ -37,8 +37,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings; use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem; use crate::ui::InstructionListItem;
const PROVIDER_ID: &str = "google"; const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: &str = "Google AI"; const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings { pub struct GoogleSettings {
@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
impl LanguageModelProvider for GoogleLanguageModelProvider { impl LanguageModelProvider for GoogleLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
); );
let request = self.stream_completion(request, cx); let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let response = request let response = request.await.map_err(LanguageModelCompletionError::from)?;
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(GoogleEventMapper::new().map_stream(response)) Ok(GoogleEventMapper::new().map_stream(response))
}); });
async move { Ok(future.await?.boxed()) }.boxed() async move { Ok(future.await?.boxed()) }.boxed()
@ -622,7 +620,7 @@ impl GoogleEventMapper {
futures::stream::iter(match event { futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event), Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => { Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))] vec![Err(LanguageModelCompletionError::from(error))]
} }
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
}) })

View file

@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models"; const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/"; const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
const PROVIDER_ID: &str = "lmstudio"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
const PROVIDER_NAME: &str = "LM Studio"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
#[derive(Default, Debug, Clone, PartialEq)] #[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings { pub struct LmStudioSettings {
@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
impl LanguageModelProvider for LmStudioLanguageModelProvider { impl LanguageModelProvider for LmStudioLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -474,7 +474,7 @@ impl LmStudioEventMapper {
events.flat_map(move |event| { events.flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Ok(event) => self.map_event(event), Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
}) })
}) })
} }
@ -484,7 +484,7 @@ impl LmStudioEventMapper {
event: lmstudio::ResponseStreamEvent, event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else { let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!( return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices" "Response contained no choices"
)))]; )))];
}; };
@ -553,7 +553,7 @@ impl LmStudioEventMapper {
raw_input: tool_call.arguments, raw_input: tool_call.arguments,
}, },
)), )),
Err(error) => Err(LanguageModelCompletionError::BadInputJson { Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(), id: tool_call.id.into(),
tool_name: tool_call.name.into(), tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(), raw_input: tool_call.arguments.into(),

View file

@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap; use collections::BTreeMap;
use credentials_provider::CredentialsProvider; use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::stream::BoxStream; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{ use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
}; };
@ -15,6 +14,7 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage, RateLimiter, Role, StopReason, TokenUsage,
}; };
use mistral::StreamResponse;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem}; use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "mistral"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: &str = "Mistral"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings { pub struct MistralSettings {
@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
impl LanguageModelProvider for MistralLanguageModelProvider { impl LanguageModelProvider for MistralLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -579,13 +579,13 @@ impl MistralEventMapper {
pub fn map_stream( pub fn map_stream(
mut self, mut self,
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{ {
events.flat_map(move |event| { events.flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Ok(event) => self.map_event(event), Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
}) })
}) })
} }
@ -595,7 +595,7 @@ impl MistralEventMapper {
event: mistral::StreamResponse, event: mistral::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else { let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!( return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices" "Response contained no choices"
)))]; )))];
}; };
@ -660,7 +660,7 @@ impl MistralEventMapper {
for (_, tool_call) in self.tool_calls_by_index.drain() { for (_, tool_call) in self.tool_calls_by_index.drain() {
if tool_call.id.is_empty() || tool_call.name.is_empty() { if tool_call.id.is_empty() || tool_call.name.is_empty() {
results.push(Err(LanguageModelCompletionError::Other(anyhow!( results.push(Err(LanguageModelCompletionError::from(anyhow!(
"Received incomplete tool call: missing id or name" "Received incomplete tool call: missing id or name"
)))); ))));
continue; continue;
@ -676,12 +676,14 @@ impl MistralEventMapper {
raw_input: tool_call.arguments, raw_input: tool_call.arguments,
}, },
))), ))),
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson { Err(error) => {
id: tool_call.id.into(), results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
tool_name: tool_call.name.into(), id: tool_call.id.into(),
raw_input: tool_call.arguments.into(), tool_name: tool_call.name.into(),
json_parse_error: error.to_string(), raw_input: tool_call.arguments.into(),
})), json_parse_error: error.to_string(),
}))
}
} }
} }

View file

@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library"; const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/"; const OLLAMA_SITE: &str = "https://ollama.com/";
const PROVIDER_ID: &str = "ollama"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: &str = "Ollama"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
#[derive(Default, Debug, Clone, PartialEq)] #[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings { pub struct OllamaSettings {
@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
impl LanguageModelProvider for OllamaLanguageModelProvider { impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
let delta = match response { let delta = match response {
Ok(delta) => delta, Ok(delta) => delta,
Err(e) => { Err(e) => {
let event = Err(LanguageModelCompletionError::Other(anyhow!(e))); let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
return Some((vec![event], state)); return Some((vec![event], state));
} }
}; };

View file

@ -31,8 +31,8 @@ use util::ResultExt;
use crate::OpenAiSettingsContent; use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem}; use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openai"; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: &str = "OpenAI"; const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings { pub struct OpenAiSettings {
@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
impl LanguageModelProvider for OpenAiLanguageModelProvider { impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
}; };
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing OpenAI API Key")?; let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?; let response = request.await?;
Ok(response) Ok(response)
@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -525,7 +529,7 @@ impl OpenAiEventMapper {
events.flat_map(move |event| { events.flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Ok(event) => self.map_event(event), Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
}) })
}) })
} }
@ -588,10 +592,10 @@ impl OpenAiEventMapper {
raw_input: tool_call.arguments.clone(), raw_input: tool_call.arguments.clone(),
}, },
)), )),
Err(error) => Err(LanguageModelCompletionError::BadInputJson { Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(), id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(), tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(), raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(), json_parse_error: error.to_string(),
}), }),
} }

View file

@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem}; use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openrouter"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: &str = "OpenRouter"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings { pub struct OpenRouterSettings {
@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
impl LanguageModelProvider for OpenRouterLanguageModelProvider { impl LanguageModelProvider for OpenRouterLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
events.flat_map(move |event| { events.flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Ok(event) => self.map_event(event), Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
}) })
}) })
} }
@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
event: ResponseStreamEvent, event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else { let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!( return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices" "Response contained no choices"
)))]; )))];
}; };
@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
raw_input: tool_call.arguments.clone(), raw_input: tool_call.arguments.clone(),
}, },
)), )),
Err(error) => Err(LanguageModelCompletionError::BadInputJson { Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(), id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(), tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(), raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(), json_parse_error: error.to_string(),
}), }),
} }

View file

@ -25,8 +25,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem}; use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "vercel"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
const PROVIDER_NAME: &str = "Vercel"; const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
#[derive(Default, Clone, Debug, PartialEq)] #[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings { pub struct VercelSettings {
@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
impl LanguageModelProvider for VercelLanguageModelProvider { impl LanguageModelProvider for VercelLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId { fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn name(&self) -> LanguageModelProviderName { fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -269,7 +269,11 @@ impl VercelLanguageModel {
}; };
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing Vercel API Key")?; let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?; let response = request.await?;
@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
} }
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into()) PROVIDER_ID
} }
fn provider_name(&self) -> LanguageModelProviderName { fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into()) PROVIDER_NAME
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {

View file

@ -7,10 +7,7 @@ use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method}; use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener}; use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId}; use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{ use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
WebSearchBody, WebSearchResponse,
};
pub struct CloudWebSearchProvider { pub struct CloudWebSearchProvider {
state: Entity<State>, state: Entity<State>,
@ -92,7 +89,6 @@ async fn perform_web_search(
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref()) .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}")) .header("Authorization", format!("Bearer {token}"))
.header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?; .body(serde_json::to_string(&body)?.into())?;
let mut response = http_client let mut response = http_client
.send(request) .send(request)