diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index 464305f73d..d00c979194 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -10,9 +10,9 @@ use serde::{Deserialize, Serialize}; use agent_settings::{AgentDockPosition, AgentSettings, CompletionMode, DefaultView}; use anyhow::{Result, anyhow}; use assistant_context_editor::{ - AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent, - ContextSummary, SlashCommandCompletionProvider, humanize_token_count, - make_lsp_adapter_delegate, render_remaining_tokens, + AgentPanelDelegate, AssistantContext, ContextEditor, ContextEvent, ContextSummary, + SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate, + render_remaining_tokens, }; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; @@ -29,7 +29,8 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ - LanguageModelProviderTosView, LanguageModelRegistry, RequestUsage, ZED_CLOUD_PROVIDER_ID, + ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, RequestUsage, + ZED_CLOUD_PROVIDER_ID, }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; @@ -2353,24 +2354,6 @@ impl AgentPanel { self.thread.clone().into_any_element() } - fn configuration_error(&self, cx: &App) -> Option { - let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { - return Some(ConfigurationError::NoProvider); - }; - - if !model.provider.is_authenticated(cx) { - return Some(ConfigurationError::ProviderNotAuthenticated); - } - - if model.provider.must_accept_terms(cx) { - return Some(ConfigurationError::ProviderPendingTermsAcceptance( - model.provider, - )); - } - - None - } - fn render_thread_empty_state( &self, window: &mut Window, @@ -2380,7 +2363,9 @@ impl AgentPanel { .history_store .update(cx, |this, cx| this.recent_entries(6, cx)); - let configuration_error = self.configuration_error(cx); + let model_registry = LanguageModelRegistry::read_global(cx); + let configuration_error = + model_registry.configuration_error(model_registry.default_model(), cx); let no_error = configuration_error.is_none(); let focus_handle = self.focus_handle(cx); @@ -2397,11 +2382,7 @@ impl AgentPanel { .justify_center() .items_center() .gap_1() - .child( - h_flex().child( - Headline::new("Welcome to the Agent Panel") - ), - ) + .child(h_flex().child(Headline::new("Welcome to the Agent Panel"))) .when(no_error, |parent| { parent .child( @@ -2425,7 +2406,10 @@ impl AgentPanel { cx, )) .on_click(|_event, window, cx| { - window.dispatch_action(NewThread::default().boxed_clone(), cx) + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ) }), ) .child( @@ -2442,7 +2426,10 @@ impl AgentPanel { cx, )) .on_click(|_event, window, cx| { - window.dispatch_action(ToggleContextPicker.boxed_clone(), cx) + window.dispatch_action( + ToggleContextPicker.boxed_clone(), + cx, + ) }), ) .child( @@ -2459,7 +2446,10 @@ impl AgentPanel { cx, )) .on_click(|_event, window, cx| { - window.dispatch_action(ToggleModelSelector.boxed_clone(), cx) + window.dispatch_action( + ToggleModelSelector.boxed_clone(), + cx, + ) }), ) .child( @@ -2476,51 +2466,50 @@ impl AgentPanel { cx, )) .on_click(|_event, window, cx| { - window.dispatch_action(OpenConfiguration.boxed_clone(), cx) + window.dispatch_action( + OpenConfiguration.boxed_clone(), + cx, + ) }), ) }) - .map(|parent| { - match configuration_error_ref { - Some(ConfigurationError::ProviderNotAuthenticated) - | Some(ConfigurationError::NoProvider) => { - parent - .child( - h_flex().child( - Label::new("To start using the agent, configure at least one LLM provider.") - .color(Color::Muted) - .mb_2p5() - ) - ) - .child( - Button::new("settings", "Configure a Provider") - .icon(IconName::Settings) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action(OpenConfiguration.boxed_clone(), cx) - }), - ) - } - Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { - parent.children( - provider.render_accept_terms( - LanguageModelProviderTosView::ThreadFreshStart, + .map(|parent| match configuration_error_ref { + Some( + err @ (ConfigurationError::ModelNotFound + | ConfigurationError::ProviderNotAuthenticated(_) + | ConfigurationError::NoProvider), + ) => parent + .child(h_flex().child( + Label::new(err.to_string()).color(Color::Muted).mb_2p5(), + )) + .child( + Button::new("settings", "Configure a Provider") + .icon(IconName::Settings) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &OpenConfiguration, + &focus_handle, + window, cx, - ), - ) - } - None => parent, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + OpenConfiguration.boxed_clone(), + cx, + ) + }), + ), + Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { + parent.children(provider.render_accept_terms( + LanguageModelProviderTosView::ThreadFreshStart, + cx, + )) } - }) + None => parent, + }), ) }) .when(!recent_history.is_empty(), |parent| { @@ -2555,7 +2544,8 @@ impl AgentPanel { &self.focus_handle(cx), window, cx, - ).map(|kb| kb.size(rems_from_px(12.))), + ) + .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(move |_event, window, cx| { window.dispatch_action(OpenHistory.boxed_clone(), cx); @@ -2565,79 +2555,68 @@ impl AgentPanel { .child( v_flex() .gap_1() - .children( - recent_history.into_iter().enumerate().map(|(index, entry)| { + .children(recent_history.into_iter().enumerate().map( + |(index, entry)| { // TODO: Add keyboard navigation. - let is_hovered = self.hovered_recent_history_item == Some(index); + let is_hovered = + self.hovered_recent_history_item == Some(index); HistoryEntryElement::new(entry.clone(), cx.entity().downgrade()) .hovered(is_hovered) - .on_hover(cx.listener(move |this, is_hovered, _window, cx| { - if *is_hovered { - this.hovered_recent_history_item = Some(index); - } else if this.hovered_recent_history_item == Some(index) { - this.hovered_recent_history_item = None; - } - cx.notify(); - })) + .on_hover(cx.listener( + move |this, is_hovered, _window, cx| { + if *is_hovered { + this.hovered_recent_history_item = Some(index); + } else if this.hovered_recent_history_item + == Some(index) + { + this.hovered_recent_history_item = None; + } + cx.notify(); + }, + )) .into_any_element() - }), - ) + }, + )), ) - .map(|parent| { - match configuration_error_ref { - Some(ConfigurationError::ProviderNotAuthenticated) - | Some(ConfigurationError::NoProvider) => { - parent - .child( - Banner::new() - .severity(ui::Severity::Warning) - .child( - Label::new( - "Configure at least one LLM provider to start using the panel.", - ) - .size(LabelSize::Small), + .map(|parent| match configuration_error_ref { + Some( + err @ (ConfigurationError::ModelNotFound + | ConfigurationError::ProviderNotAuthenticated(_) + | ConfigurationError::NoProvider), + ) => parent.child( + Banner::new() + .severity(ui::Severity::Warning) + .child(Label::new(err.to_string()).size(LabelSize::Small)) + .action_slot( + Button::new("settings", "Configure Provider") + .style(ButtonStyle::Tinted(ui::TintColor::Warning)) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &OpenConfiguration, + &focus_handle, + window, + cx, ) - .action_slot( - Button::new("settings", "Configure Provider") - .style(ButtonStyle::Tinted(ui::TintColor::Warning)) - .label_size(LabelSize::Small) - .key_binding( - KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - ) - .map(|kb| kb.size(rems_from_px(12.))), - ) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, - ) - }), - ), - ) - } - Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { - parent - .child( - Banner::new() - .severity(ui::Severity::Warning) - .child( - h_flex() - .w_full() - .children( - provider.render_accept_terms( - LanguageModelProviderTosView::ThreadtEmptyState, - cx, - ), - ), - ), - ) - } - None => parent, + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_event, window, cx| { + window.dispatch_action( + OpenConfiguration.boxed_clone(), + cx, + ) + }), + ), + ), + Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { + parent.child(Banner::new().severity(ui::Severity::Warning).child( + h_flex().w_full().children(provider.render_accept_terms( + LanguageModelProviderTosView::ThreadtEmptyState, + cx, + )), + )) } + None => parent, }) }) } diff --git a/crates/agent/src/inline_assistant.rs b/crates/agent/src/inline_assistant.rs index 7c85381cca..28841e21ad 100644 --- a/crates/agent/src/inline_assistant.rs +++ b/crates/agent/src/inline_assistant.rs @@ -24,6 +24,7 @@ use gpui::{ WeakEntity, Window, point, }; use language::{Buffer, Point, Selection, TransactionId}; +use language_model::ConfigurationError; use language_model::ConfiguredModel; use language_model::{LanguageModelRegistry, report_assistant_event}; use multi_buffer::MultiBufferRow; @@ -232,10 +233,9 @@ impl InlineAssistant { return; }; - let is_authenticated = || { - LanguageModelRegistry::read_global(cx) - .inline_assistant_model() - .map_or(false, |model| model.provider.is_authenticated(cx)) + let configuration_error = || { + let model_registry = LanguageModelRegistry::read_global(cx); + model_registry.configuration_error(model_registry.inline_assistant_model(), cx) }; let Some(agent_panel) = workspace.panel::(cx) else { @@ -283,20 +283,23 @@ impl InlineAssistant { } }; - if is_authenticated() { - handle_assist(window, cx); - } else { - cx.spawn_in(window, async move |_workspace, cx| { - let Some(task) = cx.update(|_, cx| { - LanguageModelRegistry::read_global(cx) - .inline_assistant_model() - .map_or(None, |model| Some(model.provider.authenticate(cx))) - })? - else { + if let Some(error) = configuration_error() { + if let ConfigurationError::ProviderNotAuthenticated(provider) = error { + cx.spawn(async move |_, cx| { + cx.update(|cx| provider.authenticate(cx))?.await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + if configuration_error().is_none() { + handle_assist(window, cx); + } + } else { + cx.spawn_in(window, async move |_, cx| { let answer = cx .prompt( gpui::PromptLevel::Warning, - "No language model provider configured", + &error.to_string(), None, &["Configure", "Cancel"], ) @@ -310,17 +313,12 @@ impl InlineAssistant { .ok(); } } - return Ok(()); - }; - task.await?; - - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - - if is_authenticated() { - handle_assist(window, cx); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); } + } else { + handle_assist(window, cx); } } diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index bf590df964..24e59e449c 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -39,7 +39,7 @@ use language::{ language_settings::{SoftWrap, all_language_settings}, }; use language_model::{ - LanguageModelImage, LanguageModelProvider, LanguageModelProviderTosView, LanguageModelRegistry, + ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry, Role, }; use multi_buffer::MultiBufferRow; @@ -1887,6 +1887,8 @@ impl ContextEditor { // value to not show the nudge. let nudge = Some(false); + let model_registry = LanguageModelRegistry::read_global(cx); + if nudge.map_or(false, |value| value) { Some( h_flex() @@ -1935,14 +1937,9 @@ impl ContextEditor { ) .into_any_element(), ) - } else if let Some(configuration_error) = configuration_error(cx) { - let label = match configuration_error { - ConfigurationError::NoProvider => "No LLM provider selected.", - ConfigurationError::ProviderNotAuthenticated => "LLM provider is not configured.", - ConfigurationError::ProviderPendingTermsAcceptance(_) => { - "LLM provider requires accepting the Terms of Service." - } - }; + } else if let Some(configuration_error) = + model_registry.configuration_error(model_registry.default_model(), cx) + { Some( h_flex() .px_3() @@ -1959,7 +1956,7 @@ impl ContextEditor { .size(IconSize::Small) .color(Color::Warning), ) - .child(Label::new(label)), + .child(Label::new(configuration_error.to_string())), ) .child( Button::new("open-configuration", "Configure Providers") @@ -2034,14 +2031,19 @@ impl ContextEditor { /// Will return false if the selected provided has a configuration error or /// if the user has not accepted the terms of service for this provider. fn sending_disabled(&self, cx: &mut Context<'_, ContextEditor>) -> bool { - let model = LanguageModelRegistry::read_global(cx).default_model(); + let model_registry = LanguageModelRegistry::read_global(cx); + let Some(configuration_error) = + model_registry.configuration_error(model_registry.default_model(), cx) + else { + return false; + }; - let has_configuration_error = configuration_error(cx).is_some(); - let needs_to_accept_terms = self.show_accept_terms - && model - .as_ref() - .map_or(false, |model| model.provider.must_accept_terms(cx)); - has_configuration_error || needs_to_accept_terms + match configuration_error { + ConfigurationError::NoProvider + | ConfigurationError::ModelNotFound + | ConfigurationError::ProviderNotAuthenticated(_) => true, + ConfigurationError::ProviderPendingTermsAcceptance(_) => self.show_accept_terms, + } } fn render_inject_context_menu(&self, cx: &mut Context) -> impl IntoElement { @@ -3180,33 +3182,6 @@ fn size_for_image(data: &RenderImage, max_size: Size) -> Size { } } -pub enum ConfigurationError { - NoProvider, - ProviderNotAuthenticated, - ProviderPendingTermsAcceptance(Arc), -} - -fn configuration_error(cx: &App) -> Option { - let model = LanguageModelRegistry::read_global(cx).default_model(); - let is_authenticated = model - .as_ref() - .map_or(false, |model| model.provider.is_authenticated(cx)); - - if model.is_some() && is_authenticated { - return None; - } - - if model.is_none() { - return Some(ConfigurationError::NoProvider); - } - - if !is_authenticated { - return Some(ConfigurationError::ProviderNotAuthenticated); - } - - None -} - pub fn humanize_token_count(count: usize) -> String { match count { 0..=999 => count.to_string(), diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index e094f61b08..e9f03cc1ff 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -5,6 +5,7 @@ use crate::{ use collections::BTreeMap; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::{str::FromStr, sync::Arc}; +use thiserror::Error; use util::maybe; pub fn init(cx: &mut App) { @@ -16,6 +17,34 @@ struct GlobalLanguageModelRegistry(Entity); impl Global for GlobalLanguageModelRegistry {} +#[derive(Error)] +pub enum ConfigurationError { + #[error("Configure at least one LLM provider to start using the panel.")] + NoProvider, + #[error("LLM Provider is not configured or does not support the configured model.")] + ModelNotFound, + #[error("{} LLM provider is not configured.", .0.name().0)] + ProviderNotAuthenticated(Arc), + #[error("Using the {} LLM provider requires accepting the Terms of Service.", + .0.name().0)] + ProviderPendingTermsAcceptance(Arc), +} + +impl std::fmt::Debug for ConfigurationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NoProvider => write!(f, "NoProvider"), + Self::ModelNotFound => write!(f, "ModelNotFound"), + Self::ProviderNotAuthenticated(provider) => { + write!(f, "ProviderNotAuthenticated({})", provider.id()) + } + Self::ProviderPendingTermsAcceptance(provider) => { + write!(f, "ProviderPendingTermsAcceptance({})", provider.id()) + } + } + } +} + #[derive(Default)] pub struct LanguageModelRegistry { default_model: Option, @@ -152,6 +181,36 @@ impl LanguageModelRegistry { providers } + pub fn configuration_error( + &self, + model: Option, + cx: &App, + ) -> Option { + let Some(model) = model else { + if !self.has_authenticated_provider(cx) { + return Some(ConfigurationError::NoProvider); + } + return Some(ConfigurationError::ModelNotFound); + }; + + if !model.provider.is_authenticated(cx) { + return Some(ConfigurationError::ProviderNotAuthenticated(model.provider)); + } + + if model.provider.must_accept_terms(cx) { + return Some(ConfigurationError::ProviderPendingTermsAcceptance( + model.provider, + )); + } + + None + } + + /// Check that we have at least one provider that is authenticated. + fn has_authenticated_provider(&self, cx: &App) -> bool { + self.providers.values().any(|p| p.is_authenticated(cx)) + } + pub fn available_models<'a>( &'a self, cx: &'a App,