agent: Ensure that web search tool is always available (#29799)

Some changes in the LanguageModelRegistry caused the web search tool not
to show up, because the `DefaultModelChanged` event is not emitted at
startup anymore.

Release Notes:

- agent: Fixed an issue where the web search tool would not be available
after starting Zed (only when using zed.dev as a provider).
This commit is contained in:
Bennet Bo Fenner 2025-05-02 17:34:08 +02:00 committed by GitHub
parent c4556e9909
commit fde621f0e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 44 additions and 26 deletions

View file

@ -34,7 +34,7 @@ use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry; use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool; use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt}; use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App; use gpui::{App, Entity};
use http_client::HttpClientWithUrl; use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry; use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool; use move_path_tool::MovePathTool;
@ -101,19 +101,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
cx.observe_global::<SettingsStore>(register_edit_file_tool) cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach(); .detach();
register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
move |registry, event, cx| match event { move |registry, event, cx| match event {
language_model::Event::DefaultModelChanged => { language_model::Event::DefaultModelChanged => {
let using_zed_provider = registry register_web_search_tool(&registry, cx);
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
} }
_ => {} _ => {}
}, },
@ -121,6 +114,18 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach(); .detach();
} }
fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut App) {
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
fn register_edit_file_tool(cx: &mut App) { fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx); let registry = ToolRegistry::global(cx);

View file

@ -1,7 +1,7 @@
mod cloud; mod cloud;
use client::Client; use client::Client;
use gpui::{App, Context}; use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry; use language_model::LanguageModelRegistry;
use std::sync::Arc; use std::sync::Arc;
use web_search::{WebSearchProviderId, WebSearchRegistry}; use web_search::{WebSearchProviderId, WebSearchRegistry};
@ -14,31 +14,44 @@ pub fn init(client: Arc<Client>, cx: &mut App) {
} }
fn register_web_search_providers( fn register_web_search_providers(
_registry: &mut WebSearchRegistry, registry: &mut WebSearchRegistry,
client: Arc<Client>, client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>, cx: &mut Context<WebSearchRegistry>,
) { ) {
register_zed_web_search_provider(
registry,
client.clone(),
&LanguageModelRegistry::global(cx),
cx,
);
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
move |this, registry, event, cx| match event { move |this, registry, event, cx| match event {
language_model::Event::DefaultModelChanged => { language_model::Event::DefaultModelChanged => {
let using_zed_provider = registry register_zed_web_search_provider(this, client.clone(), &registry, cx)
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
this.register_provider(
cloud::CloudWebSearchProvider::new(client.clone(), cx),
cx,
)
} else {
this.unregister_provider(WebSearchProviderId(
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
));
}
} }
_ => {} _ => {}
}, },
) )
.detach(); .detach();
} }
fn register_zed_web_search_provider(
registry: &mut WebSearchRegistry,
client: Arc<Client>,
language_model_registry: &Entity<LanguageModelRegistry>,
cx: &mut Context<WebSearchRegistry>,
) {
let using_zed_provider = language_model_registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
} else {
registry.unregister_provider(WebSearchProviderId(
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
));
}
}