diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 5a3df0b6cf..5d9e2e84a3 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -393,13 +393,8 @@ impl AssistantPanel { cx.subscribe(&context_store, Self::handle_context_store_event), cx.subscribe( &LanguageModelRegistry::global(cx), - |this, _, event: &language_model::Event, cx| match event { - language_model::Event::ActiveModelChanged => { - this.completion_provider_changed(cx); - } - language_model::Event::ProviderStateChanged => { - this.ensure_authenticated(cx); - } + |this, _, _: &language_model::ActiveModelChanged, cx| { + this.completion_provider_changed(cx); }, ), ]; @@ -592,16 +587,6 @@ impl AssistantPanel { } fn ensure_authenticated(&mut self, cx: &mut ViewContext) { - if self.is_authenticated(cx) { - for context_editor in self.context_editors(cx) { - context_editor.update(cx, |editor, cx| { - editor.set_authentication_prompt(None, cx); - }); - } - cx.notify(); - return; - } - let Some(provider_id) = LanguageModelRegistry::read_global(cx) .active_provider() .map(|p| p.id()) @@ -610,18 +595,15 @@ impl AssistantPanel { }; let load_credentials = self.authenticate(cx); + let task = cx.spawn(|this, mut cx| async move { + let _ = load_credentials.await; + this.update(&mut cx, |this, cx| { + this.show_authentication_prompt(cx); + }) + .log_err(); + }); - self.authenticate_provider_task = Some(( - provider_id, - cx.spawn(|this, mut cx| async move { - let _ = load_credentials.await; - this.update(&mut cx, |this, cx| { - this.show_authentication_prompt(cx); - this.authenticate_provider_task = None; - }) - .log_err(); - }), - )); + self.authenticate_provider_task = Some((provider_id, task)); } fn show_authentication_prompt(&mut self, cx: &mut ViewContext) { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index fb8f14ba4d..6dcc874721 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -86,25 +86,10 @@ pub trait LanguageModelProvider: 'static { fn authenticate(&self, cx: &mut AppContext) -> Task>; fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; fn reset_credentials(&self, cx: &mut AppContext) -> Task>; - - // fn observable_entity(&self) ; } pub trait LanguageModelProviderState: 'static { - type ObservableEntity; - - fn observable_entity(&self) -> Option>; - - fn subscribe( - &self, - cx: &mut gpui::ModelContext, - callback: impl Fn(&mut T, &mut gpui::ModelContext) + 'static, - ) -> Option { - let entity = self.observable_entity()?; - Some(cx.observe(&entity, move |this, _, cx| { - callback(this, cx); - })) - } + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option; } #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 3999483da0..ddaad618c4 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider { state: gpui::Model, } -pub struct State { +struct State { api_key: Option, _subscription: Subscription, } @@ -61,12 +61,11 @@ impl AnthropicLanguageModelProvider { Self { http_client, state } } } - impl LanguageModelProviderState for AnthropicLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 8132d41071..362539fd85 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -50,7 +50,7 @@ pub struct CloudLanguageModelProvider { _maintain_client_status: Task<()>, } -pub struct State { +struct State { client: Arc, status: client::Status, _subscription: Subscription, @@ -99,10 +99,10 @@ impl CloudLanguageModelProvider { } impl LanguageModelProviderState for CloudLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) } } diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs index f73ddb74bf..072c87b92e 100644 --- a/crates/language_model/src/provider/copilot_chat.rs +++ b/crates/language_model/src/provider/copilot_chat.rs @@ -11,8 +11,8 @@ use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::{ - percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render, - Subscription, Task, Transformation, + percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, + ModelContext, Render, Subscription, Task, Transformation, }; use settings::{Settings, SettingsStore}; use std::time::Duration; @@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider { } impl LanguageModelProviderState for CopilotChatLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) + fn subscribe(&self, cx: &mut ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) } } diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs index 70f8402bcc..f92ecaf467 100644 --- a/crates/language_model/src/provider/fake.rs +++ b/crates/language_model/src/provider/fake.rs @@ -36,9 +36,7 @@ pub struct FakeLanguageModelProvider { } impl LanguageModelProviderState for FakeLanguageModelProvider { - type ObservableEntity = (); - - fn observable_entity(&self) -> Option> { + fn subscribe(&self, _: &mut gpui::ModelContext) -> Option { None } } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index a1a6cbcceb..2739623c6a 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider { state: gpui::Model, } -pub struct State { +struct State { api_key: Option, _subscription: Subscription, } @@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider { } impl LanguageModelProviderState for GoogleLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) } } diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index 9afa3825b0..0364866ccd 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider { state: gpui::Model, } -pub struct State { +struct State { http_client: Arc, available_models: Vec, _subscription: Subscription, @@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider { } impl LanguageModelProviderState for OllamaLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) } } diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index e0239d959b..9f24dabb09 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider { state: gpui::Model, } -pub struct State { +struct State { api_key: Option, _subscription: Subscription, } @@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider { } impl LanguageModelProviderState for OpenAiLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) } } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index c7491cb70b..a3af7e6b18 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -54,7 +54,9 @@ fn register_language_model_providers( registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx); } else { registry.unregister_provider( - &LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()), + &LanguageModelProviderId::from( + crate::provider::cloud::PROVIDER_NAME.to_string(), + ), cx, ); } @@ -78,12 +80,9 @@ pub struct ActiveModel { model: Option>, } -pub enum Event { - ActiveModelChanged, - ProviderStateChanged, -} +pub struct ActiveModelChanged; -impl EventEmitter for LanguageModelRegistry {} +impl EventEmitter for LanguageModelRegistry {} impl LanguageModelRegistry { pub fn global(cx: &AppContext) -> Model { @@ -115,10 +114,7 @@ impl LanguageModelRegistry { ) { let name = provider.id(); - let subscription = provider.subscribe(cx, |_, cx| { - cx.emit(Event::ProviderStateChanged); - }); - if let Some(subscription) = subscription { + if let Some(subscription) = provider.subscribe(cx) { subscription.detach(); } @@ -191,7 +187,7 @@ impl LanguageModelRegistry { provider, model: None, }); - cx.emit(Event::ActiveModelChanged); + cx.emit(ActiveModelChanged); } pub fn set_active_model( @@ -206,13 +202,13 @@ impl LanguageModelRegistry { provider, model: Some(model), }); - cx.emit(Event::ActiveModelChanged); + cx.emit(ActiveModelChanged); } else { log::warn!("Active model's provider not found in registry"); } } else { self.active_model = None; - cx.emit(Event::ActiveModelChanged); + cx.emit(ActiveModelChanged); } }