agent: Add websearch tool (#28621)

Staff only for now. We'll work on making this usable for non zed.dev
users later

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Bennet Bo Fenner 2025-04-16 19:25:00 +02:00 committed by GitHub
parent 2b277123be
commit 456e54b87c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 675 additions and 51 deletions

39
Cargo.lock generated
View file

@ -704,6 +704,7 @@ dependencies = [
"assistant_tool",
"chrono",
"collections",
"feature_flags",
"futures 0.3.31",
"gpui",
"html_to_markdown",
@ -721,9 +722,11 @@ dependencies = [
"ui",
"unindent",
"util",
"web_search",
"workspace",
"workspace-hack",
"worktree",
"zed_llm_client",
]
[[package]]
@ -16609,6 +16612,36 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web_search"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"gpui",
"serde",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "web_search_providers"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",
"language_model",
"serde",
"serde_json",
"web_search",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "webpki-root-certs"
version = "0.26.8"
@ -18287,6 +18320,8 @@ dependencies = [
"uuid",
"vim",
"vim_mode_setting",
"web_search",
"web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
@ -18351,9 +18386,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.4.2"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d28a5d6bdb0f40acf5261c39cabbf65a13b55ba4b86d9beb5b8b1c484373f1a"
checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
dependencies = [
"serde",
"serde_json",

View file

@ -165,6 +165,8 @@ members = [
"crates/util_macros",
"crates/vim",
"crates/vim_mode_setting",
"crates/web_search",
"crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
@ -370,6 +372,8 @@ util = { path = "crates/util" }
util_macros = { path = "crates/util_macros" }
vim = { path = "crates/vim" }
vim_mode_setting = { path = "crates/vim_mode_setting" }
web_search = { path = "crates/web_search" }
web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
@ -601,7 +605,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.4.2"
zed_llm_client = "0.5.0"
zstd = "0.11"
metal = "0.29"

View file

@ -652,7 +652,8 @@
"path_search": true,
"read_file": true,
"regex_search": true,
"thinking": true
"thinking": true,
"web_search": true
}
},
"write": {
@ -678,7 +679,8 @@
"regex_search": true,
"rename": true,
"symbol_info": true,
"thinking": true
"thinking": true,
"web_search": true
}
}
},

View file

@ -5,11 +5,12 @@ use crate::thread::{
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
@ -943,8 +944,8 @@ impl ActiveThread {
&tool_use.input,
self.thread
.read(cx)
.tool_result(&tool_use.id)
.map(|result| result.content.clone().into())
.output_for_tool(&tool_use.id)
.map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
@ -2279,12 +2280,15 @@ impl ActiveThread {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, cx);
}
let is_open = self
.expanded_tool_uses
.get(&tool_use.id)
.copied()
.unwrap_or_default();
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
let fs = self
@ -2381,6 +2385,7 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
}),
)),
),
@ -2437,6 +2442,7 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
})),
),
),
@ -2767,7 +2773,7 @@ impl ActiveThread {
)
})
}
})
}).into_any_element()
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {

View file

@ -6,7 +6,7 @@ use std::time::Instant;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
@ -631,6 +631,14 @@ impl Thread {
self.tool_use.tool_result(id)
}
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
Some(&self.tool_use.tool_result(id)?.content)
}
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
self.tool_use.tool_result_card(id).cloned()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
@ -1426,6 +1434,12 @@ impl Thread {
)
};
// Store the card separately if it exists
if let Some(card) = tool_result.card.clone() {
self.tool_use
.insert_tool_result_card(tool_use_id.clone(), card);
}
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = tool_result.output.await;

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{Tool, ToolWorkingSet};
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
@ -27,26 +27,7 @@ pub struct ToolUse {
pub needs_confirmation: bool,
}
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
@ -54,10 +35,9 @@ pub struct ToolUseState {
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
@ -66,6 +46,7 @@ impl ToolUseState {
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
}
}
@ -257,6 +238,18 @@ impl ToolUseState {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,

View file

@ -9,6 +9,10 @@ use std::fmt::Formatter;
use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
@ -24,16 +28,87 @@ pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
/// The result of running a tool
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
/// The result of running a tool, containing both the asynchronous output
/// and an optional card view that can be rendered immediately.
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
/// An optional view to present the output of the tool.
pub card: Option<AnyToolCard>,
}
pub trait ToolCard: 'static + Sized {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement;
}
#[derive(Clone)]
pub struct AnyToolCard {
entity: gpui::AnyEntity,
render: fn(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement,
}
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
fn from(entity: Entity<T>) -> Self {
fn downcast_render<T: ToolCard>(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement {
let entity = entity.downcast::<T>().unwrap();
entity.update(cx, |entity, cx| {
entity.render(status, window, cx).into_any_element()
})
}
Self {
entity: entity.into(),
render: downcast_render::<T>,
}
}
}
impl AnyToolCard {
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
(self.render)(self.entity.clone(), status, window, cx)
}
}
impl From<Task<Result<String>>> for ToolResult {
/// Convert from a task to a ToolResult
/// Convert from a task to a ToolResult with no card
fn from(output: Task<Result<String>>) -> Self {
Self { output }
Self { output, card: None }
}
}

