diff --git a/Cargo.lock b/Cargo.lock index d30e29f83c..91538b0184 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -715,6 +715,8 @@ dependencies = [ "feature_flags", "futures 0.3.31", "gpui", + "html_to_markdown", + "http_client", "itertools 0.14.0", "language", "language_model", diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs index 827f3e6a9e..ea0eaf33c8 100644 --- a/crates/assistant_eval/src/headless_assistant.rs +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -163,7 +163,7 @@ pub fn init(cx: &mut App) -> Arc { language::init(cx); language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); - assistant_tools::init(cx); + assistant_tools::init(client.http_client().clone(), cx); context_server::init(cx); let stdout_is_a_pty = false; let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 0c765e8074..9b388cc0ff 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -18,8 +18,10 @@ chrono.workspace = true collections.workspace = true feature_flags.workspace = true futures.workspace = true -itertools.workspace = true gpui.workspace = true +html_to_markdown.workspace = true +http_client.workspace = true +itertools.workspace = true language.workspace = true language_model.workspace = true project.workspace = true @@ -27,17 +29,17 @@ release_channel.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +settings.workspace = true theme.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true worktree.workspace = true -settings.workspace = true [dev-dependencies] -rand.workspace = true collections = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features 
= ["test-support"] } project = { workspace = true, features = ["test-support"] } +rand.workspace = true workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 21eab55598..daa471e6e5 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -2,6 +2,7 @@ mod bash_tool; mod delete_path_tool; mod diagnostics_tool; mod edit_files_tool; +mod fetch_tool; mod list_directory_tool; mod now_tool; mod path_search_tool; @@ -9,13 +10,17 @@ mod read_file_tool; mod regex_search_tool; mod thinking_tool; +use std::sync::Arc; + use assistant_tool::ToolRegistry; use gpui::App; +use http_client::HttpClientWithUrl; use crate::bash_tool::BashTool; use crate::delete_path_tool::DeletePathTool; use crate::diagnostics_tool::DiagnosticsTool; use crate::edit_files_tool::EditFilesTool; +use crate::fetch_tool::FetchTool; use crate::list_directory_tool::ListDirectoryTool; use crate::now_tool::NowTool; use crate::path_search_tool::PathSearchTool; @@ -23,7 +28,7 @@ use crate::read_file_tool::ReadFileTool; use crate::regex_search_tool::RegexSearchTool; use crate::thinking_tool::ThinkingTool; -pub fn init(cx: &mut App) { +pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); crate::edit_files_tool::log::init(cx); @@ -38,4 +43,5 @@ pub fn init(cx: &mut App) { registry.register_tool(ReadFileTool); registry.register_tool(RegexSearchTool); registry.register_tool(ThinkingTool); + registry.register_tool(FetchTool::new(http_client)); } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs new file mode 100644 index 0000000000..ecdc9dddd5 --- /dev/null +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -0,0 +1,153 @@ +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +use anyhow::{anyhow, bail, Context as _, Result}; +use assistant_tool::{ActionLog, Tool}; +use 
futures::AsyncReadExt as _;
+use gpui::{App, AppContext as _, Entity, Task};
+use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
+use http_client::{AsyncBody, HttpClientWithUrl};
+use language_model::LanguageModelRequestMessage;
+use project::Project;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+/// How a fetched response body is rendered, chosen from its `Content-Type` header.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
+enum ContentType {
+    /// Converted from HTML to Markdown. Also the fallback for unrecognized types.
+    Html,
+    /// Returned verbatim (must be valid UTF-8).
+    Plaintext,
+    /// Pretty-printed inside a fenced `json` code block.
+    Json,
+}
+
+impl ContentType {
+    /// Classifies a raw `Content-Type` header value.
+    ///
+    /// Only the media type is compared: anything after the first `;` (for
+    /// example `; charset=utf-8`) is ignored, surrounding whitespace is
+    /// trimmed, and the comparison is ASCII case-insensitive. Any media type
+    /// other than `text/plain` or `application/json` is treated as HTML.
+    fn from_header(value: &str) -> Self {
+        let media_type = value
+            .split_once(';')
+            .map_or(value, |(media_type, _parameters)| media_type)
+            .trim();
+
+        if media_type.eq_ignore_ascii_case("text/plain") {
+            Self::Plaintext
+        } else if media_type.eq_ignore_ascii_case("application/json") {
+            Self::Json
+        } else {
+            Self::Html
+        }
+    }
+}
+
+/// JSON input accepted by the `fetch` tool.
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct FetchToolInput {
+    /// The URL to fetch.
+    url: String,
+}
+
+/// Assistant tool that downloads a URL and returns its content as text.
+pub struct FetchTool {
+    http_client: Arc<HttpClientWithUrl>,
+}
+
+impl FetchTool {
+    /// Creates a fetch tool that issues its requests through `http_client`.
+    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
+        Self { http_client }
+    }
+
+    /// Fetches `url` and renders the response body as text.
+    ///
+    /// `https://` is prepended when `url` has neither an `http://` nor an
+    /// `https://` prefix. HTML is converted to Markdown, plain text is
+    /// returned as-is, and JSON is pretty-printed inside a fenced code block.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the request or reading the body fails, if the response status
+    /// is a client (4xx) or server (5xx) error, if the `Content-Type` header is
+    /// missing or not a valid string, or if the body cannot be decoded for its
+    /// content type.
+    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
+        let url = if url.starts_with("https://") || url.starts_with("http://") {
+            url.to_owned()
+        } else {
+            format!("https://{url}")
+        };
+
+        // NOTE(review): the trailing `true` presumably enables redirect
+        // following; confirm against the `http_client` crate.
+        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
+
+        // The body is read before the status check so that it can be quoted in
+        // the error message below.
+        let mut body = Vec::new();
+        response
+            .body_mut()
+            .read_to_end(&mut body)
+            .await
+            .context("error reading response body")?;
+
+        // Reject 5xx as well as 4xx: otherwise a server error page would be
+        // converted and handed back as if it were the requested content.
+        let status = response.status();
+        if status.is_client_error() || status.is_server_error() {
+            let text = String::from_utf8_lossy(&body);
+            bail!("status error {}, response: {text:?}", status.as_u16());
+        }
+
+        let Some(content_type) = response.headers().get("content-type") else {
+            bail!("missing Content-Type header");
+        };
+        let content_type = content_type
+            .to_str()
+            .context("invalid Content-Type header")?;
+
+        match ContentType::from_header(content_type) {
+            ContentType::Html => {
+                let mut handlers: Vec<TagHandler> = vec![
+                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
+                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
+                    Rc::new(RefCell::new(markdown::HeadingHandler)),
+                    Rc::new(RefCell::new(markdown::ListHandler)),
+                    Rc::new(RefCell::new(markdown::TableHandler::new())),
+                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
+                ];
+                // This is a substring match on the whole URL, so it also
+                // triggers when "wikipedia.org" appears outside the host.
+                if url.contains("wikipedia.org") {
+                    use html_to_markdown::structure::wikipedia;
+
+                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
+                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
+                    handlers.push(Rc::new(
+                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
+                    ));
+                } else {
+                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
+                }
+
+                convert_html_to_markdown(&body[..], &mut handlers)
+            }
+            // Take ownership of the buffer instead of validating and copying it.
+            ContentType::Plaintext => Ok(String::from_utf8(body)?),
+            ContentType::Json => {
+                let json: serde_json::Value = serde_json::from_slice(&body)?;
+
+                Ok(format!(
+                    "```json\n{}\n```",
+                    serde_json::to_string_pretty(&json)?
+                ))
+            }
+        }
+    }
+}
+
+impl Tool for FetchTool {
+    /// Name under which the tool is registered.
+    fn name(&self) -> String {
+        "fetch".to_string()
+    }
+
+    /// Tool description, embedded at compile time from `fetch_tool/description.md`.
+    fn description(&self) -> String {
+        include_str!("./fetch_tool/description.md").to_string()
+    }
+
+    /// JSON schema generated from [`FetchToolInput`].
+    fn input_schema(&self) -> serde_json::Value {
+        let schema = schemars::schema_for!(FetchToolInput);
+        serde_json::to_value(&schema).expect("generated JSON schema is serializable")
+    }
+
+    /// Parses `input` as [`FetchToolInput`], fetches the URL on the background
+    /// executor, and resolves to the rendered text.
+    ///
+    /// The returned task fails if `input` does not deserialize, if the fetch
+    /// fails, or if the rendered text is empty after trimming whitespace.
+    fn run(
+        self: Arc<Self>,
+        input: serde_json::Value,
+        _messages: &[LanguageModelRequestMessage],
+        _project: Entity<Project>,
+        _action_log: Entity<ActionLog>,
+        cx: &mut App,
+    ) -> Task<Result<String>> {
+        let input = match serde_json::from_value::<FetchToolInput>(input) {
+            Ok(input) => input,
+            Err(err) => return Task::ready(Err(anyhow!(err))),
+        };
+
+        let text = cx.background_spawn({
+            let http_client = self.http_client.clone();
+            // `input` is owned and not used again, so the URL can be moved.
+            let url = input.url;
+            async move { Self::build_message(http_client, &url).await }
+        });
+
+        cx.foreground_executor().spawn(async move {
+            let text = text.await?;
+            if text.trim().is_empty() {
+                bail!("no textual content found");
+            }
+
+            Ok(text)
+        })
+    }
+}
diff --git a/crates/assistant_tools/src/fetch_tool/description.md b/crates/assistant_tools/src/fetch_tool/description.md new file
mode 100644 index 0000000000..007ba6c608 --- /dev/null +++ b/crates/assistant_tools/src/fetch_tool/description.md @@ -0,0 +1 @@ +Fetches a URL and returns the content as Markdown. diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index f832fe7029..fae38a14ac 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -472,7 +472,7 @@ fn main() { prompt_builder.clone(), cx, ); - assistant_tools::init(cx); + assistant_tools::init(app_state.client.http_client(), cx); repl::init(app_state.fs.clone(), cx); extension_host::init( extension_host_proxy,