assistant: Add support for displaying billing-related errors (#19082)
This PR adds support to the assistant for display billing-related errors. Pulling this out of #19081 to make it easier to cherry-pick. Release Notes: - N/A Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
parent
5cf0217549
commit
84b61c8b1a
9 changed files with 330 additions and 83 deletions
|
@ -47,6 +47,7 @@ settings.workspace = true
|
|||
smol.workspace = true
|
||||
strum.workspace = true
|
||||
theme.workspace = true
|
||||
thiserror.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
|
|
|
@ -7,7 +7,10 @@ use crate::{
|
|||
};
|
||||
use anthropic::AnthropicError;
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use client::{
|
||||
Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
||||
};
|
||||
use collections::BTreeMap;
|
||||
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
|
||||
use futures::{
|
||||
|
@ -15,10 +18,11 @@ use futures::{
|
|||
TryStreamExt as _,
|
||||
};
|
||||
use gpui::{
|
||||
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
||||
Subscription, Task,
|
||||
AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
|
||||
ModelContext, ReadGlobal, Subscription, Task,
|
||||
};
|
||||
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
|
||||
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
|
||||
use proto::TypedEnvelope;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
|
@ -27,12 +31,14 @@ use smol::{
|
|||
io::{AsyncReadExt, BufReader},
|
||||
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
|
||||
};
|
||||
use std::fmt;
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
future,
|
||||
sync::{Arc, LazyLock},
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use thiserror::Error;
|
||||
use ui::{prelude::*, TintColor};
|
||||
|
||||
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
|
||||
|
@ -90,22 +96,93 @@ pub struct AvailableModel {
|
|||
pub default_temperature: Option<f32>,
|
||||
}
|
||||
|
||||
struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);
|
||||
|
||||
impl Global for GlobalRefreshLlmTokenListener {}
|
||||
|
||||
pub struct RefreshLlmTokenEvent;
|
||||
|
||||
pub struct RefreshLlmTokenListener {
|
||||
_llm_token_subscription: client::Subscription,
|
||||
}
|
||||
|
||||
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
|
||||
|
||||
impl RefreshLlmTokenListener {
|
||||
pub fn register(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
|
||||
cx.set_global(GlobalRefreshLlmTokenListener(listener));
|
||||
}
|
||||
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
GlobalRefreshLlmTokenListener::global(cx).0.clone()
|
||||
}
|
||||
|
||||
fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
|
||||
Self {
|
||||
_llm_token_subscription: client
|
||||
.add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_refresh_llm_token(
|
||||
this: Model<Self>,
|
||||
_: TypedEnvelope<proto::RefreshLlmToken>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModelProvider {
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
state: gpui::Model<State>,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
user_store: Model<UserStore>,
|
||||
status: client::Status,
|
||||
accept_terms: Option<Task<Result<()>>>,
|
||||
_subscription: Subscription,
|
||||
_settings_subscription: Subscription,
|
||||
_llm_token_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(
|
||||
client: Arc<Client>,
|
||||
user_store: Model<UserStore>,
|
||||
status: client::Status,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
|
||||
Self {
|
||||
client: client.clone(),
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
user_store,
|
||||
status,
|
||||
accept_terms: None,
|
||||
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
_llm_token_subscription: cx.subscribe(
|
||||
&refresh_llm_token_listener,
|
||||
|this, _listener, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
cx.spawn(|_this, _cx| async move {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_signed_out(&self) -> bool {
|
||||
self.status.is_signed_out()
|
||||
}
|
||||
|
@ -144,15 +221,7 @@ impl CloudLanguageModelProvider {
|
|||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
|
||||
let state = cx.new_model(|cx| State {
|
||||
client: client.clone(),
|
||||
user_store,
|
||||
status,
|
||||
accept_terms: None,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));
|
||||
|
||||
let state_ref = state.downgrade();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
|
@ -172,8 +241,7 @@ impl CloudLanguageModelProvider {
|
|||
|
||||
Self {
|
||||
client,
|
||||
state,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
state: state.clone(),
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
|
@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
models
|
||||
.into_values()
|
||||
.map(|model| {
|
||||
Arc::new(CloudLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
llm_api_token: self.llm_api_token.clone(),
|
||||
llm_api_token: llm_api_token.clone(),
|
||||
client: self.client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
|
@ -377,6 +446,30 @@ pub struct CloudLanguageModel {
|
|||
#[derive(Clone, Default)]
|
||||
struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub struct PaymentRequiredError;
|
||||
|
||||
impl fmt::Display for PaymentRequiredError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Payment required to use this language model. Please upgrade your account."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub struct MaxMonthlySpendReachedError;
|
||||
|
||||
impl fmt::Display for MaxMonthlySpendReachedError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CloudLanguageModel {
|
||||
async fn perform_llm_completion(
|
||||
client: Arc<Client>,
|
||||
|
@ -411,6 +504,15 @@ impl CloudLanguageModel {
|
|||
{
|
||||
did_retry = true;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else if response.status() == StatusCode::FORBIDDEN
|
||||
&& response
|
||||
.headers()
|
||||
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
break Err(anyhow!(MaxMonthlySpendReachedError))?;
|
||||
} else if response.status() == StatusCode::PAYMENT_REQUIRED {
|
||||
break Err(anyhow!(PaymentRequiredError))?;
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::provider::cloud::RefreshLlmTokenListener;
|
||||
use crate::{
|
||||
provider::{
|
||||
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||
|
@ -30,6 +31,8 @@ fn register_language_model_providers(
|
|||
) {
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
|
||||
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
|
||||
registry.register_provider(
|
||||
AnthropicLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue