From cc396a5e364f659b11c8bfa1b1ae3474b99d4f64 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 11 Aug 2025 20:41:47 -0300 Subject: [PATCH] Include mention content in agent2 requests Co-authored-by: Cole Miller --- Cargo.lock | 1 + crates/acp_thread/Cargo.toml | 1 + crates/acp_thread/src/acp_thread.rs | 63 ++-- crates/acp_thread/src/mention.rs | 122 ++++++++ crates/agent2/src/agent.rs | 43 +-- crates/agent2/src/tests/mod.rs | 9 +- crates/agent2/src/thread.rs | 284 +++++++++++++++--- .../agent_ui/src/acp/completion_provider.rs | 77 ++++- crates/agent_ui/src/acp/thread_view.rs | 249 ++++++++------- 9 files changed, 597 insertions(+), 252 deletions(-) create mode 100644 crates/acp_thread/src/mention.rs diff --git a/Cargo.lock b/Cargo.lock index 7b5e82a312..4367bd3f07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,7 @@ dependencies = [ "tempfile", "terminal", "ui", + "url", "util", "workspace-hack", ] diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 33e88df761..1fef342c01 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -34,6 +34,7 @@ settings.workspace = true smol.workspace = true terminal.workspace = true ui.workspace = true +url.workspace = true util.workspace = true workspace-hack.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d632e6e570..4f8773b416 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,13 +1,15 @@ mod connection; mod diff; +mod mention; mod terminal; pub use connection::*; pub use diff::*; +pub use mention::*; pub use terminal::*; use action_log::ActionLog; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::{Context as _, Result}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; @@ -21,12 +23,7 @@ use std::error::Error; use std::fmt::Formatter; use std::process::ExitStatus; use std::rc::Rc; -use std::{ - fmt::Display, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; use util::ResultExt; @@ -53,38 +50,6 @@ impl UserMessage { } } -#[derive(Debug)] -pub struct MentionPath<'a>(&'a Path); - -impl<'a> MentionPath<'a> { - const PREFIX: &'static str = "@file:"; - - pub fn new(path: &'a Path) -> Self { - MentionPath(path) - } - - pub fn try_parse(url: &'a str) -> Option { - let path = url.strip_prefix(Self::PREFIX)?; - Some(MentionPath(Path::new(path))) - } - - pub fn path(&self) -> &Path { - self.0 - } -} - -impl Display for MentionPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "[@{}]({}{})", - self.0.file_name().unwrap_or_default().display(), - Self::PREFIX, - self.0.display() - ) - } -} - #[derive(Debug, PartialEq)] pub struct AssistantMessage { pub chunks: Vec, @@ -358,16 +323,24 @@ impl ContentBlock { ) { let new_content = match block { acp::ContentBlock::Text(text_content) => text_content.text.clone(), - acp::ContentBlock::ResourceLink(resource_link) => { - if let Some(path) = resource_link.uri.strip_prefix("file://") { - format!("{}", MentionPath(path.as_ref())) + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: + acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents { + uri, + .. + }), + .. + }) => { + if let Some(uri) = MentionUri::parse(&uri).log_err() { + uri.to_link() } else { - resource_link.uri.clone() + uri.clone() } } acp::ContentBlock::Image(_) | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => String::new(), + | acp::ContentBlock::Resource(acp::EmbeddedResource { .. }) + | acp::ContentBlock::ResourceLink(_) => String::new(), }; match self { @@ -1278,7 +1251,7 @@ mod tests { use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{cell::RefCell, rc::Rc, time::Duration}; + use std::{cell::RefCell, path::Path, rc::Rc, time::Duration}; use util::path; diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs new file mode 100644 index 0000000000..35f1d42c64 --- /dev/null +++ b/crates/acp_thread/src/mention.rs @@ -0,0 +1,122 @@ +use agent_client_protocol as acp; +use anyhow::{Result, bail}; +use std::path::PathBuf; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MentionUri { + File(PathBuf), + Symbol(PathBuf, String), + Thread(acp::SessionId), + Rule(String), +} + +impl MentionUri { + pub fn parse(input: &str) -> Result { + let url = url::Url::parse(input)?; + let path = url.path(); + match url.scheme() { + "file" => { + if let Some(fragment) = url.fragment() { + Ok(Self::Symbol(path.into(), fragment.into())) + } else { + Ok(Self::File(path.into())) + } + } + "zed" => { + if let Some(thread) = path.strip_prefix("/agent/thread/") { + Ok(Self::Thread(acp::SessionId(thread.into()))) + } else if let Some(rule) = path.strip_prefix("/agent/rule/") { + Ok(Self::Rule(rule.into())) + } else { + bail!("invalid zed url: {:?}", input); + } + } + other => bail!("unrecognized scheme {:?}", other), + } + } + + pub fn name(&self) -> String { + match self { + MentionUri::File(path) => path.file_name().unwrap().to_string_lossy().into_owned(), + MentionUri::Symbol(_path, name) => name.clone(), + MentionUri::Thread(thread) => thread.to_string(), + MentionUri::Rule(rule) => rule.clone(), + } + } + + pub fn to_link(&self) -> String { + let name = self.name(); + let uri = self.to_uri(); + format!("[{name}]({uri})") + } + + pub fn to_uri(&self) -> String { + match self { + MentionUri::File(path) => { + format!("file://{}", path.display()) + } + MentionUri::Symbol(path, name) => { + format!("file://{}#{}", path.display(), name) + } + MentionUri::Thread(thread) => { + format!("zed://agent/thread/{}", thread.0) + } + MentionUri::Rule(rule) => { + format!("zed://agent/rule/{}", rule) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mention_uri_parse_and_display() { + // Test file URI + let file_uri = "file:///path/to/file.rs"; + let parsed = MentionUri::parse(file_uri).unwrap(); + match &parsed { + MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"), + _ => panic!("Expected File variant"), + } + assert_eq!(parsed.to_uri(), file_uri); + + // Test symbol URI + let symbol_uri = "file:///path/to/file.rs#MySymbol"; + let parsed = MentionUri::parse(symbol_uri).unwrap(); + match &parsed { + MentionUri::Symbol(path, symbol) => { + assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); + assert_eq!(symbol, "MySymbol"); + } + _ => panic!("Expected Symbol variant"), + } + assert_eq!(parsed.to_uri(), symbol_uri); + + // Test thread URI + let thread_uri = "zed://agent/thread/session123"; + let parsed = MentionUri::parse(thread_uri).unwrap(); + match &parsed { + MentionUri::Thread(session_id) => assert_eq!(session_id.0.as_ref(), "session123"), + _ => panic!("Expected Thread variant"), + } + assert_eq!(parsed.to_uri(), thread_uri); + + // Test rule URI + let rule_uri = "zed://agent/rule/my_rule"; + let parsed = MentionUri::parse(rule_uri).unwrap(); + match &parsed { + MentionUri::Rule(rule) => assert_eq!(rule, "my_rule"), + _ => panic!("Expected Rule variant"), + } + assert_eq!(parsed.to_uri(), rule_uri); + + // Test invalid scheme + assert!(MentionUri::parse("http://example.com").is_err()); + + // Test invalid zed path + assert!(MentionUri::parse("zed://invalid/path").is_err()); + } +} diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index b1cefd2864..4aa9c62f4c 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,7 +1,7 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool, - MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, + MessageContent, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, }; use acp_thread::ModelSelector; @@ -495,9 +495,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Found session for: {}", session_id); // Convert prompt to message - let message = convert_prompt_to_message(params.prompt); + let message: Vec = params + .prompt + .into_iter() + .map(Into::into) + .collect::>(); log::info!("Converted prompt to message: {} chars", message.len()); - log::debug!("Message content: {}", message); + // log::debug!("Message content: {}", message); // Get model using the ModelSelector capability (always available for agent2) // Get the selected model from the thread directly @@ -601,39 +605,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } } -/// Convert ACP content blocks to a message string -fn convert_prompt_to_message(blocks: Vec) -> String { - log::debug!("Converting {} content blocks to message", blocks.len()); - let mut message = String::new(); - - for block in blocks { - match block { - acp::ContentBlock::Text(text) => { - log::trace!("Processing text block: {} chars", text.text.len()); - message.push_str(&text.text); - } - acp::ContentBlock::ResourceLink(link) => { - log::trace!("Processing resource link: {}", link.uri); - message.push_str(&format!(" @{} ", link.uri)); - } - acp::ContentBlock::Image(_) => { - log::trace!("Processing image block"); - message.push_str(" [image] "); - } - acp::ContentBlock::Audio(_) => { - log::trace!("Processing audio block"); - message.push_str(" [audio] "); - } - acp::ContentBlock::Resource(resource) => { - log::trace!("Processing resource block: {:?}", resource.resource); - message.push_str(&format!(" [resource: {:?}] ", resource.resource)); - } - } - } - - message -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index b47816f35c..0d9959f27b 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,4 +1,5 @@ use super::*; +use crate::MessageContent; use acp_thread::AgentConnection; use action_log::ActionLog; use agent_client_protocol::{self as acp}; @@ -10,8 +11,8 @@ use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient use indoc::indoc; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, - StopReason, fake_provider::FakeLanguageModel, + LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason, + fake_provider::FakeLanguageModel, }; use project::Project; use prompt_store::ProjectContext; @@ -266,14 +267,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { assert_eq!( message.content, vec![ - MessageContent::ToolResult(LanguageModelToolResult { + language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), tool_name: ToolRequiringPermission.name().into(), is_error: false, content: "Allowed".into(), output: Some("Allowed".into()) }), - MessageContent::ToolResult(LanguageModelToolResult { + language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), tool_name: ToolRequiringPermission.name().into(), is_error: true, diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 55ac548fd8..8ac38b6ed1 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,4 +1,5 @@ use crate::{SystemPromptTemplate, Template, Templates}; +use acp_thread::MentionUri; use action_log::ActionLog; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; @@ -10,12 +11,11 @@ use futures::{ stream::FuturesUnordered, }; use gpui::{App, Context, Entity, SharedString, Task}; -use language::Rope; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; use log; use project::Project; @@ -23,7 +23,8 @@ use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; use smol::stream::StreamExt; -use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; +use std::fmt::Write; +use std::{cell::RefCell, collections::BTreeMap, future::Future, path::Path, rc::Rc, sync::Arc}; use util::{ResultExt, markdown::MarkdownCodeBlock}; #[derive(Debug, Clone)] @@ -32,27 +33,15 @@ pub struct AgentMessage { pub content: Vec, } -enum MessageContent { +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MessageContent { Text(String), Thinking { text: String, signature: Option, }, - File { - path: PathBuf, - content: Rope, - }, - Symbol { - name: String, - path: PathBuf, - content: Rope, - }, - Thread { - name: String, - content: String, - }, - Rule { - name: String, + Mention { + uri: MentionUri, content: String, }, RedactedThinking(String), @@ -61,27 +50,6 @@ enum MessageContent { ToolResult(LanguageModelToolResult), } -impl Into for MessageContent { - fn into(self) -> language_model::MessageContent { - match self { - MessageContent::Text(text) => language_model::MessageContent::Text(text), - MessageContent::Thinking { data, signature } => todo!(), - MessageContent::File { path, content } => todo!(), - MessageContent::Symbol { name } => todo!(), - MessageContent::Thread { name, content } => todo!(), - MessageContent::Rule { name, content } => todo!(), - MessageContent::RedactedThinking(text) => { - language_model::MessageContent::RedactedThinking(text) - } - MessageContent::Image(image) => language_model::MessageContent::Image(image), - MessageContent::ToolUse(tool_use) => language_model::MessageContent::ToolUse(tool_use), - MessageContent::ToolResult(tool_result) => { - language_model::MessageContent::ToolResult(tool_result) - } - } - } -} - impl AgentMessage { pub fn to_markdown(&self) -> String { let mut markdown = format!("## {}\n", self.role); @@ -141,6 +109,9 @@ impl AgentMessage { .unwrap(); } } + MessageContent::Mention { uri, .. } => { + write!(markdown, "{}", uri.to_link()).ok(); + } } } @@ -251,10 +222,11 @@ impl Thread { /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. pub fn send( &mut self, - content: impl Into, + content: impl Into, cx: &mut Context, ) -> mpsc::UnboundedReceiver> { - let content = content.into(); + let content = content.into().0; + let model = self.selected_model.clone(); log::info!("Thread::send called with model: {:?}", model.name()); log::debug!("Thread::send content: {:?}", content); @@ -267,7 +239,7 @@ impl Thread { let user_message_ix = self.messages.len(); self.messages.push(AgentMessage { role: Role::User, - content: vec![content], + content, }); log::info!("Total messages in thread: {}", self.messages.len()); self.running_turn = Some(cx.spawn(async move |thread, cx| { @@ -389,7 +361,7 @@ impl Thread { log::debug!("System message built"); AgentMessage { role: Role::System, - content: vec![prompt.into()], + content: vec![prompt.as_str().into()], } } @@ -705,11 +677,7 @@ impl Thread { }, message.content.len() ); - LanguageModelRequestMessage { - role: message.role, - content: message.content.clone(), - cache: false, - } + message.to_request() }) .collect(); messages @@ -724,6 +692,20 @@ impl Thread { } } +pub struct UserMessage(Vec); + +impl From> for UserMessage { + fn from(content: Vec) -> Self { + UserMessage(content) + } +} + +impl> From for UserMessage { + fn from(content: T) -> Self { + UserMessage(vec![content.into()]) + } +} + pub trait AgentTool where Self: 'static + Sized, @@ -1088,3 +1070,207 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver { &mut self.0 } } + +impl AgentMessage { + fn to_request(&self) -> language_model::LanguageModelRequestMessage { + let mut message = LanguageModelRequestMessage { + role: self.role, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + + const OPEN_CONTEXT: &str = "\n\ + The following items were attached by the user. \ + They are up-to-date and don't need to be re-read.\n\n"; + + const OPEN_FILES_TAG: &str = ""; + const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_THREADS_TAG: &str = ""; + const OPEN_RULES_TAG: &str = + "\nThe user has specified the following rules that should be applied:\n"; + + let mut file_context = OPEN_FILES_TAG.to_string(); + let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut thread_context = OPEN_THREADS_TAG.to_string(); + let mut rules_context = OPEN_RULES_TAG.to_string(); + + for chunk in &self.content { + let chunk = match chunk { + MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()), + MessageContent::Thinking { text, signature } => { + language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + } + } + MessageContent::RedactedThinking(value) => { + language_model::MessageContent::RedactedThinking(value.clone()) + } + MessageContent::ToolUse(value) => { + language_model::MessageContent::ToolUse(value.clone()) + } + MessageContent::ToolResult(value) => { + language_model::MessageContent::ToolResult(value.clone()) + } + MessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + MessageContent::Mention { uri, content } => { + match uri { + MentionUri::File(path) | MentionUri::Symbol(path, _) => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&path), + text: &content.to_string(), + } + ) + .ok(); + } + MentionUri::Thread(_session_id) => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::Rule(_user_prompt_id) => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: "", + text: &content + } + ) + .ok(); + } + } + + language_model::MessageContent::Text(uri.to_link()) + } + }; + + message.content.push(chunk); + } + + let len_before_context = message.content.len(); + + if file_context.len() > OPEN_FILES_TAG.len() { + file_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(file_context)); + } + + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { + symbol_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(symbol_context)); + } + + if thread_context.len() > OPEN_THREADS_TAG.len() { + thread_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(thread_context)); + } + + if rules_context.len() > OPEN_RULES_TAG.len() { + rules_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(rules_context)); + } + + if message.content.len() > len_before_context { + message.content.insert( + len_before_context, + language_model::MessageContent::Text(OPEN_CONTEXT.into()), + ); + message + .content + .push(language_model::MessageContent::Text("".into())); + } + + message + } +} + +fn codeblock_tag(full_path: &Path) -> String { + let mut result = String::new(); + + if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { + let _ = write!(result, "{} ", extension); + } + + let _ = write!(result, "{}", full_path.display()); + + result +} + +impl From for MessageContent { + fn from(value: acp::ContentBlock) -> Self { + match value { + acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text), + acp::ContentBlock::Image(image_content) => { + MessageContent::Image(convert_image(image_content)) + } + acp::ContentBlock::Audio(_) => { + // TODO + MessageContent::Text("[audio]".to_string()) + } + acp::ContentBlock::ResourceLink(resource_link) => { + match MentionUri::parse(&resource_link.uri) { + Ok(uri) => Self::Mention { + uri, + content: String::new(), + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + MessageContent::Text(format!( + "[{}]({})", + resource_link.name, resource_link.uri + )) + } + } + } + acp::ContentBlock::Resource(resource) => match resource.resource { + acp::EmbeddedResourceResource::TextResourceContents(resource) => { + match MentionUri::parse(&resource.uri) { + Ok(uri) => Self::Mention { + uri, + content: resource.text, + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + MessageContent::Text( + MarkdownCodeBlock { + tag: &resource.uri, + text: &resource.text, + } + .to_string(), + ) + } + } + } + acp::EmbeddedResourceResource::BlobResourceContents(_) => { + // TODO + MessageContent::Text("[blob]".to_string()) + } + }, + } + } +} + +fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { + LanguageModelImage { + source: image_content.data.into(), + // TODO: make this optional? + size: gpui::Size::new(0.into(), 0.into()), + } +} + +impl From<&str> for MessageContent { + fn from(text: &str) -> Self { + MessageContent::Text(text.into()) + } +} diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index d8f452afa5..060bbf02b3 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,18 +1,20 @@ use std::ops::Range; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use anyhow::Result; +use acp_thread::MentionUri; +use anyhow::{Context as _, Result}; use collections::HashMap; use editor::display_map::CreaseId; use editor::{CompletionProvider, Editor, ExcerptId}; use file_icons::FileIcons; +use futures::future::try_join_all; use gpui::{App, Entity, Task, WeakEntity}; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; use parking_lot::Mutex; -use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId}; +use project::{Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, WorktreeId}; use rope::Point; use text::{Anchor, ToPoint}; use ui::prelude::*; @@ -23,21 +25,63 @@ use crate::context_picker::file_context_picker::{extract_file_name_and_directory #[derive(Default)] pub struct MentionSet { - paths_by_crease_id: HashMap, + paths_by_crease_id: HashMap, } impl MentionSet { - pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) { - self.paths_by_crease_id.insert(crease_id, path); - } - - pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option { - self.paths_by_crease_id.get(&crease_id).cloned() + pub fn insert(&mut self, crease_id: CreaseId, path: PathBuf) { + self.paths_by_crease_id + .insert(crease_id, MentionUri::File(path)); } pub fn drain(&mut self) -> impl Iterator { self.paths_by_crease_id.drain().map(|(id, _)| id) } + + pub fn contents( + &self, + project: Entity, + cx: &mut App, + ) -> Task>> { + let contents = self + .paths_by_crease_id + .iter() + .map(|(crease_id, uri)| match uri { + MentionUri::File(path) => { + let crease_id = crease_id.clone(); + let uri = uri.clone(); + let path = path.to_path_buf(); + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(path, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + anyhow::Ok((crease_id, Mention { uri, content })) + }) + } + _ => { + // TODO + unimplemented!() + } + }) + .collect::>(); + + cx.spawn(async move |_cx| { + let contents = try_join_all(contents).await?.into_iter().collect(); + anyhow::Ok(contents) + }) + } +} + +pub struct Mention { + pub uri: MentionUri, + pub content: String, } pub struct ContextPickerCompletionProvider { @@ -68,6 +112,7 @@ impl ContextPickerCompletionProvider { source_range: Range, editor: Entity, mention_set: Arc>, + project: Entity, cx: &App, ) -> Completion { let (file_name, directory) = @@ -112,6 +157,7 @@ impl ContextPickerCompletionProvider { new_text_len - 1, editor, mention_set, + project, )), } } @@ -159,6 +205,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { return Task::ready(Ok(Vec::new())); }; + let project = workspace.read(cx).project().clone(); let snapshot = buffer.read(cx).snapshot(); let source_range = snapshot.anchor_before(state.source_range.start) ..snapshot.anchor_after(state.source_range.end); @@ -195,6 +242,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { source_range.clone(), editor.clone(), mention_set.clone(), + project.clone(), cx, ) }) @@ -254,6 +302,7 @@ fn confirm_completion_callback( content_len: usize, editor: Entity, mention_set: Arc>, + project: Entity, ) -> Arc bool + Send + Sync> { Arc::new(move |_, window, cx| { let crease_text = crease_text.clone(); @@ -261,6 +310,7 @@ fn confirm_completion_callback( let editor = editor.clone(); let project_path = project_path.clone(); let mention_set = mention_set.clone(); + let project = project.clone(); window.defer(cx, move |window, cx| { let crease_id = crate::context_picker::insert_crease_for_mention( excerpt_id, @@ -272,8 +322,13 @@ fn confirm_completion_callback( window, cx, ); + + let Some(path) = project.read(cx).absolute_path(&project_path, cx) else { + return; + }; + if let Some(crease_id) = crease_id { - mention_set.lock().insert(crease_id, project_path); + mention_set.lock().insert(crease_id, path); } }); false diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2536612ece..583dcc777b 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,6 +1,6 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, + LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -28,6 +28,7 @@ use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::{CompletionIntent, Project}; use settings::{Settings as _, SettingsStore}; +use std::path::PathBuf; use std::{ cell::RefCell, collections::BTreeMap, path::Path, process::ExitStatus, rc::Rc, sync::Arc, time::Duration, @@ -374,81 +375,101 @@ impl AcpThreadView { let mut ix = 0; let mut chunks: Vec = Vec::new(); let project = self.project.clone(); - self.message_editor.update(cx, |editor, cx| { - let text = editor.text(cx); - editor.display_map.update(cx, |map, cx| { - let snapshot = map.snapshot(cx); - for (crease_id, crease) in snapshot.crease_snapshot.creases() { - // Skip creases that have been edited out of the message buffer. - if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { - continue; - } - if let Some(project_path) = - self.mention_set.lock().path_for_crease_id(crease_id) - { - let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); - if crease_range.start > ix { - chunks.push(text[ix..crease_range.start].into()); + let contents = self.mention_set.lock().contents(project, cx); + + cx.spawn_in(window, async move |this, cx| { + let contents = match contents.await { + Ok(contents) => contents, + Err(e) => { + this.update(cx, |this, cx| { + this.last_error = + Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx))); + }) + .ok(); + return; + } + }; + + this.update_in(cx, |this, window, cx| { + this.message_editor.update(cx, |editor, cx| { + let text = editor.text(cx); + editor.display_map.update(cx, |map, cx| { + let snapshot = map.snapshot(cx); + for (crease_id, crease) in snapshot.crease_snapshot.creases() { + // Skip creases that have been edited out of the message buffer. + if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { + continue; + } + + if let Some(mention) = contents.get(&crease_id) { + let crease_range = + crease.range().to_offset(&snapshot.buffer_snapshot); + if crease_range.start > ix { + chunks.push(text[ix..crease_range.start].into()); + } + chunks.push(acp::ContentBlock::Resource(acp::EmbeddedResource { + annotations: None, + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + mime_type: None, + text: mention.content.clone(), + uri: mention.uri.to_uri(), + }, + ), + })); + ix = crease_range.end; + } } - if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - let path_str = abs_path.display().to_string(); - chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { - uri: path_str.clone(), - name: path_str, - annotations: None, - description: None, - mime_type: None, - size: None, - title: None, - })); + + if ix < text.len() { + let last_chunk = text[ix..].trim_end(); + if !last_chunk.is_empty() { + chunks.push(last_chunk.into()); + } } - ix = crease_range.end; - } + }) + }); + + if chunks.is_empty() { + return; } - if ix < text.len() { - let last_chunk = text[ix..].trim_end(); - if !last_chunk.is_empty() { - chunks.push(last_chunk.into()); - } - } - }) - }); - - if chunks.is_empty() { - return; - } - - let Some(thread) = self.thread() else { - return; - }; - let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); - - cx.spawn(async move |this, cx| { - let result = task.await; - - this.update(cx, |this, cx| { - if let Err(err) = result { - this.last_error = - Some(cx.new(|cx| Markdown::new(err.to_string().into(), None, None, cx))) - } + let Some(thread) = this.thread() else { + return; + }; + let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); + + cx.spawn(async move |this, cx| { + let result = task.await; + + this.update(cx, |this, cx| { + if let Err(err) = result { + this.last_error = + Some(cx.new(|cx| { + Markdown::new(err.to_string().into(), None, None, cx) + })) + } + }) + }) + .detach(); + + let mention_set = this.mention_set.clone(); + + this.set_editor_is_expanded(false, cx); + + this.message_editor.update(cx, |editor, cx| { + editor.clear(window, cx); + editor.remove_creases(mention_set.lock().drain(), cx) + }); + + this.scroll_to_bottom(cx); + + this.message_history.borrow_mut().push(chunks); }) + .ok(); }) .detach(); - - let mention_set = self.mention_set.clone(); - - self.set_editor_is_expanded(false, cx); - - self.message_editor.update(cx, |editor, cx| { - editor.clear(window, cx); - editor.remove_creases(mention_set.lock().drain(), cx) - }); - - self.scroll_to_bottom(cx); - - self.message_history.borrow_mut().push(chunks); } fn previous_history_message( @@ -561,16 +582,19 @@ impl AcpThreadView { acp::ContentBlock::Text(text_content) => { text.push_str(&text_content.text); } - acp::ContentBlock::ResourceLink(resource_link) => { - let path = Path::new(&resource_link.uri); + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: acp::EmbeddedResourceResource::TextResourceContents(resource), + .. + }) => { + let path = PathBuf::from(&resource.uri); + let project_path = project.read(cx).project_path_for_absolute_path(&path, cx); let start = text.len(); - let content = MentionPath::new(&path).to_string(); + let content = MentionUri::File(path).to_uri(); text.push_str(&content); let end = text.len(); - if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(&path, cx) - { - let filename: SharedString = path + if let Some(project_path) = project_path { + let filename: SharedString = project_path + .path .file_name() .unwrap_or_default() .to_string_lossy() @@ -581,7 +605,8 @@ impl AcpThreadView { } acp::ContentBlock::Image(_) | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => {} + | acp::ContentBlock::Resource(_) + | acp::ContentBlock::ResourceLink(_) => {} } } @@ -600,18 +625,21 @@ impl AcpThreadView { }; let anchor = snapshot.anchor_before(range.start); - let crease_id = crate::context_picker::insert_crease_for_mention( - anchor.excerpt_id, - anchor.text_anchor, - range.end - range.start, - filename, - crease_icon_path, - message_editor.clone(), - window, - cx, - ); - if let Some(crease_id) = crease_id { - mention_set.lock().insert(crease_id, project_path); + if let Some(project_path) = project.read(cx).absolute_path(&project_path, cx) { + let crease_id = crate::context_picker::insert_crease_for_mention( + anchor.excerpt_id, + anchor.text_anchor, + range.end - range.start, + filename, + crease_icon_path, + message_editor.clone(), + window, + cx, + ); + + if let Some(crease_id) = crease_id { + mention_set.lock().insert(crease_id, project_path); + } } } @@ -2302,25 +2330,31 @@ impl AcpThreadView { return; }; - if let Some(mention_path) = MentionPath::try_parse(&url) { - workspace.update(cx, |workspace, cx| { - let project = workspace.project(); - let Some((path, entry)) = project.update(cx, |project, cx| { - let path = project.find_project_path(mention_path.path(), cx)?; - let entry = project.entry_for_path(&path, cx)?; - Some((path, entry)) - }) else { - return; - }; + if let Some(mention) = MentionUri::parse(&url).log_err() { + workspace.update(cx, |workspace, cx| match mention { + MentionUri::File(path) => { + let project = workspace.project(); + let Some((path, entry)) = project.update(cx, |project, cx| { + let path = project.find_project_path(path, cx)?; + let entry = project.entry_for_path(&path, cx)?; + Some((path, entry)) + }) else { + return; + }; - if entry.is_dir() { - project.update(cx, |_, cx| { - cx.emit(project::Event::RevealInProjectPanel(entry.id)); - }); - } else { - workspace - .open_path(path, None, true, window, cx) - .detach_and_log_err(cx); + if entry.is_dir() { + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry.id)); + }); + } else { + workspace + .open_path(path, None, true, window, cx) + .detach_and_log_err(cx); + } + } + _ => { + // TODO + unimplemented!() } }) } else { @@ -2715,6 +2749,7 @@ impl AcpThreadView { anchor..anchor, self.message_editor.clone(), self.mention_set.clone(), + self.project.clone(), cx, ); @@ -2857,7 +2892,7 @@ fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { style.base_text_style = text_style; style.link_callback = Some(Rc::new(move |url, cx| { - if MentionPath::try_parse(url).is_some() { + if MentionUri::parse(url).is_ok() { let colors = cx.theme().colors(); Some(TextStyleRefinement { background_color: Some(colors.element_background),