assistant: Add support for displaying billing-related errors (#19082)

This PR adds support to the assistant for display billing-related
errors.

Pulling this out of #19081 to make it easier to cherry-pick.

Release Notes:

- N/A

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Marshall Bowers 2024-10-11 13:22:45 -04:00 committed by GitHub
parent 5cf0217549
commit 84b61c8b1a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 330 additions and 83 deletions

1
Cargo.lock generated
View file

@ -6291,6 +6291,7 @@ dependencies = [
"strum 0.25.0",
"text",
"theme",
"thiserror",
"tiktoken-rs",
"ui",
"unindent",

View file

@ -1503,6 +1503,13 @@ struct WorkflowAssist {
type MessageHeader = MessageMetadata;
#[derive(Clone)]
enum AssistError {
PaymentRequired,
MaxMonthlySpendReached,
Message(SharedString),
}
pub struct ContextEditor {
context: Model<Context>,
fs: Arc<dyn Fs>,
@ -1521,7 +1528,7 @@ pub struct ContextEditor {
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
active_workflow_step: Option<ActiveWorkflowStep>,
assistant_panel: WeakView<AssistantPanel>,
error_message: Option<SharedString>,
last_error: Option<AssistError>,
show_accept_terms: bool,
pub(crate) slash_menu_handle:
PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
@ -1592,7 +1599,7 @@ impl ContextEditor {
workflow_steps: HashMap::default(),
active_workflow_step: None,
assistant_panel,
error_message: None,
last_error: None,
show_accept_terms: false,
slash_menu_handle: Default::default(),
dragged_file_worktrees: Vec::new(),
@ -1636,7 +1643,7 @@ impl ContextEditor {
}
if !self.apply_active_workflow_step(cx) {
self.error_message = None;
self.last_error = None;
self.send_to_model(cx);
cx.notify();
}
@ -1786,7 +1793,7 @@ impl ContextEditor {
}
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
self.error_message = None;
self.last_error = None;
if self
.context
@ -2291,7 +2298,13 @@ impl ContextEditor {
}
ContextEvent::Operation(_) => {}
ContextEvent::ShowAssistError(error_message) => {
self.error_message = Some(error_message.clone());
self.last_error = Some(AssistError::Message(error_message.clone()));
}
ContextEvent::ShowPaymentRequiredError => {
self.last_error = Some(AssistError::PaymentRequired);
}
ContextEvent::ShowMaxMonthlySpendReachedError => {
self.last_error = Some(AssistError::MaxMonthlySpendReached);
}
}
}
@ -4305,6 +4318,154 @@ impl ContextEditor {
focus_handle.dispatch_action(&Assist, cx);
})
}
fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
let last_error = self.last_error.as_ref()?;
Some(
div()
.absolute()
.right_3()
.bottom_12()
.max_w_96()
.py_2()
.px_3()
.elevation_2(cx)
.occlude()
.child(match last_error {
AssistError::PaymentRequired => self.render_payment_required_error(cx),
AssistError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx)
}
AssistError::Message(error_message) => {
self.render_assist_error(error_message, cx)
}
})
.into_any(),
)
}
fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
const SUBSCRIBE_URL: &str = "https://zed.dev/ai/subscribe";
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(ERROR_MESSAGE)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|this, _, cx| {
this.last_error = None;
cx.open_url(SUBSCRIBE_URL);
cx.notify();
},
)))
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| {
this.last_error = None;
cx.notify();
},
))),
)
.into_any()
}
fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
const ACCOUNT_URL: &str = "https://zed.dev/account";
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(ERROR_MESSAGE)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
cx.listener(|this, _, cx| {
this.last_error = None;
cx.open_url(ACCOUNT_URL);
cx.notify();
}),
),
)
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| {
this.last_error = None;
cx.notify();
},
))),
)
.into_any()
}
fn render_assist_error(
&self,
error_message: &SharedString,
cx: &mut ViewContext<Self>,
) -> AnyElement {
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(
Label::new("Error interacting with language model")
.weight(FontWeight::MEDIUM),
),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(error_message.clone())),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| {
this.last_error = None;
cx.notify();
},
))),
)
.into_any()
}
}
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
@ -4441,48 +4602,7 @@ impl Render for ContextEditor {
.child(element),
)
})
.when_some(self.error_message.clone(), |this, error_message| {
this.child(
div()
.absolute()
.right_3()
.bottom_12()
.max_w_96()
.py_2()
.px_3()
.elevation_2(cx)
.occlude()
.child(
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(
Label::new("Error interacting with language model")
.weight(FontWeight::MEDIUM),
),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(error_message)),
)
.child(h_flex().justify_end().mt_1().child(
Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| {
this.error_message = None;
cx.notify();
},
)),
)),
),
)
})
.children(self.render_last_error(cx))
.child(
h_flex().w_full().relative().child(
h_flex()

View file

@ -26,6 +26,7 @@ use gpui::{
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
use language_model::{
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
@ -294,6 +295,8 @@ impl ContextOperation {
#[derive(Debug, Clone)]
pub enum ContextEvent {
ShowAssistError(SharedString),
ShowPaymentRequiredError,
ShowMaxMonthlySpendReachedError,
MessagesEdited,
SummaryChanged,
StreamedCompletion,
@ -2112,25 +2115,36 @@ impl Context {
let result = stream_completion.await;
this.update(&mut cx, |this, cx| {
let error_message = result
.as_ref()
.err()
.map(|error| error.to_string().trim().to_string());
if let Some(error_message) = error_message.as_ref() {
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
error_message.clone(),
)));
}
this.update_metadata(assistant_message_id, cx, |metadata| {
if let Some(error_message) = error_message.as_ref() {
metadata.status =
MessageStatus::Error(SharedString::from(error_message.clone()));
let error_message = if let Some(error) = result.as_ref().err() {
if error.is::<PaymentRequiredError>() {
cx.emit(ContextEvent::ShowPaymentRequiredError);
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Canceled;
});
Some(error.to_string())
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Canceled;
});
Some(error.to_string())
} else {
metadata.status = MessageStatus::Done;
let error_message = error.to_string().trim().to_string();
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
error_message.clone(),
)));
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status =
MessageStatus::Error(SharedString::from(error_message.clone()));
});
Some(error_message)
}
});
} else {
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Done;
});
None
};
if let Some(telemetry) = this.telemetry.as_ref() {
let language_name = this

View file

@ -47,6 +47,7 @@ settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true
ui.workspace = true
util.workspace = true

View file

@ -7,7 +7,10 @@ use crate::{
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use client::{
Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
@ -15,10 +18,11 @@ use futures::{
TryStreamExt as _,
};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
Subscription, Task,
AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
ModelContext, ReadGlobal, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
@ -27,12 +31,14 @@ use smol::{
io::{AsyncReadExt, BufReader},
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::fmt;
use std::time::Duration;
use std::{
future,
sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor};
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
@ -90,22 +96,93 @@ pub struct AvailableModel {
pub default_temperature: Option<f32>,
}
struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent;
pub struct RefreshLlmTokenListener {
_llm_token_subscription: client::Subscription,
}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, cx: &mut AppContext) {
let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
pub fn global(cx: &AppContext) -> Model<Self> {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
Self {
_llm_token_subscription: client
.add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
}
}
async fn handle_refresh_llm_token(
this: Model<Self>,
_: TypedEnvelope<proto::RefreshLlmToken>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
}
}
pub struct CloudLanguageModelProvider {
client: Arc<Client>,
llm_api_token: LlmApiToken,
state: gpui::Model<State>,
_maintain_client_status: Task<()>,
}
pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
user_store: Model<UserStore>,
status: client::Status,
accept_terms: Option<Task<Result<()>>>,
_subscription: Subscription,
_settings_subscription: Subscription,
_llm_token_subscription: Subscription,
}
impl State {
fn new(
client: Arc<Client>,
user_store: Model<UserStore>,
status: client::Status,
cx: &mut ModelContext<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client: client.clone(),
llm_api_token: LlmApiToken::default(),
user_store,
status,
accept_terms: None,
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _listener, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
cx.spawn(|_this, _cx| async move {
llm_api_token.refresh(&client).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
},
),
}
}
fn is_signed_out(&self) -> bool {
self.status.is_signed_out()
}
@ -144,15 +221,7 @@ impl CloudLanguageModelProvider {
let mut status_rx = client.status();
let status = *status_rx.borrow();
let state = cx.new_model(|cx| State {
client: client.clone(),
user_store,
status,
accept_terms: None,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));
let state_ref = state.downgrade();
let maintain_client_status = cx.spawn(|mut cx| async move {
@ -172,8 +241,7 @@ impl CloudLanguageModelProvider {
Self {
client,
state,
llm_api_token: LlmApiToken::default(),
state: state.clone(),
_maintain_client_status: maintain_client_status,
}
}
@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
models.insert(model.id().to_string(), model.clone());
}
let llm_api_token = self.state.read(cx).llm_api_token.clone();
models
.into_values()
.map(|model| {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: self.llm_api_token.clone(),
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@ -377,6 +446,30 @@ pub struct CloudLanguageModel {
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
#[derive(Error, Debug)]
pub struct PaymentRequiredError;
impl fmt::Display for PaymentRequiredError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Payment required to use this language model. Please upgrade your account."
)
}
}
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;
impl fmt::Display for MaxMonthlySpendReachedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
)
}
}
impl CloudLanguageModel {
async fn perform_llm_completion(
client: Arc<Client>,
@ -411,6 +504,15 @@ impl CloudLanguageModel {
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
} else if response.status() == StatusCode::FORBIDDEN
&& response
.headers()
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
.is_some()
{
break Err(anyhow!(MaxMonthlySpendReachedError))?;
} else if response.status() == StatusCode::PAYMENT_REQUIRED {
break Err(anyhow!(PaymentRequiredError))?;
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;

View file

@ -1,3 +1,4 @@
use crate::provider::cloud::RefreshLlmTokenListener;
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
@ -30,6 +31,8 @@ fn register_language_model_providers(
) {
use feature_flags::FeatureFlagAppExt;
RefreshLlmTokenListener::register(client.clone(), cx);
registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx,

View file

@ -269,6 +269,7 @@ message Envelope {
GetLlmToken get_llm_token = 235;
GetLlmTokenResponse get_llm_token_response = 236;
RefreshLlmToken refresh_llm_token = 259; // current max
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
@ -284,7 +285,7 @@ message Envelope {
ShutdownRemoteServer shutdown_remote_server = 257;
RemoveWorktree remove_worktree = 258; // current max
RemoveWorktree remove_worktree = 258;
}
reserved 87 to 88;
@ -2429,6 +2430,8 @@ message GetLlmTokenResponse {
string token = 1;
}
message RefreshLlmToken {}
// Remote FS
message AddWorktree {

View file

@ -253,6 +253,7 @@ messages!(
(ProjectEntryResponse, Foreground),
(CountLanguageModelTokens, Background),
(CountLanguageModelTokensResponse, Background),
(RefreshLlmToken, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),

View file

@ -3,6 +3,8 @@ use strum::{Display, EnumIter, EnumString};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached";
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]