agent: Add websearch tool (#28621)
Staff only for now. We'll work on making this usable for non zed.dev users later Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
2b277123be
commit
456e54b87c
22 changed files with 675 additions and 51 deletions
26
crates/web_search_providers/Cargo.toml
Normal file
26
crates/web_search_providers/Cargo.toml
Normal file
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "web_search_providers"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/web_search_providers.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
web_search.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
103
crates/web_search_providers/src/cloud.rs
Normal file
103
crates/web_search_providers/src/cloud.rs
Normal file
|
@ -0,0 +1,103 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use client::Client;
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||
use zed_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
|
||||
pub struct CloudWebSearchProvider {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
impl CloudWebSearchProvider {
|
||||
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State::new(client, cx));
|
||||
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
_llm_token_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
|
||||
Self {
|
||||
client,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_llm_token_subscription: cx.subscribe(
|
||||
&refresh_llm_token_listener,
|
||||
|this, _, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
cx.spawn(async move |_this, _cx| {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSearchProvider for CloudWebSearchProvider {
|
||||
fn id(&self) -> WebSearchProviderId {
|
||||
WebSearchProviderId("zed.dev".into())
|
||||
}
|
||||
|
||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
||||
let state = self.state.read(cx);
|
||||
let client = state.client.clone();
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let body = WebSearchBody { query };
|
||||
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_web_search(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
body: WebSearchBody,
|
||||
) -> Result<WebSearchResponse> {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
|
||||
request_builder.uri(web_search_url)
|
||||
} else {
|
||||
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
||||
};
|
||||
let request = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send web search request")?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"error performing web search.\nStatus: {:?}\nBody: {body}",
|
||||
response.status(),
|
||||
));
|
||||
}
|
||||
}
|
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
mod cloud;
|
||||
|
||||
use client::Client;
|
||||
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
|
||||
use gpui::{App, Context};
|
||||
use std::sync::Arc;
|
||||
use web_search::WebSearchRegistry;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
let registry = WebSearchRegistry::global(cx);
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_web_search_providers(registry, client, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn register_web_search_providers(
|
||||
_registry: &mut WebSearchRegistry,
|
||||
client: Arc<Client>,
|
||||
cx: &mut Context<WebSearchRegistry>,
|
||||
) {
|
||||
cx.observe_flag::<ZedProWebSearchTool, _>({
|
||||
let client = client.clone();
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue