agent: Handle context window exceeded errors from Anthropic (#28688)
Release Notes:

- agent: Handle context window exceeded errors from Anthropic

parent 4a57664c7f
commit b45230784d

9 changed files with 190 additions and 28 deletions
Cargo.lock (generated)
@@ -7708,6 +7708,7 @@ dependencies = [
  "smol",
  "strum",
  "theme",
+ "thiserror 2.0.12",
  "tiktoken-rs",
  "tokio",
  "ui",
@@ -761,13 +761,29 @@ impl MessageEditor {
         })
     }
 
-    fn render_reaching_token_limit(&self, line_height: Pixels, cx: &mut Context<Self>) -> Div {
+    fn render_token_limit_callout(
+        &self,
+        line_height: Pixels,
+        token_usage_ratio: TokenUsageRatio,
+        cx: &mut Context<Self>,
+    ) -> Div {
+        let heading = if token_usage_ratio == TokenUsageRatio::Exceeded {
+            "Thread reached the token limit"
+        } else {
+            "Thread reaching the token limit soon"
+        };
+
         h_flex()
             .p_2()
             .gap_2()
             .flex_wrap()
             .justify_between()
-            .bg(cx.theme().status().warning_background.opacity(0.1))
+            .bg(
+                if token_usage_ratio == TokenUsageRatio::Exceeded {
+                    cx.theme().status().error_background.opacity(0.1)
+                } else {
+                    cx.theme().status().warning_background.opacity(0.1)
+                })
             .border_t_1()
             .border_color(cx.theme().colors().border)
             .child(
@@ -779,15 +795,21 @@ impl MessageEditor {
                     .h(line_height)
                     .justify_center()
                     .child(
-                        Icon::new(IconName::Warning)
-                            .color(Color::Warning)
-                            .size(IconSize::XSmall),
+                        if token_usage_ratio == TokenUsageRatio::Exceeded {
+                            Icon::new(IconName::X)
+                                .color(Color::Error)
+                                .size(IconSize::XSmall)
+                        } else {
+                            Icon::new(IconName::Warning)
+                                .color(Color::Warning)
+                                .size(IconSize::XSmall)
+                        }
                     ),
             )
             .child(
                 v_flex()
                     .mr_auto()
-                    .child(Label::new("Thread reaching the token limit soon").size(LabelSize::Small))
+                    .child(Label::new(heading).size(LabelSize::Small))
                     .child(
                         Label::new(
                             "Start a new thread from a summary to continue the conversation.",
@@ -875,7 +897,13 @@ impl Render for MessageEditor {
             .child(self.render_editor(font_size, line_height, window, cx))
             .when(
                 total_token_usage.ratio != TokenUsageRatio::Normal,
-                |parent| parent.child(self.render_reaching_token_limit(line_height, cx)),
+                |parent| {
+                    parent.child(self.render_token_limit_callout(
+                        line_height,
+                        total_token_usage.ratio,
+                        cx,
+                    ))
+                },
             )
     }
 }
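Taken together, the MessageEditor hunks above make the callout tri-state: nothing at `Normal`, a warning-styled banner when the thread is approaching the limit, and an error-styled banner once it is exceeded. A minimal, dependency-free sketch of just that branching, using a hypothetical `Callout` summary type and a locally defined enum (the `Warning` variant name is assumed; only `Normal` and `Exceeded` appear in this diff, and the real code builds gpui elements instead):

```rust
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
enum TokenUsageRatio {
    #[default]
    Normal,
    Warning, // assumed name for the "approaching the limit" state
    Exceeded,
}

#[derive(Debug, PartialEq)]
struct Callout {
    heading: &'static str,
    icon: &'static str,
    severity: &'static str,
}

fn token_limit_callout(ratio: TokenUsageRatio) -> Option<Callout> {
    match ratio {
        // Nothing is rendered below the editor in the normal case.
        TokenUsageRatio::Normal => None,
        TokenUsageRatio::Warning => Some(Callout {
            heading: "Thread reaching the token limit soon",
            icon: "warning",
            severity: "warning",
        }),
        TokenUsageRatio::Exceeded => Some(Callout {
            heading: "Thread reached the token limit",
            icon: "x",
            severity: "error",
        }),
    }
}

fn main() {
    assert_eq!(token_limit_callout(TokenUsageRatio::Normal), None);
    assert_eq!(
        token_limit_callout(TokenUsageRatio::Exceeded).unwrap().heading,
        "Thread reached the token limit"
    );
    println!("{:?}", token_limit_callout(TokenUsageRatio::Warning));
}
```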
@@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
 use git::repository::DiffType;
 use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
-    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
-    PaymentRequiredError, Role, StopReason, TokenUsage,
+    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
+    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
+    Role, StopReason, TokenUsage,
 };
 use project::Project;
 use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
@@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
     pub ratio: TokenUsageRatio,
 }
 
-#[derive(Default, PartialEq, Eq)]
+#[derive(Debug, Default, PartialEq, Eq)]
 pub enum TokenUsageRatio {
     #[default]
     Normal,
@@ -260,11 +261,20 @@ pub struct Thread {
     pending_checkpoint: Option<ThreadCheckpoint>,
     initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
     cumulative_token_usage: TokenUsage,
+    exceeded_window_error: Option<ExceededWindowError>,
     feedback: Option<ThreadFeedback>,
     message_feedback: HashMap<MessageId, ThreadFeedback>,
     last_auto_capture_at: Option<Instant>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ExceededWindowError {
+    /// Model used when last message exceeded context window
+    model_id: LanguageModelId,
+    /// Token count including last message
+    token_count: usize,
+}
+
 impl Thread {
     pub fn new(
         project: Entity<Project>,
@@ -301,6 +311,7 @@ impl Thread {
                 .shared()
             },
             cumulative_token_usage: TokenUsage::default(),
+            exceeded_window_error: None,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -367,6 +378,7 @@ impl Thread {
             action_log: cx.new(|_| ActionLog::new(project)),
             initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
             cumulative_token_usage: serialized.cumulative_token_usage,
+            exceeded_window_error: None,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -817,6 +829,7 @@ impl Thread {
                     initial_project_snapshot,
                     cumulative_token_usage: this.cumulative_token_usage.clone(),
                     detailed_summary_state: this.detailed_summary_state.clone(),
+                    exceeded_window_error: this.exceeded_window_error.clone(),
                 })
             })
         }
@@ -1129,6 +1142,20 @@ impl Thread {
                             cx.emit(ThreadEvent::ShowError(
                                 ThreadError::MaxMonthlySpendReached,
                             ));
+                        } else if let Some(known_error) =
+                            error.downcast_ref::<LanguageModelKnownError>()
+                        {
+                            match known_error {
+                                LanguageModelKnownError::ContextWindowLimitExceeded {
+                                    tokens,
+                                } => {
+                                    thread.exceeded_window_error = Some(ExceededWindowError {
+                                        model_id: model.id(),
+                                        token_count: *tokens,
+                                    });
+                                    cx.notify();
+                                }
+                            }
                         } else {
                             let error_message = error
                                 .chain()
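The new branch above leans on `anyhow`'s downcast machinery: the provider layer wraps a typed `LanguageModelKnownError` with `anyhow!`, and the thread recovers it with `downcast_ref` so the overflow is recorded instead of being shown as a generic error. A self-contained sketch of that round trip (only `anyhow` and `thiserror` are assumed as dependencies):

```rust
use anyhow::anyhow;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

fn main() {
    // Provider side: wrap the typed error into an anyhow::Error.
    let err: anyhow::Error =
        anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens: 220_000 });

    // Thread side: the same shape as the new branch in the completion handler.
    if let Some(known) = err.downcast_ref::<LanguageModelKnownError>() {
        match known {
            LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                println!("record exceeded_window_error with {tokens} tokens");
            }
        }
    } else {
        println!("fall back to the generic error path");
    }
}
```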
@@ -1784,10 +1811,6 @@ impl Thread {
         &self.project
     }
 
-    pub fn cumulative_token_usage(&self) -> TokenUsage {
-        self.cumulative_token_usage.clone()
-    }
-
     pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
         if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
             return;
@@ -1840,6 +1863,16 @@ impl Thread {
 
         let max = model.model.max_token_count();
 
+        if let Some(exceeded_error) = &self.exceeded_window_error {
+            if model.model.id() == exceeded_error.model_id {
+                return TotalTokenUsage {
+                    total: exceeded_error.token_count,
+                    max,
+                    ratio: TokenUsageRatio::Exceeded,
+                };
+            }
+        }
+
         #[cfg(debug_assertions)]
         let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
             .unwrap_or("0.8".to_string())
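With the error recorded, `total_token_usage` short-circuits: if the stored overflow came from the currently selected model, it reports the provider's token count and pins the ratio to `Exceeded`; other models fall back to the usual threshold math. A simplified, self-contained sketch with stand-in types (the model id string is illustrative, and the 0.8 warning threshold mirrors the `ZED_THREAD_WARNING_THRESHOLD` default shown above for debug builds):

```rust
#[derive(Debug, PartialEq)]
struct LanguageModelId(&'static str);

#[derive(Debug, PartialEq)]
enum TokenUsageRatio {
    Normal,
    Warning,
    Exceeded,
}

struct ExceededWindowError {
    model_id: LanguageModelId,
    token_count: usize,
}

// Returns (total, max, ratio), mirroring the shape of TotalTokenUsage.
fn total_token_usage(
    model_id: &LanguageModelId,
    current_total: usize,
    max: usize,
    exceeded: Option<&ExceededWindowError>,
) -> (usize, usize, TokenUsageRatio) {
    // A recorded overflow only applies to the model that produced it.
    if let Some(err) = exceeded {
        if &err.model_id == model_id {
            return (err.token_count, max, TokenUsageRatio::Exceeded);
        }
    }
    let ratio = if current_total >= max {
        TokenUsageRatio::Exceeded
    } else if current_total as f32 / max as f32 >= 0.8 {
        TokenUsageRatio::Warning
    } else {
        TokenUsageRatio::Normal
    };
    (current_total, max, ratio)
}

fn main() {
    let overflow = ExceededWindowError {
        model_id: LanguageModelId("claude-3-7-sonnet"),
        token_count: 220_000,
    };
    let usage = total_token_usage(
        &LanguageModelId("claude-3-7-sonnet"),
        150_000,
        200_000,
        Some(&overflow),
    );
    assert_eq!(usage, (220_000, 200_000, TokenUsageRatio::Exceeded));
}
```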
@@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize};
 use settings::{Settings as _, SettingsStore};
 use util::ResultExt as _;
 
-use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
+use crate::thread::{
+    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
+};
 
 const RULES_FILE_NAMES: [&'static str; 6] = [
     ".rules",
@@ -491,6 +493,8 @@ pub struct SerializedThread {
     pub cumulative_token_usage: TokenUsage,
     #[serde(default)]
     pub detailed_summary_state: DetailedSummaryState,
+    #[serde(default)]
+    pub exceeded_window_error: Option<ExceededWindowError>,
 }
 
 impl SerializedThread {
@@ -577,6 +581,7 @@ impl LegacySerializedThread {
             initial_project_snapshot: self.initial_project_snapshot,
             cumulative_token_usage: TokenUsage::default(),
             detailed_summary_state: DetailedSummaryState::default(),
+            exceeded_window_error: None,
         }
     }
 }
@@ -724,4 +724,54 @@ impl ApiError {
     pub fn is_rate_limit_error(&self) -> bool {
         matches!(self.error_type.as_str(), "rate_limit_error")
     }
+
+    pub fn match_window_exceeded(&self) -> Option<usize> {
+        let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
+            return None;
+        };
+
+        parse_prompt_too_long(&self.message)
+    }
+}
+
+pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
+    message
+        .strip_prefix("prompt is too long: ")?
+        .split_once(" tokens")?
+        .0
+        .parse::<usize>()
+        .ok()
+}
+
+#[test]
+fn test_match_window_exceeded() {
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "prompt is too long: 220000 tokens > 200000".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), Some(220_000));
+
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "prompt is too long: 1234953 tokens".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), Some(1234953));
+
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "not a prompt length error".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), None);
+
+    let error = ApiError {
+        error_type: "rate_limit_error".to_string(),
+        message: "prompt is too long: 12345 tokens".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), None);
+
+    let error = ApiError {
+        error_type: "invalid_request_error".to_string(),
+        message: "prompt is too long: invalid tokens".to_string(),
+    };
+    assert_eq!(error.match_window_exceeded(), None);
 }
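For context on where that string comes from: Anthropic rejects an oversized request with an HTTP 400 `invalid_request_error` whose message reads `prompt is too long: <N> tokens ...`, and `parse_prompt_too_long` pulls the count out of that message. A rough standalone illustration (the JSON envelope follows Anthropic's documented error shape rather than anything in this diff, and `parse_prompt_too_long` is re-implemented locally so the snippet only needs `serde_json`):

```rust
use serde_json::Value;

// Local copy of the parser added in the hunk above.
fn parse_prompt_too_long(message: &str) -> Option<usize> {
    message
        .strip_prefix("prompt is too long: ")?
        .split_once(" tokens")?
        .0
        .parse::<usize>()
        .ok()
}

fn main() {
    // Assumed error-body shape based on Anthropic's public API docs.
    let body = r#"{
        "type": "error",
        "error": {
            "type": "invalid_request_error",
            "message": "prompt is too long: 220000 tokens > 200000"
        }
    }"#;

    let json: Value = serde_json::from_str(body).unwrap();
    let message = json["error"]["message"].as_str().unwrap_or_default();
    assert_eq!(parse_prompt_too_long(message), Some(220_000));
    println!("context window exceeded at 220000 tokens");
}
```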
@@ -278,6 +278,12 @@ pub trait LanguageModel: Send + Sync {
     }
 }
 
+#[derive(Debug, Error)]
+pub enum LanguageModelKnownError {
+    #[error("Context window limit exceeded ({tokens})")]
+    ContextWindowLimitExceeded { tokens: usize },
+}
+
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
@@ -347,7 +353,7 @@ pub trait LanguageModelProviderState: 'static {
     }
 }
 
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
 pub struct LanguageModelId(pub SharedString);
 
 #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
@@ -47,6 +47,7 @@ settings.workspace = true
 smol.workspace = true
 strum.workspace = true
 theme.workspace = true
+thiserror.workspace = true
 tiktoken-rs.workspace = true
 tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
 ui.workspace = true
@@ -13,8 +13,9 @@ use gpui::{
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
-    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
+    RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;
@@ -454,7 +455,12 @@ impl LanguageModel for AnthropicModel {
         );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
-            let response = request.await.map_err(|err| anyhow!(err))?;
+            let response = request
+                .await
+                .map_err(|err| match err.downcast::<AnthropicError>() {
+                    Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
+                    Err(err) => anyhow!(err),
+                })?;
             Ok(map_to_language_model_completion_events(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
@@ -746,7 +752,7 @@ pub fn map_to_language_model_completion_events(
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((vec![Err(anyhow!(err))], state));
+                        return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
                     }
                 }
             }
@@ -757,6 +763,16 @@ pub fn map_to_language_model_completion_events(
     .flat_map(futures::stream::iter)
 }
 
+pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
+    if let AnthropicError::ApiError(api_err) = &err {
+        if let Some(tokens) = api_err.match_window_exceeded() {
+            return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
+        }
+    }
+
+    anyhow!(err)
+}
+
 /// Updates usage data by preferring counts from `new`.
 fn update_usage(usage: &mut Usage, new: &Usage) {
     if let Some(input_tokens) = new.input_tokens {
@@ -1,4 +1,4 @@
-use anthropic::{AnthropicError, AnthropicModelMode};
+use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::{
     Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
@@ -14,7 +14,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
     AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
-    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
     LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
 };
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
use thiserror::Error;
|
||||||
use ui::{TintColor, prelude::*};
|
use ui::{TintColor, prelude::*};
|
||||||
|
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
|
@ -575,14 +576,19 @@ impl CloudLanguageModel {
|
||||||
} else {
|
} else {
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(ApiError { status, body }));
|
||||||
"cloud language model completion failed with status {status}: {body}",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error("cloud language model completion failed with status {status}: {body}")]
|
||||||
|
struct ApiError {
|
||||||
|
status: StatusCode,
|
||||||
|
body: String,
|
||||||
|
}
|
||||||
|
|
||||||
impl LanguageModel for CloudLanguageModel {
|
impl LanguageModel for CloudLanguageModel {
|
||||||
fn id(&self) -> LanguageModelId {
|
fn id(&self) -> LanguageModelId {
|
||||||
self.id.clone()
|
self.id.clone()
|
||||||
|
@ -696,7 +702,23 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)?)?,
|
)?)?,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(|err| match err.downcast::<ApiError>() {
|
||||||
|
Ok(api_err) => {
|
||||||
|
if api_err.status == StatusCode::BAD_REQUEST {
|
||||||
|
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
|
||||||
|
return anyhow!(
|
||||||
|
LanguageModelKnownError::ContextWindowLimitExceeded {
|
||||||
|
tokens
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow!(api_err)
|
||||||
|
}
|
||||||
|
Err(err) => anyhow!(err),
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(
|
Ok(
|
||||||
crate::provider::anthropic::map_to_language_model_completion_events(
|
crate::provider::anthropic::map_to_language_model_completion_events(
|
||||||
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
|
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
|
||||||
|
|