diff --git a/Cargo.lock b/Cargo.lock index d24a399c1c..ac1a56d53f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,6 +10,7 @@ dependencies = [ "agent-client-protocol", "anyhow", "buffer_diff", + "collections", "editor", "env_logger 0.11.8", "futures 0.3.31", @@ -17,7 +18,6 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", - "language_model", "markdown", "parking_lot", "project", @@ -31,6 +31,8 @@ dependencies = [ "ui", "url", "util", + "uuid", + "watch", "workspace-hack", ] @@ -231,6 +233,7 @@ dependencies = [ "task", "tempfile", "terminal", + "text", "theme", "tree-sitter-rust", "ui", @@ -6444,6 +6447,7 @@ dependencies = [ "log", "parking_lot", "pretty_assertions", + "rand 0.8.5", "regex", "rope", "schemars", @@ -11148,14 +11152,13 @@ dependencies = [ "ai_onboarding", "anyhow", "client", - "command_palette_hooks", "component", "db", "documented", "editor", - "feature_flags", "fs", "fuzzy", + "git", "gpui", "itertools 0.14.0", "language", @@ -11167,6 +11170,7 @@ dependencies = [ "schemars", "serde", "settings", + "telemetry", "theme", "ui", "util", @@ -18882,33 +18886,6 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" -[[package]] -name = "welcome" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "component", - "db", - "documented", - "editor", - "fuzzy", - "gpui", - "install_cli", - "language", - "picker", - "project", - "serde", - "settings", - "telemetry", - "ui", - "util", - "vim_mode_setting", - "workspace", - "workspace-hack", - "zed_actions", -] - [[package]] name = "which" version = "4.4.2" @@ -20663,7 +20640,6 @@ dependencies = [ "watch", "web_search", "web_search_providers", - "welcome", "windows 0.61.1", "winresource", "workspace", @@ -20687,7 +20663,7 @@ dependencies = [ [[package]] name = "zed_emmet" -version = "0.0.5" +version = "0.0.6" dependencies = [ "zed_extension_api 0.1.0", ] diff --git a/Cargo.toml b/Cargo.toml index dd14078dd2..1baa6d3d74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -185,7 +185,6 @@ members = [ "crates/watch", "crates/web_search", "crates/web_search_providers", - "crates/welcome", "crates/workspace", "crates/worktree", "crates/x_ai", @@ -412,7 +411,6 @@ vim_mode_setting = { path = "crates/vim_mode_setting" } watch = { path = "crates/watch" } web_search = { path = "crates/web_search" } web_search_providers = { path = "crates/web_search_providers" } -welcome = { path = "crates/welcome" } workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } x_ai = { path = "crates/x_ai" } @@ -566,6 +564,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77 "socks", "stream", ] } +rodio = { version = "0.21.1", default-features = false } rsa = "0.9.6" runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [ "async-dispatcher-runtime", diff --git a/assets/settings/default.json b/assets/settings/default.json index 28cf591ee7..0f1818ac7f 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -82,10 +82,10 @@ // Layout mode of the bottom dock. Defaults to "contained" // choices: contained, full, left_aligned, right_aligned "bottom_dock_layout": "contained", - // The direction that you want to split panes horizontally. Defaults to "up" - "pane_split_direction_horizontal": "up", - // The direction that you want to split panes vertically. Defaults to "left" - "pane_split_direction_vertical": "left", + // The direction that you want to split panes horizontally. Defaults to "down" + "pane_split_direction_horizontal": "down", + // The direction that you want to split panes vertically. Defaults to "right" + "pane_split_direction_vertical": "right", // Centered layout related settings. "centered_layout": { // The relative width of the left padding of the central pane from the diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 1fef342c01..b3ec217bad 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -20,12 +20,12 @@ action_log.workspace = true agent-client-protocol.workspace = true anyhow.workspace = true buffer_diff.workspace = true +collections.workspace = true editor.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true -language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true @@ -36,6 +36,8 @@ terminal.workspace = true ui.workspace = true url.workspace = true util.workspace = true +uuid.workspace = true +watch.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index cadab3d62c..f8a5bf8032 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -9,18 +9,19 @@ pub use mention::*; pub use terminal::*; use action_log::ActionLog; -use agent_client_protocol::{self as acp}; -use anyhow::{Context as _, Result}; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, anyhow}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; -use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; +use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use itertools::Itertools; -use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff}; +use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff}; use markdown::Markdown; -use project::{AgentLocation, Project}; +use project::{AgentLocation, Project, git_store::GitStoreCheckpoint}; use std::collections::HashMap; use std::error::Error; -use std::fmt::Formatter; +use std::fmt::{Formatter, Write}; +use std::ops::Range; use std::process::ExitStatus; use std::rc::Rc; use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; @@ -29,24 +30,23 @@ use util::ResultExt; #[derive(Debug)] pub struct UserMessage { + pub id: Option, pub content: ContentBlock, + pub checkpoint: Option, } impl UserMessage { - pub fn from_acp( - message: impl IntoIterator, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let mut content = ContentBlock::Empty; - for chunk in message { - content.append(chunk, &language_registry, cx) - } - Self { content: content } - } - fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) + let mut markdown = String::new(); + if let Some(_) = self.checkpoint { + writeln!(markdown, "## User (checkpoint)").unwrap(); + } else { + writeln!(markdown, "## User").unwrap(); + } + writeln!(markdown).unwrap(); + writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap(); + writeln!(markdown).unwrap(); + markdown } } @@ -122,9 +122,17 @@ impl AgentThreadEntry { } } - pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> { - if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self { - Some(locations) + pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> { + if let AgentThreadEntry::ToolCall(ToolCall { + locations, + resolved_locations, + .. + }) = self + { + Some(( + locations.get(ix)?.clone(), + resolved_locations.get(ix)?.clone()?, + )) } else { None } @@ -139,6 +147,7 @@ pub struct ToolCall { pub content: Vec, pub status: ToolCallStatus, pub locations: Vec, + pub resolved_locations: Vec>, pub raw_input: Option, pub raw_output: Option, } @@ -167,6 +176,7 @@ impl ToolCall { .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) .collect(), locations: tool_call.locations, + resolved_locations: Vec::default(), status, raw_input: tool_call.raw_input, raw_output: tool_call.raw_output, @@ -260,6 +270,57 @@ impl ToolCall { } markdown } + + async fn resolve_location( + location: acp::ToolCallLocation, + project: WeakEntity, + cx: &mut AsyncApp, + ) -> Option { + let buffer = project + .update(cx, |project, cx| { + if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) { + Some(project.open_buffer(path, cx)) + } else { + None + } + }) + .ok()??; + let buffer = buffer.await.log_err()?; + let position = buffer + .update(cx, |buffer, _| { + if let Some(row) = location.line { + let snapshot = buffer.snapshot(); + let column = snapshot.indent_size_for_line(row).len; + let point = snapshot.clip_point(Point::new(row, column), Bias::Left); + snapshot.anchor_before(point) + } else { + Anchor::MIN + } + }) + .ok()?; + + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }) + } + + fn resolve_locations( + &self, + project: Entity, + cx: &mut App, + ) -> Task>> { + let locations = self.locations.clone(); + project.update(cx, |_, cx| { + cx.spawn(async move |project, cx| { + let mut new_locations = Vec::new(); + for location in locations { + new_locations.push(Self::resolve_location(location, project.clone(), cx).await); + } + new_locations + }) + }) + } } #[derive(Debug)] @@ -572,6 +633,7 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), + EntriesRemoved(Range), ToolAuthorizationRequired, Stopped, Error, @@ -633,6 +695,10 @@ impl AcpThread { } } + pub fn connection(&self) -> &Rc { + &self.connection + } + pub fn action_log(&self) -> &Entity { &self.action_log } @@ -707,7 +773,7 @@ impl AcpThread { ) -> Result<()> { match update { acp::SessionUpdate::UserMessageChunk { content } => { - self.push_user_content_block(content, cx); + self.push_user_content_block(None, content, cx); } acp::SessionUpdate::AgentMessageChunk { content } => { self.push_assistant_content_block(content, false, cx); @@ -728,18 +794,32 @@ impl AcpThread { Ok(()) } - pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { + pub fn push_user_content_block( + &mut self, + message_id: Option, + chunk: acp::ContentBlock, + cx: &mut Context, + ) { let language_registry = self.project.read(cx).languages().clone(); let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() - && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry + && let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry { + *id = message_id.or(id.take()); content.append(chunk, &language_registry, cx); - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + let idx = entries_len - 1; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); } else { let content = ContentBlock::new(chunk, &language_registry, cx); - self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { + id: message_id, + content, + checkpoint: None, + }), + cx, + ); } } @@ -754,7 +834,8 @@ impl AcpThread { if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry { - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + let idx = entries_len - 1; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); match (chunks.last_mut(), is_thought) { (Some(AssistantMessageChunk::Message { block }), false) | (Some(AssistantMessageChunk::Thought { block }), true) => { @@ -804,7 +885,11 @@ impl AcpThread { .context("Tool call not found")?; match update { ToolCallUpdate::UpdateFields(update) => { + let location_updated = update.fields.locations.is_some(); current_call.update_fields(update.fields, languages, cx); + if location_updated { + self.resolve_locations(update.id.clone(), cx); + } } ToolCallUpdate::UpdateDiff(update) => { current_call.content.clear(); @@ -841,8 +926,7 @@ impl AcpThread { ) { let language_registry = self.project.read(cx).languages().clone(); let call = ToolCall::from_acp(tool_call, status, language_registry, cx); - - let location = call.locations.last().cloned(); + let id = call.id.clone(); if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { *current_call = call; @@ -850,11 +934,9 @@ impl AcpThread { cx.emit(AcpThreadEvent::EntryUpdated(ix)); } else { self.push_entry(AgentThreadEntry::ToolCall(call), cx); - } + }; - if let Some(location) = location { - self.set_project_location(location, cx) - } + self.resolve_locations(id, cx); } fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { @@ -875,35 +957,50 @@ impl AcpThread { }) } - pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { - self.project.update(cx, |project, cx| { - let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { - return; - }; - let buffer = project.open_buffer(path, cx); - cx.spawn(async move |project, cx| { - let buffer = buffer.await?; - - project.update(cx, |project, cx| { - let position = if let Some(line) = location.line { - let snapshot = buffer.read(cx).snapshot(); - let point = snapshot.clip_point(Point::new(line, 0), Bias::Left); - snapshot.anchor_before(point) - } else { - Anchor::MIN - }; - - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position, - }), - cx, - ); - }) + pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context) { + let project = self.project.clone(); + let Some((_, tool_call)) = self.tool_call_mut(&id) else { + return; + }; + let task = tool_call.resolve_locations(project, cx); + cx.spawn(async move |this, cx| { + let resolved_locations = task.await; + this.update(cx, |this, cx| { + let project = this.project.clone(); + let Some((ix, tool_call)) = this.tool_call_mut(&id) else { + return; + }; + if let Some(Some(location)) = resolved_locations.last() { + project.update(cx, |project, cx| { + if let Some(agent_location) = project.agent_location() { + let should_ignore = agent_location.buffer == location.buffer + && location + .buffer + .update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let old_position = + agent_location.position.to_point(&snapshot); + let new_position = location.position.to_point(&snapshot); + // ignore this so that when we get updates from the edit tool + // the position doesn't reset to the startof line + old_position.row == new_position.row + && old_position.column > new_position.column + }) + .ok() + .unwrap_or_default(); + if !should_ignore { + project.set_agent_location(Some(location.clone()), cx); + } + } + }); + } + if tool_call.resolved_locations != resolved_locations { + tool_call.resolved_locations = resolved_locations; + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } }) - .detach_and_log_err(cx); - }); + }) + .detach(); } pub fn request_tool_call_authorization( @@ -1037,66 +1134,113 @@ impl AcpThread { self.project.read(cx).languages().clone(), cx, ); + let git_store = self.project.read(cx).git_store().clone(); + + let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); + let message_id = if self + .connection + .session_editor(&self.session_id, cx) + .is_some() + { + Some(UserMessageId::new()) + } else { + None + }; self.push_entry( - AgentThreadEntry::UserMessage(UserMessage { content: block }), + AgentThreadEntry::UserMessage(UserMessage { + id: message_id.clone(), + content: block, + checkpoint: None, + }), cx, ); self.clear_completed_plan_entries(cx); + let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel(); let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); + let request = acp::PromptRequest { + prompt: message, + session_id: self.session_id.clone(), + }; - self.send_task = Some(cx.spawn(async move |this, cx| { - async { + self.send_task = Some(cx.spawn({ + let message_id = message_id.clone(); + async move |this, cx| { cancel_task.await; - let result = this - .update(cx, |this, cx| { - this.connection.prompt( - acp::PromptRequest { - prompt: message, - session_id: this.session_id.clone(), - }, - cx, - ) - })? - .await; - - tx.send(result).log_err(); - - anyhow::Ok(()) + old_checkpoint_tx.send(old_checkpoint.await).ok(); + if let Ok(result) = this.update(cx, |this, cx| { + this.connection.prompt(message_id, request, cx) + }) { + tx.send(result.await).log_err(); + } } - .await - .log_err(); })); - cx.spawn(async move |this, cx| match rx.await { - Ok(Err(e)) => { - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) - .log_err(); - Err(e)? - } - result => { - let cancelled = matches!( - result, - Ok(Ok(acp::PromptResponse { - stop_reason: acp::StopReason::Cancelled - })) - ); + cx.spawn(async move |this, cx| { + let old_checkpoint = old_checkpoint_rx + .await + .map_err(|_| anyhow!("send canceled")) + .flatten() + .context("failed to get old checkpoint") + .log_err(); - // We only take the task if the current prompt wasn't cancelled. - // - // This prompt may have been cancelled because another one was sent - // while it was still generating. In these cases, dropping `send_task` - // would cause the next generation to be cancelled. - if !cancelled { - this.update(cx, |this, _cx| this.send_task.take()).ok(); + let response = rx.await; + + if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) { + let new_checkpoint = git_store + .update(cx, |git, cx| git.checkpoint(cx))? + .await + .context("failed to get new checkpoint") + .log_err(); + if let Some(new_checkpoint) = new_checkpoint { + let equal = git_store + .update(cx, |git, cx| { + git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) + })? + .await + .unwrap_or(true); + if !equal { + this.update(cx, |this, cx| { + if let Some((ix, message)) = this.user_message_mut(&message_id) { + message.checkpoint = Some(old_checkpoint); + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } + })?; + } } - - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) - .log_err(); - Ok(()) } + + this.update(cx, |this, cx| { + match response { + Ok(Err(e)) => { + this.send_task.take(); + cx.emit(AcpThreadEvent::Error); + Err(e) + } + result => { + let cancelled = matches!( + result, + Ok(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled + })) + ); + + // We only take the task if the current prompt wasn't cancelled. + // + // This prompt may have been cancelled because another one was sent + // while it was still generating. In these cases, dropping `send_task` + // would cause the next generation to be cancelled. + if !cancelled { + this.send_task.take(); + } + + cx.emit(AcpThreadEvent::Stopped); + Ok(()) + } + } + })? }) .boxed() } @@ -1128,6 +1272,66 @@ impl AcpThread { cx.foreground_executor().spawn(send_task) } + /// Rewinds this thread to before the entry at `index`, removing it and all + /// subsequent entries while reverting any changes made from that point. + pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { + let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else { + return Task::ready(Err(anyhow!("not supported"))); + }; + let Some(message) = self.user_message(&id) else { + return Task::ready(Err(anyhow!("message not found"))); + }; + + let checkpoint = message.checkpoint.clone(); + + let git_store = self.project.read(cx).git_store().clone(); + cx.spawn(async move |this, cx| { + if let Some(checkpoint) = checkpoint { + git_store + .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))? + .await?; + } + + cx.update(|cx| session_editor.truncate(id.clone(), cx))? + .await?; + this.update(cx, |this, cx| { + if let Some((ix, _)) = this.user_message_mut(&id) { + let range = ix..this.entries.len(); + this.entries.truncate(ix); + cx.emit(AcpThreadEvent::EntriesRemoved(range)); + } + }) + }) + } + + fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { + self.entries.iter().find_map(|entry| { + if let AgentThreadEntry::UserMessage(message) = entry { + if message.id.as_ref() == Some(&id) { + Some(message) + } else { + None + } + } else { + None + } + }) + } + + fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { + self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { + if let AgentThreadEntry::UserMessage(message) = entry { + if message.id.as_ref() == Some(&id) { + Some((ix, message)) + } else { + None + } + } else { + None + } + }) + } + pub fn read_text_file( &self, path: PathBuf, @@ -1330,13 +1534,18 @@ mod tests { use futures::{channel::mpsc, future::LocalBoxFuture, select}; use gpui::{AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; - use project::FakeFs; + use project::{FakeFs, Fs}; use rand::Rng as _; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{cell::RefCell, path::Path, rc::Rc, time::Duration}; - + use std::{ + cell::RefCell, + path::Path, + rc::Rc, + sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, + time::Duration, + }; use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1368,6 +1577,7 @@ mod tests { // Test creating a new user message thread.update(cx, |thread, cx| { thread.push_user_content_block( + None, acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "Hello, ".to_string(), @@ -1379,6 +1589,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.id, None); assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); } else { panic!("Expected UserMessage"); @@ -1386,8 +1597,10 @@ mod tests { }); // Test appending to existing user message + let message_1_id = UserMessageId::new(); thread.update(cx, |thread, cx| { thread.push_user_content_block( + Some(message_1_id.clone()), acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "world!".to_string(), @@ -1399,6 +1612,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.id, Some(message_1_id)); assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); } else { panic!("Expected UserMessage"); @@ -1417,8 +1631,10 @@ mod tests { ); }); + let message_2_id = UserMessageId::new(); thread.update(cx, |thread, cx| { thread.push_user_content_block( + Some(message_2_id.clone()), acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "New user message".to_string(), @@ -1430,6 +1646,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 3); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { + assert_eq!(user_msg.id, Some(message_2_id)); assert_eq!(user_msg.content.to_markdown(cx), "New user message"); } else { panic!("Expected UserMessage at index 2"); @@ -1746,6 +1963,180 @@ mod tests { assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); } + #[gpui::test(iterations = 10)] + async fn test_checkpoints(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/test"), + json!({ + ".git": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + + let simulate_changes = Arc::new(AtomicBool::new(true)); + let next_filename = Arc::new(AtomicUsize::new(0)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let simulate_changes = simulate_changes.clone(); + let next_filename = next_filename.clone(); + let fs = fs.clone(); + move |request, thread, mut cx| { + let fs = fs.clone(); + let simulate_changes = simulate_changes.clone(); + let next_filename = next_filename.clone(); + async move { + if simulate_changes.load(SeqCst) { + let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst)); + fs.write(Path::new(&filename), b"").await?; + } + + let acp::ContentBlock::Text(content) = &request.prompt[0] else { + panic!("expected text content block"); + }; + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: content.text.to_uppercase().into(), + }, + cx, + ) + .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + let thread = connection + .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + .await + .unwrap(); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + "} + ); + }); + assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + ## User (checkpoint) + + ipsum + + ## Assistant + + IPSUM + + "} + ); + }); + assert_eq!( + fs.files(), + vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + ); + + // Checkpoint isn't stored when there are no changes. + simulate_changes.store(false, SeqCst); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + ## User (checkpoint) + + ipsum + + ## Assistant + + IPSUM + + ## User + + dolor + + ## Assistant + + DOLOR + + "} + ); + }); + assert_eq!( + fs.files(), + vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + ); + + // Rewinding the conversation truncates the history and restores the checkpoint. + thread + .update(cx, |thread, cx| { + let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { + panic!("unexpected entries {:?}", thread.entries) + }; + thread.rewind(message.id.clone().unwrap(), cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + "} + ); + }); + assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + } + async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1854,6 +2245,7 @@ mod tests { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -1882,5 +2274,25 @@ mod tests { }) .detach(); } + + fn session_editor( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(FakeAgentSessionEditor { + _session_id: session_id.clone(), + })) + } + } + + struct FakeAgentSessionEditor { + _session_id: acp::SessionId, + } + + impl AgentSessionEditor for FakeAgentSessionEditor { + fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index cf06563bee..c3167eb2d4 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,18 +1,78 @@ -use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; - +use crate::AcpThread; use agent_client_protocol::{self as acp}; use anyhow::Result; -use gpui::{AsyncApp, Entity, Task}; -use language_model::LanguageModel; +use collections::IndexMap; +use gpui::{AsyncApp, Entity, SharedString, Task}; use project::Project; -use ui::App; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; +use ui::{App, IconName}; +use uuid::Uuid; -use crate::AcpThread; +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UserMessageId(Arc); + +impl UserMessageId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string().into()) + } +} + +pub trait AgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>>; + + fn auth_methods(&self) -> &[acp::AuthMethod]; + + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; + + fn prompt( + &self, + user_message_id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task>; + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + fn session_editor( + &self, + _session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + None + } + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. + fn model_selector(&self) -> Option> { + None + } +} + +pub trait AgentSessionEditor { + fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +} + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") + } +} /// Trait for agents that support listing, selecting, and querying language models. /// /// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. -pub trait ModelSelector: 'static { +pub trait AgentModelSelector: 'static { /// Lists all available language models for this agent. /// /// # Parameters @@ -20,7 +80,7 @@ pub trait ModelSelector: 'static { /// /// # Returns /// A task resolving to the list of models or an error (e.g., if no models are configured). - fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; + fn list_models(&self, cx: &mut App) -> Task>; /// Selects a model for a specific session (thread). /// @@ -37,8 +97,8 @@ pub trait ModelSelector: 'static { fn select_model( &self, session_id: acp::SessionId, - model: Arc, - cx: &mut AsyncApp, + model_id: AgentModelId, + cx: &mut App, ) -> Task>; /// Retrieves the currently selected model for a specific session (thread). @@ -52,42 +112,51 @@ pub trait ModelSelector: 'static { fn selected_model( &self, session_id: &acp::SessionId, - cx: &mut AsyncApp, - ) -> Task>>; + cx: &mut App, + ) -> Task>; + + /// Whenever the model list is updated the receiver will be notified. + fn watch(&self, cx: &mut App) -> watch::Receiver<()>; } -pub trait AgentConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>>; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AgentModelId(pub SharedString); - fn auth_methods(&self) -> &[acp::AuthMethod]; +impl std::ops::Deref for AgentModelId { + type Target = SharedString; - fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) - -> Task>; - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); - - /// Returns this agent as an [Rc] if the model selection capability is supported. - /// - /// If the agent does not support model selection, returns [None]. - /// This allows sharing the selector in UI components. - fn model_selector(&self) -> Option> { - None // Default impl for agents that don't support it + fn deref(&self) -> &Self::Target { + &self.0 } } -#[derive(Debug)] -pub struct AuthRequired; - -impl Error for AuthRequired {} -impl fmt::Display for AuthRequired { +impl fmt::Display for AgentModelId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AuthRequired") + self.0.fmt(f) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AgentModelInfo { + pub id: AgentModelId, + pub name: SharedString, + pub icon: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AgentModelGroupName(pub SharedString); + +#[derive(Debug, Clone)] +pub enum AgentModelList { + Flat(Vec), + Grouped(IndexMap>), +} + +impl AgentModelList { + pub fn is_empty(&self) -> bool { + match self { + AgentModelList::Flat(models) => models.is_empty(), + AgentModelList::Grouped(groups) => groups.is_empty(), + } } } diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs index f8ea7173d8..7c562aaba4 100644 --- a/crates/activity_indicator/src/activity_indicator.rs +++ b/crates/activity_indicator/src/activity_indicator.rs @@ -716,18 +716,10 @@ impl ActivityIndicator { })), tooltip_message: Some(Self::version_tooltip_message(&version)), }), - AutoUpdateStatus::Updated { - binary_path, - version, - } => Some(Content { + AutoUpdateStatus::Updated { version } => Some(Content { icon: None, message: "Click to restart and update Zed".to_string(), - on_click: Some(Arc::new({ - let reload = workspace::Reload { - binary_path: Some(binary_path.clone()), - }; - move |_, _, cx| workspace::reload(&reload, cx) - })), + on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))), tooltip_message: Some(Self::version_tooltip_message(&version)), }), AutoUpdateStatus::Errored => Some(Content { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 20d482f60d..1d417efbba 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2268,6 +2268,15 @@ impl Thread { max_attempts: 3, }) } + Other(err) + if err.is::() + || err.is::() => + { + // Retrying won't help for Payment Required or Model Request Limit errors (where + // the user must upgrade to usage-based billing to get more requests, or else wait + // for a significant amount of time for the request limit to reset). + None + } // Conservatively assume that any other errors are non-retryable HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { delay: BASE_RETRY_DELAY, diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 1030380dc0..ac1840e5e5 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -49,6 +49,7 @@ settings.workspace = true smol.workspace = true task.workspace = true terminal.workspace = true +text.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 7439b2a088..ced8c5e401 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,21 +1,26 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, - FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, - OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, + FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, + ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent, + WebSearchTool, }; -use acp_thread::ModelSelector; +use acp_thread::AgentModelSelector; use agent_client_protocol as acp; +use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; +use collections::{HashSet, IndexMap}; +use fs::Fs; use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; -use language_model::{LanguageModel, LanguageModelRegistry}; +use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, }; +use settings::update_settings_file; use std::cell::RefCell; use std::collections::HashMap; use std::path::Path; @@ -48,6 +53,104 @@ struct Session { _subscription: Subscription, } +pub struct LanguageModels { + /// Access language model by ID + models: HashMap>, + /// Cached list for returning language model information + model_list: acp_thread::AgentModelList, + refresh_models_rx: watch::Receiver<()>, + refresh_models_tx: watch::Sender<()>, +} + +impl LanguageModels { + fn new(cx: &App) -> Self { + let (refresh_models_tx, refresh_models_rx) = watch::channel(()); + let mut this = Self { + models: HashMap::default(), + model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), + refresh_models_rx, + refresh_models_tx, + }; + this.refresh_list(cx); + this + } + + fn refresh_list(&mut self, cx: &App) { + let providers = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .into_iter() + .filter(|provider| provider.is_authenticated(cx)) + .collect::>(); + + let mut language_model_list = IndexMap::default(); + let mut recommended_models = HashSet::default(); + + let mut recommended = Vec::new(); + for provider in &providers { + for model in provider.recommended_models(cx) { + recommended_models.insert(model.id()); + recommended.push(Self::map_language_model_to_info(&model, &provider)); + } + } + if !recommended.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName("Recommended".into()), + recommended, + ); + } + + let mut models = HashMap::default(); + for provider in providers { + let mut provider_models = Vec::new(); + for model in provider.provided_models(cx) { + let model_info = Self::map_language_model_to_info(&model, &provider); + let model_id = model_info.id.clone(); + if !recommended_models.contains(&model.id()) { + provider_models.push(model_info); + } + models.insert(model_id, model); + } + if !provider_models.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName(provider.name().0.clone()), + provider_models, + ); + } + } + + self.models = models; + self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); + self.refresh_models_tx.send(()).ok(); + } + + fn watch(&self) -> watch::Receiver<()> { + self.refresh_models_rx.clone() + } + + pub fn model_from_id( + &self, + model_id: &acp_thread::AgentModelId, + ) -> Option> { + self.models.get(model_id).cloned() + } + + fn map_language_model_to_info( + model: &Arc, + provider: &Arc, + ) -> acp_thread::AgentModelInfo { + acp_thread::AgentModelInfo { + id: Self::model_id(model), + name: model.name().0, + icon: Some(provider.icon()), + } + } + + fn model_id(model: &Arc) -> acp_thread::AgentModelId { + acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) + } +} + pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, @@ -58,8 +161,11 @@ pub struct NativeAgent { context_server_registry: Entity, /// Shared templates for all threads templates: Arc, + /// Cached model information + models: LanguageModels, project: Entity, prompt_store: Option>, + fs: Arc, _subscriptions: Vec, } @@ -68,6 +174,7 @@ impl NativeAgent { project: Entity, templates: Arc, prompt_store: Option>, + fs: Arc, cx: &mut AsyncApp, ) -> Result> { log::info!("Creating new NativeAgent"); @@ -77,7 +184,13 @@ impl NativeAgent { .await; cx.new(|cx| { - let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; + let mut subscriptions = vec![ + cx.subscribe(&project, Self::handle_project_event), + cx.subscribe( + &LanguageModelRegistry::global(cx), + Self::handle_models_updated_event, + ), + ]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) } @@ -95,13 +208,19 @@ impl NativeAgent { ContextServerRegistry::new(project.read(cx).context_server_store(), cx) }), templates, + models: LanguageModels::new(cx), project, prompt_store, + fs, _subscriptions: subscriptions, } }) } + pub fn models(&self) -> &LanguageModels { + &self.models + } + async fn maintain_project_context( this: WeakEntity, mut needs_refresh: watch::Receiver<()>, @@ -297,75 +416,104 @@ impl NativeAgent { ) { self.project_context_needs_refresh.send(()).ok(); } + + fn handle_models_updated_event( + &mut self, + _registry: Entity, + _event: &language_model::Event, + cx: &mut Context, + ) { + self.models.refresh_list(cx); + for session in self.sessions.values_mut() { + session.thread.update(cx, |thread, _| { + let model_id = LanguageModels::model_id(&thread.selected_model); + if let Some(model) = self.models.model_from_id(&model_id) { + thread.selected_model = model.clone(); + } + }); + } + } } /// Wrapper struct that implements the AgentConnection trait #[derive(Clone)] pub struct NativeAgentConnection(pub Entity); -impl ModelSelector for NativeAgentConnection { - fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { +impl AgentModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut App) -> Task> { log::debug!("NativeAgentConnection::list_models called"); - cx.spawn(async move |cx| { - cx.update(|cx| { - let registry = LanguageModelRegistry::read_global(cx); - let models = registry.available_models(cx).collect::>(); - log::info!("Found {} available models", models.len()); - if models.is_empty() { - Err(anyhow::anyhow!("No models available")) - } else { - Ok(models) - } - })? + let list = self.0.read(cx).models.model_list.clone(); + Task::ready(if list.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(list) }) } fn select_model( &self, session_id: acp::SessionId, - model: Arc, - cx: &mut AsyncApp, + model_id: acp_thread::AgentModelId, + cx: &mut App, ) -> Task> { - log::info!( - "Setting model for session {}: {:?}", - session_id, - model.name() - ); - let agent = self.0.clone(); + log::info!("Setting model for session {}: {}", session_id, model_id); + let Some(thread) = self + .0 + .read(cx) + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; - cx.spawn(async move |cx| { - agent.update(cx, |agent, cx| { - if let Some(session) = agent.sessions.get(&session_id) { - session.thread.update(cx, |thread, _cx| { - thread.selected_model = model; - }); - Ok(()) - } else { - Err(anyhow!("Session not found")) - } - })? - }) + let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else { + return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); + }; + + thread.update(cx, |thread, _cx| { + thread.selected_model = model.clone(); + }); + + update_settings_file::( + self.0.read(cx).fs.clone(), + cx, + move |settings, _cx| { + settings.set_model(model); + }, + ); + + Task::ready(Ok(())) } fn selected_model( &self, session_id: &acp::SessionId, - cx: &mut AsyncApp, - ) -> Task>> { - let agent = self.0.clone(); + cx: &mut App, + ) -> Task> { let session_id = session_id.clone(); - cx.spawn(async move |cx| { - let thread = agent - .read_with(cx, |agent, _| { - agent - .sessions - .get(&session_id) - .map(|session| session.thread.clone()) - })? - .ok_or_else(|| anyhow::anyhow!("Session not found"))?; - let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; - Ok(selected) - }) + + let Some(thread) = self + .0 + .read(cx) + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + let model = thread.read(cx).selected_model.clone(); + let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) + else { + return Task::ready(Err(anyhow!("Provider not found"))); + }; + Task::ready(Ok(LanguageModels::map_language_model_to_info( + &model, &provider, + ))) + } + + fn watch(&self, cx: &mut App) -> watch::Receiver<()> { + self.0.read(cx).models.watch() } } @@ -413,13 +561,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let default_model = registry .default_model() - .map(|configured| { - log::info!( - "Using configured default model: {:?} from provider: {:?}", - configured.model.name(), - configured.provider.name() - ); - configured.model + .and_then(|default_model| { + agent + .models + .model_from_id(&LanguageModels::model_id(&default_model.model)) }) .ok_or_else(|| { log::warn!("No default model configured in settings"); @@ -487,15 +632,17 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn model_selector(&self) -> Option> { - Some(Rc::new(self.clone()) as Rc) + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) } fn prompt( &self, + id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { + let id = id.expect("UserMessageId is required"); let session_id = params.session_id.clone(); let agent = self.0.clone(); log::info!("Received prompt request for session: {}", session_id); @@ -516,13 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; log::debug!("Found session for: {}", session_id); - let message: Vec = params + let content: Vec = params .prompt .into_iter() .map(Into::into) .collect::>(); - log::info!("Converted prompt to message: {} chars", message.len()); - log::debug!("Message content: {:?}", message); + log::info!("Converted prompt to message: {} chars", content.len()); + log::debug!("Message id: {:?}", id); + log::debug!("Message content: {:?}", content); // Get model using the ModelSelector capability (always available for agent2) // Get the selected model from the thread directly @@ -530,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Send to thread log::info!("Sending message to thread with model: {:?}", model.name()); - let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?; + let mut response_stream = + thread.update(cx, |thread, cx| thread.send(id, content, cx))?; // Handle response stream and forward to session.acp_thread while let Some(result) = response_stream.next().await { @@ -624,11 +773,33 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } }); } + + fn session_editor( + &self, + session_id: &agent_client_protocol::SessionId, + cx: &mut App, + ) -> Option> { + self.0.update(cx, |agent, _cx| { + agent + .sessions + .get(session_id) + .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) + }) + } +} + +struct NativeAgentSessionEditor(Entity); + +impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { + fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { + Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + } } #[cfg(test)] mod tests { use super::*; + use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; use serde_json::json; @@ -646,9 +817,15 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [], cx).await; - let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async()) - .await - .unwrap(); + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); agent.read_with(cx, |agent, _| { assert_eq!(agent.project_context.borrow().worktrees, vec![]) }); @@ -689,13 +866,131 @@ mod tests { }); } + #[gpui::test] + async fn test_listing_models(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({ "a": {} })).await; + let project = Project::test(fs.clone(), [], cx).await; + let connection = NativeAgentConnection( + NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(), + ); + + let models = cx.update(|cx| connection.list_models(cx)).await.unwrap(); + + let acp_thread::AgentModelList::Grouped(models) = models else { + panic!("Unexpected model group"); + }; + assert_eq!( + models, + IndexMap::from_iter([( + AgentModelGroupName("Fake".into()), + vec![AgentModelInfo { + id: AgentModelId("fake/fake".into()), + name: "Fake".into(), + icon: Some(ui::IconName::ZedAssistant), + }] + )]) + ); + } + + #[gpui::test] + async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_model": { + "provider": "foo", + "model": "bar" + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + + // Create the agent and connection + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = NativeAgentConnection(agent.clone()); + + // Create a thread/session + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread( + project.clone(), + Path::new("/a"), + &mut cx.to_async(), + ) + }) + .await + .unwrap(); + + let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); + + // Select a model + let model_id = AgentModelId("fake/fake".into()); + cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx)) + .await + .unwrap(); + + // Verify the thread has the selected model + agent.read_with(cx, |agent, _| { + let session = agent.sessions.get(&session_id).unwrap(); + session.thread.read_with(cx, |thread, _| { + assert_eq!(thread.selected_model.id().0, "fake"); + }); + }); + + cx.run_until_parked(); + + // Verify settings file was updated + let settings_content = fs.load(paths::settings_file()).await.unwrap(); + let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); + + // Check that the agent settings contain the selected model + assert_eq!( + settings_json["agent"]["default_model"]["model"], + json!("fake") + ); + assert_eq!( + settings_json["agent"]["default_model"]["provider"], + json!("fake") + ); + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); Project::init_settings(cx); + agent_settings::init(cx); language::init(cx); + LanguageModelRegistry::test(cx); }); } } diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index 58f6d37c54..cadd88a846 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -1,8 +1,8 @@ -use std::path::Path; -use std::rc::Rc; +use std::{path::Path, rc::Rc, sync::Arc}; use agent_servers::AgentServer; use anyhow::Result; +use fs::Fs; use gpui::{App, Entity, Task}; use project::Project; use prompt_store::PromptStore; @@ -10,7 +10,15 @@ use prompt_store::PromptStore; use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; #[derive(Clone)] -pub struct NativeAgentServer; +pub struct NativeAgentServer { + fs: Arc, +} + +impl NativeAgentServer { + pub fn new(fs: Arc) -> Self { + Self { fs } + } +} impl AgentServer for NativeAgentServer { fn name(&self) -> &'static str { @@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer { _root_dir ); let project = project.clone(); + let fs = self.fs.clone(); let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { log::debug!("Creating templates for native agent"); @@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer { let prompt_store = prompt_store.await?; log::debug!("Creating native agent entity"); - let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?; + let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?; // Create the connection wrapper let connection = NativeAgentConnection(agent); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 88cf92836b..637af73d1a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,6 +1,5 @@ use super::*; -use crate::MessageContent; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId}; use action_log::ActionLog; use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; @@ -38,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { - thread.send("Testing: Reply with 'Hello'", cx) + thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx) }) .collect() .await; thread.update(cx, |thread, _cx| { assert_eq!( - thread.messages().last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] - ); + thread.last_message().unwrap().to_markdown(), + indoc! {" + ## Assistant + + Hello + "} + ) }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } @@ -59,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.send( - indoc! {" + UserMessageId::new(), + [indoc! {" Testing: Generate a thinking step where you just think the word 'Think', and have your final answer be 'Hello' - "}, + "}], cx, ) }) @@ -72,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) { .await; thread.update(cx, |thread, _cx| { assert_eq!( - thread.messages().last().unwrap().to_markdown(), + thread.last_message().unwrap().to_markdown(), indoc! {" - ## assistant + ## Assistant + Think Hello "} @@ -95,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) { project_context.borrow_mut().shell = "test-shell".into(); thread.update(cx, |thread, _| thread.add_tool(EchoTool)); - thread.update(cx, |thread, cx| thread.send("abc", cx)); + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); assert_eq!( @@ -132,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(EchoTool); thread.send( - "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + UserMessageId::new(), + ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], cx, ) }) @@ -146,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { thread.remove_tool(&AgentTool::name(&EchoTool)); thread.add_tool(DelayTool); thread.send( - "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + UserMessageId::new(), + [ + "Now call the delay tool with 200ms.", + "When the timer goes off, then you echo the output of the tool.", + ], cx, ) }) @@ -156,13 +168,14 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { thread.update(cx, |thread, _cx| { assert!( thread - .messages() - .last() + .last_message() + .unwrap() + .as_agent_message() .unwrap() .content .iter() .any(|content| { - if let MessageContent::Text(text) = content { + if let AgentMessageContent::Text(text) = content { text.contains("Ding") } else { false @@ -182,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { // Test a tool call that's likely to complete *before* streaming stops. let mut events = thread.update(cx, |thread, cx| { thread.add_tool(WordListTool); - thread.send("Test the word_list tool.", cx) + thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) }); let mut saw_partial_tool_use = false; @@ -190,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message - let last_content = thread.messages().last().unwrap().content.last().unwrap(); - if let MessageContent::ToolUse(last_tool_use) = last_content { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + let last_content = agent_message.content.last().unwrap(); + if let AgentMessageContent::ToolUse(last_tool_use) = last_content { assert_eq!(last_tool_use.name.as_ref(), "word_list"); if tool_call.status == acp::ToolCallStatus::Pending { if !last_tool_use.is_input_complete @@ -229,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let mut events = thread.update(cx, |thread, cx| { thread.add_tool(ToolRequiringPermission); - thread.send("abc", cx) + thread.send(UserMessageId::new(), ["abc"], cx) }); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -357,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx)); + let mut events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { @@ -449,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(DelayTool); thread.send( - "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + UserMessageId::new(), + [ + "Call the delay tool twice in the same message.", + "Once with 100ms. Once with 300ms.", + "When both timers are complete, describe the outputs.", + ], cx, ) }) @@ -460,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]); thread.update(cx, |thread, _cx| { - let last_message = thread.messages().last().unwrap(); - let text = last_message + let last_message = thread.last_message().unwrap(); + let agent_message = last_message.as_agent_message().unwrap(); + let text = agent_message .content .iter() .filter_map(|content| { - if let MessageContent::Text(text) = content { + if let AgentMessageContent::Text(text) = content { Some(text.as_str()) } else { None @@ -521,7 +544,7 @@ async fn test_profiles(cx: &mut TestAppContext) { // Test that test-1 profile (default) has echo and delay tools thread.update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test-1".into())); - thread.send("test", cx); + thread.send(UserMessageId::new(), ["test"], cx); }); cx.run_until_parked(); @@ -539,7 +562,7 @@ async fn test_profiles(cx: &mut TestAppContext) { // Switch to test-2 profile, and verify that it has only the infinite tool. thread.update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test-2".into())); - thread.send("test2", cx) + thread.send(UserMessageId::new(), ["test2"], cx) }); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); @@ -562,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) { thread.add_tool(InfiniteTool); thread.add_tool(EchoTool); thread.send( - "Call the echo tool and then call the infinite tool, then explain their output", + UserMessageId::new(), + ["Call the echo tool, then call the infinite tool, then explain their output"], cx, ) }); @@ -607,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Ensure we can still send a new message after cancellation. let events = thread .update(cx, |thread, cx| { - thread.send("Testing: reply with 'Hello' then stop.", cx) + thread.send( + UserMessageId::new(), + ["Testing: reply with 'Hello' then stop."], + cx, + ) }) .collect::>() .await; thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); assert_eq!( - thread.messages().last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] + agent_message.content, + vec![AgentMessageContent::Text("Hello".to_string())] ); }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); @@ -625,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let events = thread.update(cx, |thread, cx| thread.send("Hello", cx)); + let events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( thread.to_markdown(), indoc! {" - ## user + ## User + Hello "} ); @@ -643,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) { assert_eq!( thread.to_markdown(), indoc! {" - ## user + ## User + Hello - ## assistant + + ## Assistant + Hey! "} ); @@ -661,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_truncate(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let message_id = UserMessageId::new(); + thread.update(cx, |thread, cx| { + thread.send(message_id.clone(), ["Hello"], cx) + }); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + "} + ); + }); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + + ## Assistant + + Hey! + "} + ); + }); + + thread + .update(cx, |thread, _cx| thread.truncate(message_id)) + .unwrap(); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + }); + + // Ensure we can still send a new message after truncation. + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hi"], cx) + }); + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + "} + ); + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Ahoy!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + + ## Assistant + + Ahoy! + "} + ); + }); +} + #[gpui::test] async fn test_agent_connection(cx: &mut TestAppContext) { cx.update(settings::init); @@ -686,13 +801,19 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Create a project for new_thread let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); fake_fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await; let cwd = Path::new("/test"); // Create agent and connection - let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async()) - .await - .unwrap(); + let agent = NativeAgent::new( + project.clone(), + templates.clone(), + None, + fake_fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); let connection = NativeAgentConnection(agent.clone()); // Test model_selector returns Some @@ -705,22 +826,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Test list_models let listed_models = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - selector.list_models(&mut async_cx) - }) + .update(|cx| selector.list_models(cx)) .await .expect("list_models should succeed"); + let AgentModelList::Grouped(listed_models) = listed_models else { + panic!("Unexpected model list type"); + }; assert!(!listed_models.is_empty(), "should have at least one model"); - assert_eq!(listed_models[0].id().0, "fake"); + assert_eq!( + listed_models[&AgentModelGroupName("Fake".into())][0].id.0, + "fake/fake" + ); // Create a thread using new_thread let connection_rc = Rc::new(connection.clone()); let acp_thread = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - connection_rc.new_thread(project, cwd, &mut async_cx) - }) + .update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async())) .await .expect("new_thread should succeed"); @@ -729,12 +850,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Test selected_model returns the default let model = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - selector.selected_model(&session_id, &mut async_cx) - }) + .update(|cx| selector.selected_model(&session_id, cx)) .await .expect("selected_model should succeed"); + let model = cx + .update(|cx| agent.read(cx).models().model_from_id(&model.id)) + .unwrap(); let model = model.as_fake(); assert_eq!(model.id().0, "fake", "should return default model"); @@ -768,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let result = cx .update(|cx| { connection.prompt( + Some(acp_thread::UserMessageId::new()), acp::PromptRequest { session_id: session_id.clone(), prompt: vec!["ghi".into()], @@ -790,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx)); + let mut events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Think"], cx) + }); cx.run_until_parked(); // Simulate streaming partial input. diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 678e4cb5d2..204b489124 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,12 +1,12 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; -use acp_thread::MentionUri; +use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, AgentSettings}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use cloud_llm_client::{CompletionIntent, CompletionMode}; -use collections::HashMap; +use collections::IndexMap; use fs::Fs; use futures::{ channel::{mpsc, oneshot}, @@ -15,11 +15,10 @@ use futures::{ use gpui::{App, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, - LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, - LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; -use log; use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; @@ -30,49 +29,199 @@ use std::fmt::Write; use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; use util::{ResultExt, markdown::MarkdownCodeBlock}; -#[derive(Debug, Clone)] -pub struct AgentMessage { - pub role: Role, - pub content: Vec, +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Message { + User(UserMessage), + Agent(AgentMessage), +} + +impl Message { + pub fn as_agent_message(&self) -> Option<&AgentMessage> { + match self { + Message::Agent(agent_message) => Some(agent_message), + _ => None, + } + } + + pub fn to_markdown(&self) -> String { + match self { + Message::User(message) => message.to_markdown(), + Message::Agent(message) => message.to_markdown(), + } + } } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum MessageContent { +pub struct UserMessage { + pub id: UserMessageId, + pub content: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UserMessageContent { Text(String), - Thinking { - text: String, - signature: Option, - }, - Mention { - uri: MentionUri, - content: String, - }, - RedactedThinking(String), + Mention { uri: MentionUri, content: String }, Image(LanguageModelImage), - ToolUse(LanguageModelToolUse), - ToolResult(LanguageModelToolResult), +} + +impl UserMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = String::from("## User\n\n"); + + for content in &self.content { + match content { + UserMessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + UserMessageContent::Image(_) => { + markdown.push_str("\n"); + } + UserMessageContent::Mention { uri, content } => { + if !content.is_empty() { + markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content)); + } else { + markdown.push_str(&format!("{}\n", uri.to_link())); + } + } + } + } + + markdown + } + + fn to_request(&self) -> LanguageModelRequestMessage { + let mut message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + + const OPEN_CONTEXT: &str = "\n\ + The following items were attached by the user. \ + They are up-to-date and don't need to be re-read.\n\n"; + + const OPEN_FILES_TAG: &str = ""; + const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_THREADS_TAG: &str = ""; + const OPEN_RULES_TAG: &str = + "\nThe user has specified the following rules that should be applied:\n"; + + let mut file_context = OPEN_FILES_TAG.to_string(); + let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut thread_context = OPEN_THREADS_TAG.to_string(); + let mut rules_context = OPEN_RULES_TAG.to_string(); + + for chunk in &self.content { + let chunk = match chunk { + UserMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + UserMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + UserMessageContent::Mention { uri, content } => { + match uri { + MentionUri::File(path) | MentionUri::Symbol(path, _) => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&path), + text: &content.to_string(), + } + ) + .ok(); + } + MentionUri::Thread(_session_id) => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::Rule(_user_prompt_id) => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: "", + text: &content + } + ) + .ok(); + } + } + + language_model::MessageContent::Text(uri.to_link()) + } + }; + + message.content.push(chunk); + } + + let len_before_context = message.content.len(); + + if file_context.len() > OPEN_FILES_TAG.len() { + file_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(file_context)); + } + + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { + symbol_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(symbol_context)); + } + + if thread_context.len() > OPEN_THREADS_TAG.len() { + thread_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(thread_context)); + } + + if rules_context.len() > OPEN_RULES_TAG.len() { + rules_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(rules_context)); + } + + if message.content.len() > len_before_context { + message.content.insert( + len_before_context, + language_model::MessageContent::Text(OPEN_CONTEXT.into()), + ); + message + .content + .push(language_model::MessageContent::Text("".into())); + } + + message + } } impl AgentMessage { pub fn to_markdown(&self) -> String { - let mut markdown = format!("## {}\n", self.role); + let mut markdown = String::from("## Assistant\n\n"); for content in &self.content { match content { - MessageContent::Text(text) => { + AgentMessageContent::Text(text) => { markdown.push_str(text); markdown.push('\n'); } - MessageContent::Thinking { text, .. } => { + AgentMessageContent::Thinking { text, .. } => { markdown.push_str(""); markdown.push_str(text); markdown.push_str("\n"); } - MessageContent::RedactedThinking(_) => markdown.push_str("\n"), - MessageContent::Image(_) => { + AgentMessageContent::RedactedThinking(_) => { + markdown.push_str("\n") + } + AgentMessageContent::Image(_) => { markdown.push_str("\n"); } - MessageContent::ToolUse(tool_use) => { + AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", tool_use.name, tool_use.id @@ -85,41 +234,106 @@ impl AgentMessage { } )); } - MessageContent::ToolResult(tool_result) => { - markdown.push_str(&format!( - "**Tool Result**: {} (ID: {})\n\n", - tool_result.tool_name, tool_result.tool_use_id - )); - if tool_result.is_error { - markdown.push_str("**ERROR:**\n"); - } + } + } - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}\n").ok(); - } - LanguageModelToolResultContent::Image(_) => { - writeln!(markdown, "\n").ok(); - } - } + for tool_result in self.tool_results.values() { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } - if let Some(output) = tool_result.output.as_ref() { - writeln!( - markdown, - "**Debug Output**:\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(output).unwrap() - ) - .unwrap(); - } + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); } - MessageContent::Mention { uri, .. } => { - write!(markdown, "{}", uri.to_link()).ok(); + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); } } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); + } } markdown } + + pub fn to_request(&self) -> Vec { + let mut content = Vec::with_capacity(self.content.len()); + for chunk in &self.content { + let chunk = match chunk { + AgentMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + AgentMessageContent::Thinking { text, signature } => { + language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + } + } + AgentMessageContent::RedactedThinking(value) => { + language_model::MessageContent::RedactedThinking(value.clone()) + } + AgentMessageContent::ToolUse(value) => { + language_model::MessageContent::ToolUse(value.clone()) + } + AgentMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + }; + content.push(chunk); + } + + let mut messages = vec![LanguageModelRequestMessage { + role: Role::Assistant, + content, + cache: false, + }]; + + if !self.tool_results.is_empty() { + let mut tool_results = Vec::with_capacity(self.tool_results.len()); + for tool_result in self.tool_results.values() { + tool_results.push(language_model::MessageContent::ToolResult( + tool_result.clone(), + )); + } + messages.push(LanguageModelRequestMessage { + role: Role::User, + content: tool_results, + cache: false, + }); + } + + messages + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct AgentMessage { + pub content: Vec, + pub tool_results: IndexMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentMessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), } #[derive(Debug)] @@ -140,13 +354,13 @@ pub struct ToolCallAuthorization { } pub struct Thread { - messages: Vec, + messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. running_turn: Option>, - pending_tool_uses: HashMap, + pending_agent_message: Option, tools: BTreeMap>, context_server_registry: Entity, profile_id: AgentProfileId, @@ -172,7 +386,7 @@ impl Thread { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, - pending_tool_uses: HashMap::default(), + pending_agent_message: None, tools: BTreeMap::default(), context_server_registry, profile_id, @@ -196,8 +410,13 @@ impl Thread { self.completion_mode = mode; } - pub fn messages(&self) -> &[AgentMessage] { - &self.messages + #[cfg(any(test, feature = "test-support"))] + pub fn last_message(&self) -> Option { + if let Some(message) = self.pending_agent_message.clone() { + Some(Message::Agent(message)) + } else { + self.messages.last().cloned() + } } pub fn add_tool(&mut self, tool: impl AgentTool) { @@ -213,35 +432,36 @@ impl Thread { } pub fn cancel(&mut self) { + // TODO: do we need to emit a stop::cancel for ACP? self.running_turn.take(); + self.flush_pending_agent_message(); + } - let tool_results = self - .pending_tool_uses - .drain() - .map(|(tool_use_id, tool_use)| { - MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id, - tool_name: tool_use.name.clone(), - is_error: true, - content: LanguageModelToolResultContent::Text("Tool canceled by user".into()), - output: None, - }) - }) - .collect::>(); - self.last_user_message().content.extend(tool_results); + pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { + self.cancel(); + let Some(position) = self.messages.iter().position( + |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), + ) else { + return Err(anyhow!("Message not found")); + }; + self.messages.truncate(position); + Ok(()) } /// Sending a message results in the model streaming a response, which could include tool calls. /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. - pub fn send( + pub fn send( &mut self, - content: impl Into, + message_id: UserMessageId, + content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { - let content = content.into().0; - + ) -> mpsc::UnboundedReceiver> + where + T: Into, + { let model = self.selected_model.clone(); + let content = content.into_iter().map(Into::into).collect::>(); log::info!("Thread::send called with model: {:?}", model.name()); log::debug!("Thread::send content: {:?}", content); @@ -251,10 +471,10 @@ impl Thread { let event_stream = AgentResponseEventStream(events_tx); let user_message_ix = self.messages.len(); - self.messages.push(AgentMessage { - role: Role::User, + self.messages.push(Message::User(UserMessage { + id: message_id, content, - }); + })); log::info!("Total messages in thread: {}", self.messages.len()); self.running_turn = Some(cx.spawn(async move |thread, cx| { log::info!("Starting agent turn execution"); @@ -270,15 +490,11 @@ impl Thread { thread.build_completion_request(completion_intent, cx) })?; - // println!( - // "request: {}", - // serde_json::to_string_pretty(&request).unwrap() - // ); - // Stream events, appending to messages and collecting up tool uses. log::info!("Calling model.stream_completion"); let mut events = model.stream_completion(request, cx).await?; log::debug!("Stream completion started successfully"); + let mut tool_uses = FuturesUnordered::new(); while let Some(event) = events.next().await { match event { @@ -286,6 +502,7 @@ impl Thread { event_stream.send_stop(reason); if reason == StopReason::Refusal { thread.update(cx, |thread, _cx| { + thread.pending_agent_message = None; thread.messages.truncate(user_message_ix); })?; break 'outer; @@ -338,15 +555,16 @@ impl Thread { ); thread .update(cx, |thread, _cx| { - thread.pending_tool_uses.remove(&tool_result.tool_use_id); thread - .last_user_message() - .content - .push(MessageContent::ToolResult(tool_result)); + .pending_agent_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); }) .ok(); } + thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?; + completion_intent = CompletionIntent::ToolResults; } @@ -354,6 +572,10 @@ impl Thread { } .await; + thread + .update(cx, |thread, _cx| thread.flush_pending_agent_message()) + .ok(); + if let Err(error) = turn_result { log::error!("Turn execution failed: {:?}", error); event_stream.send_error(error); @@ -364,7 +586,7 @@ impl Thread { events_rx } - pub fn build_system_message(&self) -> AgentMessage { + pub fn build_system_message(&self) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { project: &self.project_context.borrow(), @@ -374,9 +596,10 @@ impl Thread { .context("failed to build system prompt") .expect("Invalid template"); log::debug!("System message built"); - AgentMessage { + LanguageModelRequestMessage { role: Role::System, - content: vec![prompt.as_str().into()], + content: vec![prompt.into()], + cache: true, } } @@ -394,10 +617,7 @@ impl Thread { match event { StartMessage { .. } => { - self.messages.push(AgentMessage { - role: Role::Assistant, - content: Vec::new(), - }); + self.messages.push(Message::Agent(AgentMessage::default())); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), Thinking { text, signature } => { @@ -435,11 +655,13 @@ impl Thread { ) { events_stream.send_text(&new_text); - let last_message = self.last_assistant_message(); - if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + let last_message = self.pending_agent_message(); + if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { text.push_str(&new_text); } else { - last_message.content.push(MessageContent::Text(new_text)); + last_message + .content + .push(AgentMessageContent::Text(new_text)); } cx.notify(); @@ -454,13 +676,14 @@ impl Thread { ) { event_stream.send_thinking(&new_text); - let last_message = self.last_assistant_message(); - if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() + let last_message = self.pending_agent_message(); + if let Some(AgentMessageContent::Thinking { text, signature }) = + last_message.content.last_mut() { text.push_str(&new_text); *signature = new_signature.or(signature.take()); } else { - last_message.content.push(MessageContent::Thinking { + last_message.content.push(AgentMessageContent::Thinking { text: new_text, signature: new_signature, }); @@ -470,10 +693,10 @@ impl Thread { } fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { - let last_message = self.last_assistant_message(); + let last_message = self.pending_agent_message(); last_message .content - .push(MessageContent::RedactedThinking(data)); + .push(AgentMessageContent::RedactedThinking(data)); cx.notify(); } @@ -486,14 +709,17 @@ impl Thread { cx.notify(); let tool = self.tools.get(tool_use.name.as_ref()).cloned(); - - self.pending_tool_uses - .insert(tool_use.id.clone(), tool_use.clone()); - let last_message = self.last_assistant_message(); + let mut title = SharedString::from(&tool_use.name); + let mut kind = acp::ToolKind::Other; + if let Some(tool) = tool.as_ref() { + title = tool.initial_title(tool_use.input.clone()); + kind = tool.kind(); + } // Ensure the last message ends in the current tool use + let last_message = self.pending_agent_message(); let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { - if let MessageContent::ToolUse(last_tool_use) = content { + if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { *last_tool_use = tool_use.clone(); false @@ -505,18 +731,11 @@ impl Thread { } }); - let mut title = SharedString::from(&tool_use.name); - let mut kind = acp::ToolKind::Other; - if let Some(tool) = tool.as_ref() { - title = tool.initial_title(tool_use.input.clone()); - kind = tool.kind(); - } - if push_new_tool_use { event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); last_message .content - .push(MessageContent::ToolUse(tool_use.clone())); + .push(AgentMessageContent::ToolUse(tool_use.clone())); } else { event_stream.update_tool_call_fields( &tool_use.id, @@ -601,30 +820,37 @@ impl Thread { } } - /// Guarantees the last message is from the assistant and returns a mutable reference. - fn last_assistant_message(&mut self) -> &mut AgentMessage { - if self - .messages - .last() - .map_or(true, |m| m.role != Role::Assistant) - { - self.messages.push(AgentMessage { - role: Role::Assistant, - content: Vec::new(), - }); - } - self.messages.last_mut().unwrap() + fn pending_agent_message(&mut self) -> &mut AgentMessage { + self.pending_agent_message.get_or_insert_default() } - /// Guarantees the last message is from the user and returns a mutable reference. - fn last_user_message(&mut self) -> &mut AgentMessage { - if self.messages.last().map_or(true, |m| m.role != Role::User) { - self.messages.push(AgentMessage { - role: Role::User, - content: Vec::new(), - }); + fn flush_pending_agent_message(&mut self) { + let Some(mut message) = self.pending_agent_message.take() else { + return; + }; + + for content in &message.content { + let AgentMessageContent::ToolUse(tool_use) = content else { + continue; + }; + + if !message.tool_results.contains_key(&tool_use.id) { + message.tool_results.insert( + tool_use.id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text( + "Tool canceled by user".into(), + ), + output: None, + }, + ); + } } - self.messages.last_mut().unwrap() + + self.messages.push(Message::Agent(message)); } pub(crate) fn build_completion_request( @@ -681,10 +907,12 @@ impl Thread { .profiles .get(&self.profile_id) .context("profile not found")?; + let provider_id = self.selected_model.provider_id(); Ok(self .tools .iter() + .filter(move |(_, tool)| tool.supported_provider(&provider_id)) .filter_map(|(tool_name, tool)| { if profile.is_tool_enabled(tool_name) { Some(tool) @@ -710,49 +938,39 @@ impl Thread { "Building request messages from {} thread messages", self.messages.len() ); + let mut messages = vec![self.build_system_message()]; + for message in &self.messages { + match message { + Message::User(message) => messages.push(message.to_request()), + Message::Agent(message) => messages.extend(message.to_request()), + } + } + + if let Some(message) = self.pending_agent_message.as_ref() { + messages.extend(message.to_request()); + } - let messages = Some(self.build_system_message()) - .iter() - .chain(self.messages.iter()) - .map(|message| { - log::trace!( - " - {} message with {} content items", - match message.role { - Role::System => "System", - Role::User => "User", - Role::Assistant => "Assistant", - }, - message.content.len() - ); - message.to_request() - }) - .collect(); messages } pub fn to_markdown(&self) -> String { let mut markdown = String::new(); - for message in &self.messages { + for (ix, message) in self.messages.iter().enumerate() { + if ix > 0 { + markdown.push('\n'); + } markdown.push_str(&message.to_markdown()); } + + if let Some(message) = self.pending_agent_message.as_ref() { + markdown.push('\n'); + markdown.push_str(&message.to_markdown()); + } + markdown } } -pub struct UserMessage(Vec); - -impl From> for UserMessage { - fn from(content: Vec) -> Self { - UserMessage(content) - } -} - -impl> From for UserMessage { - fn from(content: T) -> Self { - UserMessage(vec![content.into()]) - } -} - pub trait AgentTool where Self: 'static + Sized, @@ -782,6 +1000,12 @@ where schemars::schema_for!(Self::Input) } + /// Some tools rely on a provider for the underlying billing or other reasons. + /// Allow the tool to check if they are compatible, or should be filtered out. + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } + /// Runs the tool with the provided input. fn run( self: Arc, @@ -808,6 +1032,9 @@ pub trait AnyAgentTool { fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } fn run( self: Arc, input: serde_json::Value, @@ -843,6 +1070,10 @@ where Ok(json) } + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + self.0.supported_provider(provider) + } + fn run( self: Arc, input: serde_json::Value, @@ -1136,130 +1367,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver { } } -impl AgentMessage { - fn to_request(&self) -> language_model::LanguageModelRequestMessage { - let mut message = LanguageModelRequestMessage { - role: self.role, - content: Vec::with_capacity(self.content.len()), - cache: false, - }; - - const OPEN_CONTEXT: &str = "\n\ - The following items were attached by the user. \ - They are up-to-date and don't need to be re-read.\n\n"; - - const OPEN_FILES_TAG: &str = ""; - const OPEN_SYMBOLS_TAG: &str = ""; - const OPEN_THREADS_TAG: &str = ""; - const OPEN_RULES_TAG: &str = - "\nThe user has specified the following rules that should be applied:\n"; - - let mut file_context = OPEN_FILES_TAG.to_string(); - let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); - let mut thread_context = OPEN_THREADS_TAG.to_string(); - let mut rules_context = OPEN_RULES_TAG.to_string(); - - for chunk in &self.content { - let chunk = match chunk { - MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()), - MessageContent::Thinking { text, signature } => { - language_model::MessageContent::Thinking { - text: text.clone(), - signature: signature.clone(), - } - } - MessageContent::RedactedThinking(value) => { - language_model::MessageContent::RedactedThinking(value.clone()) - } - MessageContent::ToolUse(value) => { - language_model::MessageContent::ToolUse(value.clone()) - } - MessageContent::ToolResult(value) => { - language_model::MessageContent::ToolResult(value.clone()) - } - MessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } - MessageContent::Mention { uri, content } => { - match uri { - MentionUri::File(path) | MentionUri::Symbol(path, _) => { - write!( - &mut symbol_context, - "\n{}", - MarkdownCodeBlock { - tag: &codeblock_tag(&path), - text: &content.to_string(), - } - ) - .ok(); - } - MentionUri::Thread(_session_id) => { - write!(&mut thread_context, "\n{}\n", content).ok(); - } - MentionUri::Rule(_user_prompt_id) => { - write!( - &mut rules_context, - "\n{}", - MarkdownCodeBlock { - tag: "", - text: &content - } - ) - .ok(); - } - } - - language_model::MessageContent::Text(uri.to_link()) - } - }; - - message.content.push(chunk); - } - - let len_before_context = message.content.len(); - - if file_context.len() > OPEN_FILES_TAG.len() { - file_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(file_context)); - } - - if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { - symbol_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(symbol_context)); - } - - if thread_context.len() > OPEN_THREADS_TAG.len() { - thread_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(thread_context)); - } - - if rules_context.len() > OPEN_RULES_TAG.len() { - rules_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(rules_context)); - } - - if message.content.len() > len_before_context { - message.content.insert( - len_before_context, - language_model::MessageContent::Text(OPEN_CONTEXT.into()), - ); - message - .content - .push(language_model::MessageContent::Text("".into())); - } - - message - } -} - fn codeblock_tag(full_path: &Path) -> String { let mut result = String::new(); @@ -1272,16 +1379,20 @@ fn codeblock_tag(full_path: &Path) -> String { result } -impl From for MessageContent { +impl From<&str> for UserMessageContent { + fn from(text: &str) -> Self { + Self::Text(text.into()) + } +} + +impl From for UserMessageContent { fn from(value: acp::ContentBlock) -> Self { match value { - acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text), - acp::ContentBlock::Image(image_content) => { - MessageContent::Image(convert_image(image_content)) - } + acp::ContentBlock::Text(text_content) => Self::Text(text_content.text), + acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)), acp::ContentBlock::Audio(_) => { // TODO - MessageContent::Text("[audio]".to_string()) + Self::Text("[audio]".to_string()) } acp::ContentBlock::ResourceLink(resource_link) => { match MentionUri::parse(&resource_link.uri) { @@ -1291,10 +1402,7 @@ impl From for MessageContent { }, Err(err) => { log::error!("Failed to parse mention link: {}", err); - MessageContent::Text(format!( - "[{}]({})", - resource_link.name, resource_link.uri - )) + Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri)) } } } @@ -1307,7 +1415,7 @@ impl From for MessageContent { }, Err(err) => { log::error!("Failed to parse mention link: {}", err); - MessageContent::Text( + Self::Text( MarkdownCodeBlock { tag: &resource.uri, text: &resource.text, @@ -1319,7 +1427,7 @@ impl From for MessageContent { } acp::EmbeddedResourceResource::BlobResourceContents(_) => { // TODO - MessageContent::Text("[blob]".to_string()) + Self::Text("[blob]".to_string()) } }, } @@ -1333,9 +1441,3 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { size: gpui::Size::new(0.into(), 0.into()), } } - -impl From<&str> for MessageContent { - fn from(text: &str) -> Self { - MessageContent::Text(text.into()) - } -} diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 134bc5e5e4..405afb585f 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -1,12 +1,13 @@ use crate::{AgentTool, Thread, ToolCallEventStream}; use acp_thread::Diff; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; use gpui::{App, AppContext, AsyncApp, Entity, Task}; use indoc::formatdoc; +use language::ToPoint; use language::language_settings::{self, FormatOnSave}; use language_model::LanguageModelToolResultContent; use paths; @@ -225,6 +226,16 @@ impl AgentTool for EditFileTool { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), }; + let abs_path = project.read(cx).absolute_path(&project_path, cx); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![acp::ToolCallLocation { + path: abs_path, + line: None, + }]), + ..Default::default() + }); + } let request = self.thread.update(cx, |thread, cx| { thread.build_completion_request(CompletionIntent::ToolResults, cx) @@ -283,13 +294,38 @@ impl AgentTool for EditFileTool { let mut hallucinated_old_text = false; let mut ambiguous_ranges = Vec::new(); + let mut emitted_location = false; while let Some(event) = events.next().await { match event { - EditAgentOutputEvent::Edited => {}, + EditAgentOutputEvent::Edited(range) => { + if !emitted_location { + let line = buffer.update(cx, |buffer, _cx| { + range.start.to_point(&buffer.snapshot()).row + }).ok(); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![ToolCallLocation { path: abs_path, line }]), + ..Default::default() + }); + } + emitted_location = true; + } + }, EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, EditAgentOutputEvent::ResolvingEditRange(range) => { - diff.update(cx, |card, cx| card.reveal_range(range, cx))?; + diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?; + // if !emitted_location { + // let line = buffer.update(cx, |buffer, _cx| { + // range.start.to_point(&buffer.snapshot()).row + // }).ok(); + // if let Some(abs_path) = abs_path.clone() { + // event_stream.update_fields(ToolCallUpdateFields { + // locations: Some(vec![ToolCallLocation { path: abs_path, line }]), + // ..Default::default() + // }); + // } + // } } } } diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs index fac637d838..f21643cbbb 100644 --- a/crates/agent2/src/tools/read_file_tool.rs +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -1,10 +1,10 @@ use action_log::ActionLog; -use agent_client_protocol::{self as acp}; +use agent_client_protocol::{self as acp, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::outline; use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; -use language::{Anchor, Point}; +use language::Point; use language_model::{LanguageModelImage, LanguageModelToolResultContent}; use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store}; use schemars::JsonSchema; @@ -97,7 +97,7 @@ impl AgentTool for ReadFileTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { @@ -166,7 +166,9 @@ impl AgentTool for ReadFileTool { cx.spawn(async move |cx| { let buffer = cx .update(|cx| { - project.update(cx, |project, cx| project.open_buffer(project_path, cx)) + project.update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + }) })? .await?; if buffer.read_with(cx, |buffer, _| { @@ -178,19 +180,10 @@ impl AgentTool for ReadFileTool { anyhow::bail!("{file_path} not found"); } - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: Anchor::MIN, - }), - cx, - ); - })?; + let mut anchor = None; // Check if specific line ranges are provided - if input.start_line.is_some() || input.end_line.is_some() { - let mut anchor = None; + let result = if input.start_line.is_some() || input.end_line.is_some() { let result = buffer.read_with(cx, |buffer, _cx| { let text = buffer.text(); // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0. @@ -214,18 +207,6 @@ impl AgentTool for ReadFileTool { log.buffer_read(buffer.clone(), cx); })?; - if let Some(anchor) = anchor { - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: anchor, - }), - cx, - ); - })?; - } - Ok(result.into()) } else { // No line ranges specified, so check file size to see if it's too big. @@ -236,7 +217,7 @@ impl AgentTool for ReadFileTool { let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?; action_log.update(cx, |log, cx| { - log.buffer_read(buffer, cx); + log.buffer_read(buffer.clone(), cx); })?; Ok(result.into()) @@ -244,7 +225,8 @@ impl AgentTool for ReadFileTool { // File is too big, so return the outline // and a suggestion to read again with line numbers. let outline = - outline::file_outline(project, file_path, action_log, None, cx).await?; + outline::file_outline(project.clone(), file_path, action_log, None, cx) + .await?; Ok(formatdoc! {" This file was too big to read all at once. @@ -261,7 +243,28 @@ impl AgentTool for ReadFileTool { } .into()) } - } + }; + + project.update(cx, |project, cx| { + if let Some(abs_path) = project.absolute_path(&project_path, cx) { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: anchor.unwrap_or(text::Anchor::MIN), + }), + cx, + ); + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![acp::ToolCallLocation { + path: abs_path, + line: input.start_line.map(|line| line.saturating_sub(1)), + }]), + ..Default::default() + }); + } + })?; + + result }) } } diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index 12587c2f67..c1c0970742 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -5,7 +5,9 @@ use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use cloud_llm_client::WebSearchResponse; use gpui::{App, AppContext, Task}; -use language_model::LanguageModelToolResultContent; +use language_model::{ + LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use ui::prelude::*; @@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool { "Searching the Web".into() } + /// We currently only support Zed Cloud as a provider. + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + provider == &ZED_CLOUD_PROVIDER_ID + } + fn run( self: Arc, input: Self::Input, diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 8d85435f92..327613de67 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index ff71783b48..de397fddf0 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index c65508f152..c394ec4a9c 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -423,7 +424,7 @@ impl ClaudeAgentSession { if !turn_state.borrow().is_cancelled() { thread .update(cx, |thread, cx| { - thread.push_user_content_block(text.into(), cx) + thread.push_user_content_block(None, text.into(), cx) }) .log_err(); } diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index cc476b1a86..b9814adb2d 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -1,6 +1,10 @@ mod completion_provider; mod message_history; +mod model_selector; +mod model_selector_popover; mod thread_view; pub use message_history::MessageHistory; +pub use model_selector::AcpModelSelector; +pub use model_selector_popover::AcpModelSelectorPopover; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs new file mode 100644 index 0000000000..563afee65f --- /dev/null +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -0,0 +1,472 @@ +use std::{cmp::Reverse, rc::Rc, sync::Arc}; + +use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use agent_client_protocol as acp; +use anyhow::Result; +use collections::IndexMap; +use futures::FutureExt; +use fuzzy::{StringMatchCandidate, match_strings}; +use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity}; +use ordered_float::OrderedFloat; +use picker::{Picker, PickerDelegate}; +use ui::{ + AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window, + prelude::*, rems, +}; +use util::ResultExt; + +pub type AcpModelSelector = Picker; + +pub fn acp_model_selector( + session_id: acp::SessionId, + selector: Rc, + window: &mut Window, + cx: &mut Context, +) -> AcpModelSelector { + let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx); + Picker::list(delegate, window, cx) + .show_scrollbar(true) + .width(rems(20.)) + .max_height(Some(rems(20.).into())) +} + +enum AcpModelPickerEntry { + Separator(SharedString), + Model(AgentModelInfo), +} + +pub struct AcpModelPickerDelegate { + session_id: acp::SessionId, + selector: Rc, + filtered_entries: Vec, + models: Option, + selected_index: usize, + selected_model: Option, + _refresh_models_task: Task<()>, +} + +impl AcpModelPickerDelegate { + fn new( + session_id: acp::SessionId, + selector: Rc, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let mut rx = selector.watch(cx); + let refresh_models_task = cx.spawn_in(window, { + let session_id = session_id.clone(); + async move |this, cx| { + async fn refresh( + this: &WeakEntity>, + session_id: &acp::SessionId, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let (models_task, selected_model_task) = this.update(cx, |this, cx| { + ( + this.delegate.selector.list_models(cx), + this.delegate.selector.selected_model(session_id, cx), + ) + })?; + + let (models, selected_model) = futures::join!(models_task, selected_model_task); + + this.update_in(cx, |this, window, cx| { + this.delegate.models = models.ok(); + this.delegate.selected_model = selected_model.ok(); + this.delegate.update_matches(this.query(cx), window, cx) + })? + .await; + + Ok(()) + } + + refresh(&this, &session_id, cx).await.log_err(); + while let Ok(()) = rx.recv().await { + refresh(&this, &session_id, cx).await.log_err(); + } + } + }); + + Self { + session_id, + selector, + filtered_entries: Vec::new(), + models: None, + selected_model: None, + selected_index: 0, + _refresh_models_task: refresh_models_task, + } + } + + pub fn active_model(&self) -> Option<&AgentModelInfo> { + self.selected_model.as_ref() + } +} + +impl PickerDelegate for AcpModelPickerDelegate { + type ListItem = AnyElement; + + fn match_count(&self) -> usize { + self.filtered_entries.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) { + self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1)); + cx.notify(); + } + + fn can_select( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + match self.filtered_entries.get(ix) { + Some(AcpModelPickerEntry::Model(_)) => true, + Some(AcpModelPickerEntry::Separator(_)) | None => false, + } + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Select a model…".into() + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + cx.spawn_in(window, async move |this, cx| { + let filtered_models = match this + .read_with(cx, |this, cx| { + this.delegate.models.clone().map(move |models| { + fuzzy_search(models, query, cx.background_executor().clone()) + }) + }) + .ok() + .flatten() + { + Some(task) => task.await, + None => AgentModelList::Flat(vec![]), + }; + + this.update_in(cx, |this, window, cx| { + this.delegate.filtered_entries = + info_list_to_picker_entries(filtered_models).collect(); + // Finds the currently selected model in the list + let new_index = this + .delegate + .selected_model + .as_ref() + .and_then(|selected| { + this.delegate.filtered_entries.iter().position(|entry| { + if let AcpModelPickerEntry::Model(model_info) = entry { + model_info.id == selected.id + } else { + false + } + }) + }) + .unwrap_or(0); + this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx); + cx.notify(); + }) + .ok(); + }) + } + + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + if let Some(AcpModelPickerEntry::Model(model_info)) = + self.filtered_entries.get(self.selected_index) + { + self.selector + .select_model(self.session_id.clone(), model_info.id.clone(), cx) + .detach_and_log_err(cx); + self.selected_model = Some(model_info.clone()); + let current_index = self.selected_index; + self.set_selected_index(current_index, window, cx); + + cx.emit(DismissEvent); + } + } + + fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { + cx.emit(DismissEvent); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + cx: &mut Context>, + ) -> Option { + match self.filtered_entries.get(ix)? { + AcpModelPickerEntry::Separator(title) => Some( + div() + .px_2() + .pb_1() + .when(ix > 1, |this| { + this.mt_1() + .pt_2() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + Label::new(title) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + ), + AcpModelPickerEntry::Model(model_info) => { + let is_selected = Some(model_info) == self.selected_model.as_ref(); + + let model_icon_color = if is_selected { + Color::Accent + } else { + Color::Muted + }; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot::(model_info.icon.map(|icon| { + Icon::new(icon) + .color(model_icon_color) + .size(IconSize::Small) + })) + .child( + h_flex() + .w_full() + .pl_0p5() + .gap_1p5() + .w(px(240.)) + .child(Label::new(model_info.name.clone()).truncate()), + ) + .end_slot(div().pr_3().when(is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + })) + .into_any_element(), + ) + } + } + } + + fn render_footer( + &self, + _: &mut Window, + cx: &mut Context>, + ) -> Option { + Some( + h_flex() + .w_full() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .p_1() + .gap_4() + .justify_between() + .child( + Button::new("configure", "Configure") + .icon(IconName::Settings) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(|_, window, cx| { + window.dispatch_action( + zed_actions::agent::OpenSettings.boxed_clone(), + cx, + ); + }), + ) + .into_any(), + ) + } +} + +fn info_list_to_picker_entries( + model_list: AgentModelList, +) -> impl Iterator { + match model_list { + AgentModelList::Flat(list) => { + itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model)) + } + AgentModelList::Grouped(index_map) => { + itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| { + std::iter::once(AcpModelPickerEntry::Separator(group_name.0)) + .chain(models.into_iter().map(AcpModelPickerEntry::Model)) + })) + } + } +} + +async fn fuzzy_search( + model_list: AgentModelList, + query: String, + executor: BackgroundExecutor, +) -> AgentModelList { + async fn fuzzy_search_list( + model_list: Vec, + query: &str, + executor: BackgroundExecutor, + ) -> Vec { + let candidates = model_list + .iter() + .enumerate() + .map(|(ix, model)| { + StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name)) + }) + .collect::>(); + let mut matches = match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + executor, + ) + .await; + + matches.sort_unstable_by_key(|mat| { + let candidate = &candidates[mat.candidate_id]; + (Reverse(OrderedFloat(mat.score)), candidate.id) + }); + + matches + .into_iter() + .map(|mat| model_list[mat.candidate_id].clone()) + .collect() + } + + match model_list { + AgentModelList::Flat(model_list) => { + AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await) + } + AgentModelList::Grouped(index_map) => { + let groups = + futures::future::join_all(index_map.into_iter().map(|(group_name, models)| { + fuzzy_search_list(models, &query, executor.clone()) + .map(|results| (group_name, results)) + })) + .await; + AgentModelList::Grouped(IndexMap::from_iter( + groups + .into_iter() + .filter(|(_, results)| !results.is_empty()), + )) + } + } +} + +#[cfg(test)] +mod tests { + use gpui::TestAppContext; + + use super::*; + + fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList { + AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map( + |(group, models)| { + ( + acp_thread::AgentModelGroupName(group.to_string().into()), + models + .into_iter() + .map(|model| acp_thread::AgentModelInfo { + id: acp_thread::AgentModelId(model.to_string().into()), + name: model.to_string().into(), + icon: None, + }) + .collect::>(), + ) + }, + ))) + } + + fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) { + let AgentModelList::Grouped(groups) = result else { + panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result); + }; + + assert_eq!( + groups.len(), + expected.len(), + "Number of groups doesn't match" + ); + + for (i, (expected_group, expected_models)) in expected.iter().enumerate() { + let (actual_group, actual_models) = groups.get_index(i).unwrap(); + assert_eq!( + actual_group.0.as_ref(), + *expected_group, + "Group at position {} doesn't match expected group", + i + ); + assert_eq!( + actual_models.len(), + expected_models.len(), + "Number of models in group {} doesn't match", + expected_group + ); + + for (j, expected_model_name) in expected_models.iter().enumerate() { + assert_eq!( + actual_models[j].name, *expected_model_name, + "Model at position {} in group {} doesn't match expected model", + j, expected_group + ); + } + } + } + + #[gpui::test] + async fn test_fuzzy_match(cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ( + "zed", + vec![ + "Claude 3.7 Sonnet", + "Claude 3.7 Sonnet Thinking", + "gpt-4.1", + "gpt-4.1-nano", + ], + ), + ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]), + ("ollama", vec!["mistral", "deepseek"]), + ]); + + // Results should preserve models order whenever possible. + // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical + // similarity scores, but `zed/gpt-4.1` was higher in the models list, + // so it should appear first in the results. + let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await; + assert_models_eq( + results, + vec![ + ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]), + ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]), + ], + ); + + // Fuzzy search + let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await; + assert_models_eq( + results, + vec![ + ("zed", vec!["gpt-4.1-nano"]), + ("openai", vec!["gpt-4.1-nano"]), + ], + ); + } +} diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs new file mode 100644 index 0000000000..e52101113a --- /dev/null +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -0,0 +1,85 @@ +use std::rc::Rc; + +use acp_thread::AgentModelSelector; +use agent_client_protocol as acp; +use gpui::{Entity, FocusHandle}; +use picker::popover_menu::PickerPopoverMenu; +use ui::{ + ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*, +}; +use zed_actions::agent::ToggleModelSelector; + +use crate::acp::{AcpModelSelector, model_selector::acp_model_selector}; + +pub struct AcpModelSelectorPopover { + selector: Entity, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, +} + +impl AcpModelSelectorPopover { + pub(crate) fn new( + session_id: acp::SessionId, + selector: Rc, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, + window: &mut Window, + cx: &mut Context, + ) -> Self { + Self { + selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)), + menu_handle, + focus_handle, + } + } + + pub fn toggle(&self, window: &mut Window, cx: &mut Context) { + self.menu_handle.toggle(window, cx); + } +} + +impl Render for AcpModelSelectorPopover { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let model = self.selector.read(cx).delegate.active_model(); + let model_name = model + .as_ref() + .map(|model| model.name.clone()) + .unwrap_or_else(|| SharedString::from("Select a Model")); + + let model_icon = model.as_ref().and_then(|model| model.icon); + + let focus_handle = self.focus_handle.clone(); + + PickerPopoverMenu::new( + self.selector.clone(), + ButtonLike::new("active-model") + .when_some(model_icon, |this, icon| { + this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)) + }) + .child( + Label::new(model_name) + .color(Color::Muted) + .size(LabelSize::Small) + .ml_0p5(), + ) + .child( + Icon::new(IconName::ChevronDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + move |window, cx| { + Tooltip::for_action_in( + "Change Model", + &ToggleModelSelector, + &focus_handle, + window, + cx, + ) + }, + gpui::Corner::BottomRight, + cx, + ) + .with_handle(self.menu_handle.clone()) + .render(window, cx) + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index f47c7a0bc5..0b3ace1baf 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -27,6 +27,7 @@ use language::{Buffer, Language}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::{CompletionIntent, Project}; +use rope::Point; use settings::{Settings as _, SettingsStore}; use std::path::PathBuf; use std::{ @@ -37,12 +38,14 @@ use terminal_view::TerminalView; use text::{Anchor, BufferSnapshot}; use theme::ThemeSettings; use ui::{ - Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*, + Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, + Tooltip, prelude::*, }; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; use workspace::{CollaboratorId, Workspace}; -use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; +use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage, ToggleModelSelector}; +use crate::acp::AcpModelSelectorPopover; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; use crate::agent_diff::AgentDiff; @@ -62,6 +65,7 @@ pub struct AcpThreadView { diff_editors: HashMap>, terminal_views: HashMap>, message_editor: Entity, + model_selector: Option>, message_set_from_history: Option, _message_editor_subscription: Subscription, mention_set: Arc>, @@ -186,6 +190,7 @@ impl AcpThreadView { project: project.clone(), thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, + model_selector: None, message_set_from_history: None, _message_editor_subscription: message_editor_subscription, mention_set, @@ -269,7 +274,7 @@ impl AcpThreadView { Err(e) } } - Ok(session_id) => Ok(session_id), + Ok(thread) => Ok(thread), }; this.update_in(cx, |this, window, cx| { @@ -287,6 +292,24 @@ impl AcpThreadView { AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); + this.model_selector = + thread + .read(cx) + .connection() + .model_selector() + .map(|selector| { + cx.new(|cx| { + AcpModelSelectorPopover::new( + thread.read(cx).session_id().clone(), + selector, + PopoverMenuHandle::default(), + this.focus_handle(cx), + window, + cx, + ) + }) + }); + this.thread_state = ThreadState::Ready { thread, _subscription: [thread_subscription, action_log_subscription], @@ -656,17 +679,19 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let count = self.list_state.item_count(); match event { AcpThreadEvent::NewEntry => { let index = thread.read(cx).entries().len() - 1; self.sync_thread_entry_view(index, window, cx); - self.list_state.splice(count..count, 1); + self.list_state.splice(index..index, 1); } AcpThreadEvent::EntryUpdated(index) => { - let index = *index; - self.sync_thread_entry_view(index, window, cx); - self.list_state.splice(index..index + 1, 1); + self.sync_thread_entry_view(*index, window, cx); + self.list_state.splice(*index..index + 1, 1); + } + AcpThreadEvent::EntriesRemoved(range) => { + // TODO: Clean up unused diff editors and terminal views + self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); @@ -2471,6 +2496,12 @@ impl AcpThreadView { v_flex() .on_action(cx.listener(Self::expand_message_editor)) + .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { + if let Some(model_selector) = this.model_selector.as_ref() { + model_selector + .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); + } + })) .p_2() .gap_2() .border_t_1() @@ -2547,7 +2578,12 @@ impl AcpThreadView { .flex_none() .justify_between() .child(self.render_follow_toggle(cx)) - .child(self.render_send_button(cx)), + .child( + h_flex() + .gap_1() + .children(self.model_selector.clone()) + .child(self.render_send_button(cx)), + ), ) .into_any() } @@ -2679,26 +2715,24 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) -> Option<()> { - let location = self + let (tool_call_location, agent_location) = self .thread()? .read(cx) .entries() .get(entry_ix)? - .locations()? - .get(location_ix)?; + .location(location_ix)?; let project_path = self .project .read(cx) - .find_project_path(&location.path, cx)?; + .find_project_path(&tool_call_location.path, cx)?; let open_task = self .workspace - .update(cx, |worskpace, cx| { - worskpace.open_path(project_path, None, true, window, cx) + .update(cx, |workspace, cx| { + workspace.open_path(project_path, None, true, window, cx) }) .log_err()?; - window .spawn(cx, async move |cx| { let item = open_task.await?; @@ -2708,17 +2742,22 @@ impl AcpThreadView { }; active_editor.update_in(cx, |editor, window, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let first_hunk = editor - .diff_hunks_in_ranges( - &[editor::Anchor::min()..editor::Anchor::max()], - &snapshot, - ) - .next(); - if let Some(first_hunk) = first_hunk { - let first_hunk_start = first_hunk.multi_buffer_range().start; + let multibuffer = editor.buffer().read(cx); + let buffer = multibuffer.as_singleton(); + if agent_location.buffer.upgrade() == buffer { + let excerpt_id = multibuffer.excerpt_ids().first().cloned(); + let anchor = editor::Anchor::in_buffer( + excerpt_id.unwrap(), + buffer.unwrap().read(cx).remote_id(), + agent_location.position, + ); editor.change_selections(Default::default(), window, cx, |selections| { - selections.select_anchor_ranges([first_hunk_start..first_hunk_start]); + selections.select_anchor_ranges([anchor..anchor]); + }) + } else { + let row = tool_call_location.line.unwrap_or_default(); + editor.change_selections(Default::default(), window, cx, |selections| { + selections.select_ranges([Point::new(row, 0)..Point::new(row, 0)]); }) } })?; @@ -3752,6 +3791,7 @@ mod tests { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -3836,6 +3876,7 @@ mod tests { fn prompt( &self, + _id: Option, _params: acp::PromptRequest, _cx: &mut App, ) -> Task> { diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 0abc5280f4..b9e1ea5d0a 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1521,7 +1521,8 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::Stopped + AcpThreadEvent::EntriesRemoved(_) + | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {} diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 87e4dd822c..d07581da93 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -916,6 +916,7 @@ impl AgentPanel { let workspace = self.workspace.clone(); let project = self.project.clone(); let message_history = self.acp_message_history.clone(); + let fs = self.fs.clone(); const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; @@ -939,7 +940,7 @@ impl AgentPanel { }) .detach(); - agent.server() + agent.server(fs) } None => cx .background_spawn(async move { @@ -953,7 +954,7 @@ impl AgentPanel { }) .unwrap_or_default() .agent - .server(), + .server(fs), }; this.update_in(cx, |this, window, cx| { diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index fceb8f4c45..b776c0830b 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -155,11 +155,11 @@ enum ExternalAgent { } impl ExternalAgent { - pub fn server(&self) -> Rc { + pub fn server(&self, fs: Arc) -> Rc { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)), } } } diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index 9305f584cb..aa321aa8f3 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -65,7 +65,7 @@ pub enum EditAgentOutputEvent { ResolvingEditRange(Range), UnresolvedEditRange, AmbiguousEditRange(Vec>), - Edited, + Edited(Range), } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -178,7 +178,9 @@ impl EditAgent { ) }); output_events_tx - .unbounded_send(EditAgentOutputEvent::Edited) + .unbounded_send(EditAgentOutputEvent::Edited( + language::Anchor::MIN..language::Anchor::MAX, + )) .ok(); })?; @@ -200,7 +202,9 @@ impl EditAgent { }); })?; output_events_tx - .unbounded_send(EditAgentOutputEvent::Edited) + .unbounded_send(EditAgentOutputEvent::Edited( + language::Anchor::MIN..language::Anchor::MAX, + )) .ok(); } } @@ -336,8 +340,8 @@ impl EditAgent { // Edit the buffer and report edits to the action log as part of the // same effect cycle, otherwise the edit will be reported as if the // user made it. - cx.update(|cx| { - let max_edit_end = buffer.update(cx, |buffer, cx| { + let (min_edit_start, max_edit_end) = cx.update(|cx| { + let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| { buffer.edit(edits.iter().cloned(), None, cx); let max_edit_end = buffer .summaries_for_anchors::( @@ -345,7 +349,16 @@ impl EditAgent { ) .max() .unwrap(); - buffer.anchor_before(max_edit_end) + let min_edit_start = buffer + .summaries_for_anchors::( + edits.iter().map(|(range, _)| &range.start), + ) + .min() + .unwrap(); + ( + buffer.anchor_after(min_edit_start), + buffer.anchor_before(max_edit_end), + ) }); self.action_log .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); @@ -358,9 +371,10 @@ impl EditAgent { cx, ); }); + (min_edit_start, max_edit_end) })?; output_events - .unbounded_send(EditAgentOutputEvent::Edited) + .unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end)) .ok(); } @@ -755,6 +769,7 @@ mod tests { use gpui::{AppContext, TestAppContext}; use indoc::indoc; use language_model::fake_provider::FakeLanguageModel; + use pretty_assertions::assert_matches; use project::{AgentLocation, Project}; use rand::prelude::*; use rand::rngs::StdRng; @@ -992,7 +1007,10 @@ mod tests { model.send_last_completion_stream_text_chunk("abX"); cx.run_until_parked(); - assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited(_)] + ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), "abXc\ndef\nghi\njkl" @@ -1007,7 +1025,10 @@ mod tests { model.send_last_completion_stream_text_chunk("cY"); cx.run_until_parked(); - assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] + ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), "abXcY\ndef\nghi\njkl" @@ -1118,9 +1139,9 @@ mod tests { model.send_last_completion_stream_text_chunk("GHI"); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1165,9 +1186,9 @@ mod tests { ); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited(_)] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1183,9 +1204,9 @@ mod tests { chunks_tx.unbounded_send("```\njkl\n").unwrap(); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1201,9 +1222,9 @@ mod tests { chunks_tx.unbounded_send("mno\n").unwrap(); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1219,9 +1240,9 @@ mod tests { chunks_tx.unbounded_send("pqr\n```").unwrap(); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited(_)], ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index b5712415ec..e819c51e1e 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -307,7 +307,7 @@ impl Tool for EditFileTool { let mut ambiguous_ranges = Vec::new(); while let Some(event) = events.next().await { match event { - EditAgentOutputEvent::Edited => { + EditAgentOutputEvent::Edited { .. } => { if let Some(card) = card_clone.as_ref() { card.update(cx, |card, cx| card.update_diff(cx))?; } diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index d857a3eb2f..f1f40ad654 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -18,6 +18,6 @@ collections.workspace = true derive_more.workspace = true gpui.workspace = true parking_lot.workspace = true -rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] } +rodio = { workspace = true, features = ["wav", "playback", "tracing"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index 074aaa6fea..4d0d2d5984 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -59,16 +59,9 @@ pub enum VersionCheckType { pub enum AutoUpdateStatus { Idle, Checking, - Downloading { - version: VersionCheckType, - }, - Installing { - version: VersionCheckType, - }, - Updated { - binary_path: PathBuf, - version: VersionCheckType, - }, + Downloading { version: VersionCheckType }, + Installing { version: VersionCheckType }, + Updated { version: VersionCheckType }, Errored, } @@ -83,6 +76,7 @@ pub struct AutoUpdater { current_version: SemanticVersion, http_client: Arc, pending_poll: Option>>, + quit_subscription: Option, } #[derive(Deserialize, Clone, Debug)] @@ -164,7 +158,7 @@ pub fn init(http_client: Arc, cx: &mut App) { AutoUpdateSetting::register(cx); cx.observe_new(|workspace: &mut Workspace, _window, _cx| { - workspace.register_action(|_, action: &Check, window, cx| check(action, window, cx)); + workspace.register_action(|_, action, window, cx| check(action, window, cx)); workspace.register_action(|_, action, _, cx| { view_release_notes(action, cx); @@ -174,7 +168,7 @@ pub fn init(http_client: Arc, cx: &mut App) { let version = release_channel::AppVersion::global(cx); let auto_updater = cx.new(|cx| { - let updater = AutoUpdater::new(version, http_client); + let updater = AutoUpdater::new(version, http_client, cx); let poll_for_updates = ReleaseChannel::try_global(cx) .map(|channel| channel.poll_for_updates()) @@ -321,12 +315,34 @@ impl AutoUpdater { cx.default_global::().0.clone() } - fn new(current_version: SemanticVersion, http_client: Arc) -> Self { + fn new( + current_version: SemanticVersion, + http_client: Arc, + cx: &mut Context, + ) -> Self { + // On windows, executable files cannot be overwritten while they are + // running, so we must wait to overwrite the application until quitting + // or restarting. When quitting the app, we spawn the auto update helper + // to finish the auto update process after Zed exits. When restarting + // the app after an update, we use `set_restart_path` to run the auto + // update helper instead of the app, so that it can overwrite the app + // and then spawn the new binary. + let quit_subscription = Some(cx.on_app_quit(|_, _| async move { + #[cfg(target_os = "windows")] + finalize_auto_update_on_quit(); + })); + + cx.on_app_restart(|this, _| { + this.quit_subscription.take(); + }) + .detach(); + Self { status: AutoUpdateStatus::Idle, current_version, http_client, pending_poll: None, + quit_subscription, } } @@ -536,6 +552,8 @@ impl AutoUpdater { ) })?; + Self::check_dependencies()?; + this.update(&mut cx, |this, cx| { this.status = AutoUpdateStatus::Checking; cx.notify(); @@ -582,13 +600,15 @@ impl AutoUpdater { cx.notify(); })?; - let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?; + let new_binary_path = Self::install_release(installer_dir, target_path, &cx).await?; + if let Some(new_binary_path) = new_binary_path { + cx.update(|cx| cx.set_restart_path(new_binary_path))?; + } this.update(&mut cx, |this, cx| { this.set_should_show_update_notification(true, cx) .detach_and_log_err(cx); this.status = AutoUpdateStatus::Updated { - binary_path, version: newer_version, }; cx.notify(); @@ -639,6 +659,15 @@ impl AutoUpdater { } } + fn check_dependencies() -> Result<()> { + #[cfg(not(target_os = "windows"))] + anyhow::ensure!( + which::which("rsync").is_ok(), + "Aborting. Could not find rsync which is required for auto-updates." + ); + Ok(()) + } + async fn target_path(installer_dir: &InstallerDir) -> Result { let filename = match OS { "macos" => anyhow::Ok("Zed.dmg"), @@ -647,20 +676,14 @@ impl AutoUpdater { unsupported_os => anyhow::bail!("not supported: {unsupported_os}"), }?; - #[cfg(not(target_os = "windows"))] - anyhow::ensure!( - which::which("rsync").is_ok(), - "Aborting. Could not find rsync which is required for auto-updates." - ); - Ok(installer_dir.path().join(filename)) } - async fn binary_path( + async fn install_release( installer_dir: InstallerDir, target_path: PathBuf, cx: &AsyncApp, - ) -> Result { + ) -> Result> { match OS { "macos" => install_release_macos(&installer_dir, target_path, cx).await, "linux" => install_release_linux(&installer_dir, target_path, cx).await, @@ -801,7 +824,7 @@ async fn install_release_linux( temp_dir: &InstallerDir, downloaded_tar_gz: PathBuf, cx: &AsyncApp, -) -> Result { +) -> Result> { let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?; let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?); let running_app_path = cx.update(|cx| cx.app_path())??; @@ -861,14 +884,14 @@ async fn install_release_linux( String::from_utf8_lossy(&output.stderr) ); - Ok(to.join(expected_suffix)) + Ok(Some(to.join(expected_suffix))) } async fn install_release_macos( temp_dir: &InstallerDir, downloaded_dmg: PathBuf, cx: &AsyncApp, -) -> Result { +) -> Result> { let running_app_path = cx.update(|cx| cx.app_path())??; let running_app_filename = running_app_path .file_name() @@ -910,10 +933,10 @@ async fn install_release_macos( String::from_utf8_lossy(&output.stderr) ); - Ok(running_app_path) + Ok(None) } -async fn install_release_windows(downloaded_installer: PathBuf) -> Result { +async fn install_release_windows(downloaded_installer: PathBuf) -> Result> { let output = Command::new(downloaded_installer) .arg("/verysilent") .arg("/update=true") @@ -926,29 +949,36 @@ async fn install_release_windows(downloaded_installer: PathBuf) -> Result bool { +pub fn finalize_auto_update_on_quit() { let Some(installer_path) = std::env::current_exe() .ok() .and_then(|p| p.parent().map(|p| p.join("updates"))) else { - return false; + return; }; // The installer will create a flag file after it finishes updating let flag_file = installer_path.join("versions.txt"); - if flag_file.exists() { - if let Some(helper) = installer_path + if flag_file.exists() + && let Some(helper) = installer_path .parent() .map(|p| p.join("tools\\auto_update_helper.exe")) - { - let _ = std::process::Command::new(helper).spawn(); - return true; - } + { + let mut command = std::process::Command::new(helper); + command.arg("--launch"); + command.arg("false"); + let _ = command.spawn(); } - false } #[cfg(test)] @@ -1002,7 +1032,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)), }; let fetched_version = SemanticVersion::new(1, 0, 1); @@ -1024,7 +1053,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)), }; let fetched_version = SemanticVersion::new(1, 0, 2); @@ -1090,7 +1118,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "b".to_string(); @@ -1112,7 +1139,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "c".to_string(); @@ -1160,7 +1186,6 @@ mod tests { let app_commit_sha = Ok(None); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "b".to_string(); @@ -1183,7 +1208,6 @@ mod tests { let app_commit_sha = Ok(None); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "c".to_string(); diff --git a/crates/auto_update_helper/src/auto_update_helper.rs b/crates/auto_update_helper/src/auto_update_helper.rs index 7c810d8724..2781176028 100644 --- a/crates/auto_update_helper/src/auto_update_helper.rs +++ b/crates/auto_update_helper/src/auto_update_helper.rs @@ -37,6 +37,11 @@ mod windows_impl { pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1; pub(crate) const WM_TERMINATE: u32 = WM_USER + 2; + #[derive(Debug)] + struct Args { + launch: Option, + } + pub(crate) fn run() -> Result<()> { let helper_dir = std::env::current_exe()? .parent() @@ -51,8 +56,9 @@ mod windows_impl { log::info!("======= Starting Zed update ======="); let (tx, rx) = std::sync::mpsc::channel(); let hwnd = create_dialog_window(rx)?.0 as isize; + let args = parse_args(); std::thread::spawn(move || { - let result = perform_update(app_dir.as_path(), Some(hwnd)); + let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch.unwrap_or(true)); tx.send(result).ok(); unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok(); }); @@ -77,6 +83,41 @@ mod windows_impl { Ok(()) } + fn parse_args() -> Args { + let mut result = Args { launch: None }; + if let Some(candidate) = std::env::args().nth(1) { + parse_single_arg(&candidate, &mut result); + } + + result + } + + fn parse_single_arg(arg: &str, result: &mut Args) { + let Some((key, value)) = arg.strip_prefix("--").and_then(|arg| arg.split_once('=')) else { + log::error!( + "Invalid argument format: '{}'. Expected format: --key=value", + arg + ); + return; + }; + + match key { + "launch" => parse_launch_arg(value, &mut result.launch), + _ => log::error!("Unknown argument: --{}", key), + } + } + + fn parse_launch_arg(value: &str, arg: &mut Option) { + match value { + "true" => *arg = Some(true), + "false" => *arg = Some(false), + _ => log::error!( + "Invalid value for --launch: '{}'. Expected 'true' or 'false'", + value + ), + } + } + pub(crate) fn show_error(mut content: String) { if content.len() > 600 { content.truncate(600); @@ -91,4 +132,47 @@ mod windows_impl { ) }; } + + #[cfg(test)] + mod tests { + use crate::windows_impl::{Args, parse_launch_arg, parse_single_arg}; + + #[test] + fn test_parse_launch_arg() { + let mut arg = None; + parse_launch_arg("true", &mut arg); + assert_eq!(arg, Some(true)); + + let mut arg = None; + parse_launch_arg("false", &mut arg); + assert_eq!(arg, Some(false)); + + let mut arg = None; + parse_launch_arg("invalid", &mut arg); + assert_eq!(arg, None); + } + + #[test] + fn test_parse_single_arg() { + let mut args = Args { launch: None }; + parse_single_arg("--launch=true", &mut args); + assert_eq!(args.launch, Some(true)); + + let mut args = Args { launch: None }; + parse_single_arg("--launch=false", &mut args); + assert_eq!(args.launch, Some(false)); + + let mut args = Args { launch: None }; + parse_single_arg("--launch=invalid", &mut args); + assert_eq!(args.launch, None); + + let mut args = Args { launch: None }; + parse_single_arg("--launch", &mut args); + assert_eq!(args.launch, None); + + let mut args = Args { launch: None }; + parse_single_arg("--unknown", &mut args); + assert_eq!(args.launch, None); + } + } } diff --git a/crates/auto_update_helper/src/dialog.rs b/crates/auto_update_helper/src/dialog.rs index 010ebb4875..757819df51 100644 --- a/crates/auto_update_helper/src/dialog.rs +++ b/crates/auto_update_helper/src/dialog.rs @@ -72,7 +72,7 @@ pub(crate) fn create_dialog_window(receiver: Receiver>) -> Result) -> Result<()> { +pub(crate) fn perform_update(app_dir: &Path, hwnd: Option, launch: bool) -> Result<()> { let hwnd = hwnd.map(|ptr| HWND(ptr as _)); for job in JOBS.iter() { @@ -145,9 +145,11 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option) -> Result<()> } } } - let _ = std::process::Command::new(app_dir.join("Zed.exe")) - .creation_flags(CREATE_NEW_PROCESS_GROUP.0) - .spawn(); + if launch { + let _ = std::process::Command::new(app_dir.join("Zed.exe")) + .creation_flags(CREATE_NEW_PROCESS_GROUP.0) + .spawn(); + } log::info!("Update completed successfully"); Ok(()) } @@ -159,11 +161,11 @@ mod test { #[test] fn test_perform_update() { let app_dir = std::path::Path::new("C:/"); - assert!(perform_update(app_dir, None).is_ok()); + assert!(perform_update(app_dir, None, false).is_ok()); // Simulate a timeout unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") }; - let ret = perform_update(app_dir, None); + let ret = perform_update(app_dir, None, false); assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out")); } } diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 8d6cd2544a..67591167df 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -957,17 +957,14 @@ mod mac_os { ) -> Result<()> { use anyhow::bail; - let app_id_prompt = format!("id of app \"{}\"", channel.display_name()); - let app_id_output = Command::new("osascript") + let app_path_prompt = format!( + "POSIX path of (path to application \"{}\")", + channel.display_name() + ); + let app_path_output = Command::new("osascript") .arg("-e") - .arg(&app_id_prompt) + .arg(&app_path_prompt) .output()?; - if !app_id_output.status.success() { - bail!("Could not determine app id for {}", channel.display_name()); - } - let app_name = String::from_utf8(app_id_output.stdout)?.trim().to_owned(); - let app_path_prompt = format!("kMDItemCFBundleIdentifier == '{app_name}'"); - let app_path_output = Command::new("mdfind").arg(app_path_prompt).output()?; if !app_path_output.status.success() { bail!( "Could not determine app path for {}", diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 43a1a0b7a4..54b3d3f801 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -340,22 +340,35 @@ impl Telemetry { } pub fn log_edit_event(self: &Arc, environment: &'static str, is_via_ssh: bool) { + static LAST_EVENT_TIME: Mutex> = Mutex::new(None); + let mut state = self.state.lock(); let period_data = state.event_coalescer.log_event(environment); drop(state); - if let Some((start, end, environment)) = period_data { - let duration = end - .saturating_duration_since(start) - .min(Duration::from_secs(60 * 60 * 24)) - .as_millis() as i64; + if let Some(mut last_event) = LAST_EVENT_TIME.try_lock() { + let current_time = std::time::Instant::now(); + let last_time = last_event.get_or_insert(current_time); - telemetry::event!( - "Editor Edited", - duration = duration, - environment = environment, - is_via_ssh = is_via_ssh - ); + if current_time.duration_since(*last_time) > Duration::from_secs(60 * 10) { + *last_time = current_time; + } else { + return; + } + + if let Some((start, end, environment)) = period_data { + let duration = end + .saturating_duration_since(start) + .min(Duration::from_secs(60 * 60 * 24)) + .as_millis() as i64; + + telemetry::event!( + "Editor Edited", + duration = duration, + environment = environment, + is_via_ssh = is_via_ssh + ); + } } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index d1bf95c794..8a9398e71f 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -250,6 +250,24 @@ pub type RenderDiffHunkControlsFn = Arc< ) -> AnyElement, >; +enum ReportEditorEvent { + Saved { auto_saved: bool }, + EditorOpened, + ZetaTosClicked, + Closed, +} + +impl ReportEditorEvent { + pub fn event_type(&self) -> &'static str { + match self { + Self::Saved { .. } => "Editor Saved", + Self::EditorOpened => "Editor Opened", + Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked", + Self::Closed => "Editor Closed", + } + } +} + struct InlineValueCache { enabled: bool, inlays: Vec, @@ -2325,7 +2343,7 @@ impl Editor { } if editor.mode.is_full() { - editor.report_editor_event("Editor Opened", None, cx); + editor.report_editor_event(ReportEditorEvent::EditorOpened, None, cx); } editor @@ -9124,7 +9142,7 @@ impl Editor { .on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default()) .on_click(cx.listener(|this, _event, window, cx| { cx.stop_propagation(); - this.report_editor_event("Edit Prediction Provider ToS Clicked", None, cx); + this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx); window.dispatch_action( zed_actions::OpenZedPredictOnboarding.boxed_clone(), cx, @@ -20547,7 +20565,7 @@ impl Editor { fn report_editor_event( &self, - event_type: &'static str, + reported_event: ReportEditorEvent, file_extension: Option, cx: &App, ) { @@ -20581,15 +20599,30 @@ impl Editor { .show_edit_predictions; let project = project.read(cx); - telemetry::event!( - event_type, - file_extension, - vim_mode, - copilot_enabled, - copilot_enabled_for_language, - edit_predictions_provider, - is_via_ssh = project.is_via_ssh(), - ); + let event_type = reported_event.event_type(); + + if let ReportEditorEvent::Saved { auto_saved } = reported_event { + telemetry::event!( + event_type, + type = if auto_saved {"autosave"} else {"manual"}, + file_extension, + vim_mode, + copilot_enabled, + copilot_enabled_for_language, + edit_predictions_provider, + is_via_ssh = project.is_via_ssh(), + ); + } else { + telemetry::event!( + event_type, + file_extension, + vim_mode, + copilot_enabled, + copilot_enabled_for_language, + edit_predictions_provider, + is_via_ssh = project.is_via_ssh(), + ); + }; } /// Copy the highlighted chunks to the clipboard as JSON. The format is an array of lines, diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index b31963c9c8..0d2ecec8f2 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -22456,7 +22456,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) { ); cx.update(|_, cx| { - workspace::reload(&workspace::Reload::default(), cx); + workspace::reload(cx); }); assert_language_servers_count( 1, diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 8fb6c53171..e86d93d024 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3012,7 +3012,7 @@ impl EditorElement { .icon_color(Color::Custom(cx.theme().colors().editor_line_number)) .icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size()))) .shape(ui::IconButtonShape::Wide) - .width(width.into()) + .width(width) .on_click(move |_, window, cx| { editor.update(cx, |editor, cx| { editor.expand_excerpt(excerpt_id, direction, window, cx); @@ -3628,7 +3628,7 @@ impl EditorElement { ButtonLike::new("toggle-buffer-fold") .style(ui::ButtonStyle::Transparent) .height(px(28.).into()) - .width(px(28.).into()) + .width(px(28.)) .children(toggle_chevron_icon) .tooltip({ let focus_handle = focus_handle.clone(); diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 231aaa1d00..1da82c605d 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -1,7 +1,7 @@ use crate::{ Anchor, Autoscroll, Editor, EditorEvent, EditorSettings, ExcerptId, ExcerptRange, FormatTarget, - MultiBuffer, MultiBufferSnapshot, NavigationData, SearchWithinRange, SelectionEffects, - ToPoint as _, + MultiBuffer, MultiBufferSnapshot, NavigationData, ReportEditorEvent, SearchWithinRange, + SelectionEffects, ToPoint as _, display_map::HighlightKey, editor_settings::SeedQuerySetting, persistence::{DB, SerializedEditor}, @@ -776,6 +776,10 @@ impl Item for Editor { } } + fn on_removed(&self, cx: &App) { + self.report_editor_event(ReportEditorEvent::Closed, None, cx); + } + fn deactivated(&mut self, _: &mut Window, cx: &mut Context) { let selection = self.selections.newest_anchor(); self.push_to_nav_history(selection.head(), None, true, false, cx); @@ -815,9 +819,9 @@ impl Item for Editor { ) -> Task> { // Add meta data tracking # of auto saves if options.autosave { - self.report_editor_event("Editor Autosaved", None, cx); + self.report_editor_event(ReportEditorEvent::Saved { auto_saved: true }, None, cx); } else { - self.report_editor_event("Editor Saved", None, cx); + self.report_editor_event(ReportEditorEvent::Saved { auto_saved: false }, None, cx); } let buffers = self.buffer().clone().read(cx).all_buffers(); @@ -896,7 +900,11 @@ impl Item for Editor { .path .extension() .map(|a| a.to_string_lossy().to_string()); - self.report_editor_event("Editor Saved", file_extension, cx); + self.report_editor_event( + ReportEditorEvent::Saved { auto_saved: false }, + file_extension, + cx, + ); project.update(cx, |project, cx| project.save_buffer_as(buffer, path, cx)) } @@ -997,12 +1005,16 @@ impl Item for Editor { ) { self.workspace = Some((workspace.weak_handle(), workspace.database_id())); if let Some(workspace) = &workspace.weak_handle().upgrade() { - cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| { - if matches!(event, workspace::Event::ModalOpened) { - editor.mouse_context_menu.take(); - editor.inline_blame_popover.take(); - } - }) + cx.subscribe( + &workspace, + |editor, _, event: &workspace::Event, _cx| match event { + workspace::Event::ModalOpened => { + editor.mouse_context_menu.take(); + editor.inline_blame_popover.take(); + } + _ => {} + }, + ) .detach(); } } diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index dc38c244f1..67baf4e692 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1118,15 +1118,17 @@ impl ExtensionStore { extensions_to_unload.len() - reload_count ); - for extension_id in &extensions_to_load { - if let Some(extension) = new_index.extensions.get(extension_id) { - telemetry::event!( - "Extension Loaded", - extension_id, - version = extension.manifest.version - ); - } - } + let extension_ids = extensions_to_load + .iter() + .filter_map(|id| { + Some(( + id.clone(), + new_index.extensions.get(id)?.manifest.version.clone(), + )) + }) + .collect::>(); + + telemetry::event!("Extensions Loaded", id_and_versions = extension_ids); let themes_to_remove = old_index .themes diff --git a/crates/file_icons/src/file_icons.rs b/crates/file_icons/src/file_icons.rs index 2f159771b1..82a8e05d85 100644 --- a/crates/file_icons/src/file_icons.rs +++ b/crates/file_icons/src/file_icons.rs @@ -33,13 +33,23 @@ impl FileIcons { // TODO: Associate a type with the languages and have the file's language // override these associations - // check if file name is in suffixes - // e.g. catch file named `eslint.config.js` instead of `.eslint.config.js` - if let Some(typ) = path.file_name().and_then(|typ| typ.to_str()) { + if let Some(mut typ) = path.file_name().and_then(|typ| typ.to_str()) { + // check if file name is in suffixes + // e.g. catch file named `eslint.config.js` instead of `.eslint.config.js` let maybe_path = get_icon_from_suffix(typ); if maybe_path.is_some() { return maybe_path; } + + // check if suffix based on first dot is in suffixes + // e.g. consider `module.js` as suffix to angular's module file named `auth.module.js` + while let Some((_, suffix)) = typ.split_once('.') { + let maybe_path = get_icon_from_suffix(suffix); + if maybe_path.is_some() { + return maybe_path; + } + typ = suffix; + } } // primary case: check if the files extension or the hidden file name diff --git a/crates/fs/Cargo.toml b/crates/fs/Cargo.toml index 633fc1fc99..1d4161134e 100644 --- a/crates/fs/Cargo.toml +++ b/crates/fs/Cargo.toml @@ -51,6 +51,7 @@ ashpd.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } +git = { workspace = true, features = ["test-support"] } [features] test-support = ["gpui/test-support", "git/test-support"] diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 73da63fd47..f0936d400a 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -1,8 +1,9 @@ -use crate::{FakeFs, Fs}; +use crate::{FakeFs, FakeFsEntry, Fs}; use anyhow::{Context as _, Result}; use collections::{HashMap, HashSet}; use futures::future::{self, BoxFuture, join_all}; use git::{ + Oid, blame::Blame, repository::{ AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository, @@ -10,8 +11,9 @@ use git::{ }, status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus}, }; -use gpui::{AsyncApp, BackgroundExecutor, SharedString}; +use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task}; use ignore::gitignore::GitignoreBuilder; +use parking_lot::Mutex; use rope::Rope; use smol::future::FutureExt as _; use std::{path::PathBuf, sync::Arc}; @@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc}; #[derive(Clone)] pub struct FakeGitRepository { pub(crate) fs: Arc, + pub(crate) checkpoints: Arc>>, pub(crate) executor: BackgroundExecutor, pub(crate) dot_git_path: PathBuf, pub(crate) repository_dir_path: PathBuf, @@ -183,7 +186,7 @@ impl GitRepository for FakeGitRepository { async move { None }.boxed() } - fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result> { + fn status(&self, path_prefixes: &[RepoPath]) -> Task> { let workdir_path = self.dot_git_path.parent().unwrap(); // Load gitignores @@ -311,7 +314,10 @@ impl GitRepository for FakeGitRepository { entries: entries.into(), }) }); - async move { result? }.boxed() + Task::ready(match result { + Ok(result) => result, + Err(e) => Err(e), + }) } fn branches(&self) -> BoxFuture<'_, Result>> { @@ -466,22 +472,57 @@ impl GitRepository for FakeGitRepository { } fn checkpoint(&self) -> BoxFuture<'static, Result> { - unimplemented!() + let executor = self.executor.clone(); + let fs = self.fs.clone(); + let checkpoints = self.checkpoints.clone(); + let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf(); + async move { + executor.simulate_random_delay().await; + let oid = Oid::random(&mut executor.rng()); + let entry = fs.entry(&repository_dir_path)?; + checkpoints.lock().insert(oid, entry); + Ok(GitRepositoryCheckpoint { commit_sha: oid }) + } + .boxed() } - fn restore_checkpoint( - &self, - _checkpoint: GitRepositoryCheckpoint, - ) -> BoxFuture<'_, Result<()>> { - unimplemented!() + fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> { + let executor = self.executor.clone(); + let fs = self.fs.clone(); + let checkpoints = self.checkpoints.clone(); + let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf(); + async move { + executor.simulate_random_delay().await; + let checkpoints = checkpoints.lock(); + let entry = checkpoints + .get(&checkpoint.commit_sha) + .context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?; + fs.insert_entry(&repository_dir_path, entry.clone())?; + Ok(()) + } + .boxed() } fn compare_checkpoints( &self, - _left: GitRepositoryCheckpoint, - _right: GitRepositoryCheckpoint, + left: GitRepositoryCheckpoint, + right: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result> { - unimplemented!() + let executor = self.executor.clone(); + let checkpoints = self.checkpoints.clone(); + async move { + executor.simulate_random_delay().await; + let checkpoints = checkpoints.lock(); + let left = checkpoints + .get(&left.commit_sha) + .context(format!("invalid left checkpoint: {}", left.commit_sha))?; + let right = checkpoints + .get(&right.commit_sha) + .context(format!("invalid right checkpoint: {}", right.commit_sha))?; + + Ok(left == right) + } + .boxed() } fn diff_checkpoints( @@ -496,3 +537,63 @@ impl GitRepository for FakeGitRepository { unimplemented!() } } + +#[cfg(test)] +mod tests { + use crate::{FakeFs, Fs}; + use gpui::BackgroundExecutor; + use serde_json::json; + use std::path::Path; + use util::path; + + #[gpui::test] + async fn test_checkpoints(executor: BackgroundExecutor) { + let fs = FakeFs::new(executor); + fs.insert_tree( + path!("/"), + json!({ + "bar": { + "baz": "qux" + }, + "foo": { + ".git": {}, + "a": "lorem", + "b": "ipsum", + }, + }), + ) + .await; + fs.with_git_state(Path::new("/foo/.git"), true, |_git| {}) + .unwrap(); + let repository = fs.open_repo(Path::new("/foo/.git")).unwrap(); + + let checkpoint_1 = repository.checkpoint().await.unwrap(); + fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap(); + fs.write(Path::new("/foo/c"), b"dolor").await.unwrap(); + let checkpoint_2 = repository.checkpoint().await.unwrap(); + let checkpoint_3 = repository.checkpoint().await.unwrap(); + + assert!( + repository + .compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone()) + .await + .unwrap() + ); + assert!( + !repository + .compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone()) + .await + .unwrap() + ); + + repository.restore_checkpoint(checkpoint_1).await.unwrap(); + assert_eq!( + fs.files_with_contents(Path::new("")), + [ + (Path::new("/bar/baz").into(), b"qux".into()), + (Path::new("/foo/a").into(), b"lorem".into()), + (Path::new("/foo/b").into(), b"ipsum".into()) + ] + ); + } +} diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index a2b75ac6a7..22bfdbcd66 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -924,7 +924,7 @@ pub struct FakeFs { #[cfg(any(test, feature = "test-support"))] struct FakeFsState { - root: Arc>, + root: FakeFsEntry, next_inode: u64, next_mtime: SystemTime, git_event_tx: smol::channel::Sender, @@ -939,7 +939,7 @@ struct FakeFsState { } #[cfg(any(test, feature = "test-support"))] -#[derive(Debug)] +#[derive(Clone, Debug)] enum FakeFsEntry { File { inode: u64, @@ -953,7 +953,7 @@ enum FakeFsEntry { inode: u64, mtime: MTime, len: u64, - entries: BTreeMap>>, + entries: BTreeMap, git_repo_state: Option>>, }, Symlink { @@ -961,6 +961,67 @@ enum FakeFsEntry { }, } +#[cfg(any(test, feature = "test-support"))] +impl PartialEq for FakeFsEntry { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::File { + inode: l_inode, + mtime: l_mtime, + len: l_len, + content: l_content, + git_dir_path: l_git_dir_path, + }, + Self::File { + inode: r_inode, + mtime: r_mtime, + len: r_len, + content: r_content, + git_dir_path: r_git_dir_path, + }, + ) => { + l_inode == r_inode + && l_mtime == r_mtime + && l_len == r_len + && l_content == r_content + && l_git_dir_path == r_git_dir_path + } + ( + Self::Dir { + inode: l_inode, + mtime: l_mtime, + len: l_len, + entries: l_entries, + git_repo_state: l_git_repo_state, + }, + Self::Dir { + inode: r_inode, + mtime: r_mtime, + len: r_len, + entries: r_entries, + git_repo_state: r_git_repo_state, + }, + ) => { + let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) { + (Some(l), Some(r)) => Arc::ptr_eq(l, r), + (None, None) => true, + _ => false, + }; + l_inode == r_inode + && l_mtime == r_mtime + && l_len == r_len + && l_entries == r_entries + && same_repo_state + } + (Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => { + l_target == r_target + } + _ => false, + } + } +} + #[cfg(any(test, feature = "test-support"))] impl FakeFsState { fn get_and_increment_mtime(&mut self) -> MTime { @@ -975,25 +1036,9 @@ impl FakeFsState { inode } - fn read_path(&self, target: &Path) -> Result>> { - Ok(self - .try_read_path(target, true) - .ok_or_else(|| { - anyhow!(io::Error::new( - io::ErrorKind::NotFound, - format!("not found: {target:?}") - )) - })? - .0) - } - - fn try_read_path( - &self, - target: &Path, - follow_symlink: bool, - ) -> Option<(Arc>, PathBuf)> { - let mut path = target.to_path_buf(); + fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option { let mut canonical_path = PathBuf::new(); + let mut path = target.to_path_buf(); let mut entry_stack = Vec::new(); 'outer: loop { let mut path_components = path.components().peekable(); @@ -1003,7 +1048,7 @@ impl FakeFsState { Component::Prefix(prefix_component) => prefix = Some(prefix_component), Component::RootDir => { entry_stack.clear(); - entry_stack.push(self.root.clone()); + entry_stack.push(&self.root); canonical_path.clear(); match prefix { Some(prefix_component) => { @@ -1020,20 +1065,18 @@ impl FakeFsState { canonical_path.pop(); } Component::Normal(name) => { - let current_entry = entry_stack.last().cloned()?; - let current_entry = current_entry.lock(); - if let FakeFsEntry::Dir { entries, .. } = &*current_entry { - let entry = entries.get(name.to_str().unwrap()).cloned()?; + let current_entry = *entry_stack.last()?; + if let FakeFsEntry::Dir { entries, .. } = current_entry { + let entry = entries.get(name.to_str().unwrap())?; if path_components.peek().is_some() || follow_symlink { - let entry = entry.lock(); - if let FakeFsEntry::Symlink { target, .. } = &*entry { + if let FakeFsEntry::Symlink { target, .. } = entry { let mut target = target.clone(); target.extend(path_components); path = target; continue 'outer; } } - entry_stack.push(entry.clone()); + entry_stack.push(entry); canonical_path = canonical_path.join(name); } else { return None; @@ -1043,19 +1086,72 @@ impl FakeFsState { } break; } - Some((entry_stack.pop()?, canonical_path)) + + if entry_stack.is_empty() { + None + } else { + Some(canonical_path) + } } - fn write_path(&self, path: &Path, callback: Fn) -> Result + fn try_entry( + &mut self, + target: &Path, + follow_symlink: bool, + ) -> Option<(&mut FakeFsEntry, PathBuf)> { + let canonical_path = self.canonicalize(target, follow_symlink)?; + + let mut components = canonical_path.components(); + let Some(Component::RootDir) = components.next() else { + panic!( + "the path {:?} was not canonicalized properly {:?}", + target, canonical_path + ) + }; + + let mut entry = &mut self.root; + for component in components { + match component { + Component::Normal(name) => { + if let FakeFsEntry::Dir { entries, .. } = entry { + entry = entries.get_mut(name.to_str().unwrap())?; + } else { + return None; + } + } + _ => { + panic!( + "the path {:?} was not canonicalized properly {:?}", + target, canonical_path + ) + } + } + } + + Some((entry, canonical_path)) + } + + fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> { + Ok(self + .try_entry(target, true) + .ok_or_else(|| { + anyhow!(io::Error::new( + io::ErrorKind::NotFound, + format!("not found: {target:?}") + )) + })? + .0) + } + + fn write_path(&mut self, path: &Path, callback: Fn) -> Result where - Fn: FnOnce(btree_map::Entry>>) -> Result, + Fn: FnOnce(btree_map::Entry) -> Result, { let path = normalize_path(path); let filename = path.file_name().context("cannot overwrite the root")?; let parent_path = path.parent().unwrap(); - let parent = self.read_path(parent_path)?; - let mut parent = parent.lock(); + let parent = self.entry(parent_path)?; let new_entry = parent .dir_entries(parent_path)? .entry(filename.to_str().unwrap().into()); @@ -1105,13 +1201,13 @@ impl FakeFs { this: this.clone(), executor: executor.clone(), state: Arc::new(Mutex::new(FakeFsState { - root: Arc::new(Mutex::new(FakeFsEntry::Dir { + root: FakeFsEntry::Dir { inode: 0, mtime: MTime(UNIX_EPOCH), len: 0, entries: Default::default(), git_repo_state: None, - })), + }, git_event_tx: tx, next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL, next_inode: 1, @@ -1161,15 +1257,15 @@ impl FakeFs { .write_path(path, move |entry| { match entry { btree_map::Entry::Vacant(e) => { - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode: new_inode, mtime: new_mtime, content: Vec::new(), len: 0, git_dir_path: None, - }))); + }); } - btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() { + btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() { FakeFsEntry::File { mtime, .. } => *mtime = new_mtime, FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime, FakeFsEntry::Symlink { .. } => {} @@ -1188,7 +1284,7 @@ impl FakeFs { pub async fn insert_symlink(&self, path: impl AsRef, target: PathBuf) { let mut state = self.state.lock(); let path = path.as_ref(); - let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + let file = FakeFsEntry::Symlink { target }; state .write_path(path.as_ref(), move |e| match e { btree_map::Entry::Vacant(e) => { @@ -1221,13 +1317,13 @@ impl FakeFs { match entry { btree_map::Entry::Vacant(e) => { kind = Some(PathEventKind::Created); - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode: new_inode, mtime: new_mtime, len: new_len, content: new_content, git_dir_path: None, - }))); + }); } btree_map::Entry::Occupied(mut e) => { kind = Some(PathEventKind::Changed); @@ -1237,7 +1333,7 @@ impl FakeFs { len, content, .. - } = &mut *e.get_mut().lock() + } = e.get_mut() { *mtime = new_mtime; *content = new_content; @@ -1259,9 +1355,8 @@ impl FakeFs { pub fn read_file_sync(&self, path: impl AsRef) -> Result> { let path = path.as_ref(); let path = normalize_path(path); - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); + let mut state = self.state.lock(); + let entry = state.entry(&path)?; entry.file_content(&path).cloned() } @@ -1269,9 +1364,8 @@ impl FakeFs { let path = path.as_ref(); let path = normalize_path(path); self.simulate_random_delay().await; - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); + let mut state = self.state.lock(); + let entry = state.entry(&path)?; entry.file_content(&path).cloned() } @@ -1292,6 +1386,25 @@ impl FakeFs { self.state.lock().flush_events(count); } + pub(crate) fn entry(&self, target: &Path) -> Result { + self.state.lock().entry(target).cloned() + } + + pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> { + let mut state = self.state.lock(); + state.write_path(target, |entry| { + match entry { + btree_map::Entry::Vacant(vacant_entry) => { + vacant_entry.insert(new_entry); + } + btree_map::Entry::Occupied(mut occupied_entry) => { + occupied_entry.insert(new_entry); + } + } + Ok(()) + }) + } + #[must_use] pub fn insert_tree<'a>( &'a self, @@ -1361,20 +1474,19 @@ impl FakeFs { F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T, { let mut state = self.state.lock(); - let entry = state.read_path(dot_git).context("open .git")?; - let mut entry = entry.lock(); + let git_event_tx = state.git_event_tx.clone(); + let entry = state.entry(dot_git).context("open .git")?; - if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry { + if let FakeFsEntry::Dir { git_repo_state, .. } = entry { let repo_state = git_repo_state.get_or_insert_with(|| { log::debug!("insert git state for {dot_git:?}"); - Arc::new(Mutex::new(FakeGitRepositoryState::new( - state.git_event_tx.clone(), - ))) + Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx))) }); let mut repo_state = repo_state.lock(); let result = f(&mut repo_state, dot_git, dot_git); + drop(repo_state); if emit_git_event { state.emit_event([(dot_git, None)]); } @@ -1398,21 +1510,20 @@ impl FakeFs { } } .clone(); - drop(entry); - let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else { + let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else { anyhow::bail!("pointed-to git dir {path:?} not found") }; let FakeFsEntry::Dir { git_repo_state, entries, .. - } = &mut *git_dir_entry.lock() + } = git_dir_entry else { anyhow::bail!("gitfile points to a non-directory") }; let common_dir = if let Some(child) = entries.get("commondir") { Path::new( - std::str::from_utf8(child.lock().file_content("commondir".as_ref())?) + std::str::from_utf8(child.file_content("commondir".as_ref())?) .context("commondir content")?, ) .to_owned() @@ -1420,15 +1531,14 @@ impl FakeFs { canonical_path.clone() }; let repo_state = git_repo_state.get_or_insert_with(|| { - Arc::new(Mutex::new(FakeGitRepositoryState::new( - state.git_event_tx.clone(), - ))) + Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx))) }); let mut repo_state = repo_state.lock(); let result = f(&mut repo_state, &canonical_path, &common_dir); if emit_git_event { + drop(repo_state); state.emit_event([(canonical_path, None)]); } @@ -1655,14 +1765,12 @@ impl FakeFs { pub fn paths(&self, include_dot_git: bool) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + if let FakeFsEntry::Dir { entries, .. } = entry { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } if include_dot_git @@ -1679,14 +1787,12 @@ impl FakeFs { pub fn directories(&self, include_dot_git: bool) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + if let FakeFsEntry::Dir { entries, .. } = entry { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } if include_dot_git || !path @@ -1703,17 +1809,14 @@ impl FakeFs { pub fn files(&self) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - let e = entry.lock(); - match &*e { + match entry { FakeFsEntry::File { .. } => result.push(path), FakeFsEntry::Dir { entries, .. } => { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } FakeFsEntry::Symlink { .. } => {} @@ -1725,13 +1828,10 @@ impl FakeFs { pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec)> { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - let e = entry.lock(); - match &*e { + match entry { FakeFsEntry::File { content, .. } => { if path.starts_with(prefix) { result.push((path, content.clone())); @@ -1739,7 +1839,7 @@ impl FakeFs { } FakeFsEntry::Dir { entries, .. } => { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } FakeFsEntry::Symlink { .. } => {} @@ -1805,10 +1905,7 @@ impl FakeFsEntry { } } - fn dir_entries( - &mut self, - path: &Path, - ) -> Result<&mut BTreeMap>>> { + fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap> { if let Self::Dir { entries, .. } = self { Ok(entries) } else { @@ -1855,12 +1952,12 @@ struct FakeHandle { impl FileHandle for FakeHandle { fn current_path(&self, fs: &Arc) -> Result { let fs = fs.as_fake(); - let state = fs.state.lock(); - let Some(target) = state.moves.get(&self.inode) else { + let mut state = fs.state.lock(); + let Some(target) = state.moves.get(&self.inode).cloned() else { anyhow::bail!("fake fd not moved") }; - if state.try_read_path(&target, false).is_some() { + if state.try_entry(&target, false).is_some() { return Ok(target.clone()); } anyhow::bail!("fake fd target not found") @@ -1888,13 +1985,13 @@ impl Fs for FakeFs { state.write_path(&cur_path, |entry| { entry.or_insert_with(|| { created_dirs.push((cur_path.clone(), Some(PathEventKind::Created))); - Arc::new(Mutex::new(FakeFsEntry::Dir { + FakeFsEntry::Dir { inode, mtime, len: 0, entries: Default::default(), git_repo_state: None, - })) + } }); Ok(()) })? @@ -1909,13 +2006,13 @@ impl Fs for FakeFs { let mut state = self.state.lock(); let inode = state.get_and_increment_inode(); let mtime = state.get_and_increment_mtime(); - let file = Arc::new(Mutex::new(FakeFsEntry::File { + let file = FakeFsEntry::File { inode, mtime, len: 0, content: Vec::new(), git_dir_path: None, - })); + }; let mut kind = Some(PathEventKind::Created); state.write_path(path, |entry| { match entry { @@ -1939,7 +2036,7 @@ impl Fs for FakeFs { async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> { let mut state = self.state.lock(); - let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + let file = FakeFsEntry::Symlink { target }; state .write_path(path.as_ref(), move |e| match e { btree_map::Entry::Vacant(e) => { @@ -2002,7 +2099,7 @@ impl Fs for FakeFs { } })?; - let inode = match *moved_entry.lock() { + let inode = match moved_entry { FakeFsEntry::File { inode, .. } => inode, FakeFsEntry::Dir { inode, .. } => inode, _ => 0, @@ -2051,8 +2148,8 @@ impl Fs for FakeFs { let mut state = self.state.lock(); let mtime = state.get_and_increment_mtime(); let inode = state.get_and_increment_inode(); - let source_entry = state.read_path(&source)?; - let content = source_entry.lock().file_content(&source)?.clone(); + let source_entry = state.entry(&source)?; + let content = source_entry.file_content(&source)?.clone(); let mut kind = Some(PathEventKind::Created); state.write_path(&target, |e| match e { btree_map::Entry::Occupied(e) => { @@ -2066,13 +2163,13 @@ impl Fs for FakeFs { } } btree_map::Entry::Vacant(e) => Ok(Some( - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode, mtime, len: content.len() as u64, content, git_dir_path: None, - }))) + }) .clone(), )), })?; @@ -2088,8 +2185,7 @@ impl Fs for FakeFs { let base_name = path.file_name().context("cannot remove the root")?; let mut state = self.state.lock(); - let parent_entry = state.read_path(parent_path)?; - let mut parent_entry = parent_entry.lock(); + let parent_entry = state.entry(parent_path)?; let entry = parent_entry .dir_entries(parent_path)? .entry(base_name.to_str().unwrap().into()); @@ -2100,15 +2196,14 @@ impl Fs for FakeFs { anyhow::bail!("{path:?} does not exist"); } } - btree_map::Entry::Occupied(e) => { + btree_map::Entry::Occupied(mut entry) => { { - let mut entry = e.get().lock(); - let children = entry.dir_entries(&path)?; + let children = entry.get_mut().dir_entries(&path)?; if !options.recursive && !children.is_empty() { anyhow::bail!("{path:?} is not empty"); } } - e.remove(); + entry.remove(); } } state.emit_event([(path, Some(PathEventKind::Removed))]); @@ -2122,8 +2217,7 @@ impl Fs for FakeFs { let parent_path = path.parent().context("cannot remove the root")?; let base_name = path.file_name().unwrap(); let mut state = self.state.lock(); - let parent_entry = state.read_path(parent_path)?; - let mut parent_entry = parent_entry.lock(); + let parent_entry = state.entry(parent_path)?; let entry = parent_entry .dir_entries(parent_path)? .entry(base_name.to_str().unwrap().into()); @@ -2133,9 +2227,9 @@ impl Fs for FakeFs { anyhow::bail!("{path:?} does not exist"); } } - btree_map::Entry::Occupied(e) => { - e.get().lock().file_content(&path)?; - e.remove(); + btree_map::Entry::Occupied(mut entry) => { + entry.get_mut().file_content(&path)?; + entry.remove(); } } state.emit_event([(path, Some(PathEventKind::Removed))]); @@ -2149,12 +2243,10 @@ impl Fs for FakeFs { async fn open_handle(&self, path: &Path) -> Result> { self.simulate_random_delay().await; - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); - let inode = match *entry { - FakeFsEntry::File { inode, .. } => inode, - FakeFsEntry::Dir { inode, .. } => inode, + let mut state = self.state.lock(); + let inode = match state.entry(&path)? { + FakeFsEntry::File { inode, .. } => *inode, + FakeFsEntry::Dir { inode, .. } => *inode, _ => unreachable!(), }; Ok(Arc::new(FakeHandle { inode })) @@ -2204,8 +2296,8 @@ impl Fs for FakeFs { let path = normalize_path(path); self.simulate_random_delay().await; let state = self.state.lock(); - let (_, canonical_path) = state - .try_read_path(&path, true) + let canonical_path = state + .canonicalize(&path, true) .with_context(|| format!("path does not exist: {path:?}"))?; Ok(canonical_path) } @@ -2213,9 +2305,9 @@ impl Fs for FakeFs { async fn is_file(&self, path: &Path) -> bool { let path = normalize_path(path); self.simulate_random_delay().await; - let state = self.state.lock(); - if let Some((entry, _)) = state.try_read_path(&path, true) { - entry.lock().is_file() + let mut state = self.state.lock(); + if let Some((entry, _)) = state.try_entry(&path, true) { + entry.is_file() } else { false } @@ -2232,17 +2324,16 @@ impl Fs for FakeFs { let path = normalize_path(path); let mut state = self.state.lock(); state.metadata_call_count += 1; - if let Some((mut entry, _)) = state.try_read_path(&path, false) { - let is_symlink = entry.lock().is_symlink(); + if let Some((mut entry, _)) = state.try_entry(&path, false) { + let is_symlink = entry.is_symlink(); if is_symlink { - if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) { + if let Some(e) = state.try_entry(&path, true).map(|e| e.0) { entry = e; } else { return Ok(None); } } - let entry = entry.lock(); Ok(Some(match &*entry { FakeFsEntry::File { inode, mtime, len, .. @@ -2274,12 +2365,11 @@ impl Fs for FakeFs { async fn read_link(&self, path: &Path) -> Result { self.simulate_random_delay().await; let path = normalize_path(path); - let state = self.state.lock(); + let mut state = self.state.lock(); let (entry, _) = state - .try_read_path(&path, false) + .try_entry(&path, false) .with_context(|| format!("path does not exist: {path:?}"))?; - let entry = entry.lock(); - if let FakeFsEntry::Symlink { target } = &*entry { + if let FakeFsEntry::Symlink { target } = entry { Ok(target.clone()) } else { anyhow::bail!("not a symlink: {path:?}") @@ -2294,8 +2384,7 @@ impl Fs for FakeFs { let path = normalize_path(path); let mut state = self.state.lock(); state.read_dir_call_count += 1; - let entry = state.read_path(&path)?; - let mut entry = entry.lock(); + let entry = state.entry(&path)?; let children = entry.dir_entries(&path)?; let paths = children .keys() @@ -2359,6 +2448,7 @@ impl Fs for FakeFs { dot_git_path: abs_dot_git.to_path_buf(), repository_dir_path: repository_dir_path.to_owned(), common_dir_path: common_dir_path.to_owned(), + checkpoints: Arc::default(), }) as _ }, ) diff --git a/crates/git/Cargo.toml b/crates/git/Cargo.toml index ab2210094d..74656f1d4c 100644 --- a/crates/git/Cargo.toml +++ b/crates/git/Cargo.toml @@ -12,7 +12,7 @@ workspace = true path = "src/git.rs" [features] -test-support = [] +test-support = ["rand"] [dependencies] anyhow.workspace = true @@ -26,6 +26,7 @@ http_client.workspace = true log.workspace = true parking_lot.workspace = true regex.workspace = true +rand = { workspace = true, optional = true } rope.workspace = true schemars.workspace = true serde.workspace = true @@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] } unindent.workspace = true gpui = { workspace = true, features = ["test-support"] } tempfile.workspace = true +rand.workspace = true diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index e6336eb656..e84014129c 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -119,6 +119,13 @@ impl Oid { Ok(Self(oid)) } + #[cfg(any(test, feature = "test-support"))] + pub fn random(rng: &mut impl rand::Rng) -> Self { + let mut bytes = [0; 20]; + rng.fill(&mut bytes); + Self::from_bytes(&bytes).unwrap() + } + pub fn as_bytes(&self) -> &[u8] { self.0.as_bytes() } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index 518b6c4f46..49eee84840 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -6,7 +6,7 @@ use collections::HashMap; use futures::future::BoxFuture; use futures::{AsyncWriteExt, FutureExt as _, select_biased}; use git2::BranchType; -use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString}; +use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString, Task}; use parking_lot::Mutex; use rope::Rope; use schemars::JsonSchema; @@ -338,7 +338,7 @@ pub trait GitRepository: Send + Sync { fn merge_message(&self) -> BoxFuture<'_, Option>; - fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result>; + fn status(&self, path_prefixes: &[RepoPath]) -> Task>; fn branches(&self) -> BoxFuture<'_, Result>>; @@ -953,25 +953,27 @@ impl GitRepository for RealGitRepository { .boxed() } - fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result> { + fn status(&self, path_prefixes: &[RepoPath]) -> Task> { let git_binary_path = self.git_binary_path.clone(); - let working_directory = self.working_directory(); - let path_prefixes = path_prefixes.to_owned(); - self.executor - .spawn(async move { - let output = new_std_command(&git_binary_path) - .current_dir(working_directory?) - .args(git_status_args(&path_prefixes)) - .output()?; - if output.status.success() { - let stdout = String::from_utf8_lossy(&output.stdout); - stdout.parse() - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - anyhow::bail!("git status failed: {stderr}"); - } - }) - .boxed() + let working_directory = match self.working_directory() { + Ok(working_directory) => working_directory, + Err(e) => return Task::ready(Err(e)), + }; + let args = git_status_args(&path_prefixes); + log::debug!("Checking for git status in {path_prefixes:?}"); + self.executor.spawn(async move { + let output = new_std_command(&git_binary_path) + .current_dir(working_directory) + .args(args) + .output()?; + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + stdout.parse() + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!("git status failed: {stderr}"); + } + }) } fn branches(&self) -> BoxFuture<'_, Result>> { diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 75fac114d2..de308b9dde 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2105,7 +2105,7 @@ impl GitPanel { Ok(_) => cx.update(|window, cx| { window.prompt( PromptLevel::Info, - "Git Clone", + &format!("Git Clone: {}", repo_name), None, &["Add repo to project", "Open repo in new project"], cx, diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index 7d5207dfb6..79aa4a6bd0 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -181,10 +181,6 @@ pub fn init(cx: &mut App) { workspace.toggle_modal(window, cx, |window, cx| { GitCloneModal::show(panel, window, cx) }); - - // panel.update(cx, |panel, cx| { - // panel.git_clone(window, cx); - // }); }); workspace.register_action(|workspace, _: &git::OpenModifiedFiles, window, cx| { open_modified_files(workspace, window, cx); diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index ded7bae316..5f6d252503 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -277,6 +277,8 @@ pub struct App { pub(crate) release_listeners: SubscriberSet, pub(crate) global_observers: SubscriberSet, pub(crate) quit_observers: SubscriberSet<(), QuitHandler>, + pub(crate) restart_observers: SubscriberSet<(), Handler>, + pub(crate) restart_path: Option, pub(crate) window_closed_observers: SubscriberSet<(), WindowClosedHandler>, pub(crate) layout_id_buffer: Vec, // We recycle this memory across layout requests. pub(crate) propagate_event: bool, @@ -349,6 +351,8 @@ impl App { keyboard_layout_observers: SubscriberSet::new(), global_observers: SubscriberSet::new(), quit_observers: SubscriberSet::new(), + restart_observers: SubscriberSet::new(), + restart_path: None, window_closed_observers: SubscriberSet::new(), layout_id_buffer: Default::default(), propagate_event: true, @@ -832,8 +836,16 @@ impl App { } /// Restarts the application. - pub fn restart(&self, binary_path: Option) { - self.platform.restart(binary_path) + pub fn restart(&mut self) { + self.restart_observers + .clone() + .retain(&(), |observer| observer(self)); + self.platform.restart(self.restart_path.take()) + } + + /// Sets the path to use when restarting the application. + pub fn set_restart_path(&mut self, path: PathBuf) { + self.restart_path = Some(path); } /// Returns the HTTP client for the application. @@ -1466,6 +1478,21 @@ impl App { subscription } + /// Register a callback to be invoked when the application is about to restart. + /// + /// These callbacks are called before any `on_app_quit` callbacks. + pub fn on_app_restart(&self, mut on_restart: impl 'static + FnMut(&mut App)) -> Subscription { + let (subscription, activate) = self.restart_observers.insert( + (), + Box::new(move |cx| { + on_restart(cx); + true + }), + ); + activate(); + subscription + } + /// Register a callback to be invoked when a window is closed /// The window is no longer accessible at the point this callback is invoked. pub fn on_window_closed(&self, mut on_closed: impl FnMut(&mut App) + 'static) -> Subscription { diff --git a/crates/gpui/src/app/context.rs b/crates/gpui/src/app/context.rs index 392be2ffe9..68c41592b3 100644 --- a/crates/gpui/src/app/context.rs +++ b/crates/gpui/src/app/context.rs @@ -164,6 +164,20 @@ impl<'a, T: 'static> Context<'a, T> { subscription } + /// Register a callback to be invoked when the application is about to restart. + pub fn on_app_restart( + &self, + mut on_restart: impl FnMut(&mut T, &mut App) + 'static, + ) -> Subscription + where + T: 'static, + { + let handle = self.weak_entity(); + self.app.on_app_restart(move |cx| { + handle.update(cx, |entity, cx| on_restart(entity, cx)).ok(); + }) + } + /// Arrange for the given function to be invoked whenever the application is quit. /// The future returned from this callback will be polled for up to [crate::SHUTDOWN_TIMEOUT] until the app fully quits. pub fn on_app_quit( @@ -175,20 +189,15 @@ impl<'a, T: 'static> Context<'a, T> { T: 'static, { let handle = self.weak_entity(); - let (subscription, activate) = self.app.quit_observers.insert( - (), - Box::new(move |cx| { - let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok(); - async move { - if let Some(future) = future { - future.await; - } + self.app.on_app_quit(move |cx| { + let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok(); + async move { + if let Some(future) = future { + future.await; } - .boxed_local() - }), - ); - activate(); - subscription + } + .boxed_local() + }) } /// Tell GPUI that this entity has changed and observers of it should be notified. diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 9e5d359e43..bbde655b80 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -370,9 +370,9 @@ impl Platform for WindowsPlatform { .detach(); } - fn restart(&self, _: Option) { + fn restart(&self, binary_path: Option) { let pid = std::process::id(); - let Some(app_path) = self.app_path().log_err() else { + let Some(app_path) = binary_path.or(self.app_path().log_err()) else { return; }; let script = format!( diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index ba110be9c5..ff8048040e 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -942,6 +942,7 @@ impl LanguageModel for CloudLanguageModel { model.id(), model.supports_parallel_tool_calls(), None, + None, ); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream(async move { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 9eac58c880..725027b2a7 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -14,7 +14,7 @@ use language_model::{ RateLimiter, Role, StopReason, TokenUsage, }; use menu; -use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion}; +use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -45,6 +45,7 @@ pub struct AvailableModel { pub max_tokens: u64, pub max_output_tokens: Option, pub max_completion_tokens: Option, + pub reasoning_effort: Option, } pub struct OpenAiLanguageModelProvider { @@ -213,6 +214,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, max_completion_tokens: model.max_completion_tokens, + reasoning_effort: model.reasoning_effort.clone(), }, ); } @@ -369,6 +371,7 @@ impl LanguageModel for OpenAiLanguageModel { self.model.id(), self.model.supports_parallel_tool_calls(), self.max_output_tokens(), + self.model.reasoning_effort(), ); let completions = self.stream_completion(request, cx); async move { @@ -384,6 +387,7 @@ pub fn into_open_ai( model_id: &str, supports_parallel_tool_calls: bool, max_output_tokens: Option, + reasoning_effort: Option, ) -> open_ai::Request { let stream = !model_id.starts_with("o1-"); @@ -490,6 +494,7 @@ pub fn into_open_ai( LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, LanguageModelToolChoice::None => open_ai::ToolChoice::None, }), + reasoning_effort, } } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 38bd7cee06..6e912765cd 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -355,7 +355,13 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { LanguageModelCompletionError, >, > { - let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens()); + let request = into_open_ai( + request, + &self.model.name, + true, + self.max_output_tokens(), + None, + ); let completions = self.stream_completion(request, cx); async move { let mapper = OpenAiEventMapper::new(); diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 037ce467d0..57a89ba4aa 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -356,6 +356,7 @@ impl LanguageModel for VercelLanguageModel { self.model.id(), self.model.supports_parallel_tool_calls(), self.max_output_tokens(), + None, ); let completions = self.stream_completion(request, cx); async move { diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 5f6034571b..5e7190ea96 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -360,6 +360,7 @@ impl LanguageModel for XAiLanguageModel { self.model.id(), self.model.supports_parallel_tool_calls(), self.max_output_tokens(), + None, ); let completions = self.stream_completion(request, cx); async move { diff --git a/crates/languages/src/cpp/outline.scm b/crates/languages/src/cpp/outline.scm index 448fe35220..c897366558 100644 --- a/crates/languages/src/cpp/outline.scm +++ b/crates/languages/src/cpp/outline.scm @@ -149,7 +149,9 @@ parameters: (parameter_list "(" @context ")" @context))) - ] - (type_qualifier)? @context) @item + ; Fields declarations may define multiple fields, and so @item is on the + ; declarator so they each get distinct ranges. + ] @item + (type_qualifier)? @context) (comment) @annotation diff --git a/crates/languages/src/go/outline.scm b/crates/languages/src/go/outline.scm index e37ae7e572..c745f55aff 100644 --- a/crates/languages/src/go/outline.scm +++ b/crates/languages/src/go/outline.scm @@ -1,4 +1,5 @@ (comment) @annotation + (type_declaration "type" @context [ @@ -42,13 +43,13 @@ (var_declaration "var" @context [ + ; The declaration may define multiple variables, and so @item is on + ; the identifier so they get distinct ranges. (var_spec - name: (identifier) @name) @item + name: (identifier) @name @item) (var_spec_list - "(" (var_spec - name: (identifier) @name) @item - ")" + name: (identifier) @name @item) ) ] ) @@ -60,5 +61,7 @@ "(" @context ")" @context)) @item +; Fields declarations may define multiple fields, and so @item is on the +; declarator so they each get distinct ranges. (field_declaration - name: (_) @name) @item + name: (_) @name @item) diff --git a/crates/languages/src/javascript/outline.scm b/crates/languages/src/javascript/outline.scm index 026c71e1f9..ca16c27a27 100644 --- a/crates/languages/src/javascript/outline.scm +++ b/crates/languages/src/javascript/outline.scm @@ -31,12 +31,16 @@ (export_statement (lexical_declaration ["let" "const"] @context + ; Multiple names may be exported - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item))) (program (lexical_declaration ["let" "const"] @context + ; Multiple names may be defined - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) diff --git a/crates/languages/src/tsx/outline.scm b/crates/languages/src/tsx/outline.scm index 5dafe791e4..f4261b9697 100644 --- a/crates/languages/src/tsx/outline.scm +++ b/crates/languages/src/tsx/outline.scm @@ -34,12 +34,16 @@ (export_statement (lexical_declaration ["let" "const"] @context + ; Multiple names may be exported - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) (program (lexical_declaration ["let" "const"] @context + ; Multiple names may be defined - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) diff --git a/crates/languages/src/typescript/outline.scm b/crates/languages/src/typescript/outline.scm index 5dafe791e4..f4261b9697 100644 --- a/crates/languages/src/typescript/outline.scm +++ b/crates/languages/src/typescript/outline.scm @@ -34,12 +34,16 @@ (export_statement (lexical_declaration ["let" "const"] @context + ; Multiple names may be exported - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) (program (lexical_declaration ["let" "const"] @context + ; Multiple names may be defined - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml index 436c714cf3..4157be3172 100644 --- a/crates/onboarding/Cargo.toml +++ b/crates/onboarding/Cargo.toml @@ -18,14 +18,13 @@ default = [] ai_onboarding.workspace = true anyhow.workspace = true client.workspace = true -command_palette_hooks.workspace = true component.workspace = true db.workspace = true documented.workspace = true editor.workspace = true -feature_flags.workspace = true fs.workspace = true fuzzy.workspace = true +git.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true @@ -37,6 +36,7 @@ project.workspace = true schemars.workspace = true serde.workspace = true settings.workspace = true +telemetry.workspace = true theme.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index 8203f96479..bb1932bdf2 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -188,6 +188,11 @@ fn render_llm_provider_card( workspace .update(cx, |workspace, cx| { workspace.toggle_modal(window, cx, |window, cx| { + telemetry::event!( + "Welcome AI Modal Opened", + provider = provider.name().0, + ); + let modal = AiConfigurationModal::new( provider.clone(), window, @@ -245,16 +250,25 @@ pub(crate) fn render_ai_setup_page( ToggleState::Selected }, |&toggle_state, _, cx| { + let enabled = match toggle_state { + ToggleState::Indeterminate => { + return; + } + ToggleState::Unselected => true, + ToggleState::Selected => false, + }; + + telemetry::event!( + "Welcome AI Enabled", + toggle = if enabled { "on" } else { "off" }, + ); + let fs = ::global(cx); update_settings_file::( fs, cx, move |ai_settings: &mut Option, _| { - *ai_settings = match toggle_state { - ToggleState::Indeterminate => None, - ToggleState::Unselected => Some(true), - ToggleState::Selected => Some(false), - }; + *ai_settings = Some(enabled); }, ); }, diff --git a/crates/welcome/src/base_keymap_picker.rs b/crates/onboarding/src/base_keymap_picker.rs similarity index 99% rename from crates/welcome/src/base_keymap_picker.rs rename to crates/onboarding/src/base_keymap_picker.rs index 92317ca711..0ac07d9a9d 100644 --- a/crates/welcome/src/base_keymap_picker.rs +++ b/crates/onboarding/src/base_keymap_picker.rs @@ -12,7 +12,7 @@ use util::ResultExt; use workspace::{ModalView, Workspace, ui::HighlightedLabel}; actions!( - welcome, + zed, [ /// Toggles the base keymap selector modal. ToggleBaseKeymapSelector diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs index a19a21fddf..86ddc22a86 100644 --- a/crates/onboarding/src/basics_page.rs +++ b/crates/onboarding/src/basics_page.rs @@ -58,7 +58,7 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement .tab_index(tab_index) .selected_index(theme_mode as usize) .style(ui::ToggleButtonGroupStyle::Outlined) - .button_width(rems_from_px(64.)), + .width(rems_from_px(3. * 64.)), ), ) .child( @@ -305,8 +305,8 @@ fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoE .when_some(base_keymap, |this, base_keymap| { this.selected_index(base_keymap) }) + .full_width() .tab_index(tab_index) - .button_width(rems_from_px(216.)) .size(ui::ToggleButtonGroupSize::Medium) .style(ui::ToggleButtonGroupStyle::Outlined), ); diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs index 13b4f6a5c1..aa7f4eee74 100644 --- a/crates/onboarding/src/editing_page.rs +++ b/crates/onboarding/src/editing_page.rs @@ -35,6 +35,11 @@ fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { EditorSettings::override_global(curr_settings, cx); update_settings_file::(fs, cx, move |editor_settings, _| { + telemetry::event!( + "Welcome Minimap Clicked", + from = editor_settings.minimap.unwrap_or_default(), + to = show + ); editor_settings.minimap.get_or_insert_default().show = Some(show); }); } @@ -71,7 +76,7 @@ fn read_git_blame(cx: &App) -> bool { ProjectSettings::get_global(cx).git.inline_blame_enabled() } -fn set_git_blame(enabled: bool, cx: &mut App) { +fn write_git_blame(enabled: bool, cx: &mut App) { let fs = ::global(cx); let mut curr_settings = ProjectSettings::get_global(cx).clone(); @@ -95,6 +100,12 @@ fn write_ui_font_family(font: SharedString, cx: &mut App) { let fs = ::global(cx); update_settings_file::(fs, cx, move |theme_settings, _| { + telemetry::event!( + "Welcome Font Changed", + type = "ui font", + old = theme_settings.ui_font_family, + new = font.clone() + ); theme_settings.ui_font_family = Some(FontFamilyName(font.into())); }); } @@ -119,6 +130,13 @@ fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { let fs = ::global(cx); update_settings_file::(fs, cx, move |theme_settings, _| { + telemetry::event!( + "Welcome Font Changed", + type = "editor font", + old = theme_settings.buffer_font_family, + new = font_family.clone() + ); + theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); }); } @@ -197,7 +215,7 @@ fn render_setting_import_button( .color(Color::Muted) .size(IconSize::XSmall), ) - .child(Label::new(label)), + .child(Label::new(label.clone())), ) .when(imported, |this| { this.child( @@ -212,7 +230,10 @@ fn render_setting_import_button( ) }), ) - .on_click(move |_, window, cx| window.dispatch_action(action.boxed_clone(), cx)), + .on_click(move |_, window, cx| { + telemetry::event!("Welcome Import Settings", import_source = label,); + window.dispatch_action(action.boxed_clone(), cx); + }), ) } @@ -573,7 +594,7 @@ fn font_picker( ) -> FontPicker { let delegate = FontPickerDelegate::new(current_font, on_font_changed, cx); - Picker::list(delegate, window, cx) + Picker::uniform_list(delegate, window, cx) .show_scrollbar(true) .width(rems_from_px(210.)) .max_height(Some(rems(20.).into())) @@ -605,7 +626,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - write_font_ligatures(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Font Ligature", + options = if enabled { "on" } else { "off" }, + ); + + write_font_ligatures(enabled, cx); }, ) .tab_index({ @@ -625,7 +652,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - write_format_on_save(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Format On Save Changed", + options = if enabled { "on" } else { "off" }, + ); + + write_format_on_save(enabled, cx); }, ) .tab_index({ @@ -644,7 +677,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - write_inlay_hints(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Inlay Hints Changed", + options = if enabled { "on" } else { "off" }, + ); + + write_inlay_hints(enabled, cx); }, ) .tab_index({ @@ -663,7 +702,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - set_git_blame(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Git Blame Changed", + options = if enabled { "on" } else { "off" }, + ); + + write_git_blame(enabled, cx); }, ) .tab_index({ @@ -706,7 +751,7 @@ fn render_popular_settings_section( }) .tab_index(tab_index) .style(ToggleButtonGroupStyle::Outlined) - .button_width(ui::rems_from_px(64.)), + .width(ui::rems_from_px(3. * 64.)), ), ) } diff --git a/crates/welcome/src/multibuffer_hint.rs b/crates/onboarding/src/multibuffer_hint.rs similarity index 100% rename from crates/welcome/src/multibuffer_hint.rs rename to crates/onboarding/src/multibuffer_hint.rs diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index 7ba7ba60cb..e07a8dc9fb 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,8 +1,7 @@ -use crate::welcome::{ShowWelcome, WelcomePage}; +pub use crate::welcome::ShowWelcome; +use crate::{multibuffer_hint::MultibufferHint, welcome::WelcomePage}; use client::{Client, UserStore, zed_urls}; -use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; -use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use fs::Fs; use gpui::{ Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, @@ -27,17 +26,13 @@ use workspace::{ }; mod ai_setup_page; +mod base_keymap_picker; mod basics_page; mod editing_page; +pub mod multibuffer_hint; mod theme_preview; mod welcome; -pub struct OnBoardingFeatureFlag {} - -impl FeatureFlag for OnBoardingFeatureFlag { - const NAME: &'static str = "onboarding"; -} - /// Imports settings from Visual Studio Code. #[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] #[action(namespace = zed)] @@ -57,6 +52,7 @@ pub struct ImportCursorSettings { } pub const FIRST_OPEN: &str = "first_open"; +pub const DOCS_URL: &str = "https://zed.dev/docs/"; actions!( zed, @@ -80,11 +76,19 @@ actions!( /// Sign in while in the onboarding flow. SignIn, /// Open the user account in zed.dev while in the onboarding flow. - OpenAccount + OpenAccount, + /// Resets the welcome screen hints to their initial state. + ResetHints ] ); pub fn init(cx: &mut App) { + cx.observe_new(|workspace: &mut Workspace, _, _cx| { + workspace + .register_action(|_workspace, _: &ResetHints, _, cx| MultibufferHint::set_count(0, cx)); + }) + .detach(); + cx.on_action(|_: &OpenOnboarding, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { workspace @@ -182,38 +186,14 @@ pub fn init(cx: &mut App) { }) .detach(); - cx.observe_new::(|_, window, cx| { - let Some(window) = window else { - return; - }; + base_keymap_picker::init(cx); - let onboarding_actions = [ - std::any::TypeId::of::(), - std::any::TypeId::of::(), - ]; - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&onboarding_actions); - }); - - cx.observe_flag::(window, move |is_enabled, _, _, cx| { - if is_enabled { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.show_action_types(onboarding_actions.iter()); - }); - } else { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&onboarding_actions); - }); - } - }) - .detach(); - }) - .detach(); register_serializable_item::(cx); + register_serializable_item::(cx); } pub fn show_onboarding_view(app_state: Arc, cx: &mut App) -> Task> { + telemetry::event!("Onboarding Page Opened"); open_new( Default::default(), app_state, @@ -242,6 +222,16 @@ enum SelectedPage { AiSetup, } +impl SelectedPage { + fn name(&self) -> &'static str { + match self { + SelectedPage::Basics => "Basics", + SelectedPage::Editing => "Editing", + SelectedPage::AiSetup => "AI Setup", + } + } +} + struct Onboarding { workspace: WeakEntity, focus_handle: FocusHandle, @@ -261,7 +251,21 @@ impl Onboarding { }) } - fn set_page(&mut self, page: SelectedPage, cx: &mut Context) { + fn set_page( + &mut self, + page: SelectedPage, + clicked: Option<&'static str>, + cx: &mut Context, + ) { + if let Some(click) = clicked { + telemetry::event!( + "Welcome Tab Clicked", + from = self.selected_page.name(), + to = page.name(), + clicked = click, + ); + } + self.selected_page = page; cx.notify(); cx.emit(ItemEvent::UpdateTab); @@ -325,8 +329,13 @@ impl Onboarding { gpui::Empty.into_any_element(), IntoElement::into_any_element, )) - .on_click(cx.listener(move |this, _, _, cx| { - this.set_page(page, cx); + .on_click(cx.listener(move |this, click_event, _, cx| { + let click = match click_event { + gpui::ClickEvent::Mouse(_) => "mouse", + gpui::ClickEvent::Keyboard(_) => "keyboard", + }; + + this.set_page(page, Some(click), cx); })) }) } @@ -475,6 +484,7 @@ impl Onboarding { } fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { + telemetry::event!("Welcome Skip Clicked"); go_to_welcome_page(cx); } @@ -532,13 +542,13 @@ impl Render for Onboarding { .on_action(Self::handle_sign_in) .on_action(Self::handle_open_account) .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { - this.set_page(SelectedPage::Basics, cx); + this.set_page(SelectedPage::Basics, Some("action"), cx); })) .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { - this.set_page(SelectedPage::Editing, cx); + this.set_page(SelectedPage::Editing, Some("action"), cx); })) .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { - this.set_page(SelectedPage::AiSetup, cx); + this.set_page(SelectedPage::AiSetup, Some("action"), cx); })) .on_action(cx.listener(|_, _: &menu::SelectNext, window, cx| { window.focus_next(); @@ -551,6 +561,7 @@ impl Render for Onboarding { .child( h_flex() .max_w(rems_from_px(1100.)) + .max_h(rems_from_px(850.)) .size_full() .m_auto() .py_20() @@ -560,12 +571,14 @@ impl Render for Onboarding { .child(self.render_nav(window, cx)) .child( v_flex() + .id("page-content") + .size_full() .max_w_full() .min_w_0() .pl_12() .border_l_1() .border_color(cx.theme().colors().border_variant.opacity(0.5)) - .size_full() + .overflow_y_scroll() .child(self.render_page(window, cx)), ), ) @@ -803,7 +816,7 @@ impl workspace::SerializableItem for Onboarding { if let Some(page) = page { zlog::info!("Onboarding page {page:?} loaded"); onboarding_page.update(cx, |onboarding_page, cx| { - onboarding_page.set_page(page, cx); + onboarding_page.set_page(page, None, cx); }) } onboarding_page diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs index d4d6c3f701..ba0053a3b6 100644 --- a/crates/onboarding/src/welcome.rs +++ b/crates/onboarding/src/welcome.rs @@ -1,6 +1,6 @@ use gpui::{ Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, - NoAction, ParentElement, Render, Styled, Window, actions, + ParentElement, Render, Styled, Task, Window, actions, }; use menu::{SelectNext, SelectPrevious}; use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; @@ -38,8 +38,7 @@ const CONTENT: (Section<4>, Section<3>) = ( SectionEntry { icon: IconName::CloudDownload, title: "Clone a Repo", - // TODO: use proper action - action: &NoAction, + action: &git::Clone, }, SectionEntry { icon: IconName::ListCollapse, @@ -353,3 +352,109 @@ impl Item for WelcomePage { f(*event) } } + +impl workspace::SerializableItem for WelcomePage { + fn serialized_item_kind() -> &'static str { + "WelcomePage" + } + + fn cleanup( + workspace_id: workspace::WorkspaceId, + alive_items: Vec, + _window: &mut Window, + cx: &mut App, + ) -> Task> { + workspace::delete_unloaded_items( + alive_items, + workspace_id, + "welcome_pages", + &persistence::WELCOME_PAGES, + cx, + ) + } + + fn deserialize( + _project: Entity, + _workspace: gpui::WeakEntity, + workspace_id: workspace::WorkspaceId, + item_id: workspace::ItemId, + window: &mut Window, + cx: &mut App, + ) -> Task>> { + if persistence::WELCOME_PAGES + .get_welcome_page(item_id, workspace_id) + .ok() + .is_some_and(|is_open| is_open) + { + window.spawn(cx, async move |cx| cx.update(WelcomePage::new)) + } else { + Task::ready(Err(anyhow::anyhow!("No welcome page to deserialize"))) + } + } + + fn serialize( + &mut self, + workspace: &mut workspace::Workspace, + item_id: workspace::ItemId, + _closing: bool, + _window: &mut Window, + cx: &mut Context, + ) -> Option>> { + let workspace_id = workspace.database_id()?; + Some(cx.background_spawn(async move { + persistence::WELCOME_PAGES + .save_welcome_page(item_id, workspace_id, true) + .await + })) + } + + fn should_serialize(&self, event: &Self::Event) -> bool { + event == &ItemEvent::UpdateTab + } +} + +mod persistence { + use db::{define_connection, query, sqlez_macros::sql}; + use workspace::WorkspaceDb; + + define_connection! { + pub static ref WELCOME_PAGES: WelcomePagesDb = + &[ + sql!( + CREATE TABLE welcome_pages ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + is_open INTEGER DEFAULT FALSE, + + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + ), + ]; + } + + impl WelcomePagesDb { + query! { + pub async fn save_welcome_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId, + is_open: bool + ) -> Result<()> { + INSERT OR REPLACE INTO welcome_pages(item_id, workspace_id, is_open) + VALUES (?, ?, ?) + } + } + + query! { + pub fn get_welcome_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId + ) -> Result { + SELECT is_open + FROM welcome_pages + WHERE item_id = ? AND workspace_id = ? + } + } + } +} diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 919b1d9ebf..5801f29623 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -89,11 +89,13 @@ pub enum Model { max_tokens: u64, max_output_tokens: Option, max_completion_tokens: Option, + reasoning_effort: Option, }, } impl Model { pub fn default_fast() -> Self { + // TODO: Replace with FiveMini since all other models are deprecated Self::FourPointOneMini } @@ -206,6 +208,15 @@ impl Model { } } + pub fn reasoning_effort(&self) -> Option { + match self { + Self::Custom { + reasoning_effort, .. + } => reasoning_effort.to_owned(), + _ => None, + } + } + /// Returns whether the given model supports the `parallel_tool_calls` parameter. /// /// If the model does not support the parameter, do not pass it up, or the API will return an error. @@ -246,6 +257,7 @@ pub struct Request { pub tools: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] pub prompt_cache_key: Option, + pub reasoning_effort: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -257,6 +269,16 @@ pub enum ToolChoice { Other(ToolDefinition), } +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ReasoningEffort { + Minimal, + Low, + Medium, + High, +} + #[derive(Clone, Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index b53809dff0..36a0af30d0 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -295,7 +295,7 @@ impl NotebookEditor { _cx: &mut Context, ) -> IconButton { let id: ElementId = ElementId::Name(id.into()); - IconButton::new(id, icon).width(px(CONTROL_SIZE).into()) + IconButton::new(id, icon).width(px(CONTROL_SIZE)) } fn render_notebook_controls( diff --git a/crates/terminal/src/terminal.rs b/crates/terminal/src/terminal.rs index 3e7d9c0ad4..c3c6de9e53 100644 --- a/crates/terminal/src/terminal.rs +++ b/crates/terminal/src/terminal.rs @@ -408,7 +408,13 @@ impl TerminalBuilder { let terminal_title_override = shell_params.as_ref().and_then(|e| e.title_override.clone()); #[cfg(windows)] - let shell_program = shell_params.as_ref().map(|params| params.program.clone()); + let shell_program = shell_params.as_ref().map(|params| { + use util::ResultExt; + + Self::resolve_path(¶ms.program) + .log_err() + .unwrap_or(params.program.clone()) + }); let pty_options = { let alac_shell = shell_params.map(|params| { @@ -589,6 +595,24 @@ impl TerminalBuilder { self.terminal } + + #[cfg(windows)] + fn resolve_path(path: &str) -> Result { + use windows::Win32::Storage::FileSystem::SearchPathW; + use windows::core::HSTRING; + + let path = if path.starts_with(r"\\?\") || !path.contains(&['/', '\\']) { + path.to_string() + } else { + r"\\?\".to_string() + path + }; + + let required_length = unsafe { SearchPathW(None, &HSTRING::from(&path), None, None, None) }; + let mut buf = vec![0u16; required_length as usize]; + let size = unsafe { SearchPathW(None, &HSTRING::from(&path), None, Some(&mut buf), None) }; + + Ok(String::from_utf16(&buf[..size as usize])?) + } } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index d11d3b7081..eb317a5616 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -595,7 +595,7 @@ impl TitleBar { .on_click(|_, window, cx| { if let Some(auto_updater) = auto_update::AutoUpdater::get(cx) { if auto_updater.read(cx).status().is_updated() { - workspace::reload(&Default::default(), cx); + workspace::reload(cx); return; } } diff --git a/crates/ui/src/components/button/button.rs b/crates/ui/src/components/button/button.rs index 19f782fb98..cee39ac23b 100644 --- a/crates/ui/src/components/button/button.rs +++ b/crates/ui/src/components/button/button.rs @@ -324,7 +324,7 @@ impl FixedWidth for Button { /// ``` /// /// This sets the button's width to be exactly 100 pixels. - fn width(mut self, width: DefiniteLength) -> Self { + fn width(mut self, width: impl Into) -> Self { self.base = self.base.width(width); self } diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 35c78fbb5d..0b30007e44 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -499,8 +499,8 @@ impl Clickable for ButtonLike { } impl FixedWidth for ButtonLike { - fn width(mut self, width: DefiniteLength) -> Self { - self.width = Some(width); + fn width(mut self, width: impl Into) -> Self { + self.width = Some(width.into()); self } diff --git a/crates/ui/src/components/button/icon_button.rs b/crates/ui/src/components/button/icon_button.rs index 40e06e7a72..3975e45b6e 100644 --- a/crates/ui/src/components/button/icon_button.rs +++ b/crates/ui/src/components/button/icon_button.rs @@ -133,7 +133,7 @@ impl Clickable for IconButton { } impl FixedWidth for IconButton { - fn width(mut self, width: DefiniteLength) -> Self { + fn width(mut self, width: impl Into) -> Self { self.base = self.base.width(width); self } @@ -194,7 +194,7 @@ impl RenderOnce for IconButton { .map(|this| match self.shape { IconButtonShape::Square => { let size = self.icon_size.square(window, cx); - this.width(size.into()).height(size.into()) + this.width(size).height(size.into()) } IconButtonShape::Wide => this, }) diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index 91defa730b..2a862f4876 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use gpui::{AnyView, ClickEvent}; +use gpui::{AnyView, ClickEvent, relative}; use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, Tooltip, prelude::*}; @@ -73,8 +73,8 @@ impl SelectableButton for ToggleButton { } impl FixedWidth for ToggleButton { - fn width(mut self, width: DefiniteLength) -> Self { - self.base.width = Some(width); + fn width(mut self, width: impl Into) -> Self { + self.base.width = Some(width.into()); self } @@ -429,7 +429,7 @@ where rows: [[T; COLS]; ROWS], style: ToggleButtonGroupStyle, size: ToggleButtonGroupSize, - button_width: Rems, + group_width: Option, selected_index: usize, tab_index: Option, } @@ -441,7 +441,7 @@ impl ToggleButtonGroup { rows: [buttons], style: ToggleButtonGroupStyle::Transparent, size: ToggleButtonGroupSize::Default, - button_width: rems_from_px(100.), + group_width: None, selected_index: 0, tab_index: None, } @@ -455,7 +455,7 @@ impl ToggleButtonGroup { rows: [first_row, second_row], style: ToggleButtonGroupStyle::Transparent, size: ToggleButtonGroupSize::Default, - button_width: rems_from_px(100.), + group_width: None, selected_index: 0, tab_index: None, } @@ -473,11 +473,6 @@ impl ToggleButtonGroup Self { - self.button_width = button_width; - self - } - pub fn selected_index(mut self, index: usize) -> Self { self.selected_index = index; self @@ -491,6 +486,24 @@ impl ToggleButtonGroup DefiniteLength { + relative(1. / COLS as f32) + } +} + +impl FixedWidth + for ToggleButtonGroup +{ + fn width(mut self, width: impl Into) -> Self { + self.group_width = Some(width.into()); + self + } + + fn full_width(mut self) -> Self { + self.group_width = Some(relative(1.)); + self + } } impl RenderOnce @@ -511,6 +524,7 @@ impl RenderOnce let entry_index = row_index * COLS + col_index; ButtonLike::new((self.group_name, entry_index)) + .full_width() .rounding(None) .when_some(self.tab_index, |this, tab_index| { this.tab_index(tab_index + entry_index as isize) @@ -527,7 +541,7 @@ impl RenderOnce }) .child( h_flex() - .min_w(self.button_width) + .w_full() .gap_1p5() .px_3() .py_1() @@ -561,6 +575,13 @@ impl RenderOnce let is_transparent = self.style == ToggleButtonGroupStyle::Transparent; v_flex() + .map(|this| { + if let Some(width) = self.group_width { + this.w(width) + } else { + this.w_full() + } + }) .rounded_md() .overflow_hidden() .map(|this| { @@ -583,6 +604,8 @@ impl RenderOnce .when(is_outlined_or_filled && !last_item, |this| { this.border_r_1().border_color(border_color) }) + .w(Self::button_width()) + .overflow_hidden() .child(item) })) })) @@ -630,7 +653,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .into_any_element(), ), single_example( @@ -656,7 +678,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .into_any_element(), ), single_example( @@ -675,7 +696,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .into_any_element(), ), single_example( @@ -718,7 +738,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .into_any_element(), ), ], @@ -763,7 +782,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Outlined) .into_any_element(), ), @@ -783,7 +801,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Outlined) .into_any_element(), ), @@ -827,7 +844,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Outlined) .into_any_element(), ), @@ -873,7 +889,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Filled) .into_any_element(), ), @@ -893,7 +908,7 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) + .width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Filled) .into_any_element(), ), @@ -937,7 +952,7 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) + .width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Filled) .into_any_element(), ), @@ -957,7 +972,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .into_any_element(), )]) .into_any_element(), diff --git a/crates/ui/src/traits/fixed.rs b/crates/ui/src/traits/fixed.rs index 9ba64da090..6ca9c8617f 100644 --- a/crates/ui/src/traits/fixed.rs +++ b/crates/ui/src/traits/fixed.rs @@ -3,7 +3,7 @@ use gpui::DefiniteLength; /// A trait for elements that can have a fixed with. Enables the use of the `width` and `full_width` methods. pub trait FixedWidth { /// Sets the width of the element. - fn width(self, width: DefiniteLength) -> Self; + fn width(self, width: impl Into) -> Self; /// Sets the element's width to the full width of its container. fn full_width(self) -> Self; diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml deleted file mode 100644 index acb3fe0f84..0000000000 --- a/crates/welcome/Cargo.toml +++ /dev/null @@ -1,40 +0,0 @@ -[package] -name = "welcome" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/welcome.rs" - -[features] -test-support = [] - -[dependencies] -anyhow.workspace = true -client.workspace = true -component.workspace = true -db.workspace = true -documented.workspace = true -fuzzy.workspace = true -gpui.workspace = true -install_cli.workspace = true -language.workspace = true -picker.workspace = true -project.workspace = true -serde.workspace = true -settings.workspace = true -telemetry.workspace = true -ui.workspace = true -util.workspace = true -vim_mode_setting.workspace = true -workspace-hack.workspace = true -workspace.workspace = true -zed_actions.workspace = true - -[dev-dependencies] -editor = { workspace = true, features = ["test-support"] } diff --git a/crates/welcome/LICENSE-GPL b/crates/welcome/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/welcome/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs deleted file mode 100644 index b0a1c316f4..0000000000 --- a/crates/welcome/src/welcome.rs +++ /dev/null @@ -1,446 +0,0 @@ -use client::{TelemetrySettings, telemetry::Telemetry}; -use db::kvp::KEY_VALUE_STORE; -use gpui::{ - Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, - ParentElement, Render, Styled, Subscription, Task, WeakEntity, Window, actions, svg, -}; -use language::language_settings::{EditPredictionProvider, all_language_settings}; -use project::DisableAiSettings; -use settings::{Settings, SettingsStore}; -use std::sync::Arc; -use ui::{CheckboxWithLabel, ElevationIndex, Tooltip, prelude::*}; -use util::ResultExt; -use vim_mode_setting::VimModeSetting; -use workspace::{ - AppState, Welcome, Workspace, WorkspaceId, - dock::DockPosition, - item::{Item, ItemEvent}, - open_new, -}; - -pub use multibuffer_hint::*; - -mod base_keymap_picker; -mod multibuffer_hint; - -actions!( - welcome, - [ - /// Resets the welcome screen hints to their initial state. - ResetHints - ] -); - -pub const FIRST_OPEN: &str = "first_open"; -pub const DOCS_URL: &str = "https://zed.dev/docs/"; - -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _, _cx| { - workspace.register_action(|workspace, _: &Welcome, window, cx| { - let welcome_page = WelcomePage::new(workspace, cx); - workspace.add_item_to_active_pane(Box::new(welcome_page), None, true, window, cx) - }); - workspace - .register_action(|_workspace, _: &ResetHints, _, cx| MultibufferHint::set_count(0, cx)); - }) - .detach(); - - base_keymap_picker::init(cx); -} - -pub fn show_welcome_view(app_state: Arc, cx: &mut App) -> Task> { - open_new( - Default::default(), - app_state, - cx, - |workspace, window, cx| { - workspace.toggle_dock(DockPosition::Left, window, cx); - let welcome_page = WelcomePage::new(workspace, cx); - workspace.add_item_to_center(Box::new(welcome_page.clone()), window, cx); - - window.focus(&welcome_page.focus_handle(cx)); - - cx.notify(); - - db::write_and_log(cx, || { - KEY_VALUE_STORE.write_kvp(FIRST_OPEN.to_string(), "false".to_string()) - }); - }, - ) -} - -pub struct WelcomePage { - workspace: WeakEntity, - focus_handle: FocusHandle, - telemetry: Arc, - _settings_subscription: Subscription, -} - -impl Render for WelcomePage { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let edit_prediction_provider_is_zed = - all_language_settings(None, cx).edit_predictions.provider - == EditPredictionProvider::Zed; - - let edit_prediction_label = if edit_prediction_provider_is_zed { - "Edit Prediction Enabled" - } else { - "Try Edit Prediction" - }; - - h_flex() - .size_full() - .bg(cx.theme().colors().editor_background) - .key_context("Welcome") - .track_focus(&self.focus_handle(cx)) - .child( - v_flex() - .gap_8() - .mx_auto() - .child( - v_flex() - .w_full() - .child( - svg() - .path("icons/logo_96.svg") - .text_color(cx.theme().colors().icon_disabled) - .w(px(40.)) - .h(px(40.)) - .mx_auto() - .mb_4(), - ) - .child( - h_flex() - .w_full() - .justify_center() - .child(Headline::new("Welcome to Zed")), - ) - .child( - h_flex().w_full().justify_center().child( - Label::new("The editor for what's next") - .color(Color::Muted) - .italic(), - ), - ), - ) - .child( - h_flex() - .items_start() - .gap_8() - .child( - v_flex() - .gap_2() - .pr_8() - .border_r_1() - .border_color(cx.theme().colors().border_variant) - .child( - self.section_label( cx).child( - Label::new("Get Started") - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ) - .child( - Button::new("choose-theme", "Choose a Theme") - .icon(IconName::SwatchBook) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - telemetry::event!("Welcome Theme Changed"); - this.workspace - .update(cx, |_workspace, cx| { - window.dispatch_action(zed_actions::theme_selector::Toggle::default().boxed_clone(), cx); - }) - .ok(); - })), - ) - .child( - Button::new("choose-keymap", "Choose a Keymap") - .icon(IconName::Keyboard) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - telemetry::event!("Welcome Keymap Changed"); - this.workspace - .update(cx, |workspace, cx| { - base_keymap_picker::toggle( - workspace, - &Default::default(), - window, cx, - ) - }) - .ok(); - })), - ) - .when(!DisableAiSettings::get_global(cx).disable_ai, |parent| { - parent.child( - Button::new( - "edit_prediction_onboarding", - edit_prediction_label, - ) - .disabled(edit_prediction_provider_is_zed) - .icon(IconName::ZedPredict) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click( - cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Screen Try Edit Prediction clicked"); - window.dispatch_action(zed_actions::OpenZedPredictOnboarding.boxed_clone(), cx); - }), - ), - ) - }) - .child( - Button::new("edit settings", "Edit Settings") - .icon(IconName::Settings) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Settings Edited"); - window.dispatch_action(Box::new( - zed_actions::OpenSettings, - ), cx); - })), - ) - - ) - .child( - v_flex() - .gap_2() - .child( - self.section_label(cx).child( - Label::new("Resources") - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ) - .when(cfg!(target_os = "macos"), |el| { - el.child( - Button::new("install-cli", "Install the CLI") - .icon(IconName::Terminal) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - telemetry::event!("Welcome CLI Installed"); - this.workspace.update(cx, |_, cx|{ - install_cli::install_cli(window, cx); - }).log_err(); - })), - ) - }) - .child( - Button::new("view-docs", "View Documentation") - .icon(IconName::FileCode) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|_, _, _, cx| { - telemetry::event!("Welcome Documentation Viewed"); - cx.open_url(DOCS_URL); - })), - ) - .child( - Button::new("explore-extensions", "Explore Extensions") - .icon(IconName::Blocks) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Extensions Page Opened"); - window.dispatch_action(Box::new( - zed_actions::Extensions::default(), - ), cx); - })), - ) - ), - ) - .child( - v_container() - .px_2() - .gap_2() - .child( - h_flex() - .justify_between() - .child( - CheckboxWithLabel::new( - "enable-vim", - Label::new("Enable Vim Mode"), - if VimModeSetting::get_global(cx).0 { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - cx.listener(move |this, selection, _window, cx| { - telemetry::event!("Welcome Vim Mode Toggled"); - this.update_settings::( - selection, - cx, - |setting, value| *setting = Some(value), - ); - }), - ) - .fill() - .elevation(ElevationIndex::ElevatedSurface), - ) - .child( - IconButton::new("vim-mode", IconName::Info) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .tooltip( - Tooltip::text( - "You can also toggle Vim Mode via the command palette or Editor Controls menu.") - ), - ), - ) - .child( - CheckboxWithLabel::new( - "enable-crash", - Label::new("Send Crash Reports"), - if TelemetrySettings::get_global(cx).diagnostics { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - cx.listener(move |this, selection, _window, cx| { - telemetry::event!("Welcome Diagnostic Telemetry Toggled"); - this.update_settings::(selection, cx, { - move |settings, value| { - settings.diagnostics = Some(value); - telemetry::event!( - "Settings Changed", - setting = "diagnostic telemetry", - value - ); - } - }); - }), - ) - .fill() - .elevation(ElevationIndex::ElevatedSurface), - ) - .child( - CheckboxWithLabel::new( - "enable-telemetry", - Label::new("Send Telemetry"), - if TelemetrySettings::get_global(cx).metrics { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - cx.listener(move |this, selection, _window, cx| { - telemetry::event!("Welcome Metric Telemetry Toggled"); - this.update_settings::(selection, cx, { - move |settings, value| { - settings.metrics = Some(value); - telemetry::event!( - "Settings Changed", - setting = "metric telemetry", - value - ); - } - }); - }), - ) - .fill() - .elevation(ElevationIndex::ElevatedSurface), - ), - ), - ) - } -} - -impl WelcomePage { - pub fn new(workspace: &Workspace, cx: &mut Context) -> Entity { - let this = cx.new(|cx| { - cx.on_release(|_: &mut Self, _| { - telemetry::event!("Welcome Page Closed"); - }) - .detach(); - - WelcomePage { - focus_handle: cx.focus_handle(), - workspace: workspace.weak_handle(), - telemetry: workspace.client().telemetry().clone(), - _settings_subscription: cx - .observe_global::(move |_, cx| cx.notify()), - } - }); - - this - } - - fn section_label(&self, cx: &mut App) -> Div { - div() - .pl_1() - .font_buffer(cx) - .text_color(Color::Muted.color(cx)) - } - - fn update_settings( - &mut self, - selection: &ToggleState, - cx: &mut Context, - callback: impl 'static + Send + Fn(&mut T::FileContent, bool), - ) { - if let Some(workspace) = self.workspace.upgrade() { - let fs = workspace.read(cx).app_state().fs.clone(); - let selection = *selection; - settings::update_settings_file::(fs, cx, move |settings, _| { - let value = match selection { - ToggleState::Unselected => false, - ToggleState::Selected => true, - _ => return, - }; - - callback(settings, value) - }); - } - } -} - -impl EventEmitter for WelcomePage {} - -impl Focusable for WelcomePage { - fn focus_handle(&self, _: &App) -> gpui::FocusHandle { - self.focus_handle.clone() - } -} - -impl Item for WelcomePage { - type Event = ItemEvent; - - fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { - "Welcome".into() - } - - fn telemetry_event_text(&self) -> Option<&'static str> { - Some("Welcome Page Opened") - } - - fn show_toolbar(&self) -> bool { - false - } - - fn clone_on_split( - &self, - _workspace_id: Option, - _: &mut Window, - cx: &mut Context, - ) -> Option> { - Some(cx.new(|cx| WelcomePage { - focus_handle: cx.focus_handle(), - workspace: self.workspace.clone(), - telemetry: self.telemetry.clone(), - _settings_subscription: cx.observe_global::(move |_, cx| cx.notify()), - })) - } - - fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { - f(*event) - } -} diff --git a/crates/workspace/src/item.rs b/crates/workspace/src/item.rs index c8ebe4550b..bba50e4431 100644 --- a/crates/workspace/src/item.rs +++ b/crates/workspace/src/item.rs @@ -293,6 +293,7 @@ pub trait Item: Focusable + EventEmitter + Render + Sized { fn deactivated(&mut self, _window: &mut Window, _: &mut Context) {} fn discarded(&self, _project: Entity, _window: &mut Window, _cx: &mut Context) {} + fn on_removed(&self, _cx: &App) {} fn workspace_deactivated(&mut self, _window: &mut Window, _: &mut Context) {} fn navigate(&mut self, _: Box, _window: &mut Window, _: &mut Context) -> bool { false @@ -532,6 +533,7 @@ pub trait ItemHandle: 'static + Send { ); fn deactivated(&self, window: &mut Window, cx: &mut App); fn discarded(&self, project: Entity, window: &mut Window, cx: &mut App); + fn on_removed(&self, cx: &App); fn workspace_deactivated(&self, window: &mut Window, cx: &mut App); fn navigate(&self, data: Box, window: &mut Window, cx: &mut App) -> bool; fn item_id(&self) -> EntityId; @@ -968,6 +970,10 @@ impl ItemHandle for Entity { self.update(cx, |this, cx| this.deactivated(window, cx)); } + fn on_removed(&self, cx: &App) { + self.read(cx).on_removed(cx); + } + fn workspace_deactivated(&self, window: &mut Window, cx: &mut App) { self.update(cx, |this, cx| this.workspace_deactivated(window, cx)); } diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index f56e983114..7183f7d379 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -1829,6 +1829,7 @@ impl Pane { let mode = self.nav_history.mode(); self.nav_history.set_mode(NavigationMode::ClosingItem); item.deactivated(window, cx); + item.on_removed(cx); self.nav_history.set_mode(mode); if self.is_active_preview_item(item.item_id()) { diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index aab8a36f45..fb78c62f9e 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -224,6 +224,8 @@ actions!( ResetActiveDockSize, /// Resets all open docks to their default sizes. ResetOpenDocksSize, + /// Reloads the application + Reload, /// Saves the current file with a new name. SaveAs, /// Saves without formatting. @@ -246,8 +248,6 @@ actions!( ToggleZoom, /// Stops following a collaborator. Unfollow, - /// Shows the welcome screen. - Welcome, /// Restores the banner. RestoreBanner, /// Toggles expansion of the selected item. @@ -340,14 +340,6 @@ pub struct CloseInactiveTabsAndPanes { #[action(namespace = workspace)] pub struct SendKeystrokes(pub String); -/// Reloads the active item or workspace. -#[derive(Clone, Deserialize, PartialEq, Default, JsonSchema, Action)] -#[action(namespace = workspace)] -#[serde(deny_unknown_fields)] -pub struct Reload { - pub binary_path: Option, -} - actions!( project_symbols, [ @@ -555,8 +547,8 @@ pub fn init(app_state: Arc, cx: &mut App) { toast_layer::init(cx); history_manager::init(cx); - cx.on_action(Workspace::close_global); - cx.on_action(reload); + cx.on_action(|_: &CloseWindow, cx| Workspace::close_global(cx)); + cx.on_action(|_: &Reload, cx| reload(cx)); cx.on_action({ let app_state = Arc::downgrade(&app_state); @@ -2184,7 +2176,7 @@ impl Workspace { } } - pub fn close_global(_: &CloseWindow, cx: &mut App) { + pub fn close_global(cx: &mut App) { cx.defer(|cx| { cx.windows().iter().find(|window| { window @@ -7642,7 +7634,7 @@ pub fn join_in_room_project( }) } -pub fn reload(reload: &Reload, cx: &mut App) { +pub fn reload(cx: &mut App) { let should_confirm = WorkspaceSettings::get_global(cx).confirm_quit; let mut workspace_windows = cx .windows() @@ -7669,7 +7661,6 @@ pub fn reload(reload: &Reload, cx: &mut App) { .ok(); } - let binary_path = reload.binary_path.clone(); cx.spawn(async move |cx| { if let Some(prompt) = prompt { let answer = prompt.await?; @@ -7688,8 +7679,7 @@ pub fn reload(reload: &Reload, cx: &mut App) { } } } - - cx.update(|cx| cx.restart(binary_path)) + cx.update(|cx| cx.restart()) }) .detach_and_log_err(cx); } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 5997e43864..bdbb39698c 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -157,7 +157,6 @@ vim_mode_setting.workspace = true watch.workspace = true web_search.workspace = true web_search_providers.workspace = true -welcome.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index e4a14b5d32..fd987ef6c5 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -20,6 +20,7 @@ use gpui::{App, AppContext as _, Application, AsyncApp, Focusable as _, UpdateGl use gpui_tokio::Tokio; use http_client::{Url, read_proxy_from_env}; use language::LanguageRegistry; +use onboarding::{FIRST_OPEN, show_onboarding_view}; use prompt_store::PromptBuilder; use reqwest_client::ReqwestClient; @@ -44,7 +45,6 @@ use theme::{ }; use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; -use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ AppState, SerializedWorkspaceLocation, Toast, Workspace, WorkspaceSettings, WorkspaceStore, notifications::NotificationId, @@ -201,16 +201,6 @@ pub fn main() { return; } - // Check if there is a pending installer - // If there is, run the installer and exit - // And we don't want to run the installer if we are not the first instance - #[cfg(target_os = "windows")] - let is_first_instance = crate::zed::windows_only_instance::is_first_instance(); - #[cfg(target_os = "windows")] - if is_first_instance && auto_update::check_pending_installation() { - return; - } - if args.dump_all_actions { dump_all_gpui_actions(); return; @@ -261,7 +251,15 @@ pub fn main() { return; } - log::info!("========== starting zed =========="); + log::info!( + "========== starting zed version {}, sha {} ==========", + app_version, + app_commit_sha + .as_ref() + .map(|sha| sha.short()) + .as_deref() + .unwrap_or("unknown"), + ); let app = Application::new().with_assets(Assets); @@ -283,30 +281,27 @@ pub fn main() { let (open_listener, mut open_rx) = OpenListener::new(); - let failed_single_instance_check = - if *db::ZED_STATELESS || *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev { - false - } else { - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - { - crate::zed::listen_for_cli_connections(open_listener.clone()).is_err() - } + let failed_single_instance_check = if *db::ZED_STATELESS + || *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev + { + false + } else { + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + { + crate::zed::listen_for_cli_connections(open_listener.clone()).is_err() + } - #[cfg(target_os = "windows")] - { - !crate::zed::windows_only_instance::handle_single_instance( - open_listener.clone(), - &args, - is_first_instance, - ) - } + #[cfg(target_os = "windows")] + { + !crate::zed::windows_only_instance::handle_single_instance(open_listener.clone(), &args) + } - #[cfg(target_os = "macos")] - { - use zed::mac_only_instance::*; - ensure_only_instance() != IsOnlyInstance::Yes - } - }; + #[cfg(target_os = "macos")] + { + use zed::mac_only_instance::*; + ensure_only_instance() != IsOnlyInstance::Yes + } + }; if failed_single_instance_check { println!("zed is already running"); return; @@ -628,7 +623,6 @@ pub fn main() { feedback::init(cx); markdown_preview::init(cx); svg_preview::init(cx); - welcome::init(cx); onboarding::init(cx); settings_ui::init(cx); extensions_ui::init(cx); @@ -1049,7 +1043,7 @@ async fn restore_or_create_workspace(app_state: Arc, cx: &mut AsyncApp } } } else if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) { - cx.update(|cx| show_welcome_view(app_state, cx))?.await?; + cx.update(|cx| show_onboarding_view(app_state, cx))?.await?; } else { cx.update(|cx| { workspace::open_new( diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 8c89a7d85a..23020d3a9b 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -34,6 +34,8 @@ use image_viewer::ImageInfo; use language_tools::lsp_tool::{self, LspTool}; use migrate::{MigrationBanner, MigrationEvent, MigrationNotification, MigrationType}; use migrator::{migrate_keymap, migrate_settings}; +use onboarding::DOCS_URL; +use onboarding::multibuffer_hint::MultibufferHint; pub use open_listener::*; use outline_panel::OutlinePanel; use paths::{ @@ -67,7 +69,6 @@ use util::markdown::MarkdownString; use util::{ResultExt, asset_str}; use uuid::Uuid; use vim_mode_setting::VimModeSetting; -use welcome::{DOCS_URL, MultibufferHint}; use workspace::notifications::{NotificationId, dismiss_app_notification, show_app_notification}; use workspace::{ AppState, NewFile, NewWindow, OpenLog, Toast, Workspace, WorkspaceSettings, @@ -3975,7 +3976,6 @@ mod tests { client::init(&app_state.client, cx); language::init(cx); workspace::init(app_state.clone(), cx); - welcome::init(cx); onboarding::init(cx); Project::init_settings(cx); app_state @@ -4380,7 +4380,6 @@ mod tests { "toolchain", "variable_list", "vim", - "welcome", "workspace", "zed", "zed_predict_onboarding", diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 53eec42ba0..9df55a2fb1 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -249,7 +249,7 @@ pub fn app_menus() -> Vec { ), MenuItem::action("View Telemetry", zed_actions::OpenTelemetryLog), MenuItem::action("View Dependency Licenses", zed_actions::OpenLicenses), - MenuItem::action("Show Welcome", workspace::Welcome), + MenuItem::action("Show Welcome", onboarding::ShowWelcome), MenuItem::action("Give Feedback...", zed_actions::feedback::GiveFeedback), MenuItem::separator(), MenuItem::action( diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index 2fd9b0a68c..82d3795e94 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -15,6 +15,8 @@ use futures::{FutureExt, SinkExt, StreamExt}; use git_ui::file_diff_view::FileDiffView; use gpui::{App, AsyncApp, Global, WindowHandle}; use language::Point; +use onboarding::FIRST_OPEN; +use onboarding::show_onboarding_view; use recent_projects::{SshSettings, open_ssh_project}; use remote::SshConnectionOptions; use settings::Settings; @@ -24,7 +26,6 @@ use std::thread; use std::time::Duration; use util::ResultExt; use util::paths::PathWithPosition; -use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::item::ItemHandle; use workspace::{AppState, OpenOptions, SerializedWorkspaceLocation, Workspace}; @@ -378,7 +379,7 @@ async fn open_workspaces( if grouped_locations.is_empty() { // If we have no paths to open, show the welcome screen if this is the first launch if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) { - cx.update(|cx| show_welcome_view(app_state, cx).detach()) + cx.update(|cx| show_onboarding_view(app_state, cx).detach()) .log_err(); } // If not the first launch, show an empty window with empty editor diff --git a/crates/zed/src/zed/quick_action_bar/repl_menu.rs b/crates/zed/src/zed/quick_action_bar/repl_menu.rs index 5d1a6c8887..ca180dccdd 100644 --- a/crates/zed/src/zed/quick_action_bar/repl_menu.rs +++ b/crates/zed/src/zed/quick_action_bar/repl_menu.rs @@ -216,7 +216,7 @@ impl QuickActionBar { .size(IconSize::XSmall) .color(Color::Muted), ) - .width(rems(1.).into()) + .width(rems(1.)) .disabled(menu_state.popover_disabled), Tooltip::text("REPL Menu"), ); diff --git a/crates/zed/src/zed/windows_only_instance.rs b/crates/zed/src/zed/windows_only_instance.rs index 277e8ee724..bd62dea75a 100644 --- a/crates/zed/src/zed/windows_only_instance.rs +++ b/crates/zed/src/zed/windows_only_instance.rs @@ -25,7 +25,8 @@ use windows::{ use crate::{Args, OpenListener, RawOpenRequest}; -pub fn is_first_instance() -> bool { +#[inline] +fn is_first_instance() -> bool { unsafe { CreateMutexW( None, @@ -37,7 +38,8 @@ pub fn is_first_instance() -> bool { unsafe { GetLastError() != ERROR_ALREADY_EXISTS } } -pub fn handle_single_instance(opener: OpenListener, args: &Args, is_first_instance: bool) -> bool { +pub fn handle_single_instance(opener: OpenListener, args: &Args) -> bool { + let is_first_instance = is_first_instance(); if is_first_instance { // We are the first instance, listen for messages sent from other instances std::thread::spawn(move || { diff --git a/docs/src/ai/models.md b/docs/src/ai/models.md index b40f17b77f..8d46d0b8d1 100644 --- a/docs/src/ai/models.md +++ b/docs/src/ai/models.md @@ -12,8 +12,10 @@ We’re working hard to expand the models supported by Zed’s subscription offe | Claude Sonnet 4 | Anthropic | ✅ | 200k | N/A | $0.05 | | Claude Opus 4 | Anthropic | ❌ | 120k | $0.20 | N/A | | Claude Opus 4 | Anthropic | ✅ | 200k | N/A | $0.25 | +| Claude Opus 4.1 | Anthropic | ❌ | 120k | $0.20 | N/A | +| Claude Opus 4.1 | Anthropic | ✅ | 200k | N/A | $0.25 | -> Note: Because of the 5x token cost for [Opus relative to Sonnet](https://www.anthropic.com/pricing#api), each Opus prompt consumes 5 prompts against your billing meter +> Note: Because of the 5x token cost for [Opus relative to Sonnet](https://www.anthropic.com/pricing#api), each Opus 4 and 4.1 prompt consumes 5 prompts against your billing meter ## Usage {#usage} diff --git a/docs/src/telemetry.md b/docs/src/telemetry.md index 107aef5a96..46c39a88ae 100644 --- a/docs/src/telemetry.md +++ b/docs/src/telemetry.md @@ -4,7 +4,8 @@ Zed collects anonymous telemetry data to help the team understand how people are ## Configuring Telemetry Settings -You have full control over what data is sent out by Zed. To enable or disable some or all telemetry types, open your `settings.json` file via {#action zed::OpenSettings}({#kb zed::OpenSettings}) from the command palette. +You have full control over what data is sent out by Zed. +To enable or disable some or all telemetry types, open your `settings.json` file via {#action zed::OpenSettings}({#kb zed::OpenSettings}) from the command palette. Insert and tweak the following: @@ -15,8 +16,6 @@ Insert and tweak the following: }, ``` -The telemetry settings can also be configured via the welcome screen, which can be invoked via the {#action workspace::Welcome} action in the command palette. - ## Dataflow Telemetry is sent from the application to our servers. Data is proxied through our servers to enable us to easily switch analytics services. We currently use: diff --git a/extensions/emmet/Cargo.toml b/extensions/emmet/Cargo.toml index ff9debdea9..2fbdf2a7e5 100644 --- a/extensions/emmet/Cargo.toml +++ b/extensions/emmet/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_emmet" -version = "0.0.5" +version = "0.0.6" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/emmet/extension.toml b/extensions/emmet/extension.toml index 0ebb801f9d..a1848400b8 100644 --- a/extensions/emmet/extension.toml +++ b/extensions/emmet/extension.toml @@ -1,7 +1,7 @@ id = "emmet" name = "Emmet" description = "Emmet support" -version = "0.0.5" +version = "0.0.6" schema_version = 1 authors = ["Piotr Osiewicz "] repository = "https://github.com/zed-industries/zed" diff --git a/extensions/emmet/src/emmet.rs b/extensions/emmet/src/emmet.rs index e4fb3cf814..1434e16e88 100644 --- a/extensions/emmet/src/emmet.rs +++ b/extensions/emmet/src/emmet.rs @@ -5,7 +5,7 @@ struct EmmetExtension { did_find_server: bool, } -const SERVER_PATH: &str = "node_modules/.bin/emmet-language-server"; +const SERVER_PATH: &str = "node_modules/@olrtg/emmet-language-server/dist/index.js"; const PACKAGE_NAME: &str = "@olrtg/emmet-language-server"; impl EmmetExtension { diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 3d87025a27..2c909e0e1e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -3,11 +3,6 @@ channel = "1.89" profile = "minimal" components = [ "rustfmt", "clippy" ] targets = [ - "x86_64-apple-darwin", - "aarch64-apple-darwin", - "x86_64-unknown-freebsd", - "x86_64-unknown-linux-gnu", - "x86_64-pc-windows-msvc", "wasm32-wasip2", # extensions "x86_64-unknown-linux-musl", # remote server ]