diff --git a/Cargo.lock b/Cargo.lock index 04dde837b5..dfe0878c8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -703,9 +703,10 @@ dependencies = [ "anyhow", "assistant_tool", "chrono", + "client", + "clock", "collections", "component", - "feature_flags", "futures 0.3.31", "gpui", "html_to_markdown", @@ -16631,7 +16632,6 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "feature_flags", "futures 0.3.31", "gpui", "http_client", diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 205db386e2..ea979ec703 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -23,7 +23,6 @@ use gpui::{ use language::LanguageRegistry; use language_model::{ AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, - ZED_CLOUD_PROVIDER_ID, }; use project::Project; use prompt_library::{PromptLibrary, open_prompt_library}; @@ -489,8 +488,8 @@ impl AssistantPanel { // If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is // the provider, we want to show a nudge to sign in. - let show_zed_ai_notice = client_status.is_signed_out() - && model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID); + let show_zed_ai_notice = + client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed()); self.show_zed_ai_notice = show_zed_ai_notice; cx.notify(); diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index eaaeff1e47..26ddb53d8f 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -17,7 +17,6 @@ assistant_tool.workspace = true chrono.workspace = true collections.workspace = true component.workspace = true -feature_flags.workspace = true futures.workspace = true gpui.workspace = true html_to_markdown.workspace = true @@ -41,6 +40,8 @@ worktree.workspace = true zed_llm_client.workspace = true [dev-dependencies] +client = { workspace = true, features = ["test-support"] } +clock = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 86e000e3b2..250c1490e5 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -29,9 +29,9 @@ use std::sync::Arc; use assistant_tool::ToolRegistry; use copy_path_tool::CopyPathTool; -use feature_flags::FeatureFlagAppExt; use gpui::App; use http_client::HttpClientWithUrl; +use language_model::LanguageModelRegistry; use move_path_tool::MovePathTool; use web_search_tool::WebSearchTool; @@ -85,34 +85,45 @@ pub fn init(http_client: Arc, cx: &mut App) { registry.register_tool(ThinkingTool); registry.register_tool(FetchTool::new(http_client)); - cx.observe_flag::({ - move |is_enabled, cx| { - if is_enabled { - ToolRegistry::global(cx).register_tool(WebSearchTool); - } else { - ToolRegistry::global(cx).unregister_tool(WebSearchTool); + cx.subscribe( + &LanguageModelRegistry::global(cx), + move |registry, event, cx| match event { + language_model::Event::DefaultModelChanged => { + let using_zed_provider = registry + .read(cx) + .default_model() + .map_or(false, |default| default.is_provided_by_zed()); + if using_zed_provider { + ToolRegistry::global(cx).register_tool(WebSearchTool); + } else { + ToolRegistry::global(cx).unregister_tool(WebSearchTool); + } } - } - }) + _ => {} + }, + ) .detach(); } #[cfg(test)] mod tests { + use client::Client; + use clock::FakeSystemClock; use http_client::FakeHttpClient; use super::*; #[gpui::test] fn test_builtin_tool_schema_compatibility(cx: &mut App) { - crate::init( - Arc::new(http_client::HttpClientWithUrl::new( - FakeHttpClient::with_200_response(), - "https://zed.dev", - None, - )), + settings::init(cx); + + let client = Client::new( + Arc::new(FakeSystemClock::new()), + FakeHttpClient::with_200_response(), cx, ); + language_model::init(client.clone(), cx); + crate::init(client.http_client(), cx); for tool in ToolRegistry::global(cx).tools() { let actual_schema = tool diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 17a2d811f3..772619a899 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro { const NAME: &'static str = "zed-pro"; } -pub struct ZedProWebSearchTool {} -impl FeatureFlag for ZedProWebSearchTool { - const NAME: &'static str = "zed-pro-web-search-tool"; -} - pub struct NotebookFeatureFlag; impl FeatureFlag for NotebookFeatureFlag { diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 62f216094b..1f17a6e822 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -42,6 +42,10 @@ impl ConfiguredModel { pub fn is_same_as(&self, other: &ConfiguredModel) -> bool { self.model.id() == other.model.id() && self.provider.id() == other.provider.id() } + + pub fn is_provided_by_zed(&self) -> bool { + self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID + } } pub enum Event { diff --git a/crates/web_search/src/web_search.rs b/crates/web_search/src/web_search.rs index 73ff75b748..a131b0de71 100644 --- a/crates/web_search/src/web_search.rs +++ b/crates/web_search/src/web_search.rs @@ -61,4 +61,11 @@ impl WebSearchRegistry { self.active_provider = Some(provider); } } + + pub fn unregister_provider(&mut self, id: WebSearchProviderId) { + self.providers.remove(&id); + if self.active_provider.as_ref().map(|provider| provider.id()) == Some(id) { + self.active_provider = None; + } + } } diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index 208cb63593..2e052796c4 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,7 +14,6 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true -feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 8a764b9671..ec1469e0d2 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -50,9 +50,11 @@ impl State { } } +pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev"; + impl WebSearchProvider for CloudWebSearchProvider { fn id(&self) -> WebSearchProviderId { - WebSearchProviderId("zed.dev".into()) + WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into()) } fn search(&self, query: String, cx: &mut App) -> Task> { diff --git a/crates/web_search_providers/src/web_search_providers.rs b/crates/web_search_providers/src/web_search_providers.rs index d547ee7308..c2b563e7eb 100644 --- a/crates/web_search_providers/src/web_search_providers.rs +++ b/crates/web_search_providers/src/web_search_providers.rs @@ -1,10 +1,10 @@ mod cloud; use client::Client; -use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool}; use gpui::{App, Context}; +use language_model::LanguageModelRegistry; use std::sync::Arc; -use web_search::WebSearchRegistry; +use web_search::{WebSearchProviderId, WebSearchRegistry}; pub fn init(client: Arc, cx: &mut App) { let registry = WebSearchRegistry::global(cx); @@ -18,18 +18,27 @@ fn register_web_search_providers( client: Arc, cx: &mut Context, ) { - cx.observe_flag::({ - let client = client.clone(); - move |is_enabled, cx| { - if is_enabled { - WebSearchRegistry::global(cx).update(cx, |registry, cx| { - registry.register_provider( + cx.subscribe( + &LanguageModelRegistry::global(cx), + move |this, registry, event, cx| match event { + language_model::Event::DefaultModelChanged => { + let using_zed_provider = registry + .read(cx) + .default_model() + .map_or(false, |default| default.is_provided_by_zed()); + if using_zed_provider { + this.register_provider( cloud::CloudWebSearchProvider::new(client.clone(), cx), cx, - ); - }); + ) + } else { + this.unregister_provider(WebSearchProviderId( + cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(), + )); + } } - } - }) + _ => {} + }, + ) .detach(); }