diff --git a/Cargo.lock b/Cargo.lock index 761bfb557b..c49a71c5eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -704,6 +704,7 @@ dependencies = [ "assistant_tool", "chrono", "collections", + "feature_flags", "futures 0.3.31", "gpui", "html_to_markdown", @@ -721,9 +722,11 @@ dependencies = [ "ui", "unindent", "util", + "web_search", "workspace", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -16609,6 +16612,36 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web_search" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "gpui", + "serde", + "workspace-hack", + "zed_llm_client", +] + +[[package]] +name = "web_search_providers" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "feature_flags", + "futures 0.3.31", + "gpui", + "http_client", + "language_model", + "serde", + "serde_json", + "web_search", + "workspace-hack", + "zed_llm_client", +] + [[package]] name = "webpki-root-certs" version = "0.26.8" @@ -18287,6 +18320,8 @@ dependencies = [ "uuid", "vim", "vim_mode_setting", + "web_search", + "web_search_providers", "welcome", "windows 0.61.1", "winresource", @@ -18351,9 +18386,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d28a5d6bdb0f40acf5261c39cabbf65a13b55ba4b86d9beb5b8b1c484373f1a" +checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94" dependencies = [ "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 23299809e8..b002855337 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -165,6 +165,8 @@ members = [ "crates/util_macros", "crates/vim", "crates/vim_mode_setting", + "crates/web_search", + "crates/web_search_providers", "crates/welcome", "crates/workspace", "crates/worktree", @@ -370,6 +372,8 @@ util = { path = "crates/util" } util_macros = { path = "crates/util_macros" } vim = { path = "crates/vim" } vim_mode_setting = { path = "crates/vim_mode_setting" } +web_search = { path = "crates/web_search" } +web_search_providers = { path = "crates/web_search_providers" } welcome = { path = "crates/welcome" } workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } @@ -601,7 +605,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.4.2" +zed_llm_client = "0.5.0" zstd = "0.11" metal = "0.29" diff --git a/assets/settings/default.json b/assets/settings/default.json index 5c6335099b..98ee37f213 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -652,7 +652,8 @@ "path_search": true, "read_file": true, "regex_search": true, - "thinking": true + "thinking": true, + "web_search": true } }, "write": { @@ -678,7 +679,8 @@ "regex_search": true, "rename": true, "symbol_info": true, - "thinking": true + "thinking": true, + "web_search": true } } }, diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 0e9399b294..0cd90ba796 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -5,11 +5,12 @@ use crate::thread::{ ThreadEvent, ThreadFeedback, }; use crate::thread_store::{RulesLoadingError, ThreadStore}; -use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus}; +use crate::tool_use::{PendingToolUseStatus, ToolUse}; use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill}; use crate::{AssistantPanel, OpenActiveThreadAsMarkdown}; use anyhow::Context as _; use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; +use assistant_tool::ToolUseStatus; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; use editor::{Editor, EditorElement, EditorStyle, MultiBuffer}; @@ -943,8 +944,8 @@ impl ActiveThread { &tool_use.input, self.thread .read(cx) - .tool_result(&tool_use.id) - .map(|result| result.content.clone().into()) + .output_for_tool(&tool_use.id) + .map(|output| output.clone().into()) .unwrap_or("".into()), cx, ); @@ -2279,12 +2280,15 @@ impl ActiveThread { window: &mut Window, cx: &mut Context, ) -> impl IntoElement + use<> { + if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) { + return card.render(&tool_use.status, window, cx); + } + let is_open = self .expanded_tool_uses .get(&tool_use.id) .copied() .unwrap_or_default(); - let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_)); let fs = self @@ -2381,6 +2385,7 @@ impl ActiveThread { open_markdown_link(text, workspace.clone(), window, cx); } }) + .into_any_element() }), )), ), @@ -2437,6 +2442,7 @@ impl ActiveThread { open_markdown_link(text, workspace.clone(), window, cx); } }) + .into_any_element() })), ), ), @@ -2767,7 +2773,7 @@ impl ActiveThread { ) }) } - }) + }).into_any_element() } fn render_rules_item(&self, cx: &Context) -> AnyElement { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index ee85c9a27f..694b212e31 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -6,7 +6,7 @@ use std::time::Instant; use anyhow::{Context as _, Result, anyhow}; use assistant_settings::AssistantSettings; -use assistant_tool::{ActionLog, Tool, ToolWorkingSet}; +use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap}; use feature_flags::{self, FeatureFlagAppExt}; @@ -631,6 +631,14 @@ impl Thread { self.tool_use.tool_result(id) } + pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { + Some(&self.tool_use.tool_result(id)?.content) + } + + pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option { + self.tool_use.tool_result_card(id).cloned() + } + pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { self.tool_use.message_has_tool_results(message_id) } @@ -1426,6 +1434,12 @@ impl Thread { ) }; + // Store the card separately if it exists + if let Some(card) = tool_result.card.clone() { + self.tool_use + .insert_tool_result_card(tool_use_id.clone(), card); + } + cx.spawn({ async move |thread: WeakEntity, cx| { let output = tool_result.output.await; diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 93576d57a3..32876a100c 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use assistant_tool::{Tool, ToolWorkingSet}; +use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet}; use collections::HashMap; use futures::FutureExt as _; use futures::future::Shared; @@ -27,26 +27,7 @@ pub struct ToolUse { pub needs_confirmation: bool, } -#[derive(Debug, Clone)] -pub enum ToolUseStatus { - NeedsConfirmation, - Pending, - Running, - Finished(SharedString), - Error(SharedString), -} - -impl ToolUseStatus { - pub fn text(&self) -> SharedString { - match self { - ToolUseStatus::NeedsConfirmation => "".into(), - ToolUseStatus::Pending => "".into(), - ToolUseStatus::Running => "".into(), - ToolUseStatus::Finished(out) => out.clone(), - ToolUseStatus::Error(out) => out.clone(), - } - } -} +pub const USING_TOOL_MARKER: &str = ""; pub struct ToolUseState { tools: Entity, @@ -54,10 +35,9 @@ pub struct ToolUseState { tool_uses_by_user_message: HashMap>, tool_results: HashMap, pending_tool_uses_by_id: HashMap, + tool_result_cards: HashMap, } -pub const USING_TOOL_MARKER: &str = ""; - impl ToolUseState { pub fn new(tools: Entity) -> Self { Self { @@ -66,6 +46,7 @@ impl ToolUseState { tool_uses_by_user_message: HashMap::default(), tool_results: HashMap::default(), pending_tool_uses_by_id: HashMap::default(), + tool_result_cards: HashMap::default(), } } @@ -257,6 +238,18 @@ impl ToolUseState { self.tool_results.get(tool_use_id) } + pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> { + self.tool_result_cards.get(tool_use_id) + } + + pub fn insert_tool_result_card( + &mut self, + tool_use_id: LanguageModelToolUseId, + card: AnyToolCard, + ) { + self.tool_result_cards.insert(tool_use_id, card); + } + pub fn request_tool_use( &mut self, assistant_message_id: MessageId, diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 89450efb13..cb7f0ff518 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -9,6 +9,10 @@ use std::fmt::Formatter; use std::sync::Arc; use anyhow::Result; +use gpui::AnyElement; +use gpui::Context; +use gpui::IntoElement; +use gpui::Window; use gpui::{App, Entity, SharedString, Task}; use icons::IconName; use language_model::LanguageModelRequestMessage; @@ -24,16 +28,87 @@ pub fn init(cx: &mut App) { ToolRegistry::default_global(cx); } -/// The result of running a tool +#[derive(Debug, Clone)] +pub enum ToolUseStatus { + NeedsConfirmation, + Pending, + Running, + Finished(SharedString), + Error(SharedString), +} + +impl ToolUseStatus { + pub fn text(&self) -> SharedString { + match self { + ToolUseStatus::NeedsConfirmation => "".into(), + ToolUseStatus::Pending => "".into(), + ToolUseStatus::Running => "".into(), + ToolUseStatus::Finished(out) => out.clone(), + ToolUseStatus::Error(out) => out.clone(), + } + } +} + +/// The result of running a tool, containing both the asynchronous output +/// and an optional card view that can be rendered immediately. pub struct ToolResult { /// The asynchronous task that will eventually resolve to the tool's output pub output: Task>, + /// An optional view to present the output of the tool. + pub card: Option, +} + +pub trait ToolCard: 'static + Sized { + fn render( + &mut self, + status: &ToolUseStatus, + window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement; +} + +#[derive(Clone)] +pub struct AnyToolCard { + entity: gpui::AnyEntity, + render: fn( + entity: gpui::AnyEntity, + status: &ToolUseStatus, + window: &mut Window, + cx: &mut App, + ) -> AnyElement, +} + +impl From> for AnyToolCard { + fn from(entity: Entity) -> Self { + fn downcast_render( + entity: gpui::AnyEntity, + status: &ToolUseStatus, + window: &mut Window, + cx: &mut App, + ) -> AnyElement { + let entity = entity.downcast::().unwrap(); + entity.update(cx, |entity, cx| { + entity.render(status, window, cx).into_any_element() + }) + } + + Self { + entity: entity.into(), + render: downcast_render::, + } + } +} + +impl AnyToolCard { + pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement { + (self.render)(self.entity.clone(), status, window, cx) + } } impl From>> for ToolResult { - /// Convert from a task to a ToolResult + /// Convert from a task to a ToolResult with no card fn from(output: Task>) -> Self { - Self { output } + Self { output, card: None } } } diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index c5fa01e76b..245a37ef4b 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -16,6 +16,7 @@ anyhow.workspace = true assistant_tool.workspace = true chrono.workspace = true collections.workspace = true +feature_flags.workspace = true futures.workspace = true gpui.workspace = true html_to_markdown.workspace = true @@ -32,7 +33,9 @@ ui.workspace = true util.workspace = true worktree.workspace = true open = { workspace = true } +web_search.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] collections = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 3016f5412f..33e06466e2 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -22,14 +22,17 @@ mod schema; mod symbol_info_tool; mod terminal_tool; mod thinking_tool; +mod web_search_tool; use std::sync::Arc; use assistant_tool::ToolRegistry; use copy_path_tool::CopyPathTool; +use feature_flags::FeatureFlagAppExt; use gpui::App; use http_client::HttpClientWithUrl; use move_path_tool::MovePathTool; +use web_search_tool::WebSearchTool; use crate::batch_tool::BatchTool; use crate::code_action_tool::CodeActionTool; @@ -56,28 +59,39 @@ pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); let registry = ToolRegistry::global(cx); - registry.register_tool(TerminalTool); registry.register_tool(BatchTool); - registry.register_tool(CreateDirectoryTool); - registry.register_tool(CreateFileTool); - registry.register_tool(CopyPathTool); - registry.register_tool(DeletePathTool); - registry.register_tool(FindReplaceFileTool); - registry.register_tool(SymbolInfoTool); registry.register_tool(CodeActionTool); - registry.register_tool(MovePathTool); - registry.register_tool(DiagnosticsTool); - registry.register_tool(ListDirectoryTool); - registry.register_tool(NowTool); - registry.register_tool(OpenTool); registry.register_tool(CodeSymbolsTool); registry.register_tool(ContentsTool); + registry.register_tool(CopyPathTool); + registry.register_tool(CreateDirectoryTool); + registry.register_tool(CreateFileTool); + registry.register_tool(DeletePathTool); + registry.register_tool(DiagnosticsTool); + registry.register_tool(FetchTool::new(http_client)); + registry.register_tool(FindReplaceFileTool); + registry.register_tool(ListDirectoryTool); + registry.register_tool(MovePathTool); + registry.register_tool(NowTool); + registry.register_tool(OpenTool); registry.register_tool(PathSearchTool); registry.register_tool(ReadFileTool); registry.register_tool(RegexSearchTool); registry.register_tool(RenameTool); + registry.register_tool(SymbolInfoTool); + registry.register_tool(TerminalTool); registry.register_tool(ThinkingTool); - registry.register_tool(FetchTool::new(http_client)); + + cx.observe_flag::({ + move |is_enabled, cx| { + if is_enabled { + ToolRegistry::global(cx).register_tool(WebSearchTool); + } else { + ToolRegistry::global(cx).unregister_tool(WebSearchTool); + } + } + }) + .detach(); } #[cfg(test)] diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs new file mode 100644 index 0000000000..081aef27ed --- /dev/null +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -0,0 +1,213 @@ +use std::{sync::Arc, time::Duration}; + +use crate::schema::json_schema_for; +use anyhow::{Context as _, Result, anyhow}; +use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus}; +use futures::{FutureExt, TryFutureExt}; +use gpui::{ + Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window, + pulsating_between, +}; +use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use ui::{IconName, Tooltip, prelude::*}; +use web_search::WebSearchRegistry; +use zed_llm_client::WebSearchResponse; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct WebSearchToolInput { + /// The search term or question to query on the web. + query: String, +} + +pub struct WebSearchTool; + +impl Tool for WebSearchTool { + fn name(&self) -> String { + "web_search".into() + } + + fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + false + } + + fn description(&self) -> String { + "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into() + } + + fn icon(&self) -> IconName { + IconName::Globe + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) + } + + fn ui_text(&self, _input: &serde_json::Value) -> String { + "Web Search".to_string() + } + + fn run( + self: Arc, + input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], + _project: Entity, + _action_log: Entity, + cx: &mut App, + ) -> ToolResult { + let input = match serde_json::from_value::(input) { + Ok(input) => input, + Err(err) => return Task::ready(Err(anyhow!(err))).into(), + }; + let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { + return Task::ready(Err(anyhow!("Web search is not available."))).into(); + }; + + let search_task = provider.search(input.query, cx).map_err(Arc::new).shared(); + let output = cx.background_spawn({ + let search_task = search_task.clone(); + async move { + let response = search_task.await.map_err(|err| anyhow!(err))?; + serde_json::to_string(&response).context("Failed to serialize search results") + } + }); + + ToolResult { + output, + card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()), + } + } +} + +struct WebSearchToolCard { + response: Option>, + _task: Task<()>, +} + +impl WebSearchToolCard { + fn new( + search_task: impl 'static + Future>>, + cx: &mut Context, + ) -> Self { + let _task = cx.spawn(async move |this, cx| { + let response = search_task.await.map_err(|err| anyhow!(err)); + this.update(cx, |this, cx| { + this.response = Some(response); + cx.notify(); + }) + .ok(); + }); + + Self { + response: None, + _task, + } + } +} + +impl ToolCard for WebSearchToolCard { + fn render( + &mut self, + _status: &ToolUseStatus, + _window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement { + let header = h_flex() + .id("tool-label-container") + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .child( + Icon::new(IconName::Globe) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(match self.response.as_ref() { + Some(Ok(response)) => { + let text: SharedString = if response.citations.len() == 1 { + "1 result".into() + } else { + format!("{} results", response.citations.len()).into() + }; + h_flex() + .gap_1p5() + .child(Label::new("Searched the Web").size(LabelSize::Small)) + .child( + div() + .size(px(3.)) + .rounded_full() + .bg(cx.theme().colors().text), + ) + .child(Label::new(text).size(LabelSize::Small)) + .into_any_element() + } + Some(Err(error)) => div() + .id("web-search-error") + .child(Label::new("Web Search failed").size(LabelSize::Small)) + .tooltip(Tooltip::text(error.to_string())) + .into_any_element(), + + None => Label::new("Searching the Web…") + .size(LabelSize::Small) + .with_animation( + "web-search-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.6, 1.)), + |label, delta| label.alpha(delta), + ) + .into_any_element(), + }) + .into_any(); + + let content = + self.response.as_ref().and_then(|response| match response { + Ok(response) => { + Some( + v_flex() + .ml_1p5() + .pl_1p5() + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .gap_1() + .children(response.citations.iter().enumerate().map( + |(index, citation)| { + let title = citation.title.clone(); + let url = citation.url.clone(); + + Button::new(("citation", index), title) + .label_size(LabelSize::Small) + .color(Color::Muted) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_position(IconPosition::End) + .truncate(true) + .tooltip({ + let url = url.clone(); + move |window, cx| { + Tooltip::with_meta( + "Citation Link", + None, + url.clone(), + window, + cx, + ) + } + }) + .on_click({ + let url = url.clone(); + move |_, _, cx| cx.open_url(&url) + }) + }, + )) + .into_any(), + ) + } + Err(_) => None, + }); + + v_flex().my_2().gap_1().child(header).children(content) + } +} diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 772619a899..17a2d811f3 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro { const NAME: &'static str = "zed-pro"; } +pub struct ZedProWebSearchTool {} +impl FeatureFlag for ZedProWebSearchTool { + const NAME: &'static str = "zed-pro-web-search-tool"; +} + pub struct NotebookFeatureFlag; impl FeatureFlag for NotebookFeatureFlag { diff --git a/crates/ui/src/components/tooltip.rs b/crates/ui/src/components/tooltip.rs index d692f45a33..647b700c37 100644 --- a/crates/ui/src/components/tooltip.rs +++ b/crates/ui/src/components/tooltip.rs @@ -160,7 +160,11 @@ impl Render for Tooltip { }), ) .when_some(self.meta.clone(), |this, meta| { - this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)) + this.child( + div() + .max_w_72() + .child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)), + ) }) }) } diff --git a/crates/web_search/Cargo.toml b/crates/web_search/Cargo.toml new file mode 100644 index 0000000000..e5b8ca63b2 --- /dev/null +++ b/crates/web_search/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "web_search" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/web_search.rs" + +[dependencies] +anyhow.workspace = true +collections.workspace = true +gpui.workspace = true +serde.workspace = true +workspace-hack.workspace = true +zed_llm_client.workspace = true diff --git a/crates/web_search/LICENSE-GPL b/crates/web_search/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/web_search/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/web_search/src/web_search.rs b/crates/web_search/src/web_search.rs new file mode 100644 index 0000000000..73ff75b748 --- /dev/null +++ b/crates/web_search/src/web_search.rs @@ -0,0 +1,64 @@ +use anyhow::Result; +use collections::HashMap; +use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task}; +use std::sync::Arc; +use zed_llm_client::WebSearchResponse; + +pub fn init(cx: &mut App) { + let registry = cx.new(|_cx| WebSearchRegistry::default()); + cx.set_global(GlobalWebSearchRegistry(registry)); +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct WebSearchProviderId(pub SharedString); + +pub trait WebSearchProvider { + fn id(&self) -> WebSearchProviderId; + fn search(&self, query: String, cx: &mut App) -> Task>; +} + +struct GlobalWebSearchRegistry(Entity); + +impl Global for GlobalWebSearchRegistry {} + +#[derive(Default)] +pub struct WebSearchRegistry { + providers: HashMap>, + active_provider: Option>, +} + +impl WebSearchRegistry { + pub fn global(cx: &App) -> Entity { + cx.global::().0.clone() + } + + pub fn read_global(cx: &App) -> &Self { + cx.global::().0.read(cx) + } + + pub fn providers(&self) -> impl Iterator> { + self.providers.values() + } + + pub fn active_provider(&self) -> Option> { + self.active_provider.clone() + } + + pub fn set_active_provider(&mut self, provider: Arc) { + self.active_provider = Some(provider.clone()); + self.providers.insert(provider.id(), provider); + } + + pub fn register_provider( + &mut self, + provider: T, + _cx: &mut Context, + ) { + let id = provider.id(); + let provider = Arc::new(provider); + self.providers.insert(id.clone(), provider.clone()); + if self.active_provider.is_none() { + self.active_provider = Some(provider); + } + } +} diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml new file mode 100644 index 0000000000..208cb63593 --- /dev/null +++ b/crates/web_search_providers/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "web_search_providers" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/web_search_providers.rs" + +[dependencies] +anyhow.workspace = true +client.workspace = true +feature_flags.workspace = true +futures.workspace = true +gpui.workspace = true +http_client.workspace = true +language_model.workspace = true +serde.workspace = true +serde_json.workspace = true +web_search.workspace = true +workspace-hack.workspace = true +zed_llm_client.workspace = true diff --git a/crates/web_search_providers/LICENSE-GPL b/crates/web_search_providers/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/web_search_providers/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs new file mode 100644 index 0000000000..8a764b9671 --- /dev/null +++ b/crates/web_search_providers/src/cloud.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use anyhow::{Context as _, Result, anyhow}; +use client::Client; +use futures::AsyncReadExt as _; +use gpui::{App, AppContext, Context, Entity, Subscription, Task}; +use http_client::{HttpClient, Method}; +use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use web_search::{WebSearchProvider, WebSearchProviderId}; +use zed_llm_client::{WebSearchBody, WebSearchResponse}; + +pub struct CloudWebSearchProvider { + state: Entity, +} + +impl CloudWebSearchProvider { + pub fn new(client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State::new(client, cx)); + + Self { state } + } +} + +pub struct State { + client: Arc, + llm_api_token: LlmApiToken, + _llm_token_subscription: Subscription, +} + +impl State { + pub fn new(client: Arc, cx: &mut Context) -> Self { + let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + + Self { + client, + llm_api_token: LlmApiToken::default(), + _llm_token_subscription: cx.subscribe( + &refresh_llm_token_listener, + |this, _, _event, cx| { + let client = this.client.clone(); + let llm_api_token = this.llm_api_token.clone(); + cx.spawn(async move |_this, _cx| { + llm_api_token.refresh(&client).await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + }, + ), + } + } +} + +impl WebSearchProvider for CloudWebSearchProvider { + fn id(&self) -> WebSearchProviderId { + WebSearchProviderId("zed.dev".into()) + } + + fn search(&self, query: String, cx: &mut App) -> Task> { + let state = self.state.read(cx); + let client = state.client.clone(); + let llm_api_token = state.llm_api_token.clone(); + let body = WebSearchBody { query }; + cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await }) + } +} + +async fn perform_web_search( + client: Arc, + llm_api_token: LlmApiToken, + body: WebSearchBody, +) -> Result { + let http_client = &client.http_client(); + + let token = llm_api_token.acquire(&client).await?; + + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") { + request_builder.uri(web_search_url) + } else { + request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref()) + }; + let request = request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .body(serde_json::to_string(&body)?.into())?; + let mut response = http_client + .send(request) + .await + .context("failed to send web search request")?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Ok(serde_json::from_str(&body)?); + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "error performing web search.\nStatus: {:?}\nBody: {body}", + response.status(), + )); + } +} diff --git a/crates/web_search_providers/src/web_search_providers.rs b/crates/web_search_providers/src/web_search_providers.rs new file mode 100644 index 0000000000..d547ee7308 --- /dev/null +++ b/crates/web_search_providers/src/web_search_providers.rs @@ -0,0 +1,35 @@ +mod cloud; + +use client::Client; +use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool}; +use gpui::{App, Context}; +use std::sync::Arc; +use web_search::WebSearchRegistry; + +pub fn init(client: Arc, cx: &mut App) { + let registry = WebSearchRegistry::global(cx); + registry.update(cx, |registry, cx| { + register_web_search_providers(registry, client, cx); + }); +} + +fn register_web_search_providers( + _registry: &mut WebSearchRegistry, + client: Arc, + cx: &mut Context, +) { + cx.observe_flag::({ + let client = client.clone(); + move |is_enabled, cx| { + if is_enabled { + WebSearchRegistry::global(cx).update(cx, |registry, cx| { + registry.register_provider( + cloud::CloudWebSearchProvider::new(client.clone(), cx), + cx, + ); + }); + } + } + }) + .detach(); +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index a4f6f71075..6b61ed00ac 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -133,6 +133,8 @@ util.workspace = true uuid.workspace = true vim.workspace = true vim_mode_setting.workspace = true +web_search.workspace = true +web_search_providers.workspace = true welcome.workspace = true workspace.workspace = true zed_actions.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 21fff01cd5..967f4aac14 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -490,6 +490,8 @@ fn main() { app_state.fs.clone(), cx, ); + web_search::init(cx); + web_search_providers::init(app_state.client.clone(), cx); snippet_provider::init(cx); inline_completion_registry::init( app_state.client.clone(), diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 03a7ad149e..691de1edca 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -4258,6 +4258,8 @@ mod tests { app_state.fs.clone(), cx, ); + web_search::init(cx); + web_search_providers::init(app_state.client.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx); assistant::init( app_state.fs.clone(),