diff --git a/Cargo.lock b/Cargo.lock index dc28a1cb44..cfbd5b653f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,6 +7,7 @@ name = "acp_thread" version = "0.1.0" dependencies = [ "action_log", + "agent", "agent-client-protocol", "anyhow", "buffer_diff", @@ -21,6 +22,7 @@ dependencies = [ "markdown", "parking_lot", "project", + "prompt_store", "rand 0.8.5", "serde", "serde_json", @@ -29,7 +31,9 @@ dependencies = [ "tempfile", "terminal", "ui", + "url", "util", + "uuid", "workspace-hack", ] @@ -388,6 +392,7 @@ dependencies = [ "ui", "ui_input", "unindent", + "url", "urlencoding", "util", "uuid", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 33e88df761..66009e355b 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -18,6 +18,7 @@ test-support = ["gpui/test-support", "project/test-support"] [dependencies] action_log.workspace = true agent-client-protocol.workspace = true +agent.workspace = true anyhow.workspace = true buffer_diff.workspace = true editor.workspace = true @@ -28,13 +29,16 @@ language.workspace = true language_model.workspace = true markdown.workspace = true project.workspace = true +prompt_store.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true terminal.workspace = true ui.workspace = true +url.workspace = true util.workspace = true +uuid.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 1c0a9479df..d853686020 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,13 +1,15 @@ mod connection; mod diff; +mod mention; mod terminal; pub use connection::*; pub use diff::*; +pub use mention::*; pub use terminal::*; use action_log::ActionLog; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::{Context as _, Result}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; @@ -21,12 +23,7 @@ use std::error::Error; use std::fmt::Formatter; use std::process::ExitStatus; use std::rc::Rc; -use std::{ - fmt::Display, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; use util::ResultExt; @@ -53,38 +50,6 @@ impl UserMessage { } } -#[derive(Debug)] -pub struct MentionPath<'a>(&'a Path); - -impl<'a> MentionPath<'a> { - const PREFIX: &'static str = "@file:"; - - pub fn new(path: &'a Path) -> Self { - MentionPath(path) - } - - pub fn try_parse(url: &'a str) -> Option { - let path = url.strip_prefix(Self::PREFIX)?; - Some(MentionPath(Path::new(path))) - } - - pub fn path(&self) -> &Path { - self.0 - } -} - -impl Display for MentionPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "[@{}]({}{})", - self.0.file_name().unwrap_or_default().display(), - Self::PREFIX, - self.0.display() - ) - } -} - #[derive(Debug, PartialEq)] pub struct AssistantMessage { pub chunks: Vec, @@ -367,16 +332,24 @@ impl ContentBlock { ) { let new_content = match block { acp::ContentBlock::Text(text_content) => text_content.text.clone(), - acp::ContentBlock::ResourceLink(resource_link) => { - if let Some(path) = resource_link.uri.strip_prefix("file://") { - format!("{}", MentionPath(path.as_ref())) + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: + acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents { + uri, + .. + }), + .. + }) => { + if let Some(uri) = MentionUri::parse(&uri).log_err() { + uri.as_link().to_string() } else { - resource_link.uri.clone() + uri.clone() } } acp::ContentBlock::Image(_) | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => String::new(), + | acp::ContentBlock::Resource(acp::EmbeddedResource { .. }) + | acp::ContentBlock::ResourceLink(_) => String::new(), }; match self { @@ -1329,7 +1302,7 @@ mod tests { use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{cell::RefCell, rc::Rc, time::Duration}; + use std::{cell::RefCell, path::Path, rc::Rc, time::Duration}; use util::path; diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs new file mode 100644 index 0000000000..3e9d93d633 --- /dev/null +++ b/crates/acp_thread/src/mention.rs @@ -0,0 +1,368 @@ +use agent::ThreadId; +use anyhow::{Context as _, Result, bail}; +use prompt_store::{PromptId, UserPromptId}; +use std::{ + fmt, + ops::Range, + path::{Path, PathBuf}, +}; +use url::Url; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MentionUri { + File(PathBuf), + Symbol { + path: PathBuf, + name: String, + line_range: Range, + }, + Thread { + id: ThreadId, + name: String, + }, + TextThread { + path: PathBuf, + name: String, + }, + Rule { + id: PromptId, + name: String, + }, + Selection { + path: PathBuf, + line_range: Range, + }, + Fetch { + url: Url, + }, +} + +impl MentionUri { + pub fn parse(input: &str) -> Result { + let url = url::Url::parse(input)?; + let path = url.path(); + match url.scheme() { + "file" => { + if let Some(fragment) = url.fragment() { + let range = fragment + .strip_prefix("L") + .context("Line range must start with \"L\"")?; + let (start, end) = range + .split_once(":") + .context("Line range must use colon as separator")?; + let line_range = start + .parse::() + .context("Parsing line range start")? + .checked_sub(1) + .context("Line numbers should be 1-based")? + ..end + .parse::() + .context("Parsing line range end")? + .checked_sub(1) + .context("Line numbers should be 1-based")?; + if let Some(name) = single_query_param(&url, "symbol")? { + Ok(Self::Symbol { + name, + path: path.into(), + line_range, + }) + } else { + Ok(Self::Selection { + path: path.into(), + line_range, + }) + } + } else { + Ok(Self::File(path.into())) + } + } + "zed" => { + if let Some(thread_id) = path.strip_prefix("/agent/thread/") { + let name = single_query_param(&url, "name")?.context("Missing thread name")?; + Ok(Self::Thread { + id: thread_id.into(), + name, + }) + } else if let Some(path) = path.strip_prefix("/agent/text-thread/") { + let name = single_query_param(&url, "name")?.context("Missing thread name")?; + Ok(Self::TextThread { + path: path.into(), + name, + }) + } else if let Some(rule_id) = path.strip_prefix("/agent/rule/") { + let name = single_query_param(&url, "name")?.context("Missing rule name")?; + let rule_id = UserPromptId(rule_id.parse()?); + Ok(Self::Rule { + id: rule_id.into(), + name, + }) + } else { + bail!("invalid zed url: {:?}", input); + } + } + "http" | "https" => Ok(MentionUri::Fetch { url }), + other => bail!("unrecognized scheme {:?}", other), + } + } + + fn name(&self) -> String { + match self { + MentionUri::File(path) => path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .into_owned(), + MentionUri::Symbol { name, .. } => name.clone(), + MentionUri::Thread { name, .. } => name.clone(), + MentionUri::TextThread { name, .. } => name.clone(), + MentionUri::Rule { name, .. } => name.clone(), + MentionUri::Selection { + path, line_range, .. + } => selection_name(path, line_range), + MentionUri::Fetch { url } => url.to_string(), + } + } + + pub fn as_link<'a>(&'a self) -> MentionLink<'a> { + MentionLink(self) + } + + pub fn to_uri(&self) -> Url { + match self { + MentionUri::File(path) => { + let mut url = Url::parse("file:///").unwrap(); + url.set_path(&path.to_string_lossy()); + url + } + MentionUri::Symbol { + path, + name, + line_range, + } => { + let mut url = Url::parse("file:///").unwrap(); + url.set_path(&path.to_string_lossy()); + url.query_pairs_mut().append_pair("symbol", name); + url.set_fragment(Some(&format!( + "L{}:{}", + line_range.start + 1, + line_range.end + 1 + ))); + url + } + MentionUri::Selection { path, line_range } => { + let mut url = Url::parse("file:///").unwrap(); + url.set_path(&path.to_string_lossy()); + url.set_fragment(Some(&format!( + "L{}:{}", + line_range.start + 1, + line_range.end + 1 + ))); + url + } + MentionUri::Thread { name, id } => { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path(&format!("/agent/thread/{id}")); + url.query_pairs_mut().append_pair("name", name); + url + } + MentionUri::TextThread { path, name } => { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy())); + url.query_pairs_mut().append_pair("name", name); + url + } + MentionUri::Rule { name, id } => { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path(&format!("/agent/rule/{id}")); + url.query_pairs_mut().append_pair("name", name); + url + } + MentionUri::Fetch { url } => url.clone(), + } + } +} + +pub struct MentionLink<'a>(&'a MentionUri); + +impl fmt::Display for MentionLink<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[@{}]({})", self.0.name(), self.0.to_uri()) + } +} + +fn single_query_param(url: &Url, name: &'static str) -> Result> { + let pairs = url.query_pairs().collect::>(); + match pairs.as_slice() { + [] => Ok(None), + [(k, v)] => { + if k != name { + bail!("invalid query parameter") + } + + Ok(Some(v.to_string())) + } + _ => bail!("too many query pairs"), + } +} + +pub fn selection_name(path: &Path, line_range: &Range) -> String { + format!( + "{} ({}:{})", + path.file_name().unwrap_or_default().display(), + line_range.start + 1, + line_range.end + 1 + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_file_uri() { + let file_uri = "file:///path/to/file.rs"; + let parsed = MentionUri::parse(file_uri).unwrap(); + match &parsed { + MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"), + _ => panic!("Expected File variant"), + } + assert_eq!(parsed.to_uri().to_string(), file_uri); + } + + #[test] + fn test_parse_symbol_uri() { + let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20"; + let parsed = MentionUri::parse(symbol_uri).unwrap(); + match &parsed { + MentionUri::Symbol { + path, + name, + line_range, + } => { + assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); + assert_eq!(name, "MySymbol"); + assert_eq!(line_range.start, 9); + assert_eq!(line_range.end, 19); + } + _ => panic!("Expected Symbol variant"), + } + assert_eq!(parsed.to_uri().to_string(), symbol_uri); + } + + #[test] + fn test_parse_selection_uri() { + let selection_uri = "file:///path/to/file.rs#L5:15"; + let parsed = MentionUri::parse(selection_uri).unwrap(); + match &parsed { + MentionUri::Selection { path, line_range } => { + assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); + assert_eq!(line_range.start, 4); + assert_eq!(line_range.end, 14); + } + _ => panic!("Expected Selection variant"), + } + assert_eq!(parsed.to_uri().to_string(), selection_uri); + } + + #[test] + fn test_parse_thread_uri() { + let thread_uri = "zed:///agent/thread/session123?name=Thread%20name"; + let parsed = MentionUri::parse(thread_uri).unwrap(); + match &parsed { + MentionUri::Thread { + id: thread_id, + name, + } => { + assert_eq!(thread_id.to_string(), "session123"); + assert_eq!(name, "Thread name"); + } + _ => panic!("Expected Thread variant"), + } + assert_eq!(parsed.to_uri().to_string(), thread_uri); + } + + #[test] + fn test_parse_rule_uri() { + let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some%20rule"; + let parsed = MentionUri::parse(rule_uri).unwrap(); + match &parsed { + MentionUri::Rule { id, name } => { + assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52"); + assert_eq!(name, "Some rule"); + } + _ => panic!("Expected Rule variant"), + } + assert_eq!(parsed.to_uri().to_string(), rule_uri); + } + + #[test] + fn test_parse_fetch_http_uri() { + let http_uri = "http://example.com/path?query=value#fragment"; + let parsed = MentionUri::parse(http_uri).unwrap(); + match &parsed { + MentionUri::Fetch { url } => { + assert_eq!(url.to_string(), http_uri); + } + _ => panic!("Expected Fetch variant"), + } + assert_eq!(parsed.to_uri().to_string(), http_uri); + } + + #[test] + fn test_parse_fetch_https_uri() { + let https_uri = "https://example.com/api/endpoint"; + let parsed = MentionUri::parse(https_uri).unwrap(); + match &parsed { + MentionUri::Fetch { url } => { + assert_eq!(url.to_string(), https_uri); + } + _ => panic!("Expected Fetch variant"), + } + assert_eq!(parsed.to_uri().to_string(), https_uri); + } + + #[test] + fn test_invalid_scheme() { + assert!(MentionUri::parse("ftp://example.com").is_err()); + assert!(MentionUri::parse("ssh://example.com").is_err()); + assert!(MentionUri::parse("unknown://example.com").is_err()); + } + + #[test] + fn test_invalid_zed_path() { + assert!(MentionUri::parse("zed:///invalid/path").is_err()); + assert!(MentionUri::parse("zed:///agent/unknown/test").is_err()); + } + + #[test] + fn test_invalid_line_range_format() { + // Missing L prefix + assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err()); + + // Missing colon separator + assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err()); + + // Invalid numbers + assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err()); + assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err()); + } + + #[test] + fn test_invalid_query_parameters() { + // Invalid query parameter name + assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err()); + + // Too many query parameters + assert!( + MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err() + ); + } + + #[test] + fn test_zero_based_line_numbers() { + // Test that 0-based line numbers are rejected (should be 1-based) + assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err()); + assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err()); + assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err()); + } +} diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cc7cb50c91..12c94a522d 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -205,6 +205,22 @@ impl ThreadStore { (this, ready_rx) } + #[cfg(any(test, feature = "test-support"))] + pub fn fake(project: Entity, cx: &mut App) -> Self { + Self { + project, + tools: cx.new(|_| ToolWorkingSet::default()), + prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()), + prompt_store: None, + context_server_tool_ids: HashMap::default(), + threads: Vec::new(), + project_context: SharedProjectContext::default(), + reload_system_prompt_tx: mpsc::channel(0).0, + _reload_system_prompt_task: Task::ready(()), + _subscriptions: vec![], + } + } + fn handle_project_event( &mut self, _project: Entity, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 18a830b978..7439b2a088 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, - FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, - ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, + FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, }; use acp_thread::ModelSelector; use agent_client_protocol as acp; @@ -516,10 +516,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; log::debug!("Found session for: {}", session_id); - // Convert prompt to message - let message = convert_prompt_to_message(params.prompt); + let message: Vec = params + .prompt + .into_iter() + .map(Into::into) + .collect::>(); log::info!("Converted prompt to message: {} chars", message.len()); - log::debug!("Message content: {}", message); + log::debug!("Message content: {:?}", message); // Get model using the ModelSelector capability (always available for agent2) // Get the selected model from the thread directly @@ -623,39 +626,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } } -/// Convert ACP content blocks to a message string -fn convert_prompt_to_message(blocks: Vec) -> String { - log::debug!("Converting {} content blocks to message", blocks.len()); - let mut message = String::new(); - - for block in blocks { - match block { - acp::ContentBlock::Text(text) => { - log::trace!("Processing text block: {} chars", text.text.len()); - message.push_str(&text.text); - } - acp::ContentBlock::ResourceLink(link) => { - log::trace!("Processing resource link: {}", link.uri); - message.push_str(&format!(" @{} ", link.uri)); - } - acp::ContentBlock::Image(_) => { - log::trace!("Processing image block"); - message.push_str(" [image] "); - } - acp::ContentBlock::Audio(_) => { - log::trace!("Processing audio block"); - message.push_str(" [audio] "); - } - acp::ContentBlock::Resource(resource) => { - log::trace!("Processing resource block: {:?}", resource.resource); - message.push_str(&format!(" [resource: {:?}] ", resource.resource)); - } - } - } - - message -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 7f4b934c08..88cf92836b 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,4 +1,5 @@ use super::*; +use crate::MessageContent; use acp_thread::AgentConnection; use action_log::ActionLog; use agent_client_protocol::{self as acp}; @@ -13,8 +14,8 @@ use gpui::{ use indoc::indoc; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, - StopReason, fake_provider::FakeLanguageModel, + LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason, + fake_provider::FakeLanguageModel, }; use project::Project; use prompt_store::ProjectContext; @@ -272,14 +273,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { assert_eq!( message.content, vec![ - MessageContent::ToolResult(LanguageModelToolResult { + language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), tool_name: ToolRequiringPermission.name().into(), is_error: false, content: "Allowed".into(), output: Some("Allowed".into()) }), - MessageContent::ToolResult(LanguageModelToolResult { + language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), tool_name: ToolRequiringPermission.name().into(), is_error: true, @@ -312,13 +313,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let message = completion.messages.last().unwrap(); assert_eq!( message.content, - vec![MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(), - tool_name: ToolRequiringPermission.name().into(), - is_error: false, - content: "Allowed".into(), - output: Some("Allowed".into()) - })] + vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + } + )] ); // Simulate a final tool call, ensuring we don't trigger authorization. @@ -337,13 +340,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let message = completion.messages.last().unwrap(); assert_eq!( message.content, - vec![MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id: "tool_id_4".into(), - tool_name: ToolRequiringPermission.name().into(), - is_error: false, - content: "Allowed".into(), - output: Some("Allowed".into()) - })] + vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: "tool_id_4".into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + } + )] ); } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 231f83ce20..d33b1b8f9d 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,4 +1,5 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; +use acp_thread::MentionUri; use action_log::ActionLog; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, AgentSettings}; @@ -13,10 +14,10 @@ use futures::{ }; use gpui::{App, Context, Entity, SharedString, Task}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; use log; use project::Project; @@ -25,7 +26,8 @@ use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; -use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc}; +use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; +use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; #[derive(Debug, Clone)] @@ -34,6 +36,23 @@ pub struct AgentMessage { pub content: Vec, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + Mention { + uri: MentionUri, + content: String, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), + ToolResult(LanguageModelToolResult), +} + impl AgentMessage { pub fn to_markdown(&self) -> String { let mut markdown = format!("## {}\n", self.role); @@ -93,6 +112,9 @@ impl AgentMessage { .unwrap(); } } + MessageContent::Mention { uri, .. } => { + write!(markdown, "{}", uri.as_link()).ok(); + } } } @@ -214,10 +236,11 @@ impl Thread { /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. pub fn send( &mut self, - content: impl Into, + content: impl Into, cx: &mut Context, ) -> mpsc::UnboundedReceiver> { - let content = content.into(); + let content = content.into().0; + let model = self.selected_model.clone(); log::info!("Thread::send called with model: {:?}", model.name()); log::debug!("Thread::send content: {:?}", content); @@ -230,7 +253,7 @@ impl Thread { let user_message_ix = self.messages.len(); self.messages.push(AgentMessage { role: Role::User, - content: vec![content], + content, }); log::info!("Total messages in thread: {}", self.messages.len()); self.running_turn = Some(cx.spawn(async move |thread, cx| { @@ -353,7 +376,7 @@ impl Thread { log::debug!("System message built"); AgentMessage { role: Role::System, - content: vec![prompt.into()], + content: vec![prompt.as_str().into()], } } @@ -701,11 +724,7 @@ impl Thread { }, message.content.len() ); - LanguageModelRequestMessage { - role: message.role, - content: message.content.clone(), - cache: false, - } + message.to_request() }) .collect(); messages @@ -720,6 +739,20 @@ impl Thread { } } +pub struct UserMessage(Vec); + +impl From> for UserMessage { + fn from(content: Vec) -> Self { + UserMessage(content) + } +} + +impl> From for UserMessage { + fn from(content: T) -> Self { + UserMessage(vec![content.into()]) + } +} + pub trait AgentTool where Self: 'static + Sized, @@ -1102,3 +1135,246 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver { &mut self.0 } } + +impl AgentMessage { + fn to_request(&self) -> language_model::LanguageModelRequestMessage { + let mut message = LanguageModelRequestMessage { + role: self.role, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + + const OPEN_CONTEXT: &str = "\n\ + The following items were attached by the user. \ + They are up-to-date and don't need to be re-read.\n\n"; + + const OPEN_FILES_TAG: &str = ""; + const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_THREADS_TAG: &str = ""; + const OPEN_FETCH_TAG: &str = ""; + const OPEN_RULES_TAG: &str = + "\nThe user has specified the following rules that should be applied:\n"; + + let mut file_context = OPEN_FILES_TAG.to_string(); + let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut thread_context = OPEN_THREADS_TAG.to_string(); + let mut fetch_context = OPEN_FETCH_TAG.to_string(); + let mut rules_context = OPEN_RULES_TAG.to_string(); + + for chunk in &self.content { + let chunk = match chunk { + MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()), + MessageContent::Thinking { text, signature } => { + language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + } + } + MessageContent::RedactedThinking(value) => { + language_model::MessageContent::RedactedThinking(value.clone()) + } + MessageContent::ToolUse(value) => { + language_model::MessageContent::ToolUse(value.clone()) + } + MessageContent::ToolResult(value) => { + language_model::MessageContent::ToolResult(value.clone()) + } + MessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + MessageContent::Mention { uri, content } => { + match uri { + MentionUri::File(path) => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&path, None), + text: &content.to_string(), + } + ) + .ok(); + } + MentionUri::Symbol { + path, line_range, .. + } + | MentionUri::Selection { + path, line_range, .. + } => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&path, Some(line_range)), + text: &content + } + ) + .ok(); + } + MentionUri::Thread { .. } => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::TextThread { .. } => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::Rule { .. } => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: "", + text: &content + } + ) + .ok(); + } + MentionUri::Fetch { url } => { + write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok(); + } + } + + language_model::MessageContent::Text(uri.as_link().to_string()) + } + }; + + message.content.push(chunk); + } + + let len_before_context = message.content.len(); + + if file_context.len() > OPEN_FILES_TAG.len() { + file_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(file_context)); + } + + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { + symbol_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(symbol_context)); + } + + if thread_context.len() > OPEN_THREADS_TAG.len() { + thread_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(thread_context)); + } + + if fetch_context.len() > OPEN_FETCH_TAG.len() { + fetch_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(fetch_context)); + } + + if rules_context.len() > OPEN_RULES_TAG.len() { + rules_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(rules_context)); + } + + if message.content.len() > len_before_context { + message.content.insert( + len_before_context, + language_model::MessageContent::Text(OPEN_CONTEXT.into()), + ); + message + .content + .push(language_model::MessageContent::Text("".into())); + } + + message + } +} + +fn codeblock_tag(full_path: &Path, line_range: Option<&Range>) -> String { + let mut result = String::new(); + + if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { + let _ = write!(result, "{} ", extension); + } + + let _ = write!(result, "{}", full_path.display()); + + if let Some(range) = line_range { + if range.start == range.end { + let _ = write!(result, ":{}", range.start + 1); + } else { + let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1); + } + } + + result +} + +impl From for MessageContent { + fn from(value: acp::ContentBlock) -> Self { + match value { + acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text), + acp::ContentBlock::Image(image_content) => { + MessageContent::Image(convert_image(image_content)) + } + acp::ContentBlock::Audio(_) => { + // TODO + MessageContent::Text("[audio]".to_string()) + } + acp::ContentBlock::ResourceLink(resource_link) => { + match MentionUri::parse(&resource_link.uri) { + Ok(uri) => Self::Mention { + uri, + content: String::new(), + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + MessageContent::Text(format!( + "[{}]({})", + resource_link.name, resource_link.uri + )) + } + } + } + acp::ContentBlock::Resource(resource) => match resource.resource { + acp::EmbeddedResourceResource::TextResourceContents(resource) => { + match MentionUri::parse(&resource.uri) { + Ok(uri) => Self::Mention { + uri, + content: resource.text, + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + MessageContent::Text( + MarkdownCodeBlock { + tag: &resource.uri, + text: &resource.text, + } + .to_string(), + ) + } + } + } + acp::EmbeddedResourceResource::BlobResourceContents(_) => { + // TODO + MessageContent::Text("[blob]".to_string()) + } + }, + } + } +} + +fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { + LanguageModelImage { + source: image_content.data.into(), + // TODO: make this optional? + size: gpui::Size::new(0.into(), 0.into()), + } +} + +impl From<&str> for MessageContent { + fn from(text: &str) -> Self { + MessageContent::Text(text.into()) + } +} diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index de0a27c2cb..b6a5710aa4 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -93,6 +93,7 @@ time.workspace = true time_format.workspace = true ui.workspace = true ui_input.workspace = true +url.workspace = true urlencoding.workspace = true util.workspace = true uuid.workspace = true @@ -102,6 +103,8 @@ workspace.workspace = true zed_actions.workspace = true [dev-dependencies] +agent = { workspace = true, features = ["test-support"] } +assistant_context = { workspace = true, features = ["test-support"] } assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index d8f452afa5..b00473d1dc 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,64 +1,648 @@ use std::ops::Range; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use anyhow::Result; -use collections::HashMap; +use acp_thread::{MentionUri, selection_name}; +use anyhow::{Context as _, Result, anyhow}; +use collections::{HashMap, HashSet}; use editor::display_map::CreaseId; -use editor::{CompletionProvider, Editor, ExcerptId}; +use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _}; use file_icons::FileIcons; +use futures::future::try_join_all; +use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{App, Entity, Task, WeakEntity}; +use http_client::HttpClientWithUrl; +use itertools::Itertools as _; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; use parking_lot::Mutex; -use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId}; +use project::{ + Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId, +}; +use prompt_store::PromptStore; use rope::Point; -use text::{Anchor, ToPoint}; +use text::{Anchor, OffsetRangeExt as _, ToPoint as _}; use ui::prelude::*; +use url::Url; use workspace::Workspace; +use workspace::notifications::NotifyResultExt; -use crate::context_picker::MentionLink; -use crate::context_picker::file_context_picker::{extract_file_name_and_directory, search_files}; +use agent::{ + context::RULES_ICON, + thread_store::{TextThreadStore, ThreadStore}, +}; + +use crate::context_picker::fetch_context_picker::fetch_url_content; +use crate::context_picker::file_context_picker::{FileMatch, search_files}; +use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; +use crate::context_picker::symbol_context_picker::SymbolMatch; +use crate::context_picker::symbol_context_picker::search_symbols; +use crate::context_picker::thread_context_picker::{ + ThreadContextEntry, ThreadMatch, search_threads, +}; +use crate::context_picker::{ + ContextPickerAction, ContextPickerEntry, ContextPickerMode, RecentEntry, + available_context_picker_entries, recent_context_picker_entries, selection_ranges, +}; #[derive(Default)] pub struct MentionSet { - paths_by_crease_id: HashMap, + uri_by_crease_id: HashMap, + fetch_results: HashMap, } impl MentionSet { - pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) { - self.paths_by_crease_id.insert(crease_id, path); + pub fn insert(&mut self, crease_id: CreaseId, uri: MentionUri) { + self.uri_by_crease_id.insert(crease_id, uri); } - pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option { - self.paths_by_crease_id.get(&crease_id).cloned() + pub fn add_fetch_result(&mut self, url: Url, content: String) { + self.fetch_results.insert(url, content); } pub fn drain(&mut self) -> impl Iterator { - self.paths_by_crease_id.drain().map(|(id, _)| id) + self.uri_by_crease_id.drain().map(|(id, _)| id) + } + + pub fn contents( + &self, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + window: &mut Window, + cx: &mut App, + ) -> Task>> { + let contents = self + .uri_by_crease_id + .iter() + .map(|(&crease_id, uri)| { + match uri { + MentionUri::File(path) => { + let uri = uri.clone(); + let path = path.to_path_buf(); + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(path, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + anyhow::Ok((crease_id, Mention { uri, content })) + }) + } + MentionUri::Symbol { + path, line_range, .. + } + | MentionUri::Selection { + path, line_range, .. + } => { + let uri = uri.clone(); + let path_buf = path.clone(); + let line_range = line_range.clone(); + + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(&path_buf, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| { + buffer + .text_for_range( + Point::new(line_range.start, 0) + ..Point::new( + line_range.end, + buffer.line_len(line_range.end), + ), + ) + .collect() + })?; + + anyhow::Ok((crease_id, Mention { uri, content })) + }) + } + MentionUri::Thread { id: thread_id, .. } => { + let open_task = thread_store.update(cx, |thread_store, cx| { + thread_store.open_thread(&thread_id, window, cx) + }); + + let uri = uri.clone(); + cx.spawn(async move |cx| { + let thread = open_task.await?; + let content = thread.read_with(cx, |thread, _cx| { + thread.latest_detailed_summary_or_text().to_string() + })?; + + anyhow::Ok((crease_id, Mention { uri, content })) + }) + } + MentionUri::TextThread { path, .. } => { + let context = text_thread_store.update(cx, |text_thread_store, cx| { + text_thread_store.open_local_context(path.as_path().into(), cx) + }); + let uri = uri.clone(); + cx.spawn(async move |cx| { + let context = context.await?; + let xml = context.update(cx, |context, cx| context.to_xml(cx))?; + anyhow::Ok((crease_id, Mention { uri, content: xml })) + }) + } + MentionUri::Rule { id: prompt_id, .. } => { + let Some(prompt_store) = thread_store.read(cx).prompt_store().clone() + else { + return Task::ready(Err(anyhow!("missing prompt store"))); + }; + let text_task = prompt_store.read(cx).load(prompt_id.clone(), cx); + let uri = uri.clone(); + cx.spawn(async move |_| { + // TODO: report load errors instead of just logging + let text = text_task.await?; + anyhow::Ok((crease_id, Mention { uri, content: text })) + }) + } + MentionUri::Fetch { url } => { + let Some(content) = self.fetch_results.get(&url) else { + return Task::ready(Err(anyhow!("missing fetch result"))); + }; + Task::ready(Ok(( + crease_id, + Mention { + uri: uri.clone(), + content: content.clone(), + }, + ))) + } + } + }) + .collect::>(); + + cx.spawn(async move |_cx| { + let contents = try_join_all(contents).await?.into_iter().collect(); + anyhow::Ok(contents) + }) + } +} + +#[derive(Debug)] +pub struct Mention { + pub uri: MentionUri, + pub content: String, +} + +pub(crate) enum Match { + File(FileMatch), + Symbol(SymbolMatch), + Thread(ThreadMatch), + Fetch(SharedString), + Rules(RulesContextEntry), + Entry(EntryMatch), +} + +pub struct EntryMatch { + mat: Option, + entry: ContextPickerEntry, +} + +impl Match { + pub fn score(&self) -> f64 { + match self { + Match::File(file) => file.mat.score, + Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.), + Match::Thread(_) => 1., + Match::Symbol(_) => 1., + Match::Rules(_) => 1., + Match::Fetch(_) => 1., + } + } +} + +fn search( + mode: Option, + query: String, + cancellation_flag: Arc, + recent_entries: Vec, + prompt_store: Option>, + thread_store: WeakEntity, + text_thread_context_store: WeakEntity, + workspace: Entity, + cx: &mut App, +) -> Task> { + match mode { + Some(ContextPickerMode::File) => { + let search_files_task = + search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); + cx.background_spawn(async move { + search_files_task + .await + .into_iter() + .map(Match::File) + .collect() + }) + } + + Some(ContextPickerMode::Symbol) => { + let search_symbols_task = + search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx); + cx.background_spawn(async move { + search_symbols_task + .await + .into_iter() + .map(|symbol| Match::Symbol(symbol)) + .collect() + }) + } + + Some(ContextPickerMode::Thread) => { + if let Some((thread_store, context_store)) = thread_store + .upgrade() + .zip(text_thread_context_store.upgrade()) + { + let search_threads_task = search_threads( + query.clone(), + cancellation_flag.clone(), + thread_store, + context_store, + cx, + ); + cx.background_spawn(async move { + search_threads_task + .await + .into_iter() + .map(Match::Thread) + .collect() + }) + } else { + Task::ready(Vec::new()) + } + } + + Some(ContextPickerMode::Fetch) => { + if !query.is_empty() { + Task::ready(vec![Match::Fetch(query.into())]) + } else { + Task::ready(Vec::new()) + } + } + + Some(ContextPickerMode::Rules) => { + if let Some(prompt_store) = prompt_store.as_ref() { + let search_rules_task = + search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx); + cx.background_spawn(async move { + search_rules_task + .await + .into_iter() + .map(Match::Rules) + .collect::>() + }) + } else { + Task::ready(Vec::new()) + } + } + + None => { + if query.is_empty() { + let mut matches = recent_entries + .into_iter() + .map(|entry| match entry { + RecentEntry::File { + project_path, + path_prefix, + } => Match::File(FileMatch { + mat: fuzzy::PathMatch { + score: 1., + positions: Vec::new(), + worktree_id: project_path.worktree_id.to_usize(), + path: project_path.path, + path_prefix, + is_dir: false, + distance_to_relative_ancestor: 0, + }, + is_recent: true, + }), + RecentEntry::Thread(thread_context_entry) => Match::Thread(ThreadMatch { + thread: thread_context_entry, + is_recent: true, + }), + }) + .collect::>(); + + matches.extend( + available_context_picker_entries( + &prompt_store, + &Some(thread_store.clone()), + &workspace, + cx, + ) + .into_iter() + .map(|mode| { + Match::Entry(EntryMatch { + entry: mode, + mat: None, + }) + }), + ); + + Task::ready(matches) + } else { + let executor = cx.background_executor().clone(); + + let search_files_task = + search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); + + let entries = available_context_picker_entries( + &prompt_store, + &Some(thread_store.clone()), + &workspace, + cx, + ); + let entry_candidates = entries + .iter() + .enumerate() + .map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword())) + .collect::>(); + + cx.background_spawn(async move { + let mut matches = search_files_task + .await + .into_iter() + .map(Match::File) + .collect::>(); + + let entry_matches = fuzzy::match_strings( + &entry_candidates, + &query, + false, + true, + 100, + &Arc::new(AtomicBool::default()), + executor, + ) + .await; + + matches.extend(entry_matches.into_iter().map(|mat| { + Match::Entry(EntryMatch { + entry: entries[mat.candidate_id], + mat: Some(mat), + }) + })); + + matches.sort_by(|a, b| { + b.score() + .partial_cmp(&a.score()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + matches + }) + } + } } } pub struct ContextPickerCompletionProvider { - workspace: WeakEntity, - editor: WeakEntity, mention_set: Arc>, + workspace: WeakEntity, + thread_store: WeakEntity, + text_thread_store: WeakEntity, + editor: WeakEntity, } impl ContextPickerCompletionProvider { pub fn new( mention_set: Arc>, workspace: WeakEntity, + thread_store: WeakEntity, + text_thread_store: WeakEntity, editor: WeakEntity, ) -> Self { Self { mention_set, workspace, + thread_store, + text_thread_store, editor, } } + fn completion_for_entry( + entry: ContextPickerEntry, + excerpt_id: ExcerptId, + source_range: Range, + editor: Entity, + mention_set: Arc>, + workspace: &Entity, + cx: &mut App, + ) -> Option { + match entry { + ContextPickerEntry::Mode(mode) => Some(Completion { + replace_range: source_range.clone(), + new_text: format!("@{} ", mode.keyword()), + label: CodeLabel::plain(mode.label().to_string(), None), + icon_path: Some(mode.icon().path().into()), + documentation: None, + source: project::CompletionSource::Custom, + insert_text_mode: None, + // This ensures that when a user accepts this completion, the + // completion menu will still be shown after "@category " is + // inserted + confirm: Some(Arc::new(|_, _, _| true)), + }), + ContextPickerEntry::Action(action) => { + let (new_text, on_action) = match action { + ContextPickerAction::AddSelections => { + let selections = selection_ranges(workspace, cx); + + const PLACEHOLDER: &str = "selection "; + + let new_text = std::iter::repeat(PLACEHOLDER) + .take(selections.len()) + .chain(std::iter::once("")) + .join(" "); + + let callback = Arc::new({ + let mention_set = mention_set.clone(); + let selections = selections.clone(); + move |_, window: &mut Window, cx: &mut App| { + let editor = editor.clone(); + let mention_set = mention_set.clone(); + let selections = selections.clone(); + window.defer(cx, move |window, cx| { + let mut current_offset = 0; + + for (buffer, selection_range) in selections { + let snapshot = + editor.read(cx).buffer().read(cx).snapshot(cx); + let Some(start) = snapshot + .anchor_in_excerpt(excerpt_id, source_range.start) + else { + return; + }; + + let offset = start.to_offset(&snapshot) + current_offset; + let text_len = PLACEHOLDER.len() - 1; + + let range = snapshot.anchor_after(offset) + ..snapshot.anchor_after(offset + text_len); + + let path = buffer + .read(cx) + .file() + .map_or(PathBuf::from("untitled"), |file| { + file.path().to_path_buf() + }); + + let point_range = snapshot + .as_singleton() + .map(|(_, _, snapshot)| { + selection_range.to_point(&snapshot) + }) + .unwrap_or_default(); + let line_range = point_range.start.row..point_range.end.row; + let crease = crate::context_picker::crease_for_mention( + selection_name(&path, &line_range).into(), + IconName::Reader.path().into(), + range, + editor.downgrade(), + ); + + let [crease_id]: [_; 1] = + editor.update(cx, |editor, cx| { + let crease_ids = + editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases( + vec![crease], + false, + window, + cx, + ); + crease_ids.try_into().unwrap() + }); + + mention_set.lock().insert( + crease_id, + MentionUri::Selection { path, line_range }, + ); + + current_offset += text_len + 1; + } + }); + + false + } + }); + + (new_text, callback) + } + }; + + Some(Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(action.label().to_string(), None), + icon_path: Some(action.icon().path().into()), + documentation: None, + source: project::CompletionSource::Custom, + insert_text_mode: None, + // This ensures that when a user accepts this completion, the + // completion menu will still be shown after "@category " is + // inserted + confirm: Some(on_action), + }) + } + } + } + + fn completion_for_thread( + thread_entry: ThreadContextEntry, + excerpt_id: ExcerptId, + source_range: Range, + recent: bool, + editor: Entity, + mention_set: Arc>, + ) -> Completion { + let icon_for_completion = if recent { + IconName::HistoryRerun + } else { + IconName::Thread + }; + + let uri = match &thread_entry { + ThreadContextEntry::Thread { id, title } => MentionUri::Thread { + id: id.clone(), + name: title.to_string(), + }, + ThreadContextEntry::Context { path, title } => MentionUri::TextThread { + path: path.to_path_buf(), + name: title.to_string(), + }, + }; + let new_text = format!("{} ", uri.as_link()); + + let new_text_len = new_text.len(); + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(thread_entry.title().to_string(), None), + documentation: None, + insert_text_mode: None, + source: project::CompletionSource::Custom, + icon_path: Some(icon_for_completion.path().into()), + confirm: Some(confirm_completion_callback( + IconName::Thread.path().into(), + thread_entry.title().clone(), + excerpt_id, + source_range.start, + new_text_len - 1, + editor.clone(), + mention_set, + uri, + )), + } + } + + fn completion_for_rules( + rule: RulesContextEntry, + excerpt_id: ExcerptId, + source_range: Range, + editor: Entity, + mention_set: Arc>, + ) -> Completion { + let uri = MentionUri::Rule { + id: rule.prompt_id.into(), + name: rule.title.to_string(), + }; + let new_text = format!("{} ", uri.as_link()); + let new_text_len = new_text.len(); + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(rule.title.to_string(), None), + documentation: None, + insert_text_mode: None, + source: project::CompletionSource::Custom, + icon_path: Some(RULES_ICON.path().into()), + confirm: Some(confirm_completion_callback( + RULES_ICON.path().into(), + rule.title.clone(), + excerpt_id, + source_range.start, + new_text_len - 1, + editor.clone(), + mention_set, + uri, + )), + } + } + pub(crate) fn completion_for_path( project_path: ProjectPath, path_prefix: &str, @@ -68,10 +652,14 @@ impl ContextPickerCompletionProvider { source_range: Range, editor: Entity, mention_set: Arc>, + project: Entity, cx: &App, - ) -> Completion { + ) -> Option { let (file_name, directory) = - extract_file_name_and_directory(&project_path.path, path_prefix); + crate::context_picker::file_context_picker::extract_file_name_and_directory( + &project_path.path, + path_prefix, + ); let label = build_code_label_for_full_path(&file_name, directory.as_ref().map(|s| s.as_ref()), cx); @@ -93,9 +681,14 @@ impl ContextPickerCompletionProvider { crease_icon_path.clone() }; - let new_text = format!("{} ", MentionLink::for_file(&file_name, &full_path)); + let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) else { + return None; + }; + + let file_uri = MentionUri::File(abs_path.into()); + let new_text = format!("{} ", file_uri.as_link()); let new_text_len = new_text.len(); - Completion { + Some(Completion { replace_range: source_range.clone(), new_text, label, @@ -106,14 +699,153 @@ impl ContextPickerCompletionProvider { confirm: Some(confirm_completion_callback( crease_icon_path, file_name, - project_path, excerpt_id, source_range.start, new_text_len - 1, editor, - mention_set, + mention_set.clone(), + file_uri, )), - } + }) + } + + fn completion_for_symbol( + symbol: Symbol, + excerpt_id: ExcerptId, + source_range: Range, + editor: Entity, + mention_set: Arc>, + workspace: Entity, + cx: &mut App, + ) -> Option { + let project = workspace.read(cx).project().clone(); + + let label = CodeLabel::plain(symbol.name.clone(), None); + + let abs_path = project.read(cx).absolute_path(&symbol.path, cx)?; + let uri = MentionUri::Symbol { + path: abs_path, + name: symbol.name.clone(), + line_range: symbol.range.start.0.row..symbol.range.end.0.row, + }; + let new_text = format!("{} ", uri.as_link()); + let new_text_len = new_text.len(); + Some(Completion { + replace_range: source_range.clone(), + new_text, + label, + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(IconName::Code.path().into()), + insert_text_mode: None, + confirm: Some(confirm_completion_callback( + IconName::Code.path().into(), + symbol.name.clone().into(), + excerpt_id, + source_range.start, + new_text_len - 1, + editor.clone(), + mention_set.clone(), + uri, + )), + }) + } + + fn completion_for_fetch( + source_range: Range, + url_to_fetch: SharedString, + excerpt_id: ExcerptId, + editor: Entity, + mention_set: Arc>, + http_client: Arc, + ) -> Option { + let new_text = format!("@fetch {} ", url_to_fetch.clone()); + let new_text_len = new_text.len(); + Some(Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(url_to_fetch.to_string(), None), + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(IconName::ToolWeb.path().into()), + insert_text_mode: None, + confirm: Some({ + let start = source_range.start; + let content_len = new_text_len - 1; + let editor = editor.clone(); + let url_to_fetch = url_to_fetch.clone(); + let source_range = source_range.clone(); + Arc::new(move |_, window, cx| { + let Some(url) = url::Url::parse(url_to_fetch.as_ref()) + .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) + .notify_app_err(cx) + else { + return false; + }; + let mention_uri = MentionUri::Fetch { url: url.clone() }; + + let editor = editor.clone(); + let mention_set = mention_set.clone(); + let http_client = http_client.clone(); + let source_range = source_range.clone(); + window.defer(cx, move |window, cx| { + let url = url.clone(); + + let Some(crease_id) = crate::context_picker::insert_crease_for_mention( + excerpt_id, + start, + content_len, + url.to_string().into(), + IconName::ToolWeb.path().into(), + editor.clone(), + window, + cx, + ) else { + return; + }; + + let editor = editor.clone(); + let mention_set = mention_set.clone(); + let http_client = http_client.clone(); + let source_range = source_range.clone(); + window + .spawn(cx, async move |cx| { + if let Some(content) = + fetch_url_content(http_client, url.to_string()) + .await + .notify_async_err(cx) + { + mention_set.lock().add_fetch_result(url, content); + mention_set.lock().insert(crease_id, mention_uri.clone()); + } else { + // Remove crease if we failed to fetch + editor + .update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let Some(anchor) = snapshot + .anchor_in_excerpt(excerpt_id, source_range.start) + else { + return; + }; + editor.display_map.update(cx, |display_map, cx| { + display_map.unfold_intersecting( + vec![anchor..anchor], + true, + cx, + ); + }); + editor.remove_creases([crease_id], cx); + }) + .ok(); + } + Some(()) + }) + .detach(); + }); + false + }) + }), + }) } } @@ -159,16 +891,67 @@ impl CompletionProvider for ContextPickerCompletionProvider { return Task::ready(Ok(Vec::new())); }; + let project = workspace.read(cx).project().clone(); + let http_client = workspace.read(cx).client().http_client(); let snapshot = buffer.read(cx).snapshot(); let source_range = snapshot.anchor_before(state.source_range.start) ..snapshot.anchor_after(state.source_range.end); + let thread_store = self.thread_store.clone(); + let text_thread_store = self.text_thread_store.clone(); let editor = self.editor.clone(); - let mention_set = self.mention_set.clone(); - let MentionCompletion { argument, .. } = state; + + let MentionCompletion { mode, argument, .. } = state; let query = argument.unwrap_or_else(|| "".to_string()); - let search_task = search_files(query.clone(), Arc::::default(), &workspace, cx); + let (exclude_paths, exclude_threads) = { + let mention_set = self.mention_set.lock(); + + let mut excluded_paths = HashSet::default(); + let mut excluded_threads = HashSet::default(); + + for uri in mention_set.uri_by_crease_id.values() { + match uri { + MentionUri::File(path) => { + excluded_paths.insert(path.clone()); + } + MentionUri::Thread { id, .. } => { + excluded_threads.insert(id.clone()); + } + _ => {} + } + } + + (excluded_paths, excluded_threads) + }; + + let recent_entries = recent_context_picker_entries( + Some(thread_store.clone()), + Some(text_thread_store.clone()), + workspace.clone(), + &exclude_paths, + &exclude_threads, + cx, + ); + + let prompt_store = thread_store + .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) + .ok() + .flatten(); + + let search_task = search( + mode, + query, + Arc::::default(), + recent_entries, + prompt_store, + thread_store.clone(), + text_thread_store.clone(), + workspace.clone(), + cx, + ); + + let mention_set = self.mention_set.clone(); cx.spawn(async move |_, cx| { let matches = search_task.await; @@ -179,24 +962,74 @@ impl CompletionProvider for ContextPickerCompletionProvider { let completions = cx.update(|cx| { matches .into_iter() - .map(|mat| { - let path_match = &mat.mat; - let project_path = ProjectPath { - worktree_id: WorktreeId::from_usize(path_match.worktree_id), - path: path_match.path.clone(), - }; + .filter_map(|mat| match mat { + Match::File(FileMatch { mat, is_recent }) => { + let project_path = ProjectPath { + worktree_id: WorktreeId::from_usize(mat.worktree_id), + path: mat.path.clone(), + }; - Self::completion_for_path( - project_path, - &path_match.path_prefix, - mat.is_recent, - path_match.is_dir, + Self::completion_for_path( + project_path, + &mat.path_prefix, + is_recent, + mat.is_dir, + excerpt_id, + source_range.clone(), + editor.clone(), + mention_set.clone(), + project.clone(), + cx, + ) + } + + Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( + symbol, excerpt_id, source_range.clone(), editor.clone(), mention_set.clone(), + workspace.clone(), cx, - ) + ), + + Match::Thread(ThreadMatch { + thread, is_recent, .. + }) => Some(Self::completion_for_thread( + thread, + excerpt_id, + source_range.clone(), + is_recent, + editor.clone(), + mention_set.clone(), + )), + + Match::Rules(user_rules) => Some(Self::completion_for_rules( + user_rules, + excerpt_id, + source_range.clone(), + editor.clone(), + mention_set.clone(), + )), + + Match::Fetch(url) => Self::completion_for_fetch( + source_range.clone(), + url, + excerpt_id, + editor.clone(), + mention_set.clone(), + http_client.clone(), + ), + + Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( + entry, + excerpt_id, + source_range.clone(), + editor.clone(), + mention_set.clone(), + &workspace, + cx, + ), }) .collect() })?; @@ -248,21 +1081,21 @@ impl CompletionProvider for ContextPickerCompletionProvider { fn confirm_completion_callback( crease_icon_path: SharedString, crease_text: SharedString, - project_path: ProjectPath, excerpt_id: ExcerptId, start: Anchor, content_len: usize, editor: Entity, mention_set: Arc>, + mention_uri: MentionUri, ) -> Arc bool + Send + Sync> { Arc::new(move |_, window, cx| { let crease_text = crease_text.clone(); let crease_icon_path = crease_icon_path.clone(); let editor = editor.clone(); - let project_path = project_path.clone(); let mention_set = mention_set.clone(); + let mention_uri = mention_uri.clone(); window.defer(cx, move |window, cx| { - let crease_id = crate::context_picker::insert_crease_for_mention( + if let Some(crease_id) = crate::context_picker::insert_crease_for_mention( excerpt_id, start, content_len, @@ -271,9 +1104,8 @@ fn confirm_completion_callback( editor.clone(), window, cx, - ); - if let Some(crease_id) = crease_id { - mention_set.lock().insert(crease_id, project_path); + ) { + mention_set.lock().insert(crease_id, mention_uri.clone()); } }); false @@ -283,6 +1115,7 @@ fn confirm_completion_callback( #[derive(Debug, Default, PartialEq)] struct MentionCompletion { source_range: Range, + mode: Option, argument: Option, } @@ -302,17 +1135,37 @@ impl MentionCompletion { } let rest_of_line = &line[last_mention_start + 1..]; + + let mut mode = None; let mut argument = None; let mut parts = rest_of_line.split_whitespace(); let mut end = last_mention_start + 1; - if let Some(argument_text) = parts.next() { - end += argument_text.len(); - argument = Some(argument_text.to_string()); + if let Some(mode_text) = parts.next() { + end += mode_text.len(); + + if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() { + mode = Some(parsed_mode); + } else { + argument = Some(mode_text.to_string()); + } + match rest_of_line[mode_text.len()..].find(|c: char| !c.is_whitespace()) { + Some(whitespace_count) => { + if let Some(argument_text) = parts.next() { + argument = Some(argument_text.to_string()); + end += whitespace_count + argument_text.len(); + } + } + None => { + // Rest of line is entirely whitespace + end += rest_of_line.len() - mode_text.len(); + } + } } Some(Self { source_range: last_mention_start + offset_to_line..end + offset_to_line, + mode, argument, }) } @@ -321,10 +1174,12 @@ impl MentionCompletion { #[cfg(test)] mod tests { use super::*; + use editor::AnchorRangeExt; use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext}; use project::{Project, ProjectPath}; use serde_json::json; use settings::SettingsStore; + use smol::stream::StreamExt as _; use std::{ops::Deref, rc::Rc}; use util::path; use workspace::{AppState, Item}; @@ -337,14 +1192,61 @@ mod tests { MentionCompletion::try_parse("Lorem @", 0), Some(MentionCompletion { source_range: 6..7, + mode: None, argument: None, }) ); + assert_eq!( + MentionCompletion::try_parse("Lorem @file", 0), + Some(MentionCompletion { + source_range: 6..11, + mode: Some(ContextPickerMode::File), + argument: None, + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file ", 0), + Some(MentionCompletion { + source_range: 6..12, + mode: Some(ContextPickerMode::File), + argument: None, + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file main.rs", 0), + Some(MentionCompletion { + source_range: 6..19, + mode: Some(ContextPickerMode::File), + argument: Some("main.rs".to_string()), + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file main.rs ", 0), + Some(MentionCompletion { + source_range: 6..19, + mode: Some(ContextPickerMode::File), + argument: Some("main.rs".to_string()), + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file main.rs Ipsum", 0), + Some(MentionCompletion { + source_range: 6..19, + mode: Some(ContextPickerMode::File), + argument: Some("main.rs".to_string()), + }) + ); + assert_eq!( MentionCompletion::try_parse("Lorem @main", 0), Some(MentionCompletion { source_range: 6..11, + mode: None, argument: Some("main".to_string()), }) ); @@ -401,16 +1303,16 @@ mod tests { json!({ "editor": "", "a": { - "one.txt": "", - "two.txt": "", - "three.txt": "", - "four.txt": "" + "one.txt": "1", + "two.txt": "2", + "three.txt": "3", + "four.txt": "4" }, "b": { - "five.txt": "", - "six.txt": "", - "seven.txt": "", - "eight.txt": "", + "five.txt": "5", + "six.txt": "6", + "seven.txt": "7", + "eight.txt": "8", } }), ) @@ -485,12 +1387,17 @@ mod tests { let mention_set = Arc::new(Mutex::new(MentionSet::default())); + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let editor_entity = editor.downgrade(); editor.update_in(&mut cx, |editor, window, cx| { window.focus(&editor.focus_handle(cx)); editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( mention_set.clone(), workspace.downgrade(), + thread_store.downgrade(), + text_thread_store.downgrade(), editor_entity, )))); }); @@ -514,22 +1421,10 @@ mod tests { "seven.txt dir/b/", "six.txt dir/b/", "five.txt dir/b/", - "four.txt dir/a/", - "three.txt dir/a/", - "two.txt dir/a/", - "one.txt dir/a/", - "dir ", - "a dir/", - "four.txt dir/a/", - "one.txt dir/a/", - "three.txt dir/a/", - "two.txt dir/a/", - "b dir/", - "eight.txt dir/b/", - "five.txt dir/b/", - "seven.txt dir/b/", - "six.txt dir/b/", - "editor dir/" + "Files & Directories", + "Symbols", + "Threads", + "Fetch" ] ); }); @@ -547,8 +1442,264 @@ mod tests { cx.run_until_parked(); editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem [@four.txt](@file:dir/a/four.txt) "); + assert_eq!(editor.text(cx), "Lorem @file "); + assert!(editor.has_visible_completions_menu()); }); + + cx.simulate_input("one"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @file one"); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), vec!["one.txt dir/a/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) "); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + let contents = cx + .update(|window, cx| { + mention_set.lock().contents( + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + window, + cx, + ) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + assert_eq!(contents.len(), 1); + assert_eq!(contents[0].content, "1"); + assert_eq!( + contents[0].uri.to_uri().to_string(), + "file:///dir/a/one.txt" + ); + + cx.simulate_input(" "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) "); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + cx.simulate_input("Ipsum "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum ", + ); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + cx.simulate_input("@file "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum @file ", + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + let contents = cx + .update(|window, cx| { + mention_set.lock().contents( + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + window, + cx, + ) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + assert_eq!(contents.len(), 2); + let new_mention = contents + .iter() + .find(|mention| mention.uri.to_uri().to_string() == "file:///dir/b/eight.txt") + .unwrap(); + assert_eq!(new_mention.content, "8"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) " + ); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![ + Point::new(0, 6)..Point::new(0, 39), + Point::new(0, 47)..Point::new(0, 84) + ] + ); + }); + + let plain_text_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Plain Text".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["txt".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(&cx, |project, _| project.languages().clone()); + language_registry.add(plain_text_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Plain Text", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + workspace_symbol_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(&mut cx, |project, cx| { + project.open_local_buffer(path!("/dir/a/one.txt"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(&mut cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + cx.run_until_parked(); + + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::( + |_, _| async move { + Ok(Some(lsp::WorkspaceSymbolResponse::Flat(vec![ + #[allow(deprecated)] + lsp::SymbolInformation { + name: "MySymbol".into(), + location: lsp::Location { + uri: lsp::Url::from_file_path(path!("/dir/a/one.txt")).unwrap(), + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 1), + ), + }, + kind: lsp::SymbolKind::CONSTANT, + tags: None, + container_name: None, + deprecated: None, + }, + ]))) + }, + ); + + cx.simulate_input("@symbol "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) @symbol " + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + current_completion_labels(editor), + &[ + "MySymbol", + ] + ); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + let contents = cx + .update(|window, cx| { + mention_set.lock().contents( + project.clone(), + thread_store, + text_thread_store, + window, + cx, + ) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + assert_eq!(contents.len(), 3); + let new_mention = contents + .iter() + .find(|mention| { + mention.uri.to_uri().to_string() == "file:///dir/a/one.txt?symbol=MySymbol#L1:1" + }) + .unwrap(); + assert_eq!(new_mention.content, "1"); + + cx.run_until_parked(); + + editor.read_with(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) [@MySymbol](file:///dir/a/one.txt?symbol=MySymbol#L1:1) " + ); + }); + } + + fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec> { + let snapshot = editor.buffer().read(cx).snapshot(cx); + editor.display_map.update(cx, |display_map, cx| { + display_map + .snapshot(cx) + .folds_in_range(0..snapshot.len()) + .map(|fold| fold.range.to_point(&snapshot)) + .collect() + }) } fn current_completion_labels(editor: &Editor) -> Vec { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index f37deac26e..940ac7135f 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,18 +1,20 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, + LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; +use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol as acp; use agent_servers::AgentServer; use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; use audio::{Audio, Sound}; use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; +use editor::scroll::Autoscroll; use editor::{ AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MinimapVisibility, MultiBuffer, PathKey, + EditorStyle, MinimapVisibility, MultiBuffer, PathKey, SelectionEffects, }; use file_icons::FileIcons; use gpui::{ @@ -27,7 +29,11 @@ use language::{Buffer, Language}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::{CompletionIntent, Project}; +use prompt_store::PromptId; +use rope::Point; use settings::{Settings as _, SettingsStore}; +use std::fmt::Write as _; +use std::path::PathBuf; use std::{ cell::RefCell, collections::BTreeMap, path::Path, process::ExitStatus, rc::Rc, sync::Arc, time::Duration, @@ -41,6 +47,7 @@ use ui::{ use util::{ResultExt, size::format_file_size, time::duration_alt_display}; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; +use zed_actions::assistant::OpenRulesLibrary; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; @@ -57,6 +64,8 @@ pub struct AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, + thread_store: WeakEntity, + text_thread_store: WeakEntity, thread_state: ThreadState, diff_editors: HashMap>, terminal_views: HashMap>, @@ -103,6 +112,8 @@ impl AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, + thread_store: WeakEntity, + text_thread_store: WeakEntity, message_history: Rc>>>, min_lines: usize, max_lines: Option, @@ -140,6 +151,8 @@ impl AcpThreadView { editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( mention_set.clone(), workspace.clone(), + thread_store.clone(), + text_thread_store.clone(), cx.weak_entity(), )))); editor.set_context_menu_options(ContextMenuOptions { @@ -183,6 +196,8 @@ impl AcpThreadView { agent: agent.clone(), workspace: workspace.clone(), project: project.clone(), + thread_store, + text_thread_store, thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, message_set_from_history: None, @@ -376,81 +391,111 @@ impl AcpThreadView { let mut ix = 0; let mut chunks: Vec = Vec::new(); let project = self.project.clone(); - self.message_editor.update(cx, |editor, cx| { - let text = editor.text(cx); - editor.display_map.update(cx, |map, cx| { - let snapshot = map.snapshot(cx); - for (crease_id, crease) in snapshot.crease_snapshot.creases() { - // Skip creases that have been edited out of the message buffer. - if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { - continue; - } - if let Some(project_path) = - self.mention_set.lock().path_for_crease_id(crease_id) - { - let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); - if crease_range.start > ix { - chunks.push(text[ix..crease_range.start].into()); - } - if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - let path_str = abs_path.display().to_string(); - chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { - uri: path_str.clone(), - name: path_str, - annotations: None, - description: None, - mime_type: None, - size: None, - title: None, - })); - } - ix = crease_range.end; - } - } - - if ix < text.len() { - let last_chunk = text[ix..].trim_end(); - if !last_chunk.is_empty() { - chunks.push(last_chunk.into()); - } - } - }) - }); - - if chunks.is_empty() { - return; - } - - let Some(thread) = self.thread() else { + let Some(thread_store) = self.thread_store.upgrade() else { + return; + }; + let Some(text_thread_store) = self.text_thread_store.upgrade() else { return; }; - let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); - cx.spawn(async move |this, cx| { - let result = task.await; + let contents = + self.mention_set + .lock() + .contents(project, thread_store, text_thread_store, window, cx); - this.update(cx, |this, cx| { - if let Err(err) = result { - this.last_error = - Some(cx.new(|cx| Markdown::new(err.to_string().into(), None, None, cx))) + cx.spawn_in(window, async move |this, cx| { + let contents = match contents.await { + Ok(contents) => contents, + Err(e) => { + this.update(cx, |this, cx| { + this.last_error = + Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx))); + }) + .ok(); + return; } + }; + + this.update_in(cx, |this, window, cx| { + this.message_editor.update(cx, |editor, cx| { + let text = editor.text(cx); + editor.display_map.update(cx, |map, cx| { + let snapshot = map.snapshot(cx); + for (crease_id, crease) in snapshot.crease_snapshot.creases() { + // Skip creases that have been edited out of the message buffer. + if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { + continue; + } + + if let Some(mention) = contents.get(&crease_id) { + let crease_range = + crease.range().to_offset(&snapshot.buffer_snapshot); + if crease_range.start > ix { + chunks.push(text[ix..crease_range.start].into()); + } + chunks.push(acp::ContentBlock::Resource(acp::EmbeddedResource { + annotations: None, + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + mime_type: None, + text: mention.content.clone(), + uri: mention.uri.to_uri().to_string(), + }, + ), + })); + ix = crease_range.end; + } + } + + if ix < text.len() { + let last_chunk = text[ix..].trim_end(); + if !last_chunk.is_empty() { + chunks.push(last_chunk.into()); + } + } + }) + }); + + if chunks.is_empty() { + return; + } + + let Some(thread) = this.thread() else { + return; + }; + let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); + + cx.spawn(async move |this, cx| { + let result = task.await; + + this.update(cx, |this, cx| { + if let Err(err) = result { + this.last_error = + Some(cx.new(|cx| { + Markdown::new(err.to_string().into(), None, None, cx) + })) + } + }) + }) + .detach(); + + let mention_set = this.mention_set.clone(); + + this.set_editor_is_expanded(false, cx); + + this.message_editor.update(cx, |editor, cx| { + editor.clear(window, cx); + editor.remove_creases(mention_set.lock().drain(), cx) + }); + + this.scroll_to_bottom(cx); + + this.message_history.borrow_mut().push(chunks); }) + .ok(); }) .detach(); - - let mention_set = self.mention_set.clone(); - - self.set_editor_is_expanded(false, cx); - - self.message_editor.update(cx, |editor, cx| { - editor.clear(window, cx); - editor.remove_creases(mention_set.lock().drain(), cx) - }); - - self.scroll_to_bottom(cx); - - self.message_history.borrow_mut().push(chunks); } fn previous_history_message( @@ -563,16 +608,18 @@ impl AcpThreadView { acp::ContentBlock::Text(text_content) => { text.push_str(&text_content.text); } - acp::ContentBlock::ResourceLink(resource_link) => { - let path = Path::new(&resource_link.uri); + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: acp::EmbeddedResourceResource::TextResourceContents(resource), + .. + }) => { + let path = PathBuf::from(&resource.uri); + let project_path = project.read(cx).project_path_for_absolute_path(&path, cx); let start = text.len(); - let content = MentionPath::new(&path).to_string(); - text.push_str(&content); + let _ = write!(&mut text, "{}", MentionUri::File(path).to_uri()); let end = text.len(); - if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(&path, cx) - { - let filename: SharedString = path + if let Some(project_path) = project_path { + let filename: SharedString = project_path + .path .file_name() .unwrap_or_default() .to_string_lossy() @@ -583,7 +630,8 @@ impl AcpThreadView { } acp::ContentBlock::Image(_) | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => {} + | acp::ContentBlock::Resource(_) + | acp::ContentBlock::ResourceLink(_) => {} } } @@ -602,18 +650,23 @@ impl AcpThreadView { }; let anchor = snapshot.anchor_before(range.start); - let crease_id = crate::context_picker::insert_crease_for_mention( - anchor.excerpt_id, - anchor.text_anchor, - range.end - range.start, - filename, - crease_icon_path, - message_editor.clone(), - window, - cx, - ); - if let Some(crease_id) = crease_id { - mention_set.lock().insert(crease_id, project_path); + if let Some(project_path) = project.read(cx).absolute_path(&project_path, cx) { + let crease_id = crate::context_picker::insert_crease_for_mention( + anchor.excerpt_id, + anchor.text_anchor, + range.end - range.start, + filename, + crease_icon_path, + message_editor.clone(), + window, + cx, + ); + + if let Some(crease_id) = crease_id { + mention_set + .lock() + .insert(crease_id, MentionUri::File(project_path)); + } } } @@ -2562,26 +2615,95 @@ impl AcpThreadView { return; }; - if let Some(mention_path) = MentionPath::try_parse(&url) { - workspace.update(cx, |workspace, cx| { - let project = workspace.project(); - let Some((path, entry)) = project.update(cx, |project, cx| { - let path = project.find_project_path(mention_path.path(), cx)?; - let entry = project.entry_for_path(&path, cx)?; - Some((path, entry)) - }) else { - return; - }; + if let Some(mention) = MentionUri::parse(&url).log_err() { + workspace.update(cx, |workspace, cx| match mention { + MentionUri::File(path) => { + let project = workspace.project(); + let Some((path, entry)) = project.update(cx, |project, cx| { + let path = project.find_project_path(path, cx)?; + let entry = project.entry_for_path(&path, cx)?; + Some((path, entry)) + }) else { + return; + }; - if entry.is_dir() { - project.update(cx, |_, cx| { - cx.emit(project::Event::RevealInProjectPanel(entry.id)); - }); - } else { - workspace - .open_path(path, None, true, window, cx) + if entry.is_dir() { + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry.id)); + }); + } else { + workspace + .open_path(path, None, true, window, cx) + .detach_and_log_err(cx); + } + } + MentionUri::Symbol { + path, line_range, .. + } + | MentionUri::Selection { path, line_range } => { + let project = workspace.project(); + let Some((path, _)) = project.update(cx, |project, cx| { + let path = project.find_project_path(path, cx)?; + let entry = project.entry_for_path(&path, cx)?; + Some((path, entry)) + }) else { + return; + }; + + let item = workspace.open_path(path, None, true, window, cx); + window + .spawn(cx, async move |cx| { + let Some(editor) = item.await?.downcast::() else { + return Ok(()); + }; + let range = Point::new(line_range.start as u32, 0) + ..Point::new(line_range.start as u32, 0); + editor + .update_in(cx, |editor, window, cx| { + editor.change_selections( + SelectionEffects::scroll(Autoscroll::center()), + window, + cx, + |s| s.select_ranges(vec![range]), + ); + }) + .ok(); + anyhow::Ok(()) + }) .detach_and_log_err(cx); } + MentionUri::Thread { id, .. } => { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel + .open_thread_by_id(&id, window, cx) + .detach_and_log_err(cx) + }); + } + } + MentionUri::TextThread { path, .. } => { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel + .open_saved_prompt_editor(path.as_path().into(), window, cx) + .detach_and_log_err(cx); + }); + } + } + MentionUri::Rule { id, .. } => { + let PromptId::User { uuid } = id else { + return; + }; + window.dispatch_action( + Box::new(OpenRulesLibrary { + prompt_to_select: Some(uuid.0), + }), + cx, + ) + } + MentionUri::Fetch { url } => { + cx.open_url(url.as_str()); + } }) } else { cx.open_url(&url); @@ -2966,7 +3088,7 @@ impl AcpThreadView { .unwrap_or(path.path.as_os_str()) .display() .to_string(); - let completion = ContextPickerCompletionProvider::completion_for_path( + let Some(completion) = ContextPickerCompletionProvider::completion_for_path( path, &path_prefix, false, @@ -2975,8 +3097,11 @@ impl AcpThreadView { anchor..anchor, self.message_editor.clone(), self.mention_set.clone(), + self.project.clone(), cx, - ); + ) else { + continue; + }; self.message_editor.update(cx, |message_editor, cx| { message_editor.edit( @@ -3117,7 +3242,7 @@ fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { style.base_text_style = text_style; style.link_callback = Some(Rc::new(move |url, cx| { - if MentionPath::try_parse(url).is_some() { + if MentionUri::parse(url).is_ok() { let colors = cx.theme().colors(); Some(TextStyleRefinement { background_color: Some(colors.element_background), @@ -3434,6 +3559,8 @@ mod tests { Rc::new(agent), workspace.downgrade(), project, + WeakEntity::new_invalid(), + WeakEntity::new_invalid(), Rc::new(RefCell::new(MessageHistory::default())), 1, None, @@ -3536,6 +3663,8 @@ mod tests { Rc::new(agent), workspace.downgrade(), project, + WeakEntity::new_invalid(), + WeakEntity::new_invalid(), Rc::new(RefCell::new(MessageHistory::default())), 1, None, diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 87e4dd822c..fbb75f28c0 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -924,6 +924,9 @@ impl AgentPanel { agent: crate::ExternalAgent, } + let thread_store = self.thread_store.clone(); + let text_thread_store = self.context_store.clone(); + cx.spawn_in(window, async move |this, cx| { let server: Rc = match agent_choice { Some(agent) => { @@ -962,6 +965,8 @@ impl AgentPanel { server, workspace.clone(), project, + thread_store.downgrade(), + text_thread_store.downgrade(), message_history, MIN_EDITOR_LINES, Some(MAX_EDITOR_LINES), diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 58f11313e6..7dc00bfae2 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -1,15 +1,16 @@ mod completion_provider; -mod fetch_context_picker; +pub(crate) mod fetch_context_picker; pub(crate) mod file_context_picker; -mod rules_context_picker; -mod symbol_context_picker; -mod thread_context_picker; +pub(crate) mod rules_context_picker; +pub(crate) mod symbol_context_picker; +pub(crate) mod thread_context_picker; use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; use anyhow::{Result, anyhow}; +use collections::HashSet; pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId}; use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; @@ -45,7 +46,7 @@ use agent::{ }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ContextPickerEntry { +pub(crate) enum ContextPickerEntry { Mode(ContextPickerMode), Action(ContextPickerAction), } @@ -74,7 +75,7 @@ impl ContextPickerEntry { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ContextPickerMode { +pub(crate) enum ContextPickerMode { File, Symbol, Fetch, @@ -83,7 +84,7 @@ enum ContextPickerMode { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ContextPickerAction { +pub(crate) enum ContextPickerAction { AddSelections, } @@ -531,7 +532,7 @@ impl ContextPicker { return vec![]; }; - recent_context_picker_entries( + recent_context_picker_entries_with_store( context_store, self.thread_store.clone(), self.text_thread_store.clone(), @@ -585,7 +586,8 @@ impl Render for ContextPicker { }) } } -enum RecentEntry { + +pub(crate) enum RecentEntry { File { project_path: ProjectPath, path_prefix: Arc, @@ -593,7 +595,7 @@ enum RecentEntry { Thread(ThreadContextEntry), } -fn available_context_picker_entries( +pub(crate) fn available_context_picker_entries( prompt_store: &Option>, thread_store: &Option>, workspace: &Entity, @@ -630,24 +632,56 @@ fn available_context_picker_entries( entries } -fn recent_context_picker_entries( +fn recent_context_picker_entries_with_store( context_store: Entity, thread_store: Option>, text_thread_store: Option>, workspace: Entity, exclude_path: Option, cx: &App, +) -> Vec { + let project = workspace.read(cx).project(); + + let mut exclude_paths = context_store.read(cx).file_paths(cx); + exclude_paths.extend(exclude_path); + + let exclude_paths = exclude_paths + .into_iter() + .filter_map(|project_path| project.read(cx).absolute_path(&project_path, cx)) + .collect(); + + let exclude_threads = context_store.read(cx).thread_ids(); + + recent_context_picker_entries( + thread_store, + text_thread_store, + workspace, + &exclude_paths, + exclude_threads, + cx, + ) +} + +pub(crate) fn recent_context_picker_entries( + thread_store: Option>, + text_thread_store: Option>, + workspace: Entity, + exclude_paths: &HashSet, + exclude_threads: &HashSet, + cx: &App, ) -> Vec { let mut recent = Vec::with_capacity(6); - let mut current_files = context_store.read(cx).file_paths(cx); - current_files.extend(exclude_path); let workspace = workspace.read(cx); let project = workspace.project().read(cx); recent.extend( workspace .recent_navigation_history_iter(cx) - .filter(|(path, _)| !current_files.contains(path)) + .filter(|(_, abs_path)| { + abs_path + .as_ref() + .map_or(true, |path| !exclude_paths.contains(path.as_path())) + }) .take(4) .filter_map(|(project_path, _)| { project @@ -659,8 +693,6 @@ fn recent_context_picker_entries( }), ); - let current_threads = context_store.read(cx).thread_ids(); - let active_thread_id = workspace .panel::(cx) .and_then(|panel| Some(panel.read(cx).active_thread(cx)?.read(cx).id())); @@ -672,7 +704,7 @@ fn recent_context_picker_entries( let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx) .filter(|(_, thread)| match thread { ThreadContextEntry::Thread { id, .. } => { - Some(id) != active_thread_id && !current_threads.contains(id) + Some(id) != active_thread_id && !exclude_threads.contains(id) } ThreadContextEntry::Context { .. } => true, }) @@ -710,7 +742,7 @@ fn add_selections_as_context( }) } -fn selection_ranges( +pub(crate) fn selection_ranges( workspace: &Entity, cx: &mut App, ) -> Vec<(Entity, Range)> { diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index 8123b3437d..962c0df03d 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -35,7 +35,7 @@ use super::symbol_context_picker::search_symbols; use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; use super::{ ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry, - available_context_picker_entries, recent_context_picker_entries, selection_ranges, + available_context_picker_entries, recent_context_picker_entries_with_store, selection_ranges, }; use crate::message_editor::ContextCreasesAddon; @@ -787,7 +787,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { .and_then(|b| b.read(cx).file()) .map(|file| ProjectPath::from_file(file.as_ref(), cx)); - let recent_entries = recent_context_picker_entries( + let recent_entries = recent_context_picker_entries_with_store( context_store.clone(), thread_store.clone(), text_thread_store.clone(), diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index 8f5ff98790..45c0072418 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -11,6 +11,9 @@ workspace = true [lib] path = "src/assistant_context.rs" +[features] +test-support = [] + [dependencies] agent_settings.workspace = true anyhow.workspace = true diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_context/src/context_store.rs index 3090a7b234..622d8867a7 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_context/src/context_store.rs @@ -138,6 +138,27 @@ impl ContextStore { }) } + #[cfg(any(test, feature = "test-support"))] + pub fn fake(project: Entity, cx: &mut Context) -> Self { + Self { + contexts: Default::default(), + contexts_metadata: Default::default(), + context_server_slash_command_ids: Default::default(), + host_contexts: Default::default(), + fs: project.read(cx).fs().clone(), + languages: project.read(cx).languages().clone(), + slash_commands: Arc::default(), + telemetry: project.read(cx).client().telemetry().clone(), + _watch_updates: Task::ready(None), + client: project.read(cx).client(), + project, + project_is_shared: false, + client_subscription: None, + _project_subscriptions: Default::default(), + prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()), + } + } + async fn handle_advertise_contexts( this: Entity, envelope: TypedEnvelope, diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index d1bf95c794..c421e1fec1 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -12158,6 +12158,8 @@ impl Editor { let clipboard_text = Cow::Borrowed(text); self.transact(window, cx, |this, window, cx| { + let had_active_edit_prediction = this.has_active_edit_prediction(); + if let Some(mut clipboard_selections) = clipboard_selections { let old_selections = this.selections.all::(cx); let all_selections_were_entire_line = @@ -12230,6 +12232,11 @@ impl Editor { } else { this.insert(&clipboard_text, window, cx); } + + let trigger_in_words = + this.show_edit_predictions_in_menu() || !had_active_edit_prediction; + + this.trigger_completion_on_input(&text, trigger_in_words, window, cx); }); } diff --git a/crates/prompt_store/src/prompt_store.rs b/crates/prompt_store/src/prompt_store.rs index f9cb26ed9a..06a65b97cd 100644 --- a/crates/prompt_store/src/prompt_store.rs +++ b/crates/prompt_store/src/prompt_store.rs @@ -90,6 +90,15 @@ impl From for UserPromptId { } } +impl std::fmt::Display for PromptId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PromptId::User { uuid } => write!(f, "{}", uuid.0), + PromptId::EditWorkflow => write!(f, "Edit workflow"), + } + } +} + pub struct PromptStore { env: heed::Env, metadata_cache: RwLock,