assistant: Add support for displaying billing-related errors (#19082)
This PR adds support to the assistant for display billing-related errors. Pulling this out of #19081 to make it easier to cherry-pick. Release Notes: - N/A Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
parent
5cf0217549
commit
84b61c8b1a
9 changed files with 330 additions and 83 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -6291,6 +6291,7 @@ dependencies = [
|
||||||
"strum 0.25.0",
|
"strum 0.25.0",
|
||||||
"text",
|
"text",
|
||||||
"theme",
|
"theme",
|
||||||
|
"thiserror",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
"ui",
|
"ui",
|
||||||
"unindent",
|
"unindent",
|
||||||
|
|
|
@ -1503,6 +1503,13 @@ struct WorkflowAssist {
|
||||||
|
|
||||||
type MessageHeader = MessageMetadata;
|
type MessageHeader = MessageMetadata;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum AssistError {
|
||||||
|
PaymentRequired,
|
||||||
|
MaxMonthlySpendReached,
|
||||||
|
Message(SharedString),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ContextEditor {
|
pub struct ContextEditor {
|
||||||
context: Model<Context>,
|
context: Model<Context>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
|
@ -1521,7 +1528,7 @@ pub struct ContextEditor {
|
||||||
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
|
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
|
||||||
active_workflow_step: Option<ActiveWorkflowStep>,
|
active_workflow_step: Option<ActiveWorkflowStep>,
|
||||||
assistant_panel: WeakView<AssistantPanel>,
|
assistant_panel: WeakView<AssistantPanel>,
|
||||||
error_message: Option<SharedString>,
|
last_error: Option<AssistError>,
|
||||||
show_accept_terms: bool,
|
show_accept_terms: bool,
|
||||||
pub(crate) slash_menu_handle:
|
pub(crate) slash_menu_handle:
|
||||||
PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
|
PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
|
||||||
|
@ -1592,7 +1599,7 @@ impl ContextEditor {
|
||||||
workflow_steps: HashMap::default(),
|
workflow_steps: HashMap::default(),
|
||||||
active_workflow_step: None,
|
active_workflow_step: None,
|
||||||
assistant_panel,
|
assistant_panel,
|
||||||
error_message: None,
|
last_error: None,
|
||||||
show_accept_terms: false,
|
show_accept_terms: false,
|
||||||
slash_menu_handle: Default::default(),
|
slash_menu_handle: Default::default(),
|
||||||
dragged_file_worktrees: Vec::new(),
|
dragged_file_worktrees: Vec::new(),
|
||||||
|
@ -1636,7 +1643,7 @@ impl ContextEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !self.apply_active_workflow_step(cx) {
|
if !self.apply_active_workflow_step(cx) {
|
||||||
self.error_message = None;
|
self.last_error = None;
|
||||||
self.send_to_model(cx);
|
self.send_to_model(cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
@ -1786,7 +1793,7 @@ impl ContextEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
|
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
|
||||||
self.error_message = None;
|
self.last_error = None;
|
||||||
|
|
||||||
if self
|
if self
|
||||||
.context
|
.context
|
||||||
|
@ -2291,7 +2298,13 @@ impl ContextEditor {
|
||||||
}
|
}
|
||||||
ContextEvent::Operation(_) => {}
|
ContextEvent::Operation(_) => {}
|
||||||
ContextEvent::ShowAssistError(error_message) => {
|
ContextEvent::ShowAssistError(error_message) => {
|
||||||
self.error_message = Some(error_message.clone());
|
self.last_error = Some(AssistError::Message(error_message.clone()));
|
||||||
|
}
|
||||||
|
ContextEvent::ShowPaymentRequiredError => {
|
||||||
|
self.last_error = Some(AssistError::PaymentRequired);
|
||||||
|
}
|
||||||
|
ContextEvent::ShowMaxMonthlySpendReachedError => {
|
||||||
|
self.last_error = Some(AssistError::MaxMonthlySpendReached);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4305,6 +4318,154 @@ impl ContextEditor {
|
||||||
focus_handle.dispatch_action(&Assist, cx);
|
focus_handle.dispatch_action(&Assist, cx);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
|
||||||
|
let last_error = self.last_error.as_ref()?;
|
||||||
|
|
||||||
|
Some(
|
||||||
|
div()
|
||||||
|
.absolute()
|
||||||
|
.right_3()
|
||||||
|
.bottom_12()
|
||||||
|
.max_w_96()
|
||||||
|
.py_2()
|
||||||
|
.px_3()
|
||||||
|
.elevation_2(cx)
|
||||||
|
.occlude()
|
||||||
|
.child(match last_error {
|
||||||
|
AssistError::PaymentRequired => self.render_payment_required_error(cx),
|
||||||
|
AssistError::MaxMonthlySpendReached => {
|
||||||
|
self.render_max_monthly_spend_reached_error(cx)
|
||||||
|
}
|
||||||
|
AssistError::Message(error_message) => {
|
||||||
|
self.render_assist_error(error_message, cx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.into_any(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
|
||||||
|
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
|
||||||
|
const SUBSCRIBE_URL: &str = "https://zed.dev/ai/subscribe";
|
||||||
|
|
||||||
|
v_flex()
|
||||||
|
.gap_0p5()
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1p5()
|
||||||
|
.items_center()
|
||||||
|
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||||
|
.child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.id("error-message")
|
||||||
|
.max_h_24()
|
||||||
|
.overflow_y_scroll()
|
||||||
|
.child(Label::new(ERROR_MESSAGE)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.justify_end()
|
||||||
|
.mt_1()
|
||||||
|
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|
||||||
|
|this, _, cx| {
|
||||||
|
this.last_error = None;
|
||||||
|
cx.open_url(SUBSCRIBE_URL);
|
||||||
|
cx.notify();
|
||||||
|
},
|
||||||
|
)))
|
||||||
|
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||||
|
|this, _, cx| {
|
||||||
|
this.last_error = None;
|
||||||
|
cx.notify();
|
||||||
|
},
|
||||||
|
))),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
|
||||||
|
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
|
||||||
|
const ACCOUNT_URL: &str = "https://zed.dev/account";
|
||||||
|
|
||||||
|
v_flex()
|
||||||
|
.gap_0p5()
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1p5()
|
||||||
|
.items_center()
|
||||||
|
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||||
|
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.id("error-message")
|
||||||
|
.max_h_24()
|
||||||
|
.overflow_y_scroll()
|
||||||
|
.child(Label::new(ERROR_MESSAGE)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.justify_end()
|
||||||
|
.mt_1()
|
||||||
|
.child(
|
||||||
|
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||||
|
cx.listener(|this, _, cx| {
|
||||||
|
this.last_error = None;
|
||||||
|
cx.open_url(ACCOUNT_URL);
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||||
|
|this, _, cx| {
|
||||||
|
this.last_error = None;
|
||||||
|
cx.notify();
|
||||||
|
},
|
||||||
|
))),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_assist_error(
|
||||||
|
&self,
|
||||||
|
error_message: &SharedString,
|
||||||
|
cx: &mut ViewContext<Self>,
|
||||||
|
) -> AnyElement {
|
||||||
|
v_flex()
|
||||||
|
.gap_0p5()
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1p5()
|
||||||
|
.items_center()
|
||||||
|
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||||
|
.child(
|
||||||
|
Label::new("Error interacting with language model")
|
||||||
|
.weight(FontWeight::MEDIUM),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.id("error-message")
|
||||||
|
.max_h_24()
|
||||||
|
.overflow_y_scroll()
|
||||||
|
.child(Label::new(error_message.clone())),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.justify_end()
|
||||||
|
.mt_1()
|
||||||
|
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||||
|
|this, _, cx| {
|
||||||
|
this.last_error = None;
|
||||||
|
cx.notify();
|
||||||
|
},
|
||||||
|
))),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
|
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
|
||||||
|
@ -4441,48 +4602,7 @@ impl Render for ContextEditor {
|
||||||
.child(element),
|
.child(element),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.when_some(self.error_message.clone(), |this, error_message| {
|
.children(self.render_last_error(cx))
|
||||||
this.child(
|
|
||||||
div()
|
|
||||||
.absolute()
|
|
||||||
.right_3()
|
|
||||||
.bottom_12()
|
|
||||||
.max_w_96()
|
|
||||||
.py_2()
|
|
||||||
.px_3()
|
|
||||||
.elevation_2(cx)
|
|
||||||
.occlude()
|
|
||||||
.child(
|
|
||||||
v_flex()
|
|
||||||
.gap_0p5()
|
|
||||||
.child(
|
|
||||||
h_flex()
|
|
||||||
.gap_1p5()
|
|
||||||
.items_center()
|
|
||||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
|
||||||
.child(
|
|
||||||
Label::new("Error interacting with language model")
|
|
||||||
.weight(FontWeight::MEDIUM),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.child(
|
|
||||||
div()
|
|
||||||
.id("error-message")
|
|
||||||
.max_h_24()
|
|
||||||
.overflow_y_scroll()
|
|
||||||
.child(Label::new(error_message)),
|
|
||||||
)
|
|
||||||
.child(h_flex().justify_end().mt_1().child(
|
|
||||||
Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
|
||||||
|this, _, cx| {
|
|
||||||
this.error_message = None;
|
|
||||||
cx.notify();
|
|
||||||
},
|
|
||||||
)),
|
|
||||||
)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.child(
|
.child(
|
||||||
h_flex().w_full().relative().child(
|
h_flex().w_full().relative().child(
|
||||||
h_flex()
|
h_flex()
|
||||||
|
|
|
@ -26,6 +26,7 @@ use gpui::{
|
||||||
|
|
||||||
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
|
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
|
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
|
||||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
||||||
|
@ -294,6 +295,8 @@ impl ContextOperation {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ContextEvent {
|
pub enum ContextEvent {
|
||||||
ShowAssistError(SharedString),
|
ShowAssistError(SharedString),
|
||||||
|
ShowPaymentRequiredError,
|
||||||
|
ShowMaxMonthlySpendReachedError,
|
||||||
MessagesEdited,
|
MessagesEdited,
|
||||||
SummaryChanged,
|
SummaryChanged,
|
||||||
StreamedCompletion,
|
StreamedCompletion,
|
||||||
|
@ -2112,25 +2115,36 @@ impl Context {
|
||||||
let result = stream_completion.await;
|
let result = stream_completion.await;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
let error_message = result
|
let error_message = if let Some(error) = result.as_ref().err() {
|
||||||
.as_ref()
|
if error.is::<PaymentRequiredError>() {
|
||||||
.err()
|
cx.emit(ContextEvent::ShowPaymentRequiredError);
|
||||||
.map(|error| error.to_string().trim().to_string());
|
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||||
|
metadata.status = MessageStatus::Canceled;
|
||||||
if let Some(error_message) = error_message.as_ref() {
|
});
|
||||||
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
|
Some(error.to_string())
|
||||||
error_message.clone(),
|
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||||
)));
|
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
|
||||||
}
|
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||||
|
metadata.status = MessageStatus::Canceled;
|
||||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
});
|
||||||
if let Some(error_message) = error_message.as_ref() {
|
Some(error.to_string())
|
||||||
metadata.status =
|
|
||||||
MessageStatus::Error(SharedString::from(error_message.clone()));
|
|
||||||
} else {
|
} else {
|
||||||
metadata.status = MessageStatus::Done;
|
let error_message = error.to_string().trim().to_string();
|
||||||
|
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
|
||||||
|
error_message.clone(),
|
||||||
|
)));
|
||||||
|
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||||
|
metadata.status =
|
||||||
|
MessageStatus::Error(SharedString::from(error_message.clone()));
|
||||||
|
});
|
||||||
|
Some(error_message)
|
||||||
}
|
}
|
||||||
});
|
} else {
|
||||||
|
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||||
|
metadata.status = MessageStatus::Done;
|
||||||
|
});
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||||
let language_name = this
|
let language_name = this
|
||||||
|
|
|
@ -47,6 +47,7 @@ settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
strum.workspace = true
|
strum.workspace = true
|
||||||
theme.workspace = true
|
theme.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
tiktoken-rs.workspace = true
|
tiktoken-rs.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
|
|
|
@ -7,7 +7,10 @@ use crate::{
|
||||||
};
|
};
|
||||||
use anthropic::AnthropicError;
|
use anthropic::AnthropicError;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use client::{
|
||||||
|
Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||||
|
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
||||||
|
};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
|
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
|
||||||
use futures::{
|
use futures::{
|
||||||
|
@ -15,10 +18,11 @@ use futures::{
|
||||||
TryStreamExt as _,
|
TryStreamExt as _,
|
||||||
};
|
};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
|
||||||
Subscription, Task,
|
ModelContext, ReadGlobal, Subscription, Task,
|
||||||
};
|
};
|
||||||
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
|
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
|
||||||
|
use proto::TypedEnvelope;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use serde_json::value::RawValue;
|
use serde_json::value::RawValue;
|
||||||
|
@ -27,12 +31,14 @@ use smol::{
|
||||||
io::{AsyncReadExt, BufReader},
|
io::{AsyncReadExt, BufReader},
|
||||||
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
|
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
|
||||||
};
|
};
|
||||||
|
use std::fmt;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::{
|
use std::{
|
||||||
future,
|
future,
|
||||||
sync::{Arc, LazyLock},
|
sync::{Arc, LazyLock},
|
||||||
};
|
};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
use thiserror::Error;
|
||||||
use ui::{prelude::*, TintColor};
|
use ui::{prelude::*, TintColor};
|
||||||
|
|
||||||
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
|
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
|
||||||
|
@ -90,22 +96,93 @@ pub struct AvailableModel {
|
||||||
pub default_temperature: Option<f32>,
|
pub default_temperature: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);
|
||||||
|
|
||||||
|
impl Global for GlobalRefreshLlmTokenListener {}
|
||||||
|
|
||||||
|
pub struct RefreshLlmTokenEvent;
|
||||||
|
|
||||||
|
pub struct RefreshLlmTokenListener {
|
||||||
|
_llm_token_subscription: client::Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
|
||||||
|
|
||||||
|
impl RefreshLlmTokenListener {
|
||||||
|
pub fn register(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
|
let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
|
||||||
|
cx.set_global(GlobalRefreshLlmTokenListener(listener));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||||
|
GlobalRefreshLlmTokenListener::global(cx).0.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
|
||||||
|
Self {
|
||||||
|
_llm_token_subscription: client
|
||||||
|
.add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_refresh_llm_token(
|
||||||
|
this: Model<Self>,
|
||||||
|
_: TypedEnvelope<proto::RefreshLlmToken>,
|
||||||
|
mut cx: AsyncAppContext,
|
||||||
|
) -> Result<()> {
|
||||||
|
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct CloudLanguageModelProvider {
|
pub struct CloudLanguageModelProvider {
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
llm_api_token: LlmApiToken,
|
|
||||||
state: gpui::Model<State>,
|
state: gpui::Model<State>,
|
||||||
_maintain_client_status: Task<()>,
|
_maintain_client_status: Task<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct State {
|
pub struct State {
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
|
llm_api_token: LlmApiToken,
|
||||||
user_store: Model<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
status: client::Status,
|
status: client::Status,
|
||||||
accept_terms: Option<Task<Result<()>>>,
|
accept_terms: Option<Task<Result<()>>>,
|
||||||
_subscription: Subscription,
|
_settings_subscription: Subscription,
|
||||||
|
_llm_token_subscription: Subscription,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
|
fn new(
|
||||||
|
client: Arc<Client>,
|
||||||
|
user_store: Model<UserStore>,
|
||||||
|
status: client::Status,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) -> Self {
|
||||||
|
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
client: client.clone(),
|
||||||
|
llm_api_token: LlmApiToken::default(),
|
||||||
|
user_store,
|
||||||
|
status,
|
||||||
|
accept_terms: None,
|
||||||
|
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
|
_llm_token_subscription: cx.subscribe(
|
||||||
|
&refresh_llm_token_listener,
|
||||||
|
|this, _listener, _event, cx| {
|
||||||
|
let client = this.client.clone();
|
||||||
|
let llm_api_token = this.llm_api_token.clone();
|
||||||
|
cx.spawn(|_this, _cx| async move {
|
||||||
|
llm_api_token.refresh(&client).await?;
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn is_signed_out(&self) -> bool {
|
fn is_signed_out(&self) -> bool {
|
||||||
self.status.is_signed_out()
|
self.status.is_signed_out()
|
||||||
}
|
}
|
||||||
|
@ -144,15 +221,7 @@ impl CloudLanguageModelProvider {
|
||||||
let mut status_rx = client.status();
|
let mut status_rx = client.status();
|
||||||
let status = *status_rx.borrow();
|
let status = *status_rx.borrow();
|
||||||
|
|
||||||
let state = cx.new_model(|cx| State {
|
let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));
|
||||||
client: client.clone(),
|
|
||||||
user_store,
|
|
||||||
status,
|
|
||||||
accept_terms: None,
|
|
||||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
|
||||||
cx.notify();
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
let state_ref = state.downgrade();
|
let state_ref = state.downgrade();
|
||||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||||
|
@ -172,8 +241,7 @@ impl CloudLanguageModelProvider {
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
state,
|
state: state.clone(),
|
||||||
llm_api_token: LlmApiToken::default(),
|
|
||||||
_maintain_client_status: maintain_client_status,
|
_maintain_client_status: maintain_client_status,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||||
models.insert(model.id().to_string(), model.clone());
|
models.insert(model.id().to_string(), model.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||||
models
|
models
|
||||||
.into_values()
|
.into_values()
|
||||||
.map(|model| {
|
.map(|model| {
|
||||||
Arc::new(CloudLanguageModel {
|
Arc::new(CloudLanguageModel {
|
||||||
id: LanguageModelId::from(model.id().to_string()),
|
id: LanguageModelId::from(model.id().to_string()),
|
||||||
model,
|
model,
|
||||||
llm_api_token: self.llm_api_token.clone(),
|
llm_api_token: llm_api_token.clone(),
|
||||||
client: self.client.clone(),
|
client: self.client.clone(),
|
||||||
request_limiter: RateLimiter::new(4),
|
request_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
|
@ -377,6 +446,30 @@ pub struct CloudLanguageModel {
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub struct PaymentRequiredError;
|
||||||
|
|
||||||
|
impl fmt::Display for PaymentRequiredError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Payment required to use this language model. Please upgrade your account."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub struct MaxMonthlySpendReachedError;
|
||||||
|
|
||||||
|
impl fmt::Display for MaxMonthlySpendReachedError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl CloudLanguageModel {
|
impl CloudLanguageModel {
|
||||||
async fn perform_llm_completion(
|
async fn perform_llm_completion(
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
|
@ -411,6 +504,15 @@ impl CloudLanguageModel {
|
||||||
{
|
{
|
||||||
did_retry = true;
|
did_retry = true;
|
||||||
token = llm_api_token.refresh(&client).await?;
|
token = llm_api_token.refresh(&client).await?;
|
||||||
|
} else if response.status() == StatusCode::FORBIDDEN
|
||||||
|
&& response
|
||||||
|
.headers()
|
||||||
|
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
break Err(anyhow!(MaxMonthlySpendReachedError))?;
|
||||||
|
} else if response.status() == StatusCode::PAYMENT_REQUIRED {
|
||||||
|
break Err(anyhow!(PaymentRequiredError))?;
|
||||||
} else {
|
} else {
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::provider::cloud::RefreshLlmTokenListener;
|
||||||
use crate::{
|
use crate::{
|
||||||
provider::{
|
provider::{
|
||||||
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||||
|
@ -30,6 +31,8 @@ fn register_language_model_providers(
|
||||||
) {
|
) {
|
||||||
use feature_flags::FeatureFlagAppExt;
|
use feature_flags::FeatureFlagAppExt;
|
||||||
|
|
||||||
|
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||||
|
|
||||||
registry.register_provider(
|
registry.register_provider(
|
||||||
AnthropicLanguageModelProvider::new(client.http_client(), cx),
|
AnthropicLanguageModelProvider::new(client.http_client(), cx),
|
||||||
cx,
|
cx,
|
||||||
|
|
|
@ -269,6 +269,7 @@ message Envelope {
|
||||||
|
|
||||||
GetLlmToken get_llm_token = 235;
|
GetLlmToken get_llm_token = 235;
|
||||||
GetLlmTokenResponse get_llm_token_response = 236;
|
GetLlmTokenResponse get_llm_token_response = 236;
|
||||||
|
RefreshLlmToken refresh_llm_token = 259; // current max
|
||||||
|
|
||||||
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
|
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
|
||||||
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
|
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
|
||||||
|
@ -284,7 +285,7 @@ message Envelope {
|
||||||
|
|
||||||
ShutdownRemoteServer shutdown_remote_server = 257;
|
ShutdownRemoteServer shutdown_remote_server = 257;
|
||||||
|
|
||||||
RemoveWorktree remove_worktree = 258; // current max
|
RemoveWorktree remove_worktree = 258;
|
||||||
}
|
}
|
||||||
|
|
||||||
reserved 87 to 88;
|
reserved 87 to 88;
|
||||||
|
@ -2429,6 +2430,8 @@ message GetLlmTokenResponse {
|
||||||
string token = 1;
|
string token = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message RefreshLlmToken {}
|
||||||
|
|
||||||
// Remote FS
|
// Remote FS
|
||||||
|
|
||||||
message AddWorktree {
|
message AddWorktree {
|
||||||
|
|
|
@ -253,6 +253,7 @@ messages!(
|
||||||
(ProjectEntryResponse, Foreground),
|
(ProjectEntryResponse, Foreground),
|
||||||
(CountLanguageModelTokens, Background),
|
(CountLanguageModelTokens, Background),
|
||||||
(CountLanguageModelTokensResponse, Background),
|
(CountLanguageModelTokensResponse, Background),
|
||||||
|
(RefreshLlmToken, Background),
|
||||||
(RefreshInlayHints, Foreground),
|
(RefreshInlayHints, Foreground),
|
||||||
(RejoinChannelBuffers, Foreground),
|
(RejoinChannelBuffers, Foreground),
|
||||||
(RejoinChannelBuffersResponse, Foreground),
|
(RejoinChannelBuffersResponse, Foreground),
|
||||||
|
|
|
@ -3,6 +3,8 @@ use strum::{Display, EnumIter, EnumString};
|
||||||
|
|
||||||
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||||
|
|
||||||
|
pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached";
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
|
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
|
||||||
)]
|
)]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue