diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 5d9e2e84a3..5a3df0b6cf 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -393,8 +393,13 @@ impl AssistantPanel { cx.subscribe(&context_store, Self::handle_context_store_event), cx.subscribe( &LanguageModelRegistry::global(cx), - |this, _, _: &language_model::ActiveModelChanged, cx| { - this.completion_provider_changed(cx); + |this, _, event: &language_model::Event, cx| match event { + language_model::Event::ActiveModelChanged => { + this.completion_provider_changed(cx); + } + language_model::Event::ProviderStateChanged => { + this.ensure_authenticated(cx); + } }, ), ]; @@ -587,6 +592,16 @@ impl AssistantPanel { } fn ensure_authenticated(&mut self, cx: &mut ViewContext) { + if self.is_authenticated(cx) { + for context_editor in self.context_editors(cx) { + context_editor.update(cx, |editor, cx| { + editor.set_authentication_prompt(None, cx); + }); + } + cx.notify(); + return; + } + let Some(provider_id) = LanguageModelRegistry::read_global(cx) .active_provider() .map(|p| p.id()) @@ -595,15 +610,18 @@ impl AssistantPanel { }; let load_credentials = self.authenticate(cx); - let task = cx.spawn(|this, mut cx| async move { - let _ = load_credentials.await; - this.update(&mut cx, |this, cx| { - this.show_authentication_prompt(cx); - }) - .log_err(); - }); - self.authenticate_provider_task = Some((provider_id, task)); + self.authenticate_provider_task = Some(( + provider_id, + cx.spawn(|this, mut cx| async move { + let _ = load_credentials.await; + this.update(&mut cx, |this, cx| { + this.show_authentication_prompt(cx); + this.authenticate_provider_task = None; + }) + .log_err(); + }), + )); } fn show_authentication_prompt(&mut self, cx: &mut ViewContext) { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 6dcc874721..fb8f14ba4d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -86,10 +86,25 @@ pub trait LanguageModelProvider: 'static { fn authenticate(&self, cx: &mut AppContext) -> Task>; fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; fn reset_credentials(&self, cx: &mut AppContext) -> Task>; + + // fn observable_entity(&self) ; } pub trait LanguageModelProviderState: 'static { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option; + type ObservableEntity; + + fn observable_entity(&self) -> Option>; + + fn subscribe( + &self, + cx: &mut gpui::ModelContext, + callback: impl Fn(&mut T, &mut gpui::ModelContext) + 'static, + ) -> Option { + let entity = self.observable_entity()?; + Some(cx.observe(&entity, move |this, _, cx| { + callback(this, cx); + })) + } } #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index ddaad618c4..3999483da0 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { api_key: Option, _subscription: Subscription, } @@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider { Self { http_client, state } } } + impl LanguageModelProviderState for AnthropicLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 362539fd85..8132d41071 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -50,7 +50,7 @@ pub struct CloudLanguageModelProvider { _maintain_client_status: Task<()>, } -struct State { +pub struct State { client: Arc, status: client::Status, _subscription: Subscription, @@ -99,10 +99,10 @@ impl CloudLanguageModelProvider { } impl LanguageModelProviderState for CloudLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs index 072c87b92e..f73ddb74bf 100644 --- a/crates/language_model/src/provider/copilot_chat.rs +++ b/crates/language_model/src/provider/copilot_chat.rs @@ -11,8 +11,8 @@ use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::{ - percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, - ModelContext, Render, Subscription, Task, Transformation, + percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render, + Subscription, Task, Transformation, }; use settings::{Settings, SettingsStore}; use std::time::Duration; @@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider { } impl LanguageModelProviderState for CopilotChatLanguageModelProvider { - fn subscribe(&self, cx: &mut ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs index f92ecaf467..70f8402bcc 100644 --- a/crates/language_model/src/provider/fake.rs +++ b/crates/language_model/src/provider/fake.rs @@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider { } impl LanguageModelProviderState for FakeLanguageModelProvider { - fn subscribe(&self, _: &mut gpui::ModelContext) -> Option { + type ObservableEntity = (); + + fn observable_entity(&self) -> Option> { None } } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index 2739623c6a..a1a6cbcceb 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { api_key: Option, _subscription: Subscription, } @@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider { } impl LanguageModelProviderState for GoogleLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index 0364866ccd..9afa3825b0 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { http_client: Arc, available_models: Vec, _subscription: Subscription, @@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider { } impl LanguageModelProviderState for OllamaLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index 9f24dabb09..e0239d959b 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { api_key: Option, _subscription: Subscription, } @@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider { } impl LanguageModelProviderState for OpenAiLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index a3af7e6b18..c7491cb70b 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -54,9 +54,7 @@ fn register_language_model_providers( registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx); } else { registry.unregister_provider( - &LanguageModelProviderId::from( - crate::provider::cloud::PROVIDER_NAME.to_string(), - ), + &LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()), cx, ); } @@ -80,9 +78,12 @@ pub struct ActiveModel { model: Option>, } -pub struct ActiveModelChanged; +pub enum Event { + ActiveModelChanged, + ProviderStateChanged, +} -impl EventEmitter for LanguageModelRegistry {} +impl EventEmitter for LanguageModelRegistry {} impl LanguageModelRegistry { pub fn global(cx: &AppContext) -> Model { @@ -114,7 +115,10 @@ impl LanguageModelRegistry { ) { let name = provider.id(); - if let Some(subscription) = provider.subscribe(cx) { + let subscription = provider.subscribe(cx, |_, cx| { + cx.emit(Event::ProviderStateChanged); + }); + if let Some(subscription) = subscription { subscription.detach(); } @@ -187,7 +191,7 @@ impl LanguageModelRegistry { provider, model: None, }); - cx.emit(ActiveModelChanged); + cx.emit(Event::ActiveModelChanged); } pub fn set_active_model( @@ -202,13 +206,13 @@ impl LanguageModelRegistry { provider, model: Some(model), }); - cx.emit(ActiveModelChanged); + cx.emit(Event::ActiveModelChanged); } else { log::warn!("Active model's provider not found in registry"); } } else { self.active_model = None; - cx.emit(ActiveModelChanged); + cx.emit(Event::ActiveModelChanged); } }