agent: Add websearch tool (#28621)
Staff only for now. We'll work on making this usable for non zed.dev users later Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
2b277123be
commit
456e54b87c
22 changed files with 675 additions and 51 deletions
39
Cargo.lock
generated
39
Cargo.lock
generated
|
@ -704,6 +704,7 @@ dependencies = [
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
"chrono",
|
"chrono",
|
||||||
"collections",
|
"collections",
|
||||||
|
"feature_flags",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
"html_to_markdown",
|
"html_to_markdown",
|
||||||
|
@ -721,9 +722,11 @@ dependencies = [
|
||||||
"ui",
|
"ui",
|
||||||
"unindent",
|
"unindent",
|
||||||
"util",
|
"util",
|
||||||
|
"web_search",
|
||||||
"workspace",
|
"workspace",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
"worktree",
|
"worktree",
|
||||||
|
"zed_llm_client",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -16609,6 +16612,36 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "web_search"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"collections",
|
||||||
|
"gpui",
|
||||||
|
"serde",
|
||||||
|
"workspace-hack",
|
||||||
|
"zed_llm_client",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "web_search_providers"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"client",
|
||||||
|
"feature_flags",
|
||||||
|
"futures 0.3.31",
|
||||||
|
"gpui",
|
||||||
|
"http_client",
|
||||||
|
"language_model",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"web_search",
|
||||||
|
"workspace-hack",
|
||||||
|
"zed_llm_client",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "webpki-root-certs"
|
name = "webpki-root-certs"
|
||||||
version = "0.26.8"
|
version = "0.26.8"
|
||||||
|
@ -18287,6 +18320,8 @@ dependencies = [
|
||||||
"uuid",
|
"uuid",
|
||||||
"vim",
|
"vim",
|
||||||
"vim_mode_setting",
|
"vim_mode_setting",
|
||||||
|
"web_search",
|
||||||
|
"web_search_providers",
|
||||||
"welcome",
|
"welcome",
|
||||||
"windows 0.61.1",
|
"windows 0.61.1",
|
||||||
"winresource",
|
"winresource",
|
||||||
|
@ -18351,9 +18386,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zed_llm_client"
|
name = "zed_llm_client"
|
||||||
version = "0.4.2"
|
version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1d28a5d6bdb0f40acf5261c39cabbf65a13b55ba4b86d9beb5b8b1c484373f1a"
|
checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -165,6 +165,8 @@ members = [
|
||||||
"crates/util_macros",
|
"crates/util_macros",
|
||||||
"crates/vim",
|
"crates/vim",
|
||||||
"crates/vim_mode_setting",
|
"crates/vim_mode_setting",
|
||||||
|
"crates/web_search",
|
||||||
|
"crates/web_search_providers",
|
||||||
"crates/welcome",
|
"crates/welcome",
|
||||||
"crates/workspace",
|
"crates/workspace",
|
||||||
"crates/worktree",
|
"crates/worktree",
|
||||||
|
@ -370,6 +372,8 @@ util = { path = "crates/util" }
|
||||||
util_macros = { path = "crates/util_macros" }
|
util_macros = { path = "crates/util_macros" }
|
||||||
vim = { path = "crates/vim" }
|
vim = { path = "crates/vim" }
|
||||||
vim_mode_setting = { path = "crates/vim_mode_setting" }
|
vim_mode_setting = { path = "crates/vim_mode_setting" }
|
||||||
|
web_search = { path = "crates/web_search" }
|
||||||
|
web_search_providers = { path = "crates/web_search_providers" }
|
||||||
welcome = { path = "crates/welcome" }
|
welcome = { path = "crates/welcome" }
|
||||||
workspace = { path = "crates/workspace" }
|
workspace = { path = "crates/workspace" }
|
||||||
worktree = { path = "crates/worktree" }
|
worktree = { path = "crates/worktree" }
|
||||||
|
@ -601,7 +605,7 @@ wasmtime-wasi = "29"
|
||||||
which = "6.0.0"
|
which = "6.0.0"
|
||||||
wit-component = "0.221"
|
wit-component = "0.221"
|
||||||
workspace-hack = "0.1.0"
|
workspace-hack = "0.1.0"
|
||||||
zed_llm_client = "0.4.2"
|
zed_llm_client = "0.5.0"
|
||||||
zstd = "0.11"
|
zstd = "0.11"
|
||||||
metal = "0.29"
|
metal = "0.29"
|
||||||
|
|
||||||
|
|
|
@ -652,7 +652,8 @@
|
||||||
"path_search": true,
|
"path_search": true,
|
||||||
"read_file": true,
|
"read_file": true,
|
||||||
"regex_search": true,
|
"regex_search": true,
|
||||||
"thinking": true
|
"thinking": true,
|
||||||
|
"web_search": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"write": {
|
"write": {
|
||||||
|
@ -678,7 +679,8 @@
|
||||||
"regex_search": true,
|
"regex_search": true,
|
||||||
"rename": true,
|
"rename": true,
|
||||||
"symbol_info": true,
|
"symbol_info": true,
|
||||||
"thinking": true
|
"thinking": true,
|
||||||
|
"web_search": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -5,11 +5,12 @@ use crate::thread::{
|
||||||
ThreadEvent, ThreadFeedback,
|
ThreadEvent, ThreadFeedback,
|
||||||
};
|
};
|
||||||
use crate::thread_store::{RulesLoadingError, ThreadStore};
|
use crate::thread_store::{RulesLoadingError, ThreadStore};
|
||||||
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
|
use crate::tool_use::{PendingToolUseStatus, ToolUse};
|
||||||
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
||||||
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
|
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
|
||||||
use anyhow::Context as _;
|
use anyhow::Context as _;
|
||||||
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
|
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
|
||||||
|
use assistant_tool::ToolUseStatus;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use editor::scroll::Autoscroll;
|
use editor::scroll::Autoscroll;
|
||||||
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
|
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
|
||||||
|
@ -943,8 +944,8 @@ impl ActiveThread {
|
||||||
&tool_use.input,
|
&tool_use.input,
|
||||||
self.thread
|
self.thread
|
||||||
.read(cx)
|
.read(cx)
|
||||||
.tool_result(&tool_use.id)
|
.output_for_tool(&tool_use.id)
|
||||||
.map(|result| result.content.clone().into())
|
.map(|output| output.clone().into())
|
||||||
.unwrap_or("".into()),
|
.unwrap_or("".into()),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
@ -2279,12 +2280,15 @@ impl ActiveThread {
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> impl IntoElement + use<> {
|
) -> impl IntoElement + use<> {
|
||||||
|
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
|
||||||
|
return card.render(&tool_use.status, window, cx);
|
||||||
|
}
|
||||||
|
|
||||||
let is_open = self
|
let is_open = self
|
||||||
.expanded_tool_uses
|
.expanded_tool_uses
|
||||||
.get(&tool_use.id)
|
.get(&tool_use.id)
|
||||||
.copied()
|
.copied()
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
|
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
|
||||||
|
|
||||||
let fs = self
|
let fs = self
|
||||||
|
@ -2381,6 +2385,7 @@ impl ActiveThread {
|
||||||
open_markdown_link(text, workspace.clone(), window, cx);
|
open_markdown_link(text, workspace.clone(), window, cx);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
.into_any_element()
|
||||||
}),
|
}),
|
||||||
)),
|
)),
|
||||||
),
|
),
|
||||||
|
@ -2437,6 +2442,7 @@ impl ActiveThread {
|
||||||
open_markdown_link(text, workspace.clone(), window, cx);
|
open_markdown_link(text, workspace.clone(), window, cx);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
.into_any_element()
|
||||||
})),
|
})),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
@ -2767,7 +2773,7 @@ impl ActiveThread {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
}).into_any_element()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
|
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
|
||||||
|
|
|
@ -6,7 +6,7 @@ use std::time::Instant;
|
||||||
|
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_settings::AssistantSettings;
|
use assistant_settings::AssistantSettings;
|
||||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use collections::{BTreeMap, HashMap};
|
use collections::{BTreeMap, HashMap};
|
||||||
use feature_flags::{self, FeatureFlagAppExt};
|
use feature_flags::{self, FeatureFlagAppExt};
|
||||||
|
@ -631,6 +631,14 @@ impl Thread {
|
||||||
self.tool_use.tool_result(id)
|
self.tool_use.tool_result(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
|
||||||
|
Some(&self.tool_use.tool_result(id)?.content)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
|
||||||
|
self.tool_use.tool_result_card(id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
self.tool_use.message_has_tool_results(message_id)
|
self.tool_use.message_has_tool_results(message_id)
|
||||||
}
|
}
|
||||||
|
@ -1426,6 +1434,12 @@ impl Thread {
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Store the card separately if it exists
|
||||||
|
if let Some(card) = tool_result.card.clone() {
|
||||||
|
self.tool_use
|
||||||
|
.insert_tool_result_card(tool_use_id.clone(), card);
|
||||||
|
}
|
||||||
|
|
||||||
cx.spawn({
|
cx.spawn({
|
||||||
async move |thread: WeakEntity<Thread>, cx| {
|
async move |thread: WeakEntity<Thread>, cx| {
|
||||||
let output = tool_result.output.await;
|
let output = tool_result.output.await;
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use assistant_tool::{Tool, ToolWorkingSet};
|
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use futures::FutureExt as _;
|
use futures::FutureExt as _;
|
||||||
use futures::future::Shared;
|
use futures::future::Shared;
|
||||||
|
@ -27,26 +27,7 @@ pub struct ToolUse {
|
||||||
pub needs_confirmation: bool,
|
pub needs_confirmation: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||||
pub enum ToolUseStatus {
|
|
||||||
NeedsConfirmation,
|
|
||||||
Pending,
|
|
||||||
Running,
|
|
||||||
Finished(SharedString),
|
|
||||||
Error(SharedString),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ToolUseStatus {
|
|
||||||
pub fn text(&self) -> SharedString {
|
|
||||||
match self {
|
|
||||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
|
||||||
ToolUseStatus::Pending => "".into(),
|
|
||||||
ToolUseStatus::Running => "".into(),
|
|
||||||
ToolUseStatus::Finished(out) => out.clone(),
|
|
||||||
ToolUseStatus::Error(out) => out.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ToolUseState {
|
pub struct ToolUseState {
|
||||||
tools: Entity<ToolWorkingSet>,
|
tools: Entity<ToolWorkingSet>,
|
||||||
|
@ -54,10 +35,9 @@ pub struct ToolUseState {
|
||||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||||
|
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
|
||||||
|
|
||||||
impl ToolUseState {
|
impl ToolUseState {
|
||||||
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
|
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -66,6 +46,7 @@ impl ToolUseState {
|
||||||
tool_uses_by_user_message: HashMap::default(),
|
tool_uses_by_user_message: HashMap::default(),
|
||||||
tool_results: HashMap::default(),
|
tool_results: HashMap::default(),
|
||||||
pending_tool_uses_by_id: HashMap::default(),
|
pending_tool_uses_by_id: HashMap::default(),
|
||||||
|
tool_result_cards: HashMap::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,6 +238,18 @@ impl ToolUseState {
|
||||||
self.tool_results.get(tool_use_id)
|
self.tool_results.get(tool_use_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
|
||||||
|
self.tool_result_cards.get(tool_use_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert_tool_result_card(
|
||||||
|
&mut self,
|
||||||
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
card: AnyToolCard,
|
||||||
|
) {
|
||||||
|
self.tool_result_cards.insert(tool_use_id, card);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn request_tool_use(
|
pub fn request_tool_use(
|
||||||
&mut self,
|
&mut self,
|
||||||
assistant_message_id: MessageId,
|
assistant_message_id: MessageId,
|
||||||
|
|
|
@ -9,6 +9,10 @@ use std::fmt::Formatter;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use gpui::AnyElement;
|
||||||
|
use gpui::Context;
|
||||||
|
use gpui::IntoElement;
|
||||||
|
use gpui::Window;
|
||||||
use gpui::{App, Entity, SharedString, Task};
|
use gpui::{App, Entity, SharedString, Task};
|
||||||
use icons::IconName;
|
use icons::IconName;
|
||||||
use language_model::LanguageModelRequestMessage;
|
use language_model::LanguageModelRequestMessage;
|
||||||
|
@ -24,16 +28,87 @@ pub fn init(cx: &mut App) {
|
||||||
ToolRegistry::default_global(cx);
|
ToolRegistry::default_global(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The result of running a tool
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ToolUseStatus {
|
||||||
|
NeedsConfirmation,
|
||||||
|
Pending,
|
||||||
|
Running,
|
||||||
|
Finished(SharedString),
|
||||||
|
Error(SharedString),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolUseStatus {
|
||||||
|
pub fn text(&self) -> SharedString {
|
||||||
|
match self {
|
||||||
|
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||||
|
ToolUseStatus::Pending => "".into(),
|
||||||
|
ToolUseStatus::Running => "".into(),
|
||||||
|
ToolUseStatus::Finished(out) => out.clone(),
|
||||||
|
ToolUseStatus::Error(out) => out.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The result of running a tool, containing both the asynchronous output
|
||||||
|
/// and an optional card view that can be rendered immediately.
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
/// The asynchronous task that will eventually resolve to the tool's output
|
/// The asynchronous task that will eventually resolve to the tool's output
|
||||||
pub output: Task<Result<String>>,
|
pub output: Task<Result<String>>,
|
||||||
|
/// An optional view to present the output of the tool.
|
||||||
|
pub card: Option<AnyToolCard>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ToolCard: 'static + Sized {
|
||||||
|
fn render(
|
||||||
|
&mut self,
|
||||||
|
status: &ToolUseStatus,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> impl IntoElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AnyToolCard {
|
||||||
|
entity: gpui::AnyEntity,
|
||||||
|
render: fn(
|
||||||
|
entity: gpui::AnyEntity,
|
||||||
|
status: &ToolUseStatus,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> AnyElement,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||||
|
fn from(entity: Entity<T>) -> Self {
|
||||||
|
fn downcast_render<T: ToolCard>(
|
||||||
|
entity: gpui::AnyEntity,
|
||||||
|
status: &ToolUseStatus,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> AnyElement {
|
||||||
|
let entity = entity.downcast::<T>().unwrap();
|
||||||
|
entity.update(cx, |entity, cx| {
|
||||||
|
entity.render(status, window, cx).into_any_element()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
entity: entity.into(),
|
||||||
|
render: downcast_render::<T>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnyToolCard {
|
||||||
|
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
|
||||||
|
(self.render)(self.entity.clone(), status, window, cx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Task<Result<String>>> for ToolResult {
|
impl From<Task<Result<String>>> for ToolResult {
|
||||||
/// Convert from a task to a ToolResult
|
/// Convert from a task to a ToolResult with no card
|
||||||
fn from(output: Task<Result<String>>) -> Self {
|
fn from(output: Task<Result<String>>) -> Self {
|
||||||
Self { output }
|
Self { output, card: None }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
assistant_tool.workspace = true
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
|
feature_flags.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
html_to_markdown.workspace = true
|
html_to_markdown.workspace = true
|
||||||
|
@ -32,7 +33,9 @@ ui.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
worktree.workspace = true
|
worktree.workspace = true
|
||||||
open = { workspace = true }
|
open = { workspace = true }
|
||||||
|
web_search.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
zed_llm_client.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
|
|
|
@ -22,14 +22,17 @@ mod schema;
|
||||||
mod symbol_info_tool;
|
mod symbol_info_tool;
|
||||||
mod terminal_tool;
|
mod terminal_tool;
|
||||||
mod thinking_tool;
|
mod thinking_tool;
|
||||||
|
mod web_search_tool;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use assistant_tool::ToolRegistry;
|
use assistant_tool::ToolRegistry;
|
||||||
use copy_path_tool::CopyPathTool;
|
use copy_path_tool::CopyPathTool;
|
||||||
|
use feature_flags::FeatureFlagAppExt;
|
||||||
use gpui::App;
|
use gpui::App;
|
||||||
use http_client::HttpClientWithUrl;
|
use http_client::HttpClientWithUrl;
|
||||||
use move_path_tool::MovePathTool;
|
use move_path_tool::MovePathTool;
|
||||||
|
use web_search_tool::WebSearchTool;
|
||||||
|
|
||||||
use crate::batch_tool::BatchTool;
|
use crate::batch_tool::BatchTool;
|
||||||
use crate::code_action_tool::CodeActionTool;
|
use crate::code_action_tool::CodeActionTool;
|
||||||
|
@ -56,28 +59,39 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||||
assistant_tool::init(cx);
|
assistant_tool::init(cx);
|
||||||
|
|
||||||
let registry = ToolRegistry::global(cx);
|
let registry = ToolRegistry::global(cx);
|
||||||
registry.register_tool(TerminalTool);
|
|
||||||
registry.register_tool(BatchTool);
|
registry.register_tool(BatchTool);
|
||||||
registry.register_tool(CreateDirectoryTool);
|
|
||||||
registry.register_tool(CreateFileTool);
|
|
||||||
registry.register_tool(CopyPathTool);
|
|
||||||
registry.register_tool(DeletePathTool);
|
|
||||||
registry.register_tool(FindReplaceFileTool);
|
|
||||||
registry.register_tool(SymbolInfoTool);
|
|
||||||
registry.register_tool(CodeActionTool);
|
registry.register_tool(CodeActionTool);
|
||||||
registry.register_tool(MovePathTool);
|
|
||||||
registry.register_tool(DiagnosticsTool);
|
|
||||||
registry.register_tool(ListDirectoryTool);
|
|
||||||
registry.register_tool(NowTool);
|
|
||||||
registry.register_tool(OpenTool);
|
|
||||||
registry.register_tool(CodeSymbolsTool);
|
registry.register_tool(CodeSymbolsTool);
|
||||||
registry.register_tool(ContentsTool);
|
registry.register_tool(ContentsTool);
|
||||||
|
registry.register_tool(CopyPathTool);
|
||||||
|
registry.register_tool(CreateDirectoryTool);
|
||||||
|
registry.register_tool(CreateFileTool);
|
||||||
|
registry.register_tool(DeletePathTool);
|
||||||
|
registry.register_tool(DiagnosticsTool);
|
||||||
|
registry.register_tool(FetchTool::new(http_client));
|
||||||
|
registry.register_tool(FindReplaceFileTool);
|
||||||
|
registry.register_tool(ListDirectoryTool);
|
||||||
|
registry.register_tool(MovePathTool);
|
||||||
|
registry.register_tool(NowTool);
|
||||||
|
registry.register_tool(OpenTool);
|
||||||
registry.register_tool(PathSearchTool);
|
registry.register_tool(PathSearchTool);
|
||||||
registry.register_tool(ReadFileTool);
|
registry.register_tool(ReadFileTool);
|
||||||
registry.register_tool(RegexSearchTool);
|
registry.register_tool(RegexSearchTool);
|
||||||
registry.register_tool(RenameTool);
|
registry.register_tool(RenameTool);
|
||||||
|
registry.register_tool(SymbolInfoTool);
|
||||||
|
registry.register_tool(TerminalTool);
|
||||||
registry.register_tool(ThinkingTool);
|
registry.register_tool(ThinkingTool);
|
||||||
registry.register_tool(FetchTool::new(http_client));
|
|
||||||
|
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
||||||
|
move |is_enabled, cx| {
|
||||||
|
if is_enabled {
|
||||||
|
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||||
|
} else {
|
||||||
|
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
213
crates/assistant_tools/src/web_search_tool.rs
Normal file
213
crates/assistant_tools/src/web_search_tool.rs
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
|
use crate::schema::json_schema_for;
|
||||||
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||||
|
use futures::{FutureExt, TryFutureExt};
|
||||||
|
use gpui::{
|
||||||
|
Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
|
||||||
|
pulsating_between,
|
||||||
|
};
|
||||||
|
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||||
|
use project::Project;
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use ui::{IconName, Tooltip, prelude::*};
|
||||||
|
use web_search::WebSearchRegistry;
|
||||||
|
use zed_llm_client::WebSearchResponse;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
pub struct WebSearchToolInput {
|
||||||
|
/// The search term or question to query on the web.
|
||||||
|
query: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct WebSearchTool;
|
||||||
|
|
||||||
|
impl Tool for WebSearchTool {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"web_search".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> String {
|
||||||
|
"Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn icon(&self) -> IconName {
|
||||||
|
IconName::Globe
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||||
|
json_schema_for::<WebSearchToolInput>(format)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||||
|
"Web Search".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
_messages: &[LanguageModelRequestMessage],
|
||||||
|
_project: Entity<Project>,
|
||||||
|
_action_log: Entity<ActionLog>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> ToolResult {
|
||||||
|
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
|
||||||
|
Ok(input) => input,
|
||||||
|
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||||
|
};
|
||||||
|
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
|
||||||
|
return Task::ready(Err(anyhow!("Web search is not available."))).into();
|
||||||
|
};
|
||||||
|
|
||||||
|
let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
|
||||||
|
let output = cx.background_spawn({
|
||||||
|
let search_task = search_task.clone();
|
||||||
|
async move {
|
||||||
|
let response = search_task.await.map_err(|err| anyhow!(err))?;
|
||||||
|
serde_json::to_string(&response).context("Failed to serialize search results")
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
ToolResult {
|
||||||
|
output,
|
||||||
|
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WebSearchToolCard {
|
||||||
|
response: Option<Result<WebSearchResponse>>,
|
||||||
|
_task: Task<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebSearchToolCard {
|
||||||
|
fn new(
|
||||||
|
search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Self {
|
||||||
|
let _task = cx.spawn(async move |this, cx| {
|
||||||
|
let response = search_task.await.map_err(|err| anyhow!(err));
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.response = Some(response);
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
response: None,
|
||||||
|
_task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCard for WebSearchToolCard {
|
||||||
|
fn render(
|
||||||
|
&mut self,
|
||||||
|
_status: &ToolUseStatus,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> impl IntoElement {
|
||||||
|
let header = h_flex()
|
||||||
|
.id("tool-label-container")
|
||||||
|
.gap_1p5()
|
||||||
|
.max_w_full()
|
||||||
|
.overflow_x_scroll()
|
||||||
|
.child(
|
||||||
|
Icon::new(IconName::Globe)
|
||||||
|
.size(IconSize::XSmall)
|
||||||
|
.color(Color::Muted),
|
||||||
|
)
|
||||||
|
.child(match self.response.as_ref() {
|
||||||
|
Some(Ok(response)) => {
|
||||||
|
let text: SharedString = if response.citations.len() == 1 {
|
||||||
|
"1 result".into()
|
||||||
|
} else {
|
||||||
|
format!("{} results", response.citations.len()).into()
|
||||||
|
};
|
||||||
|
h_flex()
|
||||||
|
.gap_1p5()
|
||||||
|
.child(Label::new("Searched the Web").size(LabelSize::Small))
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.size(px(3.))
|
||||||
|
.rounded_full()
|
||||||
|
.bg(cx.theme().colors().text),
|
||||||
|
)
|
||||||
|
.child(Label::new(text).size(LabelSize::Small))
|
||||||
|
.into_any_element()
|
||||||
|
}
|
||||||
|
Some(Err(error)) => div()
|
||||||
|
.id("web-search-error")
|
||||||
|
.child(Label::new("Web Search failed").size(LabelSize::Small))
|
||||||
|
.tooltip(Tooltip::text(error.to_string()))
|
||||||
|
.into_any_element(),
|
||||||
|
|
||||||
|
None => Label::new("Searching the Web…")
|
||||||
|
.size(LabelSize::Small)
|
||||||
|
.with_animation(
|
||||||
|
"web-search-label",
|
||||||
|
Animation::new(Duration::from_secs(2))
|
||||||
|
.repeat()
|
||||||
|
.with_easing(pulsating_between(0.6, 1.)),
|
||||||
|
|label, delta| label.alpha(delta),
|
||||||
|
)
|
||||||
|
.into_any_element(),
|
||||||
|
})
|
||||||
|
.into_any();
|
||||||
|
|
||||||
|
let content =
|
||||||
|
self.response.as_ref().and_then(|response| match response {
|
||||||
|
Ok(response) => {
|
||||||
|
Some(
|
||||||
|
v_flex()
|
||||||
|
.ml_1p5()
|
||||||
|
.pl_1p5()
|
||||||
|
.border_l_1()
|
||||||
|
.border_color(cx.theme().colors().border_variant)
|
||||||
|
.gap_1()
|
||||||
|
.children(response.citations.iter().enumerate().map(
|
||||||
|
|(index, citation)| {
|
||||||
|
let title = citation.title.clone();
|
||||||
|
let url = citation.url.clone();
|
||||||
|
|
||||||
|
Button::new(("citation", index), title)
|
||||||
|
.label_size(LabelSize::Small)
|
||||||
|
.color(Color::Muted)
|
||||||
|
.icon(IconName::ArrowUpRight)
|
||||||
|
.icon_size(IconSize::XSmall)
|
||||||
|
.icon_position(IconPosition::End)
|
||||||
|
.truncate(true)
|
||||||
|
.tooltip({
|
||||||
|
let url = url.clone();
|
||||||
|
move |window, cx| {
|
||||||
|
Tooltip::with_meta(
|
||||||
|
"Citation Link",
|
||||||
|
None,
|
||||||
|
url.clone(),
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.on_click({
|
||||||
|
let url = url.clone();
|
||||||
|
move |_, _, cx| cx.open_url(&url)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
))
|
||||||
|
.into_any(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Err(_) => None,
|
||||||
|
});
|
||||||
|
|
||||||
|
v_flex().my_2().gap_1().child(header).children(content)
|
||||||
|
}
|
||||||
|
}
|
|
@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro {
|
||||||
const NAME: &'static str = "zed-pro";
|
const NAME: &'static str = "zed-pro";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ZedProWebSearchTool {}
|
||||||
|
impl FeatureFlag for ZedProWebSearchTool {
|
||||||
|
const NAME: &'static str = "zed-pro-web-search-tool";
|
||||||
|
}
|
||||||
|
|
||||||
pub struct NotebookFeatureFlag;
|
pub struct NotebookFeatureFlag;
|
||||||
|
|
||||||
impl FeatureFlag for NotebookFeatureFlag {
|
impl FeatureFlag for NotebookFeatureFlag {
|
||||||
|
|
|
@ -160,7 +160,11 @@ impl Render for Tooltip {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.when_some(self.meta.clone(), |this, meta| {
|
.when_some(self.meta.clone(), |this, meta| {
|
||||||
this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
|
this.child(
|
||||||
|
div()
|
||||||
|
.max_w_72()
|
||||||
|
.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
20
crates/web_search/Cargo.toml
Normal file
20
crates/web_search/Cargo.toml
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
[package]
|
||||||
|
name = "web_search"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition.workspace = true
|
||||||
|
publish.workspace = true
|
||||||
|
license = "GPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/web_search.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
workspace-hack.workspace = true
|
||||||
|
zed_llm_client.workspace = true
|
1
crates/web_search/LICENSE-GPL
Symbolic link
1
crates/web_search/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE-GPL
|
64
crates/web_search/src/web_search.rs
Normal file
64
crates/web_search/src/web_search.rs
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use collections::HashMap;
|
||||||
|
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use zed_llm_client::WebSearchResponse;
|
||||||
|
|
||||||
|
pub fn init(cx: &mut App) {
|
||||||
|
let registry = cx.new(|_cx| WebSearchRegistry::default());
|
||||||
|
cx.set_global(GlobalWebSearchRegistry(registry));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||||
|
pub struct WebSearchProviderId(pub SharedString);
|
||||||
|
|
||||||
|
pub trait WebSearchProvider {
|
||||||
|
fn id(&self) -> WebSearchProviderId;
|
||||||
|
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
|
||||||
|
|
||||||
|
impl Global for GlobalWebSearchRegistry {}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct WebSearchRegistry {
|
||||||
|
providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
|
||||||
|
active_provider: Option<Arc<dyn WebSearchProvider>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebSearchRegistry {
|
||||||
|
pub fn global(cx: &App) -> Entity<Self> {
|
||||||
|
cx.global::<GlobalWebSearchRegistry>().0.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_global(cx: &App) -> &Self {
|
||||||
|
cx.global::<GlobalWebSearchRegistry>().0.read(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
|
||||||
|
self.providers.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
|
||||||
|
self.active_provider.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
|
||||||
|
self.active_provider = Some(provider.clone());
|
||||||
|
self.providers.insert(provider.id(), provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register_provider<T: WebSearchProvider + 'static>(
|
||||||
|
&mut self,
|
||||||
|
provider: T,
|
||||||
|
_cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let id = provider.id();
|
||||||
|
let provider = Arc::new(provider);
|
||||||
|
self.providers.insert(id.clone(), provider.clone());
|
||||||
|
if self.active_provider.is_none() {
|
||||||
|
self.active_provider = Some(provider);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
26
crates/web_search_providers/Cargo.toml
Normal file
26
crates/web_search_providers/Cargo.toml
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
[package]
|
||||||
|
name = "web_search_providers"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition.workspace = true
|
||||||
|
publish.workspace = true
|
||||||
|
license = "GPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/web_search_providers.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
client.workspace = true
|
||||||
|
feature_flags.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
|
http_client.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
web_search.workspace = true
|
||||||
|
workspace-hack.workspace = true
|
||||||
|
zed_llm_client.workspace = true
|
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE-GPL
|
103
crates/web_search_providers/src/cloud.rs
Normal file
103
crates/web_search_providers/src/cloud.rs
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
use client::Client;
|
||||||
|
use futures::AsyncReadExt as _;
|
||||||
|
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
|
||||||
|
use http_client::{HttpClient, Method};
|
||||||
|
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||||
|
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||||
|
use zed_llm_client::{WebSearchBody, WebSearchResponse};
|
||||||
|
|
||||||
|
pub struct CloudWebSearchProvider {
|
||||||
|
state: Entity<State>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CloudWebSearchProvider {
|
||||||
|
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
|
||||||
|
let state = cx.new(|cx| State::new(client, cx));
|
||||||
|
|
||||||
|
Self { state }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct State {
|
||||||
|
client: Arc<Client>,
|
||||||
|
llm_api_token: LlmApiToken,
|
||||||
|
_llm_token_subscription: Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||||
|
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
client,
|
||||||
|
llm_api_token: LlmApiToken::default(),
|
||||||
|
_llm_token_subscription: cx.subscribe(
|
||||||
|
&refresh_llm_token_listener,
|
||||||
|
|this, _, _event, cx| {
|
||||||
|
let client = this.client.clone();
|
||||||
|
let llm_api_token = this.llm_api_token.clone();
|
||||||
|
cx.spawn(async move |_this, _cx| {
|
||||||
|
llm_api_token.refresh(&client).await?;
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebSearchProvider for CloudWebSearchProvider {
|
||||||
|
fn id(&self) -> WebSearchProviderId {
|
||||||
|
WebSearchProviderId("zed.dev".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
||||||
|
let state = self.state.read(cx);
|
||||||
|
let client = state.client.clone();
|
||||||
|
let llm_api_token = state.llm_api_token.clone();
|
||||||
|
let body = WebSearchBody { query };
|
||||||
|
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn perform_web_search(
|
||||||
|
client: Arc<Client>,
|
||||||
|
llm_api_token: LlmApiToken,
|
||||||
|
body: WebSearchBody,
|
||||||
|
) -> Result<WebSearchResponse> {
|
||||||
|
let http_client = &client.http_client();
|
||||||
|
|
||||||
|
let token = llm_api_token.acquire(&client).await?;
|
||||||
|
|
||||||
|
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||||
|
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
|
||||||
|
request_builder.uri(web_search_url)
|
||||||
|
} else {
|
||||||
|
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
||||||
|
};
|
||||||
|
let request = request_builder
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(serde_json::to_string(&body)?.into())?;
|
||||||
|
let mut response = http_client
|
||||||
|
.send(request)
|
||||||
|
.await
|
||||||
|
.context("failed to send web search request")?;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
return Ok(serde_json::from_str(&body)?);
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
return Err(anyhow!(
|
||||||
|
"error performing web search.\nStatus: {:?}\nBody: {body}",
|
||||||
|
response.status(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
mod cloud;
|
||||||
|
|
||||||
|
use client::Client;
|
||||||
|
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
|
||||||
|
use gpui::{App, Context};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use web_search::WebSearchRegistry;
|
||||||
|
|
||||||
|
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||||
|
let registry = WebSearchRegistry::global(cx);
|
||||||
|
registry.update(cx, |registry, cx| {
|
||||||
|
register_web_search_providers(registry, client, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_web_search_providers(
|
||||||
|
_registry: &mut WebSearchRegistry,
|
||||||
|
client: Arc<Client>,
|
||||||
|
cx: &mut Context<WebSearchRegistry>,
|
||||||
|
) {
|
||||||
|
cx.observe_flag::<ZedProWebSearchTool, _>({
|
||||||
|
let client = client.clone();
|
||||||
|
move |is_enabled, cx| {
|
||||||
|
if is_enabled {
|
||||||
|
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
|
||||||
|
registry.register_provider(
|
||||||
|
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
|
@ -133,6 +133,8 @@ util.workspace = true
|
||||||
uuid.workspace = true
|
uuid.workspace = true
|
||||||
vim.workspace = true
|
vim.workspace = true
|
||||||
vim_mode_setting.workspace = true
|
vim_mode_setting.workspace = true
|
||||||
|
web_search.workspace = true
|
||||||
|
web_search_providers.workspace = true
|
||||||
welcome.workspace = true
|
welcome.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
zed_actions.workspace = true
|
zed_actions.workspace = true
|
||||||
|
|
|
@ -490,6 +490,8 @@ fn main() {
|
||||||
app_state.fs.clone(),
|
app_state.fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
web_search::init(cx);
|
||||||
|
web_search_providers::init(app_state.client.clone(), cx);
|
||||||
snippet_provider::init(cx);
|
snippet_provider::init(cx);
|
||||||
inline_completion_registry::init(
|
inline_completion_registry::init(
|
||||||
app_state.client.clone(),
|
app_state.client.clone(),
|
||||||
|
|
|
@ -4258,6 +4258,8 @@ mod tests {
|
||||||
app_state.fs.clone(),
|
app_state.fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
web_search::init(cx);
|
||||||
|
web_search_providers::init(app_state.client.clone(), cx);
|
||||||
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
|
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
|
||||||
assistant::init(
|
assistant::init(
|
||||||
app_state.fs.clone(),
|
app_state.fs.clone(),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue