assistant: Add support for displaying billing-related errors (#19082)

This PR adds support to the assistant for display billing-related
errors.

Pulling this out of #19081 to make it easier to cherry-pick.

Release Notes:

- N/A

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Marshall Bowers 2024-10-11 13:22:45 -04:00 committed by GitHub
parent 5cf0217549
commit 84b61c8b1a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 330 additions and 83 deletions

View file

@ -47,6 +47,7 @@ settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true
ui.workspace = true
util.workspace = true

View file

@ -7,7 +7,10 @@ use crate::{
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use client::{
Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
@ -15,10 +18,11 @@ use futures::{
TryStreamExt as _,
};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
Subscription, Task,
AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
ModelContext, ReadGlobal, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
@ -27,12 +31,14 @@ use smol::{
io::{AsyncReadExt, BufReader},
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::fmt;
use std::time::Duration;
use std::{
future,
sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor};
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
@ -90,22 +96,93 @@ pub struct AvailableModel {
pub default_temperature: Option<f32>,
}
struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent;
pub struct RefreshLlmTokenListener {
_llm_token_subscription: client::Subscription,
}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, cx: &mut AppContext) {
let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
pub fn global(cx: &AppContext) -> Model<Self> {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
Self {
_llm_token_subscription: client
.add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
}
}
async fn handle_refresh_llm_token(
this: Model<Self>,
_: TypedEnvelope<proto::RefreshLlmToken>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
}
}
pub struct CloudLanguageModelProvider {
client: Arc<Client>,
llm_api_token: LlmApiToken,
state: gpui::Model<State>,
_maintain_client_status: Task<()>,
}
pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
user_store: Model<UserStore>,
status: client::Status,
accept_terms: Option<Task<Result<()>>>,
_subscription: Subscription,
_settings_subscription: Subscription,
_llm_token_subscription: Subscription,
}
impl State {
fn new(
client: Arc<Client>,
user_store: Model<UserStore>,
status: client::Status,
cx: &mut ModelContext<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client: client.clone(),
llm_api_token: LlmApiToken::default(),
user_store,
status,
accept_terms: None,
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _listener, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
cx.spawn(|_this, _cx| async move {
llm_api_token.refresh(&client).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
},
),
}
}
fn is_signed_out(&self) -> bool {
self.status.is_signed_out()
}
@ -144,15 +221,7 @@ impl CloudLanguageModelProvider {
let mut status_rx = client.status();
let status = *status_rx.borrow();
let state = cx.new_model(|cx| State {
client: client.clone(),
user_store,
status,
accept_terms: None,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));
let state_ref = state.downgrade();
let maintain_client_status = cx.spawn(|mut cx| async move {
@ -172,8 +241,7 @@ impl CloudLanguageModelProvider {
Self {
client,
state,
llm_api_token: LlmApiToken::default(),
state: state.clone(),
_maintain_client_status: maintain_client_status,
}
}
@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
models.insert(model.id().to_string(), model.clone());
}
let llm_api_token = self.state.read(cx).llm_api_token.clone();
models
.into_values()
.map(|model| {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: self.llm_api_token.clone(),
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@ -377,6 +446,30 @@ pub struct CloudLanguageModel {
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
#[derive(Error, Debug)]
pub struct PaymentRequiredError;
impl fmt::Display for PaymentRequiredError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Payment required to use this language model. Please upgrade your account."
)
}
}
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;
impl fmt::Display for MaxMonthlySpendReachedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
)
}
}
impl CloudLanguageModel {
async fn perform_llm_completion(
client: Arc<Client>,
@ -411,6 +504,15 @@ impl CloudLanguageModel {
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
} else if response.status() == StatusCode::FORBIDDEN
&& response
.headers()
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
.is_some()
{
break Err(anyhow!(MaxMonthlySpendReachedError))?;
} else if response.status() == StatusCode::PAYMENT_REQUIRED {
break Err(anyhow!(PaymentRequiredError))?;
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;

View file

@ -1,3 +1,4 @@
use crate::provider::cloud::RefreshLlmTokenListener;
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
@ -30,6 +31,8 @@ fn register_language_model_providers(
) {
use feature_flags::FeatureFlagAppExt;
RefreshLlmTokenListener::register(client.clone(), cx);
registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx,