diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index ee5ae4c575..ad648bd54e 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -9,6 +9,7 @@ use assistant_context_editor::{ use assistant_settings::{AssistantDockPosition, AssistantSettings}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; + use client::zed_urls; use editor::Editor; use fs::Fs; @@ -18,7 +19,7 @@ use gpui::{ ViewContext, WeakView, WindowContext, }; use language::LanguageRegistry; -use language_model::LanguageModelRegistry; +use language_model::{LanguageModelProviderTosView, LanguageModelRegistry}; use project::Project; use prompt_library::{open_prompt_library, PromptBuilder, PromptLibrary}; use settings::{update_settings_file, Settings}; @@ -663,17 +664,16 @@ impl AssistantPanel { } fn configuration_error(&self, cx: &AppContext) -> Option { - let provider = LanguageModelRegistry::read_global(cx).active_provider(); - let is_authenticated = provider - .as_ref() - .map_or(false, |provider| provider.is_authenticated(cx)); + let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { + return Some(ConfigurationError::NoProvider); + }; - if provider.is_some() && is_authenticated { - return None; + if !provider.is_authenticated(cx) { + return Some(ConfigurationError::ProviderNotAuthenticated); } - if !is_authenticated { - return Some(ConfigurationError::ProviderNotAuthenticated); + if provider.must_accept_terms(cx) { + return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)); } None @@ -691,6 +691,9 @@ impl AssistantPanel { .child(Headline::new("Welcome to the Assistant Panel").size(HeadlineSize::Small)) }; + let configuration_error = self.configuration_error(cx); + let no_error = configuration_error.is_none(); + v_flex() .gap_2() .child( @@ -704,41 +707,51 @@ impl AssistantPanel { .mb_4(), ), ) - .when( - matches!( - self.configuration_error(cx), - Some(ConfigurationError::ProviderNotAuthenticated) - ), - |parent| { - parent.child( - v_flex() - .gap_0p5() - .child(create_welcome_heading()) - .child( - h_flex().mb_2().w_full().justify_center().child( - Label::new( - "To start using the assistant, configure at least one LLM provider.", - ) - .color(Color::Muted), + .map(|parent| { + match configuration_error { + Some(ConfigurationError::ProviderNotAuthenticated) | Some(ConfigurationError::NoProvider) => { + parent.child( + v_flex() + .gap_0p5() + .child(create_welcome_heading()) + .child( + h_flex().mb_2().w_full().justify_center().child( + Label::new( + "To start using the assistant, configure at least one LLM provider.", + ) + .color(Color::Muted), + ), + ) + .child( + h_flex().w_full().justify_center().child( + Button::new("open-configuration", "Configure a Provider") + .size(ButtonSize::Compact) + .icon(Some(IconName::Sliders)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .on_click(cx.listener(|this, _, cx| { + this.open_configuration(cx); + })), + ), ), - ) - .child( - h_flex().w_full().justify_center().child( - Button::new("open-configuration", "Configure a Provider") - .size(ButtonSize::Compact) - .icon(Some(IconName::Sliders)) - .icon_size(IconSize::Small) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, cx| { - this.open_configuration(cx); - })), - ), - ), - ) - }, - ) + ) + } + Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { + parent.child( + v_flex() + .gap_0p5() + .child(create_welcome_heading()) + .children(provider.render_accept_terms( + LanguageModelProviderTosView::ThreadEmptyState, + cx, + )), + ) + } + None => parent, + } + }) .when( - recent_threads.is_empty() && self.configuration_error(cx).is_none(), + recent_threads.is_empty() && no_error, |parent| { parent.child( v_flex().gap_0p5().child(create_welcome_heading()).child( diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index dfe9c6fc74..c22ebeec7d 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -31,7 +31,10 @@ use gpui::{ }; use indexed_docs::IndexedDocsStore; use language::{language_settings::SoftWrap, BufferSnapshot, LspAdapterDelegate, ToOffset}; -use language_model::{LanguageModelImage, LanguageModelRegistry, LanguageModelToolUse, Role}; +use language_model::{ + LanguageModelImage, LanguageModelProvider, LanguageModelProviderTosView, LanguageModelRegistry, + LanguageModelToolUse, Role, +}; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use multi_buffer::MultiBufferRow; use picker::Picker; @@ -2260,6 +2263,9 @@ impl ContextEditor { let label = match configuration_error { ConfigurationError::NoProvider => "No LLM provider selected.", ConfigurationError::ProviderNotAuthenticated => "LLM provider is not configured.", + ConfigurationError::ProviderPendingTermsAcceptance(_) => { + "LLM provider requires accepting the Terms of Service." + } }; Some( h_flex() @@ -2855,9 +2861,9 @@ impl Render for ContextEditor { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let provider = LanguageModelRegistry::read_global(cx).active_provider(); let accept_terms = if self.show_accept_terms { - provider - .as_ref() - .and_then(|provider| provider.render_accept_terms(cx)) + provider.as_ref().and_then(|provider| { + provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx) + }) } else { None }; @@ -3502,6 +3508,7 @@ fn size_for_image(data: &RenderImage, max_size: Size) -> Size { pub enum ConfigurationError { NoProvider, ProviderNotAuthenticated, + ProviderPendingTermsAcceptance(Arc), } fn configuration_error(cx: &AppContext) -> Option { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index a10c743b35..203e4a025d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -245,12 +245,23 @@ pub trait LanguageModelProvider: 'static { fn must_accept_terms(&self, _cx: &AppContext) -> bool { false } - fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option { + fn render_accept_terms( + &self, + _view: LanguageModelProviderTosView, + _cx: &mut WindowContext, + ) -> Option { None } fn reset_credentials(&self, cx: &mut AppContext) -> Task>; } +#[derive(PartialEq, Eq)] +pub enum LanguageModelProviderTosView { + ThreadEmptyState, + PromptEditorPopup, + Configuration, +} + pub trait LanguageModelProviderState: 'static { type ObservableEntity; diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index adb7960d9f..cccdc90229 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -12,14 +12,14 @@ use futures::{ TryStreamExt as _, }; use gpui::{ - AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model, - ModelContext, ReadGlobal, Subscription, Task, + AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, Global, Model, ModelContext, + ReadGlobal, Subscription, Task, }; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID, + LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, @@ -378,60 +378,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider { !self.state.read(cx).has_accepted_terms_of_service(cx) } - fn render_accept_terms(&self, cx: &mut WindowContext) -> Option { - let state = self.state.read(cx); - - let terms = [( - "terms_of_service", - "Terms of Service", - "https://zed.dev/terms-of-service", - )] - .map(|(id, label, url)| { - Button::new(id, label) - .style(ButtonStyle::Subtle) - .icon(IconName::ExternalLink) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .on_click(move |_, cx| cx.open_url(url)) - }); - - if state.has_accepted_terms_of_service(cx) { - None - } else { - let disabled = state.accept_terms.is_some(); - Some( - v_flex() - .gap_2() - .child( - v_flex() - .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM)) - .child( - Label::new( - "Please read and accept our terms and conditions to continue.", - ) - .size(LabelSize::Small), - ), - ) - .child(v_flex().gap_1().children(terms)) - .child( - h_flex().justify_end().child( - Button::new("accept_terms", "I've read it and accept it") - .disabled(disabled) - .on_click({ - let state = self.state.downgrade(); - move |_, cx| { - state - .update(cx, |state, cx| { - state.accept_terms_of_service(cx) - }) - .ok(); - } - }), - ), - ) - .into_any(), - ) - } + fn render_accept_terms( + &self, + view: LanguageModelProviderTosView, + cx: &mut WindowContext, + ) -> Option { + render_accept_terms(self.state.clone(), view, cx) } fn reset_credentials(&self, _cx: &mut AppContext) -> Task> { @@ -439,6 +391,68 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } } +fn render_accept_terms( + state: Model, + view_kind: LanguageModelProviderTosView, + cx: &mut WindowContext, +) -> Option { + if state.read(cx).has_accepted_terms_of_service(cx) { + return None; + } + + let accept_terms_disabled = state.read(cx).accept_terms.is_some(); + + let terms_button = Button::new("terms_of_service", "Terms of Service") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service")); + + let text = "To start using Zed AI, please read and accept the"; + + let form = v_flex() + .w_full() + .gap_2() + .when( + view_kind == LanguageModelProviderTosView::ThreadEmptyState, + |form| form.items_center(), + ) + .child( + h_flex() + .flex_wrap() + .when( + view_kind == LanguageModelProviderTosView::ThreadEmptyState, + |form| form.justify_center(), + ) + .child(Label::new(text)) + .child(terms_button), + ) + .child({ + let button_container = h_flex().w_full().child( + Button::new("accept_terms", "I accept the Terms of Service") + .style(ButtonStyle::Tinted(TintColor::Accent)) + .disabled(accept_terms_disabled) + .on_click({ + let state = state.downgrade(); + move |_, cx| { + state + .update(cx, |state, cx| state.accept_terms_of_service(cx)) + .ok(); + } + }), + ); + + match view_kind { + LanguageModelProviderTosView::ThreadEmptyState => button_container.justify_center(), + LanguageModelProviderTosView::PromptEditorPopup => button_container.justify_end(), + LanguageModelProviderTosView::Configuration => button_container.justify_start(), + } + }); + + Some(form.into_any()) +} + pub struct CloudLanguageModel { id: LanguageModelId, model: CloudModel, @@ -852,44 +866,6 @@ impl ConfigurationView { }); cx.notify(); } - - fn render_accept_terms(&mut self, cx: &mut ViewContext) -> Option { - if self.state.read(cx).has_accepted_terms_of_service(cx) { - return None; - } - - let accept_terms_disabled = self.state.read(cx).accept_terms.is_some(); - - let terms_button = Button::new("terms_of_service", "Terms of Service") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::XSmall) - .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service")); - - let text = "To start using Zed AI, please read and accept the"; - - let form = v_flex() - .gap_1() - .child(h_flex().child(Label::new(text)).child(terms_button)) - .child( - h_flex().child( - Button::new("accept_terms", "I've read and accept the Terms of Service") - .style(ButtonStyle::Tinted(TintColor::Accent)) - .disabled(accept_terms_disabled) - .on_click({ - let state = self.state.downgrade(); - move |_, cx| { - state - .update(cx, |state, cx| state.accept_terms_of_service(cx)) - .ok(); - } - }), - ), - ); - - Some(form.into_any()) - } } impl Render for ConfigurationView { @@ -939,8 +915,12 @@ impl Render for ConfigurationView { if is_connected { v_flex() .gap_3() - .max_w_4_5() - .children(self.render_accept_terms(cx)) + .w_full() + .children(render_accept_terms( + self.state.clone(), + LanguageModelProviderTosView::Configuration, + cx, + )) .when(has_accepted_terms, |this| { this.child(subscription_text) .children(manage_subscription_button)