View file

@ -16,6 +16,7 @@ anyhow.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
collections.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@ -32,7 +33,9 @@ ui.workspace = true
util.workspace = true
worktree.workspace = true
open = { workspace = true }
web_search.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }

View file

@ -22,14 +22,17 @@ mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
@ -56,28 +59,39 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(FindReplaceFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(CopyPathTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(DeletePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(FetchTool::new(http_client));
registry.register_tool(FindReplaceFileTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(MovePathTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(TerminalTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
move |is_enabled, cx| {
if is_enabled {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
})
.detach();
}
#[cfg(test)]

View file

@ -0,0 +1,213 @@
use std::{sync::Arc, time::Duration};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt, TryFutureExt};
use gpui::{
Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
pulsating_between,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use zed_llm_client::WebSearchResponse;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
/// The search term or question to query on the web.
query: String,
}
pub struct WebSearchTool;
impl Tool for WebSearchTool {
fn name(&self) -> String {
"web_search".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
"Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
}
fn icon(&self) -> IconName {
IconName::Globe
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<WebSearchToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Web Search".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into();
};
let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
let output = cx.background_spawn({
let search_task = search_task.clone();
async move {
let response = search_task.await.map_err(|err| anyhow!(err))?;
serde_json::to_string(&response).context("Failed to serialize search results")
}
});
ToolResult {
output,
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
}
}
}
struct WebSearchToolCard {
response: Option<Result<WebSearchResponse>>,
_task: Task<()>,
}
impl WebSearchToolCard {
fn new(
search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
cx: &mut Context<Self>,
) -> Self {
let _task = cx.spawn(async move |this, cx| {
let response = search_task.await.map_err(|err| anyhow!(err));
this.update(cx, |this, cx| {
this.response = Some(response);
cx.notify();
})
.ok();
});
Self {
response: None,
_task,
}
}
}
impl ToolCard for WebSearchToolCard {
fn render(
&mut self,
_status: &ToolUseStatus,
_window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
Icon::new(IconName::Globe)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(match self.response.as_ref() {
Some(Ok(response)) => {
let text: SharedString = if response.citations.len() == 1 {
"1 result".into()
} else {
format!("{} results", response.citations.len()).into()
};
h_flex()
.gap_1p5()
.child(Label::new("Searched the Web").size(LabelSize::Small))
.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
)
.child(Label::new(text).size(LabelSize::Small))
.into_any_element()
}
Some(Err(error)) => div()
.id("web-search-error")
.child(Label::new("Web Search failed").size(LabelSize::Small))
.tooltip(Tooltip::text(error.to_string()))
.into_any_element(),
None => Label::new("Searching the Web…")
.size(LabelSize::Small)
.with_animation(
"web-search-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element(),
})
.into_any();
let content =
self.response.as_ref().and_then(|response| match response {
Ok(response) => {
Some(
v_flex()
.ml_1p5()
.pl_1p5()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.gap_1()
.children(response.citations.iter().enumerate().map(
|(index, citation)| {
let title = citation.title.clone();
let url = citation.url.clone();
Button::new(("citation", index), title)
.label_size(LabelSize::Small)
.color(Color::Muted)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.truncate(true)
.tooltip({
let url = url.clone();
move |window, cx| {
Tooltip::with_meta(
"Citation Link",
None,
url.clone(),
window,
cx,
)
}
})
.on_click({
let url = url.clone();
move |_, _, cx| cx.open_url(&url)
})
},
))
.into_any(),
)
}
Err(_) => None,
});
v_flex().my_2().gap_1().child(header).children(content)
}
}

View file

@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
pub struct ZedProWebSearchTool {}
impl FeatureFlag for ZedProWebSearchTool {
const NAME: &'static str = "zed-pro-web-search-tool";
}
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {

View file

@ -160,7 +160,11 @@ impl Render for Tooltip {
}),
)
.when_some(self.meta.clone(), |this, meta| {
this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
this.child(
div()
.max_w_72()
.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
)
})
})
}

