agent: Add websearch tool (#28621)
Staff only for now. We'll work on making this usable for non zed.dev users later Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
2b277123be
commit
456e54b87c
22 changed files with 675 additions and 51 deletions
39
Cargo.lock
generated
39
Cargo.lock
generated
|
@ -704,6 +704,7 @@ dependencies = [
|
|||
"assistant_tool",
|
||||
"chrono",
|
||||
"collections",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"html_to_markdown",
|
||||
|
@ -721,9 +722,11 @@ dependencies = [
|
|||
"ui",
|
||||
"unindent",
|
||||
"util",
|
||||
"web_search",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"worktree",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -16609,6 +16612,36 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web_search"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"gpui",
|
||||
"serde",
|
||||
"workspace-hack",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web_search_providers"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
"language_model",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"web_search",
|
||||
"workspace-hack",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-root-certs"
|
||||
version = "0.26.8"
|
||||
|
@ -18287,6 +18320,8 @@ dependencies = [
|
|||
"uuid",
|
||||
"vim",
|
||||
"vim_mode_setting",
|
||||
"web_search",
|
||||
"web_search_providers",
|
||||
"welcome",
|
||||
"windows 0.61.1",
|
||||
"winresource",
|
||||
|
@ -18351,9 +18386,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.4.2"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d28a5d6bdb0f40acf5261c39cabbf65a13b55ba4b86d9beb5b8b1c484373f1a"
|
||||
checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
@ -165,6 +165,8 @@ members = [
|
|||
"crates/util_macros",
|
||||
"crates/vim",
|
||||
"crates/vim_mode_setting",
|
||||
"crates/web_search",
|
||||
"crates/web_search_providers",
|
||||
"crates/welcome",
|
||||
"crates/workspace",
|
||||
"crates/worktree",
|
||||
|
@ -370,6 +372,8 @@ util = { path = "crates/util" }
|
|||
util_macros = { path = "crates/util_macros" }
|
||||
vim = { path = "crates/vim" }
|
||||
vim_mode_setting = { path = "crates/vim_mode_setting" }
|
||||
web_search = { path = "crates/web_search" }
|
||||
web_search_providers = { path = "crates/web_search_providers" }
|
||||
welcome = { path = "crates/welcome" }
|
||||
workspace = { path = "crates/workspace" }
|
||||
worktree = { path = "crates/worktree" }
|
||||
|
@ -601,7 +605,7 @@ wasmtime-wasi = "29"
|
|||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.4.2"
|
||||
zed_llm_client = "0.5.0"
|
||||
zstd = "0.11"
|
||||
metal = "0.29"
|
||||
|
||||
|
|
|
@ -652,7 +652,8 @@
|
|||
"path_search": true,
|
||||
"read_file": true,
|
||||
"regex_search": true,
|
||||
"thinking": true
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
}
|
||||
},
|
||||
"write": {
|
||||
|
@ -678,7 +679,8 @@
|
|||
"regex_search": true,
|
||||
"rename": true,
|
||||
"symbol_info": true,
|
||||
"thinking": true
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -5,11 +5,12 @@ use crate::thread::{
|
|||
ThreadEvent, ThreadFeedback,
|
||||
};
|
||||
use crate::thread_store::{RulesLoadingError, ThreadStore};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse};
|
||||
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
||||
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
|
||||
use anyhow::Context as _;
|
||||
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
|
||||
use assistant_tool::ToolUseStatus;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::scroll::Autoscroll;
|
||||
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
|
||||
|
@ -943,8 +944,8 @@ impl ActiveThread {
|
|||
&tool_use.input,
|
||||
self.thread
|
||||
.read(cx)
|
||||
.tool_result(&tool_use.id)
|
||||
.map(|result| result.content.clone().into())
|
||||
.output_for_tool(&tool_use.id)
|
||||
.map(|output| output.clone().into())
|
||||
.unwrap_or("".into()),
|
||||
cx,
|
||||
);
|
||||
|
@ -2279,12 +2280,15 @@ impl ActiveThread {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement + use<> {
|
||||
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
|
||||
return card.render(&tool_use.status, window, cx);
|
||||
}
|
||||
|
||||
let is_open = self
|
||||
.expanded_tool_uses
|
||||
.get(&tool_use.id)
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
|
||||
|
||||
let fs = self
|
||||
|
@ -2381,6 +2385,7 @@ impl ActiveThread {
|
|||
open_markdown_link(text, workspace.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
.into_any_element()
|
||||
}),
|
||||
)),
|
||||
),
|
||||
|
@ -2437,6 +2442,7 @@ impl ActiveThread {
|
|||
open_markdown_link(text, workspace.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
.into_any_element()
|
||||
})),
|
||||
),
|
||||
),
|
||||
|
@ -2767,7 +2773,7 @@ impl ActiveThread {
|
|||
)
|
||||
})
|
||||
}
|
||||
})
|
||||
}).into_any_element()
|
||||
}
|
||||
|
||||
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
|
||||
|
|
|
@ -6,7 +6,7 @@ use std::time::Instant;
|
|||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
|
@ -631,6 +631,14 @@ impl Thread {
|
|||
self.tool_use.tool_result(id)
|
||||
}
|
||||
|
||||
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
|
||||
Some(&self.tool_use.tool_result(id)?.content)
|
||||
}
|
||||
|
||||
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
|
||||
self.tool_use.tool_result_card(id).cloned()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_use.message_has_tool_results(message_id)
|
||||
}
|
||||
|
@ -1426,6 +1434,12 @@ impl Thread {
|
|||
)
|
||||
};
|
||||
|
||||
// Store the card separately if it exists
|
||||
if let Some(card) = tool_result.card.clone() {
|
||||
self.tool_use
|
||||
.insert_tool_result_card(tool_use_id.clone(), card);
|
||||
}
|
||||
|
||||
cx.spawn({
|
||||
async move |thread: WeakEntity<Thread>, cx| {
|
||||
let output = tool_result.output.await;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{Tool, ToolWorkingSet};
|
||||
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
|
@ -27,26 +27,7 @@ pub struct ToolUse {
|
|||
pub needs_confirmation: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolUseStatus {
|
||||
NeedsConfirmation,
|
||||
Pending,
|
||||
Running,
|
||||
Finished(SharedString),
|
||||
Error(SharedString),
|
||||
}
|
||||
|
||||
impl ToolUseStatus {
|
||||
pub fn text(&self) -> SharedString {
|
||||
match self {
|
||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||
ToolUseStatus::Pending => "".into(),
|
||||
ToolUseStatus::Running => "".into(),
|
||||
ToolUseStatus::Finished(out) => out.clone(),
|
||||
ToolUseStatus::Error(out) => out.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||
|
||||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
|
@ -54,10 +35,9 @@ pub struct ToolUseState {
|
|||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
}
|
||||
|
||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||
|
||||
impl ToolUseState {
|
||||
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
|
||||
Self {
|
||||
|
@ -66,6 +46,7 @@ impl ToolUseState {
|
|||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -257,6 +238,18 @@ impl ToolUseState {
|
|||
self.tool_results.get(tool_use_id)
|
||||
}
|
||||
|
||||
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
|
||||
self.tool_result_cards.get(tool_use_id)
|
||||
}
|
||||
|
||||
pub fn insert_tool_result_card(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
card: AnyToolCard,
|
||||
) {
|
||||
self.tool_result_cards.insert(tool_use_id, card);
|
||||
}
|
||||
|
||||
pub fn request_tool_use(
|
||||
&mut self,
|
||||
assistant_message_id: MessageId,
|
||||
|
|
|
@ -9,6 +9,10 @@ use std::fmt::Formatter;
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::AnyElement;
|
||||
use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use icons::IconName;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
|
@ -24,16 +28,87 @@ pub fn init(cx: &mut App) {
|
|||
ToolRegistry::default_global(cx);
|
||||
}
|
||||
|
||||
/// The result of running a tool
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolUseStatus {
|
||||
NeedsConfirmation,
|
||||
Pending,
|
||||
Running,
|
||||
Finished(SharedString),
|
||||
Error(SharedString),
|
||||
}
|
||||
|
||||
impl ToolUseStatus {
|
||||
pub fn text(&self) -> SharedString {
|
||||
match self {
|
||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||
ToolUseStatus::Pending => "".into(),
|
||||
ToolUseStatus::Running => "".into(),
|
||||
ToolUseStatus::Finished(out) => out.clone(),
|
||||
ToolUseStatus::Error(out) => out.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of running a tool, containing both the asynchronous output
|
||||
/// and an optional card view that can be rendered immediately.
|
||||
pub struct ToolResult {
|
||||
/// The asynchronous task that will eventually resolve to the tool's output
|
||||
pub output: Task<Result<String>>,
|
||||
/// An optional view to present the output of the tool.
|
||||
pub card: Option<AnyToolCard>,
|
||||
}
|
||||
|
||||
pub trait ToolCard: 'static + Sized {
|
||||
fn render(
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AnyToolCard {
|
||||
entity: gpui::AnyEntity,
|
||||
render: fn(
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyElement,
|
||||
}
|
||||
|
||||
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||
fn from(entity: Entity<T>) -> Self {
|
||||
fn downcast_render<T: ToolCard>(
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
let entity = entity.downcast::<T>().unwrap();
|
||||
entity.update(cx, |entity, cx| {
|
||||
entity.render(status, window, cx).into_any_element()
|
||||
})
|
||||
}
|
||||
|
||||
Self {
|
||||
entity: entity.into(),
|
||||
render: downcast_render::<T>,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnyToolCard {
|
||||
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
|
||||
(self.render)(self.entity.clone(), status, window, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Task<Result<String>>> for ToolResult {
|
||||
/// Convert from a task to a ToolResult
|
||||
/// Convert from a task to a ToolResult with no card
|
||||
fn from(output: Task<Result<String>>) -> Self {
|
||||
Self { output }
|
||||
Self { output, card: None }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ anyhow.workspace = true
|
|||
assistant_tool.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
|
@ -32,7 +33,9 @@ ui.workspace = true
|
|||
util.workspace = true
|
||||
worktree.workspace = true
|
||||
open = { workspace = true }
|
||||
web_search.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -22,14 +22,17 @@ mod schema;
|
|||
mod symbol_info_tool;
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod web_search_tool;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use gpui::App;
|
||||
use http_client::HttpClientWithUrl;
|
||||
use move_path_tool::MovePathTool;
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
use crate::batch_tool::BatchTool;
|
||||
use crate::code_action_tool::CodeActionTool;
|
||||
|
@ -56,28 +59,39 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
|||
assistant_tool::init(cx);
|
||||
|
||||
let registry = ToolRegistry::global(cx);
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(BatchTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
registry.register_tool(FindReplaceFileTool);
|
||||
registry.register_tool(SymbolInfoTool);
|
||||
registry.register_tool(CodeActionTool);
|
||||
registry.register_tool(MovePathTool);
|
||||
registry.register_tool(DiagnosticsTool);
|
||||
registry.register_tool(ListDirectoryTool);
|
||||
registry.register_tool(NowTool);
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(CodeSymbolsTool);
|
||||
registry.register_tool(ContentsTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
registry.register_tool(DiagnosticsTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
registry.register_tool(FindReplaceFileTool);
|
||||
registry.register_tool(ListDirectoryTool);
|
||||
registry.register_tool(MovePathTool);
|
||||
registry.register_tool(NowTool);
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(PathSearchTool);
|
||||
registry.register_tool(ReadFileTool);
|
||||
registry.register_tool(RegexSearchTool);
|
||||
registry.register_tool(RenameTool);
|
||||
registry.register_tool(SymbolInfoTool);
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
|
||||
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
213
crates/assistant_tools/src/web_search_tool.rs
Normal file
213
crates/assistant_tools/src/web_search_tool.rs
Normal file
|
@ -0,0 +1,213 @@
|
|||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
|
||||
pulsating_between,
|
||||
};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::{IconName, Tooltip, prelude::*};
|
||||
use web_search::WebSearchRegistry;
|
||||
use zed_llm_client::WebSearchResponse;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct WebSearchToolInput {
|
||||
/// The search term or question to query on the web.
|
||||
query: String,
|
||||
}
|
||||
|
||||
pub struct WebSearchTool;
|
||||
|
||||
impl Tool for WebSearchTool {
|
||||
fn name(&self) -> String {
|
||||
"web_search".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Globe
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<WebSearchToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Web Search".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
|
||||
return Task::ready(Err(anyhow!("Web search is not available."))).into();
|
||||
};
|
||||
|
||||
let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
|
||||
let output = cx.background_spawn({
|
||||
let search_task = search_task.clone();
|
||||
async move {
|
||||
let response = search_task.await.map_err(|err| anyhow!(err))?;
|
||||
serde_json::to_string(&response).context("Failed to serialize search results")
|
||||
}
|
||||
});
|
||||
|
||||
ToolResult {
|
||||
output,
|
||||
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct WebSearchToolCard {
|
||||
response: Option<Result<WebSearchResponse>>,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
impl WebSearchToolCard {
|
||||
fn new(
|
||||
search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let _task = cx.spawn(async move |this, cx| {
|
||||
let response = search_task.await.map_err(|err| anyhow!(err));
|
||||
this.update(cx, |this, cx| {
|
||||
this.response = Some(response);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
Self {
|
||||
response: None,
|
||||
_task,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCard for WebSearchToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
_status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let header = h_flex()
|
||||
.id("tool-label-container")
|
||||
.gap_1p5()
|
||||
.max_w_full()
|
||||
.overflow_x_scroll()
|
||||
.child(
|
||||
Icon::new(IconName::Globe)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(match self.response.as_ref() {
|
||||
Some(Ok(response)) => {
|
||||
let text: SharedString = if response.citations.len() == 1 {
|
||||
"1 result".into()
|
||||
} else {
|
||||
format!("{} results", response.citations.len()).into()
|
||||
};
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Label::new("Searched the Web").size(LabelSize::Small))
|
||||
.child(
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text),
|
||||
)
|
||||
.child(Label::new(text).size(LabelSize::Small))
|
||||
.into_any_element()
|
||||
}
|
||||
Some(Err(error)) => div()
|
||||
.id("web-search-error")
|
||||
.child(Label::new("Web Search failed").size(LabelSize::Small))
|
||||
.tooltip(Tooltip::text(error.to_string()))
|
||||
.into_any_element(),
|
||||
|
||||
None => Label::new("Searching the Web…")
|
||||
.size(LabelSize::Small)
|
||||
.with_animation(
|
||||
"web-search-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.6, 1.)),
|
||||
|label, delta| label.alpha(delta),
|
||||
)
|
||||
.into_any_element(),
|
||||
})
|
||||
.into_any();
|
||||
|
||||
let content =
|
||||
self.response.as_ref().and_then(|response| match response {
|
||||
Ok(response) => {
|
||||
Some(
|
||||
v_flex()
|
||||
.ml_1p5()
|
||||
.pl_1p5()
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.gap_1()
|
||||
.children(response.citations.iter().enumerate().map(
|
||||
|(index, citation)| {
|
||||
let title = citation.title.clone();
|
||||
let url = citation.url.clone();
|
||||
|
||||
Button::new(("citation", index), title)
|
||||
.label_size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_position(IconPosition::End)
|
||||
.truncate(true)
|
||||
.tooltip({
|
||||
let url = url.clone();
|
||||
move |window, cx| {
|
||||
Tooltip::with_meta(
|
||||
"Citation Link",
|
||||
None,
|
||||
url.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click({
|
||||
let url = url.clone();
|
||||
move |_, _, cx| cx.open_url(&url)
|
||||
})
|
||||
},
|
||||
))
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
Err(_) => None,
|
||||
});
|
||||
|
||||
v_flex().my_2().gap_1().child(header).children(content)
|
||||
}
|
||||
}
|
|
@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro {
|
|||
const NAME: &'static str = "zed-pro";
|
||||
}
|
||||
|
||||
pub struct ZedProWebSearchTool {}
|
||||
impl FeatureFlag for ZedProWebSearchTool {
|
||||
const NAME: &'static str = "zed-pro-web-search-tool";
|
||||
}
|
||||
|
||||
pub struct NotebookFeatureFlag;
|
||||
|
||||
impl FeatureFlag for NotebookFeatureFlag {
|
||||
|
|
|
@ -160,7 +160,11 @@ impl Render for Tooltip {
|
|||
}),
|
||||
)
|
||||
.when_some(self.meta.clone(), |this, meta| {
|
||||
this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
|
||||
this.child(
|
||||
div()
|
||||
.max_w_72()
|
||||
.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
20
crates/web_search/Cargo.toml
Normal file
20
crates/web_search/Cargo.toml
Normal file
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "web_search"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/web_search.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
serde.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
1
crates/web_search/LICENSE-GPL
Symbolic link
1
crates/web_search/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
64
crates/web_search/src/web_search.rs
Normal file
64
crates/web_search/src/web_search.rs
Normal file
|
@ -0,0 +1,64 @@
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
|
||||
use std::sync::Arc;
|
||||
use zed_llm_client::WebSearchResponse;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let registry = cx.new(|_cx| WebSearchRegistry::default());
|
||||
cx.set_global(GlobalWebSearchRegistry(registry));
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct WebSearchProviderId(pub SharedString);
|
||||
|
||||
pub trait WebSearchProvider {
|
||||
fn id(&self) -> WebSearchProviderId;
|
||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
|
||||
}
|
||||
|
||||
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
|
||||
|
||||
impl Global for GlobalWebSearchRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct WebSearchRegistry {
|
||||
providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
|
||||
active_provider: Option<Arc<dyn WebSearchProvider>>,
|
||||
}
|
||||
|
||||
impl WebSearchRegistry {
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
cx.global::<GlobalWebSearchRegistry>().0.clone()
|
||||
}
|
||||
|
||||
pub fn read_global(cx: &App) -> &Self {
|
||||
cx.global::<GlobalWebSearchRegistry>().0.read(cx)
|
||||
}
|
||||
|
||||
pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
|
||||
self.providers.values()
|
||||
}
|
||||
|
||||
pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
|
||||
self.active_provider.clone()
|
||||
}
|
||||
|
||||
pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
|
||||
self.active_provider = Some(provider.clone());
|
||||
self.providers.insert(provider.id(), provider);
|
||||
}
|
||||
|
||||
pub fn register_provider<T: WebSearchProvider + 'static>(
|
||||
&mut self,
|
||||
provider: T,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
let id = provider.id();
|
||||
let provider = Arc::new(provider);
|
||||
self.providers.insert(id.clone(), provider.clone());
|
||||
if self.active_provider.is_none() {
|
||||
self.active_provider = Some(provider);
|
||||
}
|
||||
}
|
||||
}
|
26
crates/web_search_providers/Cargo.toml
Normal file
26
crates/web_search_providers/Cargo.toml
Normal file
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "web_search_providers"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/web_search_providers.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
web_search.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
103
crates/web_search_providers/src/cloud.rs
Normal file
103
crates/web_search_providers/src/cloud.rs
Normal file
|
@ -0,0 +1,103 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use client::Client;
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||
use zed_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
|
||||
pub struct CloudWebSearchProvider {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
impl CloudWebSearchProvider {
|
||||
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State::new(client, cx));
|
||||
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
_llm_token_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
|
||||
Self {
|
||||
client,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_llm_token_subscription: cx.subscribe(
|
||||
&refresh_llm_token_listener,
|
||||
|this, _, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
cx.spawn(async move |_this, _cx| {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSearchProvider for CloudWebSearchProvider {
|
||||
fn id(&self) -> WebSearchProviderId {
|
||||
WebSearchProviderId("zed.dev".into())
|
||||
}
|
||||
|
||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
||||
let state = self.state.read(cx);
|
||||
let client = state.client.clone();
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let body = WebSearchBody { query };
|
||||
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_web_search(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
body: WebSearchBody,
|
||||
) -> Result<WebSearchResponse> {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
|
||||
request_builder.uri(web_search_url)
|
||||
} else {
|
||||
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
||||
};
|
||||
let request = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send web search request")?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"error performing web search.\nStatus: {:?}\nBody: {body}",
|
||||
response.status(),
|
||||
));
|
||||
}
|
||||
}
|
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
mod cloud;
|
||||
|
||||
use client::Client;
|
||||
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
|
||||
use gpui::{App, Context};
|
||||
use std::sync::Arc;
|
||||
use web_search::WebSearchRegistry;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
let registry = WebSearchRegistry::global(cx);
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_web_search_providers(registry, client, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn register_web_search_providers(
|
||||
_registry: &mut WebSearchRegistry,
|
||||
client: Arc<Client>,
|
||||
cx: &mut Context<WebSearchRegistry>,
|
||||
) {
|
||||
cx.observe_flag::<ZedProWebSearchTool, _>({
|
||||
let client = client.clone();
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
|
@ -133,6 +133,8 @@ util.workspace = true
|
|||
uuid.workspace = true
|
||||
vim.workspace = true
|
||||
vim_mode_setting.workspace = true
|
||||
web_search.workspace = true
|
||||
web_search_providers.workspace = true
|
||||
welcome.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
|
|
@ -490,6 +490,8 @@ fn main() {
|
|||
app_state.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
web_search::init(cx);
|
||||
web_search_providers::init(app_state.client.clone(), cx);
|
||||
snippet_provider::init(cx);
|
||||
inline_completion_registry::init(
|
||||
app_state.client.clone(),
|
||||
|
|
|
@ -4258,6 +4258,8 @@ mod tests {
|
|||
app_state.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
web_search::init(cx);
|
||||
web_search_providers::init(app_state.client.clone(), cx);
|
||||
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
|
||||
assistant::init(
|
||||
app_state.fs.clone(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue