agent: Improve error handling and retry for zed-provided models (#33565)
* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after` when Anthropic provides it.
* Distinguishes upstream provider errors and rate limits from errors that originate from Zed's servers.
* Moves `LanguageModelCompletionError::BadInputJson` to `LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably this is an error case, the logic in thread is cleaner with this move. There is also precedent for including errors in the event type: `CompletionRequestStatus::Failed` is how cloud errors arrive.
* Updates the `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types instead of `&str`, since they can be constructed in a const fashion.
* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`, as the server no longer reads this header and just defaults to that behavior.

Release notes for this are covered by #33275.

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
parent
f022a13091
commit
d497f52e17
25 changed files with 656 additions and 479 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -20139,9 +20139,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zed_llm_client"
|
name = "zed_llm_client"
|
||||||
version = "0.8.4"
|
version = "0.8.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
|
checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
|
||||||
wasmtime-wasi = "29"
|
wasmtime-wasi = "29"
|
||||||
which = "6.0.0"
|
which = "6.0.0"
|
||||||
workspace-hack = "0.1.0"
|
workspace-hack = "0.1.0"
|
||||||
zed_llm_client = "0.8.4"
|
zed_llm_client = "0.8.5"
|
||||||
zstd = "0.11"
|
zstd = "0.11"
|
||||||
|
|
||||||
[workspace.dependencies.async-stripe]
|
[workspace.dependencies.async-stripe]
|
||||||
|
|
|
@ -23,11 +23,10 @@ use gpui::{
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||||
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||||
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
|
LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
|
||||||
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
|
Role, SelectedModel, StopReason, TokenUsage,
|
||||||
TokenUsage,
|
|
||||||
};
|
};
|
||||||
use postage::stream::Stream as _;
|
use postage::stream::Stream as _;
|
||||||
use project::{
|
use project::{
|
||||||
|
@ -1531,82 +1530,7 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
let event = match event {
|
match event? {
|
||||||
Ok(event) => event,
|
|
||||||
Err(error) => {
|
|
||||||
match error {
|
|
||||||
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
|
|
||||||
anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::Overloaded => {
|
|
||||||
anyhow::bail!(LanguageModelKnownError::Overloaded);
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::ApiInternalServerError =>{
|
|
||||||
anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::PromptTooLarge { tokens } => {
|
|
||||||
let tokens = tokens.unwrap_or_else(|| {
|
|
||||||
// We didn't get an exact token count from the API, so fall back on our estimate.
|
|
||||||
thread.total_token_usage()
|
|
||||||
.map(|usage| usage.total)
|
|
||||||
.unwrap_or(0)
|
|
||||||
// We know the context window was exceeded in practice, so if our estimate was
|
|
||||||
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
|
|
||||||
.max(model.max_token_count().saturating_add(1))
|
|
||||||
});
|
|
||||||
|
|
||||||
anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::ApiReadResponseError(io_error) => {
|
|
||||||
anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::UnknownResponseFormat(error) => {
|
|
||||||
anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::HttpResponseError { status, ref body } => {
|
|
||||||
if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
|
|
||||||
anyhow::bail!(known_error);
|
|
||||||
} else {
|
|
||||||
return Err(error.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::DeserializeResponse(error) => {
|
|
||||||
anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::BadInputJson {
|
|
||||||
id,
|
|
||||||
tool_name,
|
|
||||||
raw_input: invalid_input_json,
|
|
||||||
json_parse_error,
|
|
||||||
} => {
|
|
||||||
thread.receive_invalid_tool_json(
|
|
||||||
id,
|
|
||||||
tool_name,
|
|
||||||
invalid_input_json,
|
|
||||||
json_parse_error,
|
|
||||||
window,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
// These are all errors we can't automatically attempt to recover from (e.g. by retrying)
|
|
||||||
err @ LanguageModelCompletionError::BadRequestFormat |
|
|
||||||
err @ LanguageModelCompletionError::AuthenticationError |
|
|
||||||
err @ LanguageModelCompletionError::PermissionError |
|
|
||||||
err @ LanguageModelCompletionError::ApiEndpointNotFound |
|
|
||||||
err @ LanguageModelCompletionError::SerializeRequest(_) |
|
|
||||||
err @ LanguageModelCompletionError::BuildRequestBody(_) |
|
|
||||||
err @ LanguageModelCompletionError::HttpSend(_) => {
|
|
||||||
anyhow::bail!(err);
|
|
||||||
}
|
|
||||||
LanguageModelCompletionError::Other(error) => {
|
|
||||||
return Err(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match event {
|
|
||||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||||
request_assistant_message_id =
|
request_assistant_message_id =
|
||||||
Some(thread.insert_assistant_message(
|
Some(thread.insert_assistant_message(
|
||||||
|
@ -1683,9 +1607,7 @@ impl Thread {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::RedactedThinking {
|
LanguageModelCompletionEvent::RedactedThinking { data } => {
|
||||||
data
|
|
||||||
} => {
|
|
||||||
thread.received_chunk();
|
thread.received_chunk();
|
||||||
|
|
||||||
if let Some(last_message) = thread.messages.last_mut() {
|
if let Some(last_message) = thread.messages.last_mut() {
|
||||||
|
@ -1734,6 +1656,21 @@ impl Thread {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
|
id,
|
||||||
|
tool_name,
|
||||||
|
raw_input: invalid_input_json,
|
||||||
|
json_parse_error,
|
||||||
|
} => {
|
||||||
|
thread.receive_invalid_tool_json(
|
||||||
|
id,
|
||||||
|
tool_name,
|
||||||
|
invalid_input_json,
|
||||||
|
json_parse_error,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
||||||
if let Some(completion) = thread
|
if let Some(completion) = thread
|
||||||
.pending_completions
|
.pending_completions
|
||||||
|
@ -1741,23 +1678,34 @@ impl Thread {
|
||||||
.find(|completion| completion.id == pending_completion_id)
|
.find(|completion| completion.id == pending_completion_id)
|
||||||
{
|
{
|
||||||
match status_update {
|
match status_update {
|
||||||
CompletionRequestStatus::Queued {
|
CompletionRequestStatus::Queued { position } => {
|
||||||
position,
|
completion.queue_state =
|
||||||
} => {
|
QueueState::Queued { position };
|
||||||
completion.queue_state = QueueState::Queued { position };
|
|
||||||
}
|
}
|
||||||
CompletionRequestStatus::Started => {
|
CompletionRequestStatus::Started => {
|
||||||
completion.queue_state = QueueState::Started;
|
completion.queue_state = QueueState::Started;
|
||||||
}
|
}
|
||||||
CompletionRequestStatus::Failed {
|
CompletionRequestStatus::Failed {
|
||||||
code, message, request_id
|
code,
|
||||||
|
message,
|
||||||
|
request_id: _,
|
||||||
|
retry_after,
|
||||||
} => {
|
} => {
|
||||||
anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
|
return Err(
|
||||||
|
LanguageModelCompletionError::from_cloud_failure(
|
||||||
|
model.upstream_provider_name(),
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
retry_after.map(Duration::from_secs_f64),
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
CompletionRequestStatus::UsageUpdated {
|
CompletionRequestStatus::UsageUpdated { amount, limit } => {
|
||||||
amount, limit
|
thread.update_model_request_usage(
|
||||||
} => {
|
amount as u32,
|
||||||
thread.update_model_request_usage(amount as u32, limit, cx);
|
limit,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
CompletionRequestStatus::ToolUseLimitReached => {
|
CompletionRequestStatus::ToolUseLimitReached => {
|
||||||
thread.tool_use_limit_reached = true;
|
thread.tool_use_limit_reached = true;
|
||||||
|
@ -1808,10 +1756,11 @@ impl Thread {
|
||||||
Ok(stop_reason) => {
|
Ok(stop_reason) => {
|
||||||
match stop_reason {
|
match stop_reason {
|
||||||
StopReason::ToolUse => {
|
StopReason::ToolUse => {
|
||||||
let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
|
let tool_uses =
|
||||||
|
thread.use_pending_tools(window, model.clone(), cx);
|
||||||
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
||||||
}
|
}
|
||||||
StopReason::EndTurn | StopReason::MaxTokens => {
|
StopReason::EndTurn | StopReason::MaxTokens => {
|
||||||
thread.project.update(cx, |project, cx| {
|
thread.project.update(cx, |project, cx| {
|
||||||
project.set_agent_location(None, cx);
|
project.set_agent_location(None, cx);
|
||||||
});
|
});
|
||||||
|
@ -1827,7 +1776,9 @@ impl Thread {
|
||||||
{
|
{
|
||||||
let mut messages_to_remove = Vec::new();
|
let mut messages_to_remove = Vec::new();
|
||||||
|
|
||||||
for (ix, message) in thread.messages.iter().enumerate().rev() {
|
for (ix, message) in
|
||||||
|
thread.messages.iter().enumerate().rev()
|
||||||
|
{
|
||||||
messages_to_remove.push(message.id);
|
messages_to_remove.push(message.id);
|
||||||
|
|
||||||
if message.role == Role::User {
|
if message.role == Role::User {
|
||||||
|
@ -1835,7 +1786,9 @@ impl Thread {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(prev_message) = thread.messages.get(ix - 1) {
|
if let Some(prev_message) =
|
||||||
|
thread.messages.get(ix - 1)
|
||||||
|
{
|
||||||
if prev_message.role == Role::Assistant {
|
if prev_message.role == Role::Assistant {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -1850,14 +1803,16 @@ impl Thread {
|
||||||
|
|
||||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
|
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
|
||||||
header: "Language model refusal".into(),
|
header: "Language model refusal".into(),
|
||||||
message: "Model refused to generate content for safety reasons.".into(),
|
message:
|
||||||
|
"Model refused to generate content for safety reasons."
|
||||||
|
.into(),
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We successfully completed, so cancel any remaining retries.
|
// We successfully completed, so cancel any remaining retries.
|
||||||
thread.retry_state = None;
|
thread.retry_state = None;
|
||||||
},
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
thread.project.update(cx, |project, cx| {
|
thread.project.update(cx, |project, cx| {
|
||||||
project.set_agent_location(None, cx);
|
project.set_agent_location(None, cx);
|
||||||
|
@ -1883,26 +1838,38 @@ impl Thread {
|
||||||
cx.emit(ThreadEvent::ShowError(
|
cx.emit(ThreadEvent::ShowError(
|
||||||
ThreadError::ModelRequestLimitReached { plan: error.plan },
|
ThreadError::ModelRequestLimitReached { plan: error.plan },
|
||||||
));
|
));
|
||||||
} else if let Some(known_error) =
|
} else if let Some(completion_error) =
|
||||||
error.downcast_ref::<LanguageModelKnownError>()
|
error.downcast_ref::<LanguageModelCompletionError>()
|
||||||
{
|
{
|
||||||
match known_error {
|
use LanguageModelCompletionError::*;
|
||||||
LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
|
match &completion_error {
|
||||||
|
PromptTooLarge { tokens, .. } => {
|
||||||
|
let tokens = tokens.unwrap_or_else(|| {
|
||||||
|
// We didn't get an exact token count from the API, so fall back on our estimate.
|
||||||
|
thread
|
||||||
|
.total_token_usage()
|
||||||
|
.map(|usage| usage.total)
|
||||||
|
.unwrap_or(0)
|
||||||
|
// We know the context window was exceeded in practice, so if our estimate was
|
||||||
|
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
|
||||||
|
.max(model.max_token_count().saturating_add(1))
|
||||||
|
});
|
||||||
thread.exceeded_window_error = Some(ExceededWindowError {
|
thread.exceeded_window_error = Some(ExceededWindowError {
|
||||||
model_id: model.id(),
|
model_id: model.id(),
|
||||||
token_count: *tokens,
|
token_count: tokens,
|
||||||
});
|
});
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
LanguageModelKnownError::RateLimitExceeded { retry_after } => {
|
RateLimitExceeded {
|
||||||
let provider_name = model.provider_name();
|
retry_after: Some(retry_after),
|
||||||
let error_message = format!(
|
..
|
||||||
"{}'s API rate limit exceeded",
|
}
|
||||||
provider_name.0.as_ref()
|
| ServerOverloaded {
|
||||||
);
|
retry_after: Some(retry_after),
|
||||||
|
..
|
||||||
|
} => {
|
||||||
thread.handle_rate_limit_error(
|
thread.handle_rate_limit_error(
|
||||||
&error_message,
|
&completion_error,
|
||||||
*retry_after,
|
*retry_after,
|
||||||
model.clone(),
|
model.clone(),
|
||||||
intent,
|
intent,
|
||||||
|
@ -1911,15 +1878,9 @@ impl Thread {
|
||||||
);
|
);
|
||||||
retry_scheduled = true;
|
retry_scheduled = true;
|
||||||
}
|
}
|
||||||
LanguageModelKnownError::Overloaded => {
|
RateLimitExceeded { .. } | ServerOverloaded { .. } => {
|
||||||
let provider_name = model.provider_name();
|
|
||||||
let error_message = format!(
|
|
||||||
"{}'s API servers are overloaded right now",
|
|
||||||
provider_name.0.as_ref()
|
|
||||||
);
|
|
||||||
|
|
||||||
retry_scheduled = thread.handle_retryable_error(
|
retry_scheduled = thread.handle_retryable_error(
|
||||||
&error_message,
|
&completion_error,
|
||||||
model.clone(),
|
model.clone(),
|
||||||
intent,
|
intent,
|
||||||
window,
|
window,
|
||||||
|
@ -1929,15 +1890,11 @@ impl Thread {
|
||||||
emit_generic_error(error, cx);
|
emit_generic_error(error, cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LanguageModelKnownError::ApiInternalServerError => {
|
ApiInternalServerError { .. }
|
||||||
let provider_name = model.provider_name();
|
| ApiReadResponseError { .. }
|
||||||
let error_message = format!(
|
| HttpSend { .. } => {
|
||||||
"{}'s API server reported an internal server error",
|
|
||||||
provider_name.0.as_ref()
|
|
||||||
);
|
|
||||||
|
|
||||||
retry_scheduled = thread.handle_retryable_error(
|
retry_scheduled = thread.handle_retryable_error(
|
||||||
&error_message,
|
&completion_error,
|
||||||
model.clone(),
|
model.clone(),
|
||||||
intent,
|
intent,
|
||||||
window,
|
window,
|
||||||
|
@ -1947,12 +1904,16 @@ impl Thread {
|
||||||
emit_generic_error(error, cx);
|
emit_generic_error(error, cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LanguageModelKnownError::ReadResponseError(_) |
|
NoApiKey { .. }
|
||||||
LanguageModelKnownError::DeserializeResponse(_) |
|
| HttpResponseError { .. }
|
||||||
LanguageModelKnownError::UnknownResponseFormat(_) => {
|
| BadRequestFormat { .. }
|
||||||
// In the future we will attempt to re-roll response, but only once
|
| AuthenticationError { .. }
|
||||||
emit_generic_error(error, cx);
|
| PermissionError { .. }
|
||||||
}
|
| ApiEndpointNotFound { .. }
|
||||||
|
| SerializeRequest { .. }
|
||||||
|
| BuildRequestBody { .. }
|
||||||
|
| DeserializeResponse { .. }
|
||||||
|
| Other { .. } => emit_generic_error(error, cx),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
emit_generic_error(error, cx);
|
emit_generic_error(error, cx);
|
||||||
|
@ -2084,7 +2045,7 @@ impl Thread {
|
||||||
|
|
||||||
fn handle_rate_limit_error(
|
fn handle_rate_limit_error(
|
||||||
&mut self,
|
&mut self,
|
||||||
error_message: &str,
|
error: &LanguageModelCompletionError,
|
||||||
retry_after: Duration,
|
retry_after: Duration,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
intent: CompletionIntent,
|
intent: CompletionIntent,
|
||||||
|
@ -2092,9 +2053,10 @@ impl Thread {
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
// For rate limit errors, we only retry once with the specified duration
|
// For rate limit errors, we only retry once with the specified duration
|
||||||
let retry_message = format!(
|
let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
|
||||||
"{error_message}. Retrying in {} seconds…",
|
log::warn!(
|
||||||
retry_after.as_secs()
|
"Retrying completion request in {} seconds: {error:?}",
|
||||||
|
retry_after.as_secs(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add a UI-only message instead of a regular message
|
// Add a UI-only message instead of a regular message
|
||||||
|
@ -2127,18 +2089,18 @@ impl Thread {
|
||||||
|
|
||||||
fn handle_retryable_error(
|
fn handle_retryable_error(
|
||||||
&mut self,
|
&mut self,
|
||||||
error_message: &str,
|
error: &LanguageModelCompletionError,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
intent: CompletionIntent,
|
intent: CompletionIntent,
|
||||||
window: Option<AnyWindowHandle>,
|
window: Option<AnyWindowHandle>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
|
self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_retryable_error_with_delay(
|
fn handle_retryable_error_with_delay(
|
||||||
&mut self,
|
&mut self,
|
||||||
error_message: &str,
|
error: &LanguageModelCompletionError,
|
||||||
custom_delay: Option<Duration>,
|
custom_delay: Option<Duration>,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
intent: CompletionIntent,
|
intent: CompletionIntent,
|
||||||
|
@ -2168,8 +2130,12 @@ impl Thread {
|
||||||
// Add a transient message to inform the user
|
// Add a transient message to inform the user
|
||||||
let delay_secs = delay.as_secs();
|
let delay_secs = delay.as_secs();
|
||||||
let retry_message = format!(
|
let retry_message = format!(
|
||||||
"{}. Retrying (attempt {} of {}) in {} seconds...",
|
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
|
||||||
error_message, attempt, max_attempts, delay_secs
|
in {delay_secs} seconds..."
|
||||||
|
);
|
||||||
|
log::warn!(
|
||||||
|
"Retrying completion request (attempt {attempt} of {max_attempts}) \
|
||||||
|
in {delay_secs} seconds: {error:?}",
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add a UI-only message instead of a regular message
|
// Add a UI-only message instead of a regular message
|
||||||
|
@ -4139,9 +4105,15 @@ fn main() {{
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let error = match self.error_type {
|
let error = match self.error_type {
|
||||||
TestError::Overloaded => LanguageModelCompletionError::Overloaded,
|
TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
|
||||||
|
provider: self.provider_name(),
|
||||||
|
retry_after: None,
|
||||||
|
},
|
||||||
TestError::InternalServerError => {
|
TestError::InternalServerError => {
|
||||||
LanguageModelCompletionError::ApiInternalServerError
|
LanguageModelCompletionError::ApiInternalServerError {
|
||||||
|
provider: self.provider_name(),
|
||||||
|
message: "I'm a teapot orbiting the sun".to_string(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
async move {
|
async move {
|
||||||
|
@ -4649,9 +4621,13 @@ fn main() {{
|
||||||
> {
|
> {
|
||||||
if !*self.failed_once.lock() {
|
if !*self.failed_once.lock() {
|
||||||
*self.failed_once.lock() = true;
|
*self.failed_once.lock() = true;
|
||||||
|
let provider = self.provider_name();
|
||||||
// Return error on first attempt
|
// Return error on first attempt
|
||||||
let stream = futures::stream::once(async move {
|
let stream = futures::stream::once(async move {
|
||||||
Err(LanguageModelCompletionError::Overloaded)
|
Err(LanguageModelCompletionError::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after: None,
|
||||||
|
})
|
||||||
});
|
});
|
||||||
async move { Ok(stream.boxed()) }.boxed()
|
async move { Ok(stream.boxed()) }.boxed()
|
||||||
} else {
|
} else {
|
||||||
|
@ -4814,9 +4790,13 @@ fn main() {{
|
||||||
> {
|
> {
|
||||||
if !*self.failed_once.lock() {
|
if !*self.failed_once.lock() {
|
||||||
*self.failed_once.lock() = true;
|
*self.failed_once.lock() = true;
|
||||||
|
let provider = self.provider_name();
|
||||||
// Return error on first attempt
|
// Return error on first attempt
|
||||||
let stream = futures::stream::once(async move {
|
let stream = futures::stream::once(async move {
|
||||||
Err(LanguageModelCompletionError::Overloaded)
|
Err(LanguageModelCompletionError::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after: None,
|
||||||
|
})
|
||||||
});
|
});
|
||||||
async move { Ok(stream.boxed()) }.boxed()
|
async move { Ok(stream.boxed()) }.boxed()
|
||||||
} else {
|
} else {
|
||||||
|
@ -4969,10 +4949,12 @@ fn main() {{
|
||||||
LanguageModelCompletionError,
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
|
let provider = self.provider_name();
|
||||||
async move {
|
async move {
|
||||||
let stream = futures::stream::once(async move {
|
let stream = futures::stream::once(async move {
|
||||||
Err(LanguageModelCompletionError::RateLimitExceeded {
|
Err(LanguageModelCompletionError::RateLimitExceeded {
|
||||||
retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
|
provider,
|
||||||
|
retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
Ok(stream.boxed())
|
Ok(stream.boxed())
|
||||||
|
|
|
@ -2025,9 +2025,7 @@ impl AgentPanel {
|
||||||
.thread()
|
.thread()
|
||||||
.read(cx)
|
.read(cx)
|
||||||
.configured_model()
|
.configured_model()
|
||||||
.map_or(false, |model| {
|
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
|
||||||
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
|
|
||||||
});
|
|
||||||
|
|
||||||
if !is_using_zed_provider {
|
if !is_using_zed_provider {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -1250,9 +1250,7 @@ impl MessageEditor {
|
||||||
self.thread
|
self.thread
|
||||||
.read(cx)
|
.read(cx)
|
||||||
.configured_model()
|
.configured_model()
|
||||||
.map_or(false, |model| {
|
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
|
||||||
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
|
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
|
||||||
|
|
|
@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||||
use http_client::http::{self, HeaderMap, HeaderValue};
|
use http_client::http::{self, HeaderMap, HeaderValue};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use strum::{EnumIter, EnumString};
|
use strum::{EnumIter, EnumString};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -356,7 +356,7 @@ pub async fn complete(
|
||||||
.send(request)
|
.send(request)
|
||||||
.await
|
.await
|
||||||
.map_err(AnthropicError::HttpSend)?;
|
.map_err(AnthropicError::HttpSend)?;
|
||||||
let status = response.status();
|
let status_code = response.status();
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response
|
response
|
||||||
.body_mut()
|
.body_mut()
|
||||||
|
@ -364,12 +364,12 @@ pub async fn complete(
|
||||||
.await
|
.await
|
||||||
.map_err(AnthropicError::ReadResponse)?;
|
.map_err(AnthropicError::ReadResponse)?;
|
||||||
|
|
||||||
if status.is_success() {
|
if status_code.is_success() {
|
||||||
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
|
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
|
||||||
} else {
|
} else {
|
||||||
Err(AnthropicError::HttpResponseError {
|
Err(AnthropicError::HttpResponseError {
|
||||||
status: status.as_u16(),
|
status_code,
|
||||||
body,
|
message: body,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -444,11 +444,7 @@ impl RateLimitInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
retry_after: headers
|
retry_after: parse_retry_after(headers),
|
||||||
.get("retry-after")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.and_then(|v| v.parse::<u64>().ok())
|
|
||||||
.map(Duration::from_secs),
|
|
||||||
requests: RateLimit::from_headers("requests", headers).ok(),
|
requests: RateLimit::from_headers("requests", headers).ok(),
|
||||||
tokens: RateLimit::from_headers("tokens", headers).ok(),
|
tokens: RateLimit::from_headers("tokens", headers).ok(),
|
||||||
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
|
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
|
||||||
|
@ -457,6 +453,17 @@ impl RateLimitInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
|
||||||
|
/// seconds). Note that other services might specify an HTTP date or some other format for this
|
||||||
|
/// header. Returns `None` if the header is not present or cannot be parsed.
|
||||||
|
pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
|
||||||
|
headers
|
||||||
|
.get("retry-after")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.parse::<u64>().ok())
|
||||||
|
.map(Duration::from_secs)
|
||||||
|
}
|
||||||
|
|
||||||
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
|
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
|
||||||
Ok(headers
|
Ok(headers
|
||||||
.get(key)
|
.get(key)
|
||||||
|
@ -520,6 +527,10 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||||
})
|
})
|
||||||
.boxed();
|
.boxed();
|
||||||
Ok((stream, Some(rate_limits)))
|
Ok((stream, Some(rate_limits)))
|
||||||
|
} else if response.status().as_u16() == 529 {
|
||||||
|
Err(AnthropicError::ServerOverloaded {
|
||||||
|
retry_after: rate_limits.retry_after,
|
||||||
|
})
|
||||||
} else if let Some(retry_after) = rate_limits.retry_after {
|
} else if let Some(retry_after) = rate_limits.retry_after {
|
||||||
Err(AnthropicError::RateLimit { retry_after })
|
Err(AnthropicError::RateLimit { retry_after })
|
||||||
} else {
|
} else {
|
||||||
|
@ -532,10 +543,9 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||||
|
|
||||||
match serde_json::from_str::<Event>(&body) {
|
match serde_json::from_str::<Event>(&body) {
|
||||||
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
|
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
|
||||||
Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
|
Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
|
||||||
Err(_) => Err(AnthropicError::HttpResponseError {
|
status_code: response.status(),
|
||||||
status: response.status().as_u16(),
|
message: body,
|
||||||
body: body,
|
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -801,16 +811,19 @@ pub enum AnthropicError {
|
||||||
ReadResponse(io::Error),
|
ReadResponse(io::Error),
|
||||||
|
|
||||||
/// HTTP error response from the API
|
/// HTTP error response from the API
|
||||||
HttpResponseError { status: u16, body: String },
|
HttpResponseError {
|
||||||
|
status_code: StatusCode,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
|
||||||
/// Rate limit exceeded
|
/// Rate limit exceeded
|
||||||
RateLimit { retry_after: Duration },
|
RateLimit { retry_after: Duration },
|
||||||
|
|
||||||
|
/// Server overloaded
|
||||||
|
ServerOverloaded { retry_after: Option<Duration> },
|
||||||
|
|
||||||
/// API returned an error response
|
/// API returned an error response
|
||||||
ApiError(ApiError),
|
ApiError(ApiError),
|
||||||
|
|
||||||
/// Unexpected response format
|
|
||||||
UnexpectedResponseFormat(String),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Error)]
|
#[derive(Debug, Serialize, Deserialize, Error)]
|
||||||
|
|
|
@ -2140,7 +2140,8 @@ impl AssistantContext {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::ToolUse(_) |
|
LanguageModelCompletionEvent::ToolUse(_) |
|
||||||
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
|
||||||
|
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ use std::{
|
||||||
path::Path,
|
path::Path,
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
sync::mpsc,
|
sync::mpsc,
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
|
@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
|
||||||
match request().await {
|
match request().await {
|
||||||
Ok(result) => return Ok(result),
|
Ok(result) => return Ok(result),
|
||||||
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
|
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
|
||||||
Ok(err) => match err {
|
Ok(err) => match &err {
|
||||||
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
|
LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
|
||||||
|
| LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
|
||||||
|
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
|
||||||
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
|
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
|
||||||
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
|
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
|
||||||
eprintln!(
|
eprintln!(
|
||||||
"Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
|
"Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
|
||||||
);
|
);
|
||||||
Timer::after(retry_after + jitter).await;
|
Timer::after(retry_after + jitter).await;
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown(
|
||||||
| LanguageModelCompletionEvent::StartMessage { .. }
|
| LanguageModelCompletionEvent::StartMessage { .. }
|
||||||
| LanguageModelCompletionEvent::StatusUpdate { .. },
|
| LanguageModelCompletionEvent::StatusUpdate { .. },
|
||||||
) => {}
|
) => {}
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
|
json_parse_error, ..
|
||||||
|
}) => {
|
||||||
|
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||||
|
response.push_str(&format!(
|
||||||
|
"**Error**: parse error in tool use JSON: {}\n\n",
|
||||||
|
json_parse_error
|
||||||
|
));
|
||||||
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||||
response.push_str(&format!("**Error**: {}\n\n", error));
|
response.push_str(&format!("**Error**: {}\n\n", error));
|
||||||
|
@ -1132,6 +1141,17 @@ impl ThreadDialog {
|
||||||
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
||||||
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
||||||
|
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
|
json_parse_error,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
flush_text(&mut current_text, &mut content);
|
||||||
|
content.push(MessageContent::Text(format!(
|
||||||
|
"ERROR: parse error in tool use JSON: {}",
|
||||||
|
json_parse_error
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
flush_text(&mut current_text, &mut content);
|
flush_text(&mut current_text, &mut content);
|
||||||
content.push(MessageContent::Text(format!("ERROR: {}", error)));
|
content.push(MessageContent::Text(format!("ERROR: {}", error)));
|
||||||
|
|
|
@ -9,17 +9,18 @@ mod telemetry;
|
||||||
pub mod fake_provider;
|
pub mod fake_provider;
|
||||||
|
|
||||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||||
use anyhow::Result;
|
use anyhow::{Result, anyhow};
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use futures::FutureExt;
|
use futures::FutureExt;
|
||||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||||
use http_client::http;
|
use http_client::{StatusCode, http};
|
||||||
use icons::IconName;
|
use icons::IconName;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||||
use std::ops::{Add, Sub};
|
use std::ops::{Add, Sub};
|
||||||
|
use std::str::FromStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::{fmt, io};
|
use std::{fmt, io};
|
||||||
|
@ -34,11 +35,22 @@ pub use crate::request::*;
|
||||||
pub use crate::role::*;
|
pub use crate::role::*;
|
||||||
pub use crate::telemetry::*;
|
pub use crate::telemetry::*;
|
||||||
|
|
||||||
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
|
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
|
||||||
|
LanguageModelProviderId::new("anthropic");
|
||||||
|
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
|
||||||
|
LanguageModelProviderName::new("Anthropic");
|
||||||
|
|
||||||
/// If we get a rate limit error that doesn't tell us when we can retry,
|
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
|
||||||
/// default to waiting this long before retrying.
|
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
|
||||||
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
|
LanguageModelProviderName::new("Google AI");
|
||||||
|
|
||||||
|
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
|
||||||
|
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
|
||||||
|
LanguageModelProviderName::new("OpenAI");
|
||||||
|
|
||||||
|
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
|
||||||
|
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
|
||||||
|
LanguageModelProviderName::new("Zed");
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||||
init_settings(cx);
|
init_settings(cx);
|
||||||
|
@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
|
||||||
data: String,
|
data: String,
|
||||||
},
|
},
|
||||||
ToolUse(LanguageModelToolUse),
|
ToolUse(LanguageModelToolUse),
|
||||||
|
ToolUseJsonParseError {
|
||||||
|
id: LanguageModelToolUseId,
|
||||||
|
tool_name: Arc<str>,
|
||||||
|
raw_input: Arc<str>,
|
||||||
|
json_parse_error: String,
|
||||||
|
},
|
||||||
StartMessage {
|
StartMessage {
|
||||||
message_id: String,
|
message_id: String,
|
||||||
},
|
},
|
||||||
|
@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum LanguageModelCompletionError {
|
pub enum LanguageModelCompletionError {
|
||||||
#[error("rate limit exceeded, retry after {retry_after:?}")]
|
|
||||||
RateLimitExceeded { retry_after: Duration },
|
|
||||||
#[error("received bad input JSON")]
|
|
||||||
BadInputJson {
|
|
||||||
id: LanguageModelToolUseId,
|
|
||||||
tool_name: Arc<str>,
|
|
||||||
raw_input: Arc<str>,
|
|
||||||
json_parse_error: String,
|
|
||||||
},
|
|
||||||
#[error("language model provider's API is overloaded")]
|
|
||||||
Overloaded,
|
|
||||||
#[error(transparent)]
|
|
||||||
Other(#[from] anyhow::Error),
|
|
||||||
#[error("invalid request format to language model provider's API")]
|
|
||||||
BadRequestFormat,
|
|
||||||
#[error("authentication error with language model provider's API")]
|
|
||||||
AuthenticationError,
|
|
||||||
#[error("permission error with language model provider's API")]
|
|
||||||
PermissionError,
|
|
||||||
#[error("language model provider API endpoint not found")]
|
|
||||||
ApiEndpointNotFound,
|
|
||||||
#[error("prompt too large for context window")]
|
#[error("prompt too large for context window")]
|
||||||
PromptTooLarge { tokens: Option<u64> },
|
PromptTooLarge { tokens: Option<u64> },
|
||||||
#[error("internal server error in language model provider's API")]
|
#[error("missing {provider} API key")]
|
||||||
ApiInternalServerError,
|
NoApiKey { provider: LanguageModelProviderName },
|
||||||
#[error("I/O error reading response from language model provider's API: {0:?}")]
|
#[error("{provider}'s API rate limit exceeded")]
|
||||||
ApiReadResponseError(io::Error),
|
RateLimitExceeded {
|
||||||
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
|
provider: LanguageModelProviderName,
|
||||||
HttpResponseError { status: u16, body: String },
|
retry_after: Option<Duration>,
|
||||||
#[error("error serializing request to language model provider API: {0}")]
|
},
|
||||||
SerializeRequest(serde_json::Error),
|
#[error("{provider}'s API servers are overloaded right now")]
|
||||||
#[error("error building request body to language model provider API: {0}")]
|
ServerOverloaded {
|
||||||
BuildRequestBody(http::Error),
|
provider: LanguageModelProviderName,
|
||||||
#[error("error sending HTTP request to language model provider API: {0}")]
|
retry_after: Option<Duration>,
|
||||||
HttpSend(anyhow::Error),
|
},
|
||||||
#[error("error deserializing language model provider API response: {0}")]
|
#[error("{provider}'s API server reported an internal server error: {message}")]
|
||||||
DeserializeResponse(serde_json::Error),
|
ApiInternalServerError {
|
||||||
#[error("unexpected language model provider API response format: {0}")]
|
provider: LanguageModelProviderName,
|
||||||
UnknownResponseFormat(String),
|
message: String,
|
||||||
|
},
|
||||||
|
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||||
|
HttpResponseError {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
status_code: StatusCode,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Client errors
|
||||||
|
#[error("invalid request format to {provider}'s API: {message}")]
|
||||||
|
BadRequestFormat {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
#[error("authentication error with {provider}'s API: {message}")]
|
||||||
|
AuthenticationError {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
#[error("permission error with {provider}'s API: {message}")]
|
||||||
|
PermissionError {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
#[error("language model provider API endpoint not found")]
|
||||||
|
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
||||||
|
#[error("I/O error reading response from {provider}'s API")]
|
||||||
|
ApiReadResponseError {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
#[source]
|
||||||
|
error: io::Error,
|
||||||
|
},
|
||||||
|
#[error("error serializing request to {provider} API")]
|
||||||
|
SerializeRequest {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
#[source]
|
||||||
|
error: serde_json::Error,
|
||||||
|
},
|
||||||
|
#[error("error building request body to {provider} API")]
|
||||||
|
BuildRequestBody {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
#[source]
|
||||||
|
error: http::Error,
|
||||||
|
},
|
||||||
|
#[error("error sending HTTP request to {provider} API")]
|
||||||
|
HttpSend {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
#[source]
|
||||||
|
error: anyhow::Error,
|
||||||
|
},
|
||||||
|
#[error("error deserializing {provider} API response")]
|
||||||
|
DeserializeResponse {
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
#[source]
|
||||||
|
error: serde_json::Error,
|
||||||
|
},
|
||||||
|
|
||||||
|
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelCompletionError {
|
||||||
|
pub fn from_cloud_failure(
|
||||||
|
upstream_provider: LanguageModelProviderName,
|
||||||
|
code: String,
|
||||||
|
message: String,
|
||||||
|
retry_after: Option<Duration>,
|
||||||
|
) -> Self {
|
||||||
|
if let Some(tokens) = parse_prompt_too_long(&message) {
|
||||||
|
// TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
|
||||||
|
// to be reported. This is a temporary workaround to handle this in the case where the
|
||||||
|
// token limit has been exceeded.
|
||||||
|
Self::PromptTooLarge {
|
||||||
|
tokens: Some(tokens),
|
||||||
|
}
|
||||||
|
} else if let Some(status_code) = code
|
||||||
|
.strip_prefix("upstream_http_")
|
||||||
|
.and_then(|code| StatusCode::from_str(code).ok())
|
||||||
|
{
|
||||||
|
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
||||||
|
} else if let Some(status_code) = code
|
||||||
|
.strip_prefix("http_")
|
||||||
|
.and_then(|code| StatusCode::from_str(code).ok())
|
||||||
|
{
|
||||||
|
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
||||||
|
} else {
|
||||||
|
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_http_status(
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
|
status_code: StatusCode,
|
||||||
|
message: String,
|
||||||
|
retry_after: Option<Duration>,
|
||||||
|
) -> Self {
|
||||||
|
match status_code {
|
||||||
|
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
||||||
|
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
||||||
|
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
||||||
|
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
||||||
|
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||||
|
tokens: parse_prompt_too_long(&message),
|
||||||
|
},
|
||||||
|
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
||||||
|
provider,
|
||||||
|
retry_after,
|
||||||
|
},
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after,
|
||||||
|
},
|
||||||
|
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after,
|
||||||
|
},
|
||||||
|
_ => Self::HttpResponseError {
|
||||||
|
provider,
|
||||||
|
status_code,
|
||||||
|
message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||||
fn from(error: AnthropicError) -> Self {
|
fn from(error: AnthropicError) -> Self {
|
||||||
|
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||||
match error {
|
match error {
|
||||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
|
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
|
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||||
AnthropicError::HttpSend(error) => Self::HttpSend(error),
|
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||||
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
|
AnthropicError::DeserializeResponse(error) => {
|
||||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
|
Self::DeserializeResponse { provider, error }
|
||||||
AnthropicError::HttpResponseError { status, body } => {
|
|
||||||
Self::HttpResponseError { status, body }
|
|
||||||
}
|
}
|
||||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
|
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||||
|
AnthropicError::HttpResponseError {
|
||||||
|
status_code,
|
||||||
|
message,
|
||||||
|
} => Self::HttpResponseError {
|
||||||
|
provider,
|
||||||
|
status_code,
|
||||||
|
message,
|
||||||
|
},
|
||||||
|
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||||
|
provider,
|
||||||
|
retry_after: Some(retry_after),
|
||||||
|
},
|
||||||
|
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after: retry_after,
|
||||||
|
},
|
||||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||||
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
|
||||||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||||
fn from(error: anthropic::ApiError) -> Self {
|
fn from(error: anthropic::ApiError) -> Self {
|
||||||
use anthropic::ApiErrorCode::*;
|
use anthropic::ApiErrorCode::*;
|
||||||
|
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||||
match error.code() {
|
match error.code() {
|
||||||
Some(code) => match code {
|
Some(code) => match code {
|
||||||
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
|
InvalidRequestError => Self::BadRequestFormat {
|
||||||
AuthenticationError => LanguageModelCompletionError::AuthenticationError,
|
provider,
|
||||||
PermissionError => LanguageModelCompletionError::PermissionError,
|
message: error.message,
|
||||||
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
|
},
|
||||||
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
|
AuthenticationError => Self::AuthenticationError {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
PermissionError => Self::PermissionError {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
NotFoundError => Self::ApiEndpointNotFound { provider },
|
||||||
|
RequestTooLarge => Self::PromptTooLarge {
|
||||||
tokens: parse_prompt_too_long(&error.message),
|
tokens: parse_prompt_too_long(&error.message),
|
||||||
},
|
},
|
||||||
RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
|
RateLimitError => Self::RateLimitExceeded {
|
||||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
provider,
|
||||||
|
retry_after: None,
|
||||||
|
},
|
||||||
|
ApiError => Self::ApiInternalServerError {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
OverloadedError => Self::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after: None,
|
||||||
},
|
},
|
||||||
ApiError => LanguageModelCompletionError::ApiInternalServerError,
|
|
||||||
OverloadedError => LanguageModelCompletionError::Overloaded,
|
|
||||||
},
|
},
|
||||||
None => LanguageModelCompletionError::Other(error.into()),
|
None => Self::Other(error.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
|
||||||
fn name(&self) -> LanguageModelName;
|
fn name(&self) -> LanguageModelName;
|
||||||
fn provider_id(&self) -> LanguageModelProviderId;
|
fn provider_id(&self) -> LanguageModelProviderId;
|
||||||
fn provider_name(&self) -> LanguageModelProviderName;
|
fn provider_name(&self) -> LanguageModelProviderName;
|
||||||
|
fn upstream_provider_id(&self) -> LanguageModelProviderId {
|
||||||
|
self.provider_id()
|
||||||
|
}
|
||||||
|
fn upstream_provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
self.provider_name()
|
||||||
|
}
|
||||||
|
|
||||||
fn telemetry_id(&self) -> String;
|
fn telemetry_id(&self) -> String;
|
||||||
|
|
||||||
fn api_key(&self, _cx: &App) -> Option<String> {
|
fn api_key(&self, _cx: &App) -> Option<String> {
|
||||||
|
@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
|
||||||
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
|
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
|
||||||
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
||||||
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
|
..
|
||||||
|
}) => None,
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||||
*last_token_usage.lock() = token_usage;
|
*last_token_usage.lock() = token_usage;
|
||||||
None
|
None
|
||||||
|
@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum LanguageModelKnownError {
|
|
||||||
#[error("Context window limit exceeded ({tokens})")]
|
|
||||||
ContextWindowLimitExceeded { tokens: u64 },
|
|
||||||
#[error("Language model provider's API is currently overloaded")]
|
|
||||||
Overloaded,
|
|
||||||
#[error("Language model provider's API encountered an internal server error")]
|
|
||||||
ApiInternalServerError,
|
|
||||||
#[error("I/O error while reading response from language model provider's API: {0:?}")]
|
|
||||||
ReadResponseError(io::Error),
|
|
||||||
#[error("Error deserializing response from language model provider's API: {0:?}")]
|
|
||||||
DeserializeResponse(serde_json::Error),
|
|
||||||
#[error("Language model provider's API returned a response in an unknown format")]
|
|
||||||
UnknownResponseFormat(String),
|
|
||||||
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
|
|
||||||
RateLimitExceeded { retry_after: Duration },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModelKnownError {
|
|
||||||
/// Attempts to map an HTTP response status code to a known error type.
|
|
||||||
/// Returns None if the status code doesn't map to a specific known error.
|
|
||||||
pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
|
|
||||||
match status {
|
|
||||||
429 => Some(Self::RateLimitExceeded {
|
|
||||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
|
||||||
}),
|
|
||||||
503 => Some(Self::Overloaded),
|
|
||||||
500..=599 => Some(Self::ApiInternalServerError),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||||
fn name() -> String;
|
fn name() -> String;
|
||||||
fn description() -> String;
|
fn description() -> String;
|
||||||
|
@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
|
||||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||||
pub struct LanguageModelProviderName(pub SharedString);
|
pub struct LanguageModelProviderName(pub SharedString);
|
||||||
|
|
||||||
|
impl LanguageModelProviderId {
|
||||||
|
pub const fn new(id: &'static str) -> Self {
|
||||||
|
Self(SharedString::new_static(id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProviderName {
|
||||||
|
pub const fn new(id: &'static str) -> Self {
|
||||||
|
Self(SharedString::new_static(id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl fmt::Display for LanguageModelProviderId {
|
impl fmt::Display for LanguageModelProviderId {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
write!(f, "{}", self.0)
|
write!(f, "{}", self.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for LanguageModelProviderName {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<String> for LanguageModelId {
|
impl From<String> for LanguageModelId {
|
||||||
fn from(value: String) -> Self {
|
fn from(value: String) -> Self {
|
||||||
Self(SharedString::from(value))
|
Self(SharedString::from(value))
|
||||||
|
|
|
@ -98,7 +98,7 @@ impl ConfiguredModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_provided_by_zed(&self) -> bool {
|
pub fn is_provided_by_zed(&self) -> bool {
|
||||||
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
|
self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::ANTHROPIC_PROVIDER_ID;
|
||||||
use anthropic::ANTHROPIC_API_URL;
|
use anthropic::ANTHROPIC_API_URL;
|
||||||
use anyhow::{Context as _, anyhow};
|
use anyhow::{Context as _, anyhow};
|
||||||
use client::telemetry::Telemetry;
|
use client::telemetry::Telemetry;
|
||||||
|
@ -8,8 +9,6 @@ use std::sync::Arc;
|
||||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
|
|
||||||
|
|
||||||
pub fn report_assistant_event(
|
pub fn report_assistant_event(
|
||||||
event: AssistantEventData,
|
event: AssistantEventData,
|
||||||
telemetry: Option<Arc<Telemetry>>,
|
telemetry: Option<Arc<Telemetry>>,
|
||||||
|
@ -19,7 +18,7 @@ pub fn report_assistant_event(
|
||||||
) {
|
) {
|
||||||
if let Some(telemetry) = telemetry.as_ref() {
|
if let Some(telemetry) = telemetry.as_ref() {
|
||||||
telemetry.report_assistant_event(event.clone());
|
telemetry.report_assistant_event(event.clone());
|
||||||
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
|
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
|
||||||
if let Some(api_key) = model_api_key {
|
if let Some(api_key) = model_api_key {
|
||||||
executor
|
executor
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
|
|
|
@ -33,8 +33,8 @@ use theme::ThemeSettings;
|
||||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
|
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
|
||||||
const PROVIDER_NAME: &str = "Anthropic";
|
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct AnthropicSettings {
|
pub struct AnthropicSettings {
|
||||||
|
@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -403,7 +403,11 @@ impl AnthropicModel {
|
||||||
};
|
};
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let api_key = api_key.context("Missing Anthropic API Key")?;
|
let Some(api_key) = api_key else {
|
||||||
|
return Err(LanguageModelCompletionError::NoApiKey {
|
||||||
|
provider: PROVIDER_NAME,
|
||||||
|
});
|
||||||
|
};
|
||||||
let request =
|
let request =
|
||||||
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
request.await.map_err(Into::into)
|
request.await.map_err(Into::into)
|
||||||
|
@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -806,12 +810,14 @@ impl AnthropicEventMapper {
|
||||||
raw_input: tool_use.input_json.clone(),
|
raw_input: tool_use.input_json.clone(),
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
|
Err(json_parse_err) => {
|
||||||
id: tool_use.id.into(),
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
tool_name: tool_use.name.into(),
|
id: tool_use.id.into(),
|
||||||
raw_input: input_json.into(),
|
tool_name: tool_use.name.into(),
|
||||||
json_parse_error: json_parse_err.to_string(),
|
raw_input: input_json.into(),
|
||||||
}),
|
json_parse_error: json_parse_err.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
vec![event_result]
|
vec![event_result]
|
||||||
|
|
|
@ -52,8 +52,8 @@ use util::ResultExt;
|
||||||
|
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "amazon-bedrock";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
|
||||||
const PROVIDER_NAME: &str = "Amazon Bedrock";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
|
||||||
|
|
||||||
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
|
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
|
||||||
pub struct BedrockCredentials {
|
pub struct BedrockCredentials {
|
||||||
|
@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for BedrockLanguageModelProvider {
|
impl LanguageModelProvider for BedrockLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
|
use anthropic::AnthropicModelMode;
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
|
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
|
||||||
use futures::{
|
use futures::{
|
||||||
|
@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
|
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
|
||||||
};
|
};
|
||||||
|
use http_client::http::{HeaderMap, HeaderValue};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
|
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
|
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
|
||||||
ZED_CLOUD_PROVIDER_ID,
|
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
|
||||||
};
|
|
||||||
use language_model::{
|
|
||||||
LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
|
|
||||||
RefreshLlmTokenListener,
|
|
||||||
};
|
};
|
||||||
use proto::Plan;
|
use proto::Plan;
|
||||||
use release_channel::AppVersion;
|
use release_channel::AppVersion;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::Timer;
|
|
||||||
use smol::io::{AsyncReadExt, BufReader};
|
use smol::io::{AsyncReadExt, BufReader};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::str::FromStr as _;
|
use std::str::FromStr as _;
|
||||||
|
@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
|
||||||
use crate::provider::google::{GoogleEventMapper, into_google};
|
use crate::provider::google::{GoogleEventMapper, into_google};
|
||||||
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
|
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
|
||||||
|
|
||||||
pub const PROVIDER_NAME: &str = "Zed";
|
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
|
||||||
|
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct ZedDotDevSettings {
|
pub struct ZedDotDevSettings {
|
||||||
|
@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for CloudLanguageModelProvider {
|
impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CloudLanguageModel {
|
impl CloudLanguageModel {
|
||||||
const MAX_RETRIES: usize = 3;
|
|
||||||
|
|
||||||
async fn perform_llm_completion(
|
async fn perform_llm_completion(
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
llm_api_token: LlmApiToken,
|
llm_api_token: LlmApiToken,
|
||||||
|
@ -547,8 +542,7 @@ impl CloudLanguageModel {
|
||||||
let http_client = &client.http_client();
|
let http_client = &client.http_client();
|
||||||
|
|
||||||
let mut token = llm_api_token.acquire(&client).await?;
|
let mut token = llm_api_token.acquire(&client).await?;
|
||||||
let mut retries_remaining = Self::MAX_RETRIES;
|
let mut refreshed_token = false;
|
||||||
let mut retry_delay = Duration::from_secs(1);
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let request_builder = http_client::Request::builder()
|
let request_builder = http_client::Request::builder()
|
||||||
|
@ -590,14 +584,20 @@ impl CloudLanguageModel {
|
||||||
includes_status_messages,
|
includes_status_messages,
|
||||||
tool_use_limit_reached,
|
tool_use_limit_reached,
|
||||||
});
|
});
|
||||||
} else if response
|
}
|
||||||
.headers()
|
|
||||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
if !refreshed_token
|
||||||
.is_some()
|
&& response
|
||||||
|
.headers()
|
||||||
|
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||||
|
.is_some()
|
||||||
{
|
{
|
||||||
retries_remaining -= 1;
|
|
||||||
token = llm_api_token.refresh(&client).await?;
|
token = llm_api_token.refresh(&client).await?;
|
||||||
} else if status == StatusCode::FORBIDDEN
|
refreshed_token = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if status == StatusCode::FORBIDDEN
|
||||||
&& response
|
&& response
|
||||||
.headers()
|
.headers()
|
||||||
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
|
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
|
||||||
|
@ -622,35 +622,18 @@ impl CloudLanguageModel {
|
||||||
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
|
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::bail!("Forbidden");
|
|
||||||
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
|
|
||||||
// If we encounter an error in the 500 range, retry after a delay.
|
|
||||||
// We've seen at least these in the wild from API providers:
|
|
||||||
// * 500 Internal Server Error
|
|
||||||
// * 502 Bad Gateway
|
|
||||||
// * 529 Service Overloaded
|
|
||||||
|
|
||||||
if retries_remaining == 0 {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
anyhow::bail!(
|
|
||||||
"cloud language model completion failed after {} retries with status {status}: {body}",
|
|
||||||
Self::MAX_RETRIES
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Timer::after(retry_delay).await;
|
|
||||||
|
|
||||||
retries_remaining -= 1;
|
|
||||||
retry_delay *= 2; // If it fails again, wait longer.
|
|
||||||
} else if status == StatusCode::PAYMENT_REQUIRED {
|
} else if status == StatusCode::PAYMENT_REQUIRED {
|
||||||
return Err(anyhow!(PaymentRequiredError));
|
return Err(anyhow!(PaymentRequiredError));
|
||||||
} else {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
return Err(anyhow!(ApiError { status, body }));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut body = String::new();
|
||||||
|
let headers = response.headers().clone();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
return Err(anyhow!(ApiError {
|
||||||
|
status,
|
||||||
|
body,
|
||||||
|
headers
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -660,6 +643,19 @@ impl CloudLanguageModel {
|
||||||
struct ApiError {
|
struct ApiError {
|
||||||
status: StatusCode,
|
status: StatusCode,
|
||||||
body: String,
|
body: String,
|
||||||
|
headers: HeaderMap<HeaderValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ApiError> for LanguageModelCompletionError {
|
||||||
|
fn from(error: ApiError) -> Self {
|
||||||
|
let retry_after = None;
|
||||||
|
LanguageModelCompletionError::from_http_status(
|
||||||
|
PROVIDER_NAME,
|
||||||
|
error.status,
|
||||||
|
error.body,
|
||||||
|
retry_after,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for CloudLanguageModel {
|
impl LanguageModel for CloudLanguageModel {
|
||||||
|
@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upstream_provider_id(&self) -> LanguageModelProviderId {
|
||||||
|
use zed_llm_client::LanguageModelProvider::*;
|
||||||
|
match self.model.provider {
|
||||||
|
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
|
||||||
|
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
|
||||||
|
Google => language_model::GOOGLE_PROVIDER_ID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upstream_provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
use zed_llm_client::LanguageModelProvider::*;
|
||||||
|
match self.model.provider {
|
||||||
|
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
|
||||||
|
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
|
||||||
|
Google => language_model::GOOGLE_PROVIDER_NAME,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
.body(serde_json::to_string(&request_body)?.into())?;
|
.body(serde_json::to_string(&request_body)?.into())?;
|
||||||
let mut response = http_client.send(request).await?;
|
let mut response = http_client.send(request).await?;
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
let headers = response.headers().clone();
|
||||||
let mut response_body = String::new();
|
let mut response_body = String::new();
|
||||||
response
|
response
|
||||||
.body_mut()
|
.body_mut()
|
||||||
|
@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow!(ApiError {
|
Err(anyhow!(ApiError {
|
||||||
status,
|
status,
|
||||||
body: response_body
|
body: response_body,
|
||||||
|
headers
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|err| match err.downcast::<ApiError>() {
|
.map_err(|err| match err.downcast::<ApiError>() {
|
||||||
Ok(api_err) => {
|
Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
|
||||||
if api_err.status == StatusCode::BAD_REQUEST {
|
|
||||||
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
|
|
||||||
return anyhow!(
|
|
||||||
LanguageModelKnownError::ContextWindowLimitExceeded {
|
|
||||||
tokens
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anyhow!(api_err)
|
|
||||||
}
|
|
||||||
Err(err) => anyhow!(err),
|
Err(err) => anyhow!(err),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -995,7 +1000,7 @@ where
|
||||||
.flat_map(move |event| {
|
.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
vec![Err(LanguageModelCompletionError::Other(error))]
|
vec![Err(LanguageModelCompletionError::from(error))]
|
||||||
}
|
}
|
||||||
Ok(CloudCompletionEvent::Status(event)) => {
|
Ok(CloudCompletionEvent::Status(event)) => {
|
||||||
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
||||||
|
|
|
@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
|
||||||
use super::google::count_google_tokens;
|
use super::google::count_google_tokens;
|
||||||
use super::open_ai::count_open_ai_tokens;
|
use super::open_ai::count_open_ai_tokens;
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "copilot_chat";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
|
||||||
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
|
const PROVIDER_NAME: LanguageModelProviderName =
|
||||||
|
LanguageModelProviderName::new("GitHub Copilot Chat");
|
||||||
|
|
||||||
pub struct CopilotChatLanguageModelProvider {
|
pub struct CopilotChatLanguageModelProvider {
|
||||||
state: Entity<State>,
|
state: Entity<State>,
|
||||||
|
@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -391,24 +392,24 @@ pub fn map_to_language_model_completion_events(
|
||||||
serde_json::Value::from_str(&tool_call.arguments)
|
serde_json::Value::from_str(&tool_call.arguments)
|
||||||
};
|
};
|
||||||
match arguments {
|
match arguments {
|
||||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
id: tool_call.id.clone().into(),
|
id: tool_call.id.clone().into(),
|
||||||
name: tool_call.name.as_str().into(),
|
name: tool_call.name.as_str().into(),
|
||||||
is_input_complete: true,
|
is_input_complete: true,
|
||||||
input,
|
input,
|
||||||
raw_input: tool_call.arguments.clone(),
|
raw_input: tool_call.arguments.clone(),
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
Err(error) => {
|
Err(error) => Ok(
|
||||||
Err(LanguageModelCompletionError::BadInputJson {
|
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.into(),
|
||||||
tool_name: tool_call.name.as_str().into(),
|
tool_name: tool_call.name.as_str().into(),
|
||||||
raw_input: tool_call.arguments.into(),
|
raw_input: tool_call.arguments.into(),
|
||||||
json_parse_error: error.to_string(),
|
json_parse_error: error.to_string(),
|
||||||
})
|
},
|
||||||
}
|
),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,8 @@ use util::ResultExt;
|
||||||
|
|
||||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "deepseek";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
|
||||||
const PROVIDER_NAME: &str = "DeepSeek";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
|
||||||
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
|
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for DeepSeekLanguageModelProvider {
|
impl LanguageModelProvider for DeepSeekLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -466,7 +466,7 @@ impl DeepSeekEventMapper {
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Ok(event) => self.map_event(event),
|
Ok(event) => self.map_event(event),
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -476,7 +476,7 @@ impl DeepSeekEventMapper {
|
||||||
event: deepseek::StreamResponse,
|
event: deepseek::StreamResponse,
|
||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||||
let Some(choice) = event.choices.first() else {
|
let Some(choice) = event.choices.first() else {
|
||||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||||
"Response contained no choices"
|
"Response contained no choices"
|
||||||
)))];
|
)))];
|
||||||
};
|
};
|
||||||
|
@ -538,8 +538,8 @@ impl DeepSeekEventMapper {
|
||||||
raw_input: tool_call.arguments.clone(),
|
raw_input: tool_call.arguments.clone(),
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.clone().into(),
|
||||||
tool_name: tool_call.name.as_str().into(),
|
tool_name: tool_call.name.as_str().into(),
|
||||||
raw_input: tool_call.arguments.into(),
|
raw_input: tool_call.arguments.into(),
|
||||||
json_parse_error: error.to_string(),
|
json_parse_error: error.to_string(),
|
||||||
|
|
|
@ -37,8 +37,8 @@ use util::ResultExt;
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
use crate::ui::InstructionListItem;
|
use crate::ui::InstructionListItem;
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "google";
|
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
|
||||||
const PROVIDER_NAME: &str = "Google AI";
|
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct GoogleSettings {
|
pub struct GoogleSettings {
|
||||||
|
@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for GoogleLanguageModelProvider {
|
impl LanguageModelProvider for GoogleLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
);
|
);
|
||||||
let request = self.stream_completion(request, cx);
|
let request = self.stream_completion(request, cx);
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = request
|
let response = request.await.map_err(LanguageModelCompletionError::from)?;
|
||||||
.await
|
|
||||||
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
|
|
||||||
Ok(GoogleEventMapper::new().map_stream(response))
|
Ok(GoogleEventMapper::new().map_stream(response))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
|
@ -622,7 +620,7 @@ impl GoogleEventMapper {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Some(Ok(event)) => self.map_event(event),
|
Some(Ok(event)) => self.map_event(event),
|
||||||
Some(Err(error)) => {
|
Some(Err(error)) => {
|
||||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
|
vec![Err(LanguageModelCompletionError::from(error))]
|
||||||
}
|
}
|
||||||
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
|
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
|
||||||
})
|
})
|
||||||
|
|
|
@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
|
||||||
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
|
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
|
||||||
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
|
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "lmstudio";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
|
||||||
const PROVIDER_NAME: &str = "LM Studio";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone, PartialEq)]
|
#[derive(Default, Debug, Clone, PartialEq)]
|
||||||
pub struct LmStudioSettings {
|
pub struct LmStudioSettings {
|
||||||
|
@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for LmStudioLanguageModelProvider {
|
impl LanguageModelProvider for LmStudioLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -474,7 +474,7 @@ impl LmStudioEventMapper {
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Ok(event) => self.map_event(event),
|
Ok(event) => self.map_event(event),
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -484,7 +484,7 @@ impl LmStudioEventMapper {
|
||||||
event: lmstudio::ResponseStreamEvent,
|
event: lmstudio::ResponseStreamEvent,
|
||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||||
let Some(choice) = event.choices.into_iter().next() else {
|
let Some(choice) = event.choices.into_iter().next() else {
|
||||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||||
"Response contained no choices"
|
"Response contained no choices"
|
||||||
)))];
|
)))];
|
||||||
};
|
};
|
||||||
|
@ -553,7 +553,7 @@ impl LmStudioEventMapper {
|
||||||
raw_input: tool_call.arguments,
|
raw_input: tool_call.arguments,
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.into(),
|
||||||
tool_name: tool_call.name.into(),
|
tool_name: tool_call.name.into(),
|
||||||
raw_input: tool_call.arguments.into(),
|
raw_input: tool_call.arguments.into(),
|
||||||
|
|
|
@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use credentials_provider::CredentialsProvider;
|
use credentials_provider::CredentialsProvider;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::stream::BoxStream;
|
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||||
};
|
};
|
||||||
|
@ -15,6 +14,7 @@ use language_model::{
|
||||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||||
RateLimiter, Role, StopReason, TokenUsage,
|
RateLimiter, Role, StopReason, TokenUsage,
|
||||||
};
|
};
|
||||||
|
use mistral::StreamResponse;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
|
@ -29,8 +29,8 @@ use util::ResultExt;
|
||||||
|
|
||||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "mistral";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
|
||||||
const PROVIDER_NAME: &str = "Mistral";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct MistralSettings {
|
pub struct MistralSettings {
|
||||||
|
@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for MistralLanguageModelProvider {
|
impl LanguageModelProvider for MistralLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -579,13 +579,13 @@ impl MistralEventMapper {
|
||||||
|
|
||||||
pub fn map_stream(
|
pub fn map_stream(
|
||||||
mut self,
|
mut self,
|
||||||
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
|
events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
|
||||||
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||||
{
|
{
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Ok(event) => self.map_event(event),
|
Ok(event) => self.map_event(event),
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -595,7 +595,7 @@ impl MistralEventMapper {
|
||||||
event: mistral::StreamResponse,
|
event: mistral::StreamResponse,
|
||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||||
let Some(choice) = event.choices.first() else {
|
let Some(choice) = event.choices.first() else {
|
||||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||||
"Response contained no choices"
|
"Response contained no choices"
|
||||||
)))];
|
)))];
|
||||||
};
|
};
|
||||||
|
@ -660,7 +660,7 @@ impl MistralEventMapper {
|
||||||
|
|
||||||
for (_, tool_call) in self.tool_calls_by_index.drain() {
|
for (_, tool_call) in self.tool_calls_by_index.drain() {
|
||||||
if tool_call.id.is_empty() || tool_call.name.is_empty() {
|
if tool_call.id.is_empty() || tool_call.name.is_empty() {
|
||||||
results.push(Err(LanguageModelCompletionError::Other(anyhow!(
|
results.push(Err(LanguageModelCompletionError::from(anyhow!(
|
||||||
"Received incomplete tool call: missing id or name"
|
"Received incomplete tool call: missing id or name"
|
||||||
))));
|
))));
|
||||||
continue;
|
continue;
|
||||||
|
@ -676,12 +676,14 @@ impl MistralEventMapper {
|
||||||
raw_input: tool_call.arguments,
|
raw_input: tool_call.arguments,
|
||||||
},
|
},
|
||||||
))),
|
))),
|
||||||
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
|
Err(error) => {
|
||||||
id: tool_call.id.into(),
|
results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
tool_name: tool_call.name.into(),
|
id: tool_call.id.into(),
|
||||||
raw_input: tool_call.arguments.into(),
|
tool_name: tool_call.name.into(),
|
||||||
json_parse_error: error.to_string(),
|
raw_input: tool_call.arguments.into(),
|
||||||
})),
|
json_parse_error: error.to_string(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||||
const OLLAMA_SITE: &str = "https://ollama.com/";
|
const OLLAMA_SITE: &str = "https://ollama.com/";
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "ollama";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
|
||||||
const PROVIDER_NAME: &str = "Ollama";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone, PartialEq)]
|
#[derive(Default, Debug, Clone, PartialEq)]
|
||||||
pub struct OllamaSettings {
|
pub struct OllamaSettings {
|
||||||
|
@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for OllamaLanguageModelProvider {
|
impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
|
||||||
let delta = match response {
|
let delta = match response {
|
||||||
Ok(delta) => delta,
|
Ok(delta) => delta,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
|
let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
|
||||||
return Some((vec![event], state));
|
return Some((vec![event], state));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -31,8 +31,8 @@ use util::ResultExt;
|
||||||
use crate::OpenAiSettingsContent;
|
use crate::OpenAiSettingsContent;
|
||||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "openai";
|
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
|
||||||
const PROVIDER_NAME: &str = "OpenAI";
|
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct OpenAiSettings {
|
pub struct OpenAiSettings {
|
||||||
|
@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
|
||||||
};
|
};
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let api_key = api_key.context("Missing OpenAI API Key")?;
|
let Some(api_key) = api_key else {
|
||||||
|
return Err(LanguageModelCompletionError::NoApiKey {
|
||||||
|
provider: PROVIDER_NAME,
|
||||||
|
});
|
||||||
|
};
|
||||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(response)
|
Ok(response)
|
||||||
|
@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -525,7 +529,7 @@ impl OpenAiEventMapper {
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Ok(event) => self.map_event(event),
|
Ok(event) => self.map_event(event),
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -588,10 +592,10 @@ impl OpenAiEventMapper {
|
||||||
raw_input: tool_call.arguments.clone(),
|
raw_input: tool_call.arguments.clone(),
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.into(),
|
||||||
tool_name: tool_call.name.as_str().into(),
|
tool_name: tool_call.name.into(),
|
||||||
raw_input: tool_call.arguments.into(),
|
raw_input: tool_call.arguments.clone().into(),
|
||||||
json_parse_error: error.to_string(),
|
json_parse_error: error.to_string(),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,8 +29,8 @@ use util::ResultExt;
|
||||||
|
|
||||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "openrouter";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
|
||||||
const PROVIDER_NAME: &str = "OpenRouter";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct OpenRouterSettings {
|
pub struct OpenRouterSettings {
|
||||||
|
@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for OpenRouterLanguageModelProvider {
|
impl LanguageModelProvider for OpenRouterLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Ok(event) => self.map_event(event),
|
Ok(event) => self.map_event(event),
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
|
||||||
event: ResponseStreamEvent,
|
event: ResponseStreamEvent,
|
||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||||
let Some(choice) = event.choices.first() else {
|
let Some(choice) = event.choices.first() else {
|
||||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||||
"Response contained no choices"
|
"Response contained no choices"
|
||||||
)))];
|
)))];
|
||||||
};
|
};
|
||||||
|
@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
|
||||||
raw_input: tool_call.arguments.clone(),
|
raw_input: tool_call.arguments.clone(),
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.clone().into(),
|
||||||
tool_name: tool_call.name.as_str().into(),
|
tool_name: tool_call.name.as_str().into(),
|
||||||
raw_input: tool_call.arguments.into(),
|
raw_input: tool_call.arguments.clone().into(),
|
||||||
json_parse_error: error.to_string(),
|
json_parse_error: error.to_string(),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,8 +25,8 @@ use util::ResultExt;
|
||||||
|
|
||||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "vercel";
|
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
|
||||||
const PROVIDER_NAME: &str = "Vercel";
|
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
|
||||||
|
|
||||||
#[derive(Default, Clone, Debug, PartialEq)]
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
pub struct VercelSettings {
|
pub struct VercelSettings {
|
||||||
|
@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
|
||||||
|
|
||||||
impl LanguageModelProvider for VercelLanguageModelProvider {
|
impl LanguageModelProvider for VercelLanguageModelProvider {
|
||||||
fn id(&self) -> LanguageModelProviderId {
|
fn id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> LanguageModelProviderName {
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon(&self) -> IconName {
|
fn icon(&self) -> IconName {
|
||||||
|
@ -269,7 +269,11 @@ impl VercelLanguageModel {
|
||||||
};
|
};
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let api_key = api_key.context("Missing Vercel API Key")?;
|
let Some(api_key) = api_key else {
|
||||||
|
return Err(LanguageModelCompletionError::NoApiKey {
|
||||||
|
provider: PROVIDER_NAME,
|
||||||
|
});
|
||||||
|
};
|
||||||
let request =
|
let request =
|
||||||
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
|
@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_id(&self) -> LanguageModelProviderId {
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
LanguageModelProviderId(PROVIDER_ID.into())
|
PROVIDER_ID
|
||||||
}
|
}
|
||||||
|
|
||||||
fn provider_name(&self) -> LanguageModelProviderName {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
PROVIDER_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
|
|
|
@ -7,10 +7,7 @@ use gpui::{App, AppContext, Context, Entity, Subscription, Task};
|
||||||
use http_client::{HttpClient, Method};
|
use http_client::{HttpClient, Method};
|
||||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||||
use zed_llm_client::{
|
use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
|
||||||
CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
|
||||||
WebSearchBody, WebSearchResponse,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub struct CloudWebSearchProvider {
|
pub struct CloudWebSearchProvider {
|
||||||
state: Entity<State>,
|
state: Entity<State>,
|
||||||
|
@ -92,7 +89,6 @@ async fn perform_web_search(
|
||||||
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {token}"))
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
.header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
|
|
||||||
.body(serde_json::to_string(&body)?.into())?;
|
.body(serde_json::to_string(&body)?.into())?;
|
||||||
let mut response = http_client
|
let mut response = http_client
|
||||||
.send(request)
|
.send(request)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue