diff --git a/Cargo.lock b/Cargo.lock index 7b5e82a312..8a3e319a57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -204,6 +204,8 @@ dependencies = [ "gpui", "gpui_tokio", "handlebars 4.5.0", + "html_to_markdown", + "http_client", "indoc", "itertools 0.14.0", "language", diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs index 025aba060d..c4eaffc228 100644 --- a/crates/action_log/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -17,8 +17,6 @@ use util::{ pub struct ActionLog { /// Buffers that we want to notify the model about when they change. tracked_buffers: BTreeMap, TrackedBuffer>, - /// Has the model edited a file since it last checked diagnostics? - edited_since_project_diagnostics_check: bool, /// The project this action log is associated with project: Entity, } @@ -28,7 +26,6 @@ impl ActionLog { pub fn new(project: Entity) -> Self { Self { tracked_buffers: BTreeMap::default(), - edited_since_project_diagnostics_check: false, project, } } @@ -37,16 +34,6 @@ impl ActionLog { &self.project } - /// Notifies a diagnostics check - pub fn checked_project_diagnostics(&mut self) { - self.edited_since_project_diagnostics_check = false; - } - - /// Returns true if any files have been edited since the last project diagnostics check - pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool { - self.edited_since_project_diagnostics_check - } - pub fn latest_snapshot(&self, buffer: &Entity) -> Option { Some(self.tracked_buffers.get(buffer)?.snapshot.clone()) } @@ -543,14 +530,11 @@ impl ActionLog { /// Mark a buffer as created by agent, so we can refresh it in the context pub fn buffer_created(&mut self, buffer: Entity, cx: &mut Context) { - self.edited_since_project_diagnostics_check = true; self.track_buffer_internal(buffer.clone(), true, cx); } /// Mark a buffer as edited by agent, so we can refresh it in the context pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { - self.edited_since_project_diagnostics_check = true; - let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); if let TrackedBufferStatus::Deleted = tracked_buffer.status { tracked_buffer.status = TrackedBufferStatus::Modified; diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 622b08016a..7ee48aca04 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -27,6 +27,8 @@ fs.workspace = true futures.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } +html_to_markdown.workspace = true +http_client.workspace = true indoc.workspace = true itertools.workspace = true language.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index b1cefd2864..66893f49f9 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ - CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool, - MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, - ToolCallAuthorization, WebSearchTool, + CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, + GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, + ThinkingTool, ToolCallAuthorization, WebSearchTool, }; use acp_thread::ModelSelector; use agent_client_protocol as acp; @@ -420,11 +420,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(CopyPathTool::new(project.clone())); + thread.add_tool(DiagnosticsTool::new(project.clone())); thread.add_tool(MovePathTool::new(project.clone())); thread.add_tool(ListDirectoryTool::new(project.clone())); thread.add_tool(OpenTool::new(project.clone())); thread.add_tool(ThinkingTool); thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(GrepTool::new(project.clone())); thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(EditFileTool::new(cx.entity())); diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index 29ba6780b8..8896b14538 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1,7 +1,9 @@ mod copy_path_tool; mod create_directory_tool; mod delete_path_tool; +mod diagnostics_tool; mod edit_file_tool; +mod fetch_tool; mod find_path_tool; mod grep_tool; mod list_directory_tool; @@ -16,7 +18,9 @@ mod web_search_tool; pub use copy_path_tool::*; pub use create_directory_tool::*; pub use delete_path_tool::*; +pub use diagnostics_tool::*; pub use edit_file_tool::*; +pub use fetch_tool::*; pub use find_path_tool::*; pub use grep_tool::*; pub use list_directory_tool::*; diff --git a/crates/agent2/src/tools/diagnostics_tool.rs b/crates/agent2/src/tools/diagnostics_tool.rs new file mode 100644 index 0000000000..bd0b20df5a --- /dev/null +++ b/crates/agent2/src/tools/diagnostics_tool.rs @@ -0,0 +1,177 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol as acp; +use anyhow::{Result, anyhow}; +use gpui::{App, Entity, Task}; +use language::{DiagnosticSeverity, OffsetRangeExt}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{fmt::Write, path::Path, sync::Arc}; +use ui::SharedString; +use util::markdown::MarkdownInlineCode; + +/// Get errors and warnings for the project or a specific file. +/// +/// This tool can be invoked after a series of edits to determine if further edits are necessary, or if the user asks to fix errors or warnings in their codebase. +/// +/// When a path is provided, shows all diagnostics for that specific file. +/// When no path is provided, shows a summary of error and warning counts for all files in the project. +/// +/// +/// To get diagnostics for a specific file: +/// { +/// "path": "src/main.rs" +/// } +/// +/// To get a project-wide diagnostic summary: +/// {} +/// +/// +/// +/// - If you think you can fix a diagnostic, make 1-2 attempts and then give up. +/// - Don't remove code you've generated just because you can't fix an error. The user can help you fix it. +/// +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct DiagnosticsToolInput { + /// The path to get diagnostics for. If not provided, returns a project-wide summary. + /// + /// This path should never be absolute, and the first component + /// of the path should always be a root directory in a project. + /// + /// + /// If the project has the following root directories: + /// + /// - lorem + /// - ipsum + /// + /// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`. + /// + pub path: Option, +} + +pub struct DiagnosticsTool { + project: Entity, +} + +impl DiagnosticsTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for DiagnosticsTool { + type Input = DiagnosticsToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "diagnostics".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Read + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Some(path) = input.ok().and_then(|input| match input.path { + Some(path) if !path.is_empty() => Some(path), + _ => None, + }) { + format!("Check diagnostics for {}", MarkdownInlineCode(&path)).into() + } else { + "Check project diagnostics".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + match input.path { + Some(path) if !path.is_empty() => { + let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else { + return Task::ready(Err(anyhow!("Could not find path {path} in project",))); + }; + + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)); + + cx.spawn(async move |cx| { + let mut output = String::new(); + let buffer = buffer.await?; + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + for (_, group) in snapshot.diagnostic_groups(None) { + let entry = &group.entries[group.primary_ix]; + let range = entry.range.to_point(&snapshot); + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + _ => continue, + }; + + writeln!( + output, + "{} at line {}: {}", + severity, + range.start.row + 1, + entry.diagnostic.message + )?; + + event_stream.update_fields(acp::ToolCallUpdateFields { + content: Some(vec![output.clone().into()]), + ..Default::default() + }); + } + + if output.is_empty() { + Ok("File doesn't have errors or warnings!".to_string()) + } else { + Ok(output) + } + }) + } + _ => { + let project = self.project.read(cx); + let mut output = String::new(); + let mut has_diagnostics = false; + + for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { + if summary.error_count > 0 || summary.warning_count > 0 { + let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) + else { + continue; + }; + + has_diagnostics = true; + output.push_str(&format!( + "{}: {} error(s), {} warning(s)\n", + Path::new(worktree.read(cx).root_name()) + .join(project_path.path) + .display(), + summary.error_count, + summary.warning_count + )); + } + } + + if has_diagnostics { + event_stream.update_fields(acp::ToolCallUpdateFields { + content: Some(vec![output.clone().into()]), + ..Default::default() + }); + Task::ready(Ok(output)) + } else { + let text = "No errors or warnings found in the project."; + event_stream.update_fields(acp::ToolCallUpdateFields { + content: Some(vec![text.into()]), + ..Default::default() + }); + Task::ready(Ok(text.into())) + } + } + } + } +} diff --git a/crates/agent2/src/tools/fetch_tool.rs b/crates/agent2/src/tools/fetch_tool.rs new file mode 100644 index 0000000000..7f3752843c --- /dev/null +++ b/crates/agent2/src/tools/fetch_tool.rs @@ -0,0 +1,161 @@ +use std::rc::Rc; +use std::sync::Arc; +use std::{borrow::Cow, cell::RefCell}; + +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, bail}; +use futures::AsyncReadExt as _; +use gpui::{App, AppContext as _, Task}; +use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; +use http_client::{AsyncBody, HttpClientWithUrl}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use ui::SharedString; +use util::markdown::MarkdownEscaped; + +use crate::{AgentTool, ToolCallEventStream}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +enum ContentType { + Html, + Plaintext, + Json, +} + +/// Fetches a URL and returns the content as Markdown. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FetchToolInput { + /// The URL to fetch. + url: String, +} + +pub struct FetchTool { + http_client: Arc, +} + +impl FetchTool { + pub fn new(http_client: Arc) -> Self { + Self { http_client } + } + + async fn build_message(http_client: Arc, url: &str) -> Result { + let url = if !url.starts_with("https://") && !url.starts_with("http://") { + Cow::Owned(format!("https://{url}")) + } else { + Cow::Borrowed(url) + }; + + let mut response = http_client.get(&url, AsyncBody::default(), true).await?; + + let mut body = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .context("error reading response body")?; + + if response.status().is_client_error() { + let text = String::from_utf8_lossy(body.as_slice()); + bail!( + "status error {}, response: {text:?}", + response.status().as_u16() + ); + } + + let Some(content_type) = response.headers().get("content-type") else { + bail!("missing Content-Type header"); + }; + let content_type = content_type + .to_str() + .context("invalid Content-Type header")?; + + let content_type = if content_type.starts_with("text/plain") { + ContentType::Plaintext + } else if content_type.starts_with("application/json") { + ContentType::Json + } else { + ContentType::Html + }; + + match content_type { + ContentType::Html => { + let mut handlers: Vec = vec![ + Rc::new(RefCell::new(markdown::WebpageChromeRemover)), + Rc::new(RefCell::new(markdown::ParagraphHandler)), + Rc::new(RefCell::new(markdown::HeadingHandler)), + Rc::new(RefCell::new(markdown::ListHandler)), + Rc::new(RefCell::new(markdown::TableHandler::new())), + Rc::new(RefCell::new(markdown::StyledTextHandler)), + ]; + if url.contains("wikipedia.org") { + use html_to_markdown::structure::wikipedia; + + handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover))); + handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler))); + handlers.push(Rc::new( + RefCell::new(wikipedia::WikipediaCodeHandler::new()), + )); + } else { + handlers.push(Rc::new(RefCell::new(markdown::CodeHandler))); + } + + convert_html_to_markdown(&body[..], &mut handlers) + } + ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()), + ContentType::Json => { + let json: serde_json::Value = serde_json::from_slice(&body)?; + + Ok(format!( + "```json\n{}\n```", + serde_json::to_string_pretty(&json)? + )) + } + } + } +} + +impl AgentTool for FetchTool { + type Input = FetchToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "fetch".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Fetch + } + + fn initial_title(&self, input: Result) -> SharedString { + match input { + Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(), + Err(_) => "Fetch URL".into(), + } + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let text = cx.background_spawn({ + let http_client = self.http_client.clone(); + async move { Self::build_message(http_client, &input.url).await } + }); + + cx.foreground_executor().spawn(async move { + let text = text.await?; + if text.trim().is_empty() { + bail!("no textual content found"); + } + + event_stream.update_fields(acp::ToolCallUpdateFields { + content: Some(vec![text.clone().into()]), + ..Default::default() + }); + + Ok(text) + }) + } +} diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index bc479eb596..4ec794e127 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -86,7 +86,7 @@ impl Tool for DiagnosticsTool { input: serde_json::Value, _request: Arc, project: Entity, - action_log: Entity, + _action_log: Entity, _model: Arc, _window: Option, cx: &mut App, @@ -159,10 +159,6 @@ impl Tool for DiagnosticsTool { } } - action_log.update(cx, |action_log, _cx| { - action_log.checked_project_diagnostics(); - }); - if has_diagnostics { Task::ready(Ok(output.into())).into() } else {