View file

@ -0,0 +1,20 @@
[package]
name = "web_search"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/web_search.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
serde.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,64 @@
use anyhow::Result;
use collections::HashMap;
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
use std::sync::Arc;
use zed_llm_client::WebSearchResponse;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| WebSearchRegistry::default());
cx.set_global(GlobalWebSearchRegistry(registry));
}
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct WebSearchProviderId(pub SharedString);
pub trait WebSearchProvider {
fn id(&self) -> WebSearchProviderId;
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
}
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
impl Global for GlobalWebSearchRegistry {}
#[derive(Default)]
pub struct WebSearchRegistry {
providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
active_provider: Option<Arc<dyn WebSearchProvider>>,
}
impl WebSearchRegistry {
pub fn global(cx: &App) -> Entity<Self> {
cx.global::<GlobalWebSearchRegistry>().0.clone()
}
pub fn read_global(cx: &App) -> &Self {
cx.global::<GlobalWebSearchRegistry>().0.read(cx)
}
pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
self.providers.values()
}
pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
self.active_provider.clone()
}
pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
self.active_provider = Some(provider.clone());
self.providers.insert(provider.id(), provider);
}
pub fn register_provider<T: WebSearchProvider + 'static>(
&mut self,
provider: T,
_cx: &mut Context<Self>,
) {
let id = provider.id();
let provider = Arc::new(provider);
self.providers.insert(id.clone(), provider.clone());
if self.active_provider.is_none() {
self.active_provider = Some(provider);
}
}
}

View file

@ -0,0 +1,26 @@
[package]
name = "web_search_providers"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
language_model.workspace = true
serde.workspace = true
serde_json.workspace = true
web_search.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,103 @@
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use client::Client;
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{WebSearchBody, WebSearchResponse};
pub struct CloudWebSearchProvider {
state: Entity<State>,
}
impl CloudWebSearchProvider {
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
let state = cx.new(|cx| State::new(client, cx));
Self { state }
}
}
pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
_llm_token_subscription: Subscription,
}
impl State {
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client,
llm_api_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
cx.spawn(async move |_this, _cx| {
llm_api_token.refresh(&client).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
},
),
}
}
}
impl WebSearchProvider for CloudWebSearchProvider {
fn id(&self) -> WebSearchProviderId {
WebSearchProviderId("zed.dev".into())
}
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
let state = self.state.read(cx);
let client = state.client.clone();
let llm_api_token = state.llm_api_token.clone();
let body = WebSearchBody { query };
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
}
}
async fn perform_web_search(
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
let request_builder = http_client::Request::builder().method(Method::POST);
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
request_builder.uri(web_search_url)
} else {
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
};
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client
.send(request)
.await
.context("failed to send web search request")?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"error performing web search.\nStatus: {:?}\nBody: {body}",
response.status(),
));
}
}

View file

@ -0,0 +1,35 @@
mod cloud;
use client::Client;
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
use gpui::{App, Context};
use std::sync::Arc;
use web_search::WebSearchRegistry;
pub fn init(client: Arc<Client>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_web_search_providers(registry, client, cx);
});
}
fn register_web_search_providers(
_registry: &mut WebSearchRegistry,
client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>,
) {
cx.observe_flag::<ZedProWebSearchTool, _>({
let client = client.clone();
move |is_enabled, cx| {
if is_enabled {
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
registry.register_provider(
cloud::CloudWebSearchProvider::new(client.clone(), cx),
cx,
);
});
}
}
})
.detach();
}

View file

@ -133,6 +133,8 @@ util.workspace = true
uuid.workspace = true
vim.workspace = true
vim_mode_setting.workspace = true
web_search.workspace = true
web_search_providers.workspace = true
welcome.workspace = true
workspace.workspace = true
zed_actions.workspace = true

View file

@ -490,6 +490,8 @@ fn main() {
app_state.fs.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);
inline_completion_registry::init(
app_state.client.clone(),

View file

@ -4258,6 +4258,8 @@ mod tests {
app_state.fs.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
assistant::init(
app_state.fs.clone(),