From 84b61c8b1aec84b5235b7798b83aad039d98033e Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 11 Oct 2024 13:22:45 -0400 Subject: [PATCH] assistant: Add support for displaying billing-related errors (#19082) This PR adds support to the assistant for displaying billing-related errors. Pulling this out of #19081 to make it easier to cherry-pick. Release Notes: - N/A Co-authored-by: Antonio Co-authored-by: Richard --- Cargo.lock | 1 + crates/assistant/src/assistant_panel.rs | 214 +++++++++++++++----- crates/assistant/src/context.rs | 48 +++-- crates/language_model/Cargo.toml | 1 + crates/language_model/src/provider/cloud.rs | 138 +++++++++++-- crates/language_model/src/registry.rs | 3 + crates/proto/proto/zed.proto | 5 +- crates/proto/src/proto.rs | 1 + crates/rpc/src/llm.rs | 2 + 9 files changed, 330 insertions(+), 83 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 018e9219c5..752050c995 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6291,6 +6291,7 @@ dependencies = [ "strum 0.25.0", "text", "theme", + "thiserror", "tiktoken-rs", "ui", "unindent", diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index e4d02e5bb7..0439122cb4 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1503,6 +1503,13 @@ struct WorkflowAssist { type MessageHeader = MessageMetadata; +#[derive(Clone)] +enum AssistError { + PaymentRequired, + MaxMonthlySpendReached, + Message(SharedString), +} + pub struct ContextEditor { context: Model, fs: Arc, @@ -1521,7 +1528,7 @@ pub struct ContextEditor { workflow_steps: HashMap, WorkflowStepViewState>, active_workflow_step: Option, assistant_panel: WeakView, - error_message: Option, + last_error: Option, show_accept_terms: bool, pub(crate) slash_menu_handle: PopoverMenuHandle>, @@ -1592,7 +1599,7 @@ impl ContextEditor { workflow_steps: HashMap::default(), active_workflow_step: None, assistant_panel, - error_message: None, + last_error: None, 
show_accept_terms: false, slash_menu_handle: Default::default(), dragged_file_worktrees: Vec::new(), @@ -1636,7 +1643,7 @@ impl ContextEditor { } if !self.apply_active_workflow_step(cx) { - self.error_message = None; + self.last_error = None; self.send_to_model(cx); cx.notify(); } @@ -1786,7 +1793,7 @@ impl ContextEditor { } fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { - self.error_message = None; + self.last_error = None; if self .context @@ -2291,7 +2298,13 @@ impl ContextEditor { } ContextEvent::Operation(_) => {} ContextEvent::ShowAssistError(error_message) => { - self.error_message = Some(error_message.clone()); + self.last_error = Some(AssistError::Message(error_message.clone())); + } + ContextEvent::ShowPaymentRequiredError => { + self.last_error = Some(AssistError::PaymentRequired); + } + ContextEvent::ShowMaxMonthlySpendReachedError => { + self.last_error = Some(AssistError::MaxMonthlySpendReached); } } } @@ -4305,6 +4318,154 @@ impl ContextEditor { focus_handle.dispatch_action(&Assist, cx); }) } + + fn render_last_error(&self, cx: &mut ViewContext) -> Option { + let last_error = self.last_error.as_ref()?; + + Some( + div() + .absolute() + .right_3() + .bottom_12() + .max_w_96() + .py_2() + .px_3() + .elevation_2(cx) + .occlude() + .child(match last_error { + AssistError::PaymentRequired => self.render_payment_required_error(cx), + AssistError::MaxMonthlySpendReached => { + self.render_max_monthly_spend_reached_error(cx) + } + AssistError::Message(error_message) => { + self.render_assist_error(error_message, cx) + } + }) + .into_any(), + ) + } + + fn render_payment_required_error(&self, cx: &mut ViewContext) -> AnyElement { + const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. 
You'll be billed at cost for tokens used."; + const SUBSCRIBE_URL: &str = "https://zed.dev/ai/subscribe"; + + v_flex() + .gap_0p5() + .child( + h_flex() + .gap_1p5() + .items_center() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)), + ) + .child( + div() + .id("error-message") + .max_h_24() + .overflow_y_scroll() + .child(Label::new(ERROR_MESSAGE)), + ) + .child( + h_flex() + .justify_end() + .mt_1() + .child(Button::new("subscribe", "Subscribe").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.open_url(SUBSCRIBE_URL); + cx.notify(); + }, + ))) + .child(Button::new("dismiss", "Dismiss").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.notify(); + }, + ))), + ) + .into_any() + } + + fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext) -> AnyElement { + const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs."; + const ACCOUNT_URL: &str = "https://zed.dev/account"; + + v_flex() + .gap_0p5() + .child( + h_flex() + .gap_1p5() + .items_center() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)), + ) + .child( + div() + .id("error-message") + .max_h_24() + .overflow_y_scroll() + .child(Label::new(ERROR_MESSAGE)), + ) + .child( + h_flex() + .justify_end() + .mt_1() + .child( + Button::new("subscribe", "Update Monthly Spend Limit").on_click( + cx.listener(|this, _, cx| { + this.last_error = None; + cx.open_url(ACCOUNT_URL); + cx.notify(); + }), + ), + ) + .child(Button::new("dismiss", "Dismiss").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.notify(); + }, + ))), + ) + .into_any() + } + + fn render_assist_error( + &self, + error_message: &SharedString, + cx: &mut ViewContext, + ) -> AnyElement { + v_flex() + .gap_0p5() + .child( + h_flex() + .gap_1p5() + 
.items_center() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child( + Label::new("Error interacting with language model") + .weight(FontWeight::MEDIUM), + ), + ) + .child( + div() + .id("error-message") + .max_h_24() + .overflow_y_scroll() + .child(Label::new(error_message.clone())), + ) + .child( + h_flex() + .justify_end() + .mt_1() + .child(Button::new("dismiss", "Dismiss").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.notify(); + }, + ))), + ) + .into_any() + } } /// Returns the contents of the *outermost* fenced code block that contains the given offset. @@ -4441,48 +4602,7 @@ impl Render for ContextEditor { .child(element), ) }) - .when_some(self.error_message.clone(), |this, error_message| { - this.child( - div() - .absolute() - .right_3() - .bottom_12() - .max_w_96() - .py_2() - .px_3() - .elevation_2(cx) - .occlude() - .child( - v_flex() - .gap_0p5() - .child( - h_flex() - .gap_1p5() - .items_center() - .child(Icon::new(IconName::XCircle).color(Color::Error)) - .child( - Label::new("Error interacting with language model") - .weight(FontWeight::MEDIUM), - ), - ) - .child( - div() - .id("error-message") - .max_h_24() - .overflow_y_scroll() - .child(Label::new(error_message)), - ) - .child(h_flex().justify_end().mt_1().child( - Button::new("dismiss", "Dismiss").on_click(cx.listener( - |this, _, cx| { - this.error_message = None; - cx.notify(); - }, - )), - )), - ), - ) - }) + .children(self.render_last_error(cx)) .child( h_flex().w_full().relative().child( h_flex() diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index c27d17f8c5..1a5a4d44f0 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -26,6 +26,7 @@ use gpui::{ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset}; use language_model::{ + provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError}, LanguageModel, LanguageModelCacheConfiguration, 
LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, @@ -294,6 +295,8 @@ impl ContextOperation { #[derive(Debug, Clone)] pub enum ContextEvent { ShowAssistError(SharedString), + ShowPaymentRequiredError, + ShowMaxMonthlySpendReachedError, MessagesEdited, SummaryChanged, StreamedCompletion, @@ -2112,25 +2115,36 @@ impl Context { let result = stream_completion.await; this.update(&mut cx, |this, cx| { - let error_message = result - .as_ref() - .err() - .map(|error| error.to_string().trim().to_string()); - - if let Some(error_message) = error_message.as_ref() { - cx.emit(ContextEvent::ShowAssistError(SharedString::from( - error_message.clone(), - ))); - } - - this.update_metadata(assistant_message_id, cx, |metadata| { - if let Some(error_message) = error_message.as_ref() { - metadata.status = - MessageStatus::Error(SharedString::from(error_message.clone())); + let error_message = if let Some(error) = result.as_ref().err() { + if error.is::() { + cx.emit(ContextEvent::ShowPaymentRequiredError); + this.update_metadata(assistant_message_id, cx, |metadata| { + metadata.status = MessageStatus::Canceled; + }); + Some(error.to_string()) + } else if error.is::() { + cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError); + this.update_metadata(assistant_message_id, cx, |metadata| { + metadata.status = MessageStatus::Canceled; + }); + Some(error.to_string()) } else { - metadata.status = MessageStatus::Done; + let error_message = error.to_string().trim().to_string(); + cx.emit(ContextEvent::ShowAssistError(SharedString::from( + error_message.clone(), + ))); + this.update_metadata(assistant_message_id, cx, |metadata| { + metadata.status = + MessageStatus::Error(SharedString::from(error_message.clone())); + }); + Some(error_message) } - }); + } else { + this.update_metadata(assistant_message_id, cx, |metadata| { + 
metadata.status = MessageStatus::Done; + }); + None + }; if let Some(telemetry) = this.telemetry.as_ref() { let language_name = this diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index ef273ac44f..74a2ed0ed0 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -47,6 +47,7 @@ settings.workspace = true smol.workspace = true strum.workspace = true theme.workspace = true +thiserror.workspace = true tiktoken-rs.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index b81f6f9fba..eb66a22e59 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -7,7 +7,10 @@ use crate::{ }; use anthropic::AnthropicError; use anyhow::{anyhow, Result}; -use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME}; +use client::{ + Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME, + MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, +}; use collections::BTreeMap; use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro}; use futures::{ @@ -15,10 +18,11 @@ use futures::{ TryStreamExt as _, }; use gpui::{ - AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext, - Subscription, Task, + AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model, + ModelContext, ReadGlobal, Subscription, Task, }; -use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response}; +use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode}; +use proto::TypedEnvelope; use schemars::JsonSchema; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::value::RawValue; @@ -27,12 +31,14 @@ use smol::{ io::{AsyncReadExt, BufReader}, lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}, }; +use std::fmt; use 
std::time::Duration; use std::{ future, sync::{Arc, LazyLock}, }; use strum::IntoEnumIterator; +use thiserror::Error; use ui::{prelude::*, TintColor}; use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider}; @@ -90,22 +96,93 @@ pub struct AvailableModel { pub default_temperature: Option, } +struct GlobalRefreshLlmTokenListener(Model); + +impl Global for GlobalRefreshLlmTokenListener {} + +pub struct RefreshLlmTokenEvent; + +pub struct RefreshLlmTokenListener { + _llm_token_subscription: client::Subscription, +} + +impl EventEmitter for RefreshLlmTokenListener {} + +impl RefreshLlmTokenListener { + pub fn register(client: Arc, cx: &mut AppContext) { + let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx)); + cx.set_global(GlobalRefreshLlmTokenListener(listener)); + } + + pub fn global(cx: &AppContext) -> Model { + GlobalRefreshLlmTokenListener::global(cx).0.clone() + } + + fn new(client: Arc, cx: &mut ModelContext) -> Self { + Self { + _llm_token_subscription: client + .add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token), + } + } + + async fn handle_refresh_llm_token( + this: Model, + _: TypedEnvelope, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) + } +} + pub struct CloudLanguageModelProvider { client: Arc, - llm_api_token: LlmApiToken, state: gpui::Model, _maintain_client_status: Task<()>, } pub struct State { client: Arc, + llm_api_token: LlmApiToken, user_store: Model, status: client::Status, accept_terms: Option>>, - _subscription: Subscription, + _settings_subscription: Subscription, + _llm_token_subscription: Subscription, } impl State { + fn new( + client: Arc, + user_store: Model, + status: client::Status, + cx: &mut ModelContext, + ) -> Self { + let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + + Self { + client: client.clone(), + llm_api_token: LlmApiToken::default(), + user_store, + status, + 
accept_terms: None, + _settings_subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + _llm_token_subscription: cx.subscribe( + &refresh_llm_token_listener, + |this, _listener, _event, cx| { + let client = this.client.clone(); + let llm_api_token = this.llm_api_token.clone(); + cx.spawn(|_this, _cx| async move { + llm_api_token.refresh(&client).await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + }, + ), + } + } + fn is_signed_out(&self) -> bool { self.status.is_signed_out() } @@ -144,15 +221,7 @@ impl CloudLanguageModelProvider { let mut status_rx = client.status(); let status = *status_rx.borrow(); - let state = cx.new_model(|cx| State { - client: client.clone(), - user_store, - status, - accept_terms: None, - _subscription: cx.observe_global::(|_, cx| { - cx.notify(); - }), - }); + let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx)); let state_ref = state.downgrade(); let maintain_client_status = cx.spawn(|mut cx| async move { @@ -172,8 +241,7 @@ impl CloudLanguageModelProvider { Self { client, - state, - llm_api_token: LlmApiToken::default(), + state: state.clone(), _maintain_client_status: maintain_client_status, } } @@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider { models.insert(model.id().to_string(), model.clone()); } + let llm_api_token = self.state.read(cx).llm_api_token.clone(); models .into_values() .map(|model| { Arc::new(CloudLanguageModel { id: LanguageModelId::from(model.id().to_string()), model, - llm_api_token: self.llm_api_token.clone(), + llm_api_token: llm_api_token.clone(), client: self.client.clone(), request_limiter: RateLimiter::new(4), }) as Arc @@ -377,6 +446,30 @@ pub struct CloudLanguageModel { #[derive(Clone, Default)] struct LlmApiToken(Arc>>); +#[derive(Error, Debug)] +pub struct PaymentRequiredError; + +impl fmt::Display for PaymentRequiredError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Payment required to use 
this language model. Please upgrade your account." + ) + } +} + +#[derive(Error, Debug)] +pub struct MaxMonthlySpendReachedError; + +impl fmt::Display for MaxMonthlySpendReachedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Maximum spending limit reached for this month. For more usage, increase your spending limit." + ) + } +} + impl CloudLanguageModel { async fn perform_llm_completion( client: Arc, @@ -411,6 +504,15 @@ impl CloudLanguageModel { { did_retry = true; token = llm_api_token.refresh(&client).await?; + } else if response.status() == StatusCode::FORBIDDEN + && response + .headers() + .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME) + .is_some() + { + break Err(anyhow!(MaxMonthlySpendReachedError))?; + } else if response.status() == StatusCode::PAYMENT_REQUIRED { + break Err(anyhow!(PaymentRequiredError))?; } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index e1ba1c5886..72dfd998d4 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,3 +1,4 @@ +use crate::provider::cloud::RefreshLlmTokenListener; use crate::{ provider::{ anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, @@ -30,6 +31,8 @@ fn register_language_model_providers( ) { use feature_flags::FeatureFlagAppExt; + RefreshLlmTokenListener::register(client.clone(), cx); + registry.register_provider( AnthropicLanguageModelProvider::new(client.http_client(), cx), cx, diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index e4e6ac4240..f6711f8db9 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -269,6 +269,7 @@ message Envelope { GetLlmToken get_llm_token = 235; GetLlmTokenResponse get_llm_token_response = 236; + RefreshLlmToken refresh_llm_token = 259; // current max LspExtSwitchSourceHeader 
lsp_ext_switch_source_header = 241; LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242; @@ -284,7 +285,7 @@ message Envelope { ShutdownRemoteServer shutdown_remote_server = 257; - RemoveWorktree remove_worktree = 258; // current max + RemoveWorktree remove_worktree = 258; } reserved 87 to 88; @@ -2429,6 +2430,8 @@ message GetLlmTokenResponse { string token = 1; } +message RefreshLlmToken {} + // Remote FS message AddWorktree { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 93c92e9d47..2c038c2e1c 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -253,6 +253,7 @@ messages!( (ProjectEntryResponse, Foreground), (CountLanguageModelTokens, Background), (CountLanguageModelTokensResponse, Background), + (RefreshLlmToken, Background), (RefreshInlayHints, Foreground), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 681f2d8db3..0a7510d891 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -3,6 +3,8 @@ use strum::{Display, EnumIter, EnumString}; pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; +pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached"; + #[derive( Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, )]