Merge branch 'zed-industries:main' into main

This commit is contained in:
0x11 2025-08-20 10:31:30 +08:00 committed by GitHub
commit 62af5e6542
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
489 changed files with 13211 additions and 10509 deletions

View file

@ -56,7 +56,6 @@ runs:
$env:COMPlus_CreateDumpDiagnostics = "1"
cargo nextest run --workspace --no-fail-fast
continue-on-error: true
- name: Analyze crash dumps
if: always()

29
Cargo.lock generated
View file

@ -191,10 +191,12 @@ version = "0.1.0"
dependencies = [
"acp_thread",
"action_log",
"agent",
"agent-client-protocol",
"agent_servers",
"agent_settings",
"anyhow",
"assistant_context",
"assistant_tool",
"assistant_tools",
"chrono",
@ -204,10 +206,12 @@ dependencies = [
"collections",
"context_server",
"ctor",
"db",
"editor",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"git",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
@ -221,6 +225,7 @@ dependencies = [
"log",
"lsp",
"open",
"parking_lot",
"paths",
"portable-pty",
"pretty_assertions",
@ -233,6 +238,7 @@ dependencies = [
"serde_json",
"settings",
"smol",
"sqlez",
"task",
"tempfile",
"terminal",
@ -249,6 +255,7 @@ dependencies = [
"workspace-hack",
"worktree",
"zlog",
"zstd",
]
[[package]]
@ -256,6 +263,7 @@ name = "agent_servers"
version = "0.1.0"
dependencies = [
"acp_thread",
"action_log",
"agent-client-protocol",
"agent_settings",
"agentic-coding-protocol",
@ -277,6 +285,7 @@ dependencies = [
"project",
"rand 0.8.5",
"schemars",
"semver",
"serde",
"serde_json",
"settings",
@ -3865,7 +3874,7 @@ dependencies = [
"jni",
"js-sys",
"libc",
"mach2",
"mach2 0.4.2",
"ndk",
"ndk-context",
"num-derive",
@ -4015,7 +4024,7 @@ checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3"
dependencies = [
"cfg-if",
"libc",
"mach2",
"mach2 0.4.2",
]
[[package]]
@ -4027,7 +4036,7 @@ dependencies = [
"cfg-if",
"crash-context",
"libc",
"mach2",
"mach2 0.4.2",
"parking_lot",
]
@ -4037,6 +4046,7 @@ version = "0.1.0"
dependencies = [
"crash-handler",
"log",
"mach2 0.5.0",
"minidumper",
"paths",
"release_channel",
@ -9859,6 +9869,15 @@ dependencies = [
"libc",
]
[[package]]
name = "mach2"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a1b95cd5421ec55b445b5ae102f5ea0e768de1f82bd3001e11f426c269c3aea"
dependencies = [
"libc",
]
[[package]]
name = "malloc_buf"
version = "0.0.6"
@ -10195,7 +10214,7 @@ dependencies = [
"goblin",
"libc",
"log",
"mach2",
"mach2 0.4.2",
"memmap2",
"memoffset",
"minidump-common",
@ -18285,7 +18304,7 @@ dependencies = [
"indexmap",
"libc",
"log",
"mach2",
"mach2 0.4.2",
"memfd",
"object",
"once_cell",

View file

@ -515,6 +515,7 @@ libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0"
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" }
mach2 = "0.5"
markup5ever_rcdom = "0.3.0"
metal = "0.29"
minidumper = "0.8"
@ -821,16 +822,30 @@ single_range_in_vec_init = "allow"
style = { level = "allow", priority = -1 }
# Temporary list of style lints that we've fixed so far.
comparison_to_empty = "warn"
into_iter_on_ref = "warn"
iter_cloned_collect = "warn"
iter_next_slice = "warn"
iter_nth = "warn"
iter_nth_zero = "warn"
iter_skip_next = "warn"
let_and_return = "warn"
module_inception = { level = "deny" }
question_mark = { level = "deny" }
single_match = "warn"
redundant_closure = { level = "deny" }
redundant_static_lifetimes = { level = "warn" }
redundant_pattern_matching = "warn"
redundant_field_names = "warn"
declare_interior_mutable_const = { level = "deny" }
collapsible_if = { level = "warn"}
collapsible_else_if = { level = "warn" }
needless_borrow = { level = "warn"}
needless_return = { level = "warn" }
unnecessary_mut_passed = {level = "warn"}
unnecessary_map_or = { level = "warn" }
unused_unit = "warn"
# Individual rules that have violations in the codebase:
type_complexity = "allow"
# We often return trait objects from `new` functions.
@ -849,6 +864,10 @@ too_many_arguments = "allow"
# We often have large enum variants yet we rarely actually bother with splitting them up.
large_enum_variant = "allow"
# `enum_variant_names` fires for all enums, even when they derive serde traits.
# Adhering to this lint would be a breaking change.
enum_variant_names = "allow"
[workspace.metadata.cargo-machete]
ignored = [
"bindgen",

View file

@ -3,9 +3,13 @@ mod diff;
mod mention;
mod terminal;
use collections::HashSet;
pub use connection::*;
pub use diff::*;
use language::language_settings::FormatOnSave;
pub use mention::*;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
use serde::{Deserialize, Serialize};
pub use terminal::*;
use action_log::ActionLog;
@ -49,7 +53,7 @@ impl UserMessage {
if self
.checkpoint
.as_ref()
.map_or(false, |checkpoint| checkpoint.show)
.is_some_and(|checkpoint| checkpoint.show)
{
writeln!(markdown, "## User (checkpoint)").unwrap();
} else {
@ -249,14 +253,13 @@ impl ToolCall {
}
if let Some(raw_output) = raw_output {
if self.content.is_empty() {
if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
{
self.content
.push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
markdown,
}));
}
if self.content.is_empty()
&& let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
{
self.content
.push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
markdown,
}));
}
self.raw_output = Some(raw_output);
}
@ -430,11 +433,11 @@ impl ContentBlock {
language_registry: &Arc<LanguageRegistry>,
cx: &mut App,
) {
if matches!(self, ContentBlock::Empty) {
if let acp::ContentBlock::ResourceLink(resource_link) = block {
*self = ContentBlock::ResourceLink { resource_link };
return;
}
if matches!(self, ContentBlock::Empty)
&& let acp::ContentBlock::ResourceLink(resource_link) = block
{
*self = ContentBlock::ResourceLink { resource_link };
return;
}
let new_content = self.block_string_contents(block);
@ -538,9 +541,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
}
acp::ToolCallContent::Diff { diff } => {
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
}
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
Diff::finalized(
diff.path,
diff.old_text,
diff.new_text,
language_registry,
cx,
)
})),
}
}
@ -659,6 +668,12 @@ impl PlanEntry {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
pub max_tokens: u64,
pub used_tokens: u64,
}
#[derive(Debug, Clone)]
pub struct RetryStatus {
pub last_error: SharedString,
@ -678,18 +693,21 @@ pub struct AcpThread {
send_task: Option<Task<()>>,
connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId,
token_usage: Option<TokenUsage>,
}
#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
TitleUpdated,
TokenUsageUpdated,
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
Retry(RetryStatus),
Stopped,
Error,
ServerExited(ExitStatus),
LoadError(LoadError),
}
impl EventEmitter<AcpThreadEvent> for AcpThread {}
@ -703,20 +721,30 @@ pub enum ThreadStatus {
#[derive(Debug, Clone)]
pub enum LoadError {
NotInstalled {
error_message: SharedString,
install_message: SharedString,
install_command: String,
},
Unsupported {
error_message: SharedString,
upgrade_message: SharedString,
upgrade_command: String,
},
Exited(i32),
Exited {
status: ExitStatus,
},
Other(SharedString),
}
impl Display for LoadError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
LoadError::NotInstalled { error_message, .. }
| LoadError::Unsupported { error_message, .. } => {
write!(f, "{error_message}")
}
LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
LoadError::Other(msg) => write!(f, "{}", msg),
}
}
@ -729,11 +757,9 @@ impl AcpThread {
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
session_id: acp::SessionId,
cx: &mut Context<Self>,
) -> Self {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
action_log,
shared_buffers: Default::default(),
@ -744,6 +770,7 @@ impl AcpThread {
send_task: None,
connection,
session_id,
token_usage: None,
}
}
@ -783,6 +810,10 @@ impl AcpThread {
}
}
pub fn token_usage(&self) -> Option<&TokenUsage> {
self.token_usage.as_ref()
}
pub fn has_pending_edit_tool_calls(&self) -> bool {
for entry in self.entries.iter().rev() {
match entry {
@ -927,6 +958,17 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry);
}
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
self.title = title;
cx.emit(AcpThreadEvent::TitleUpdated);
Ok(())
}
pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
self.token_usage = usage;
cx.emit(AcpThreadEvent::TokenUsageUpdated);
}
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status));
}
@ -1022,6 +1064,22 @@ impl AcpThread {
})
}
pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
self.entries
.iter()
.enumerate()
.rev()
.find_map(|(index, tool_call)| {
if let AgentThreadEntry::ToolCall(tool_call) = tool_call
&& &tool_call.id == id
{
Some((index, tool_call))
} else {
None
}
})
}
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
let project = self.project.clone();
let Some((_, tool_call)) = self.tool_call_mut(&id) else {
@ -1572,30 +1630,59 @@ impl AcpThread {
.collect::<Vec<_>>()
})
.await;
cx.update(|cx| {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: edits
.last()
.map(|(range, _)| range.end)
.unwrap_or(Anchor::MIN),
}),
cx,
);
});
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: edits
.last()
.map(|(range, _)| range.end)
.unwrap_or(Anchor::MIN),
}),
cx,
);
})?;
let format_on_save = cx.update(|cx| {
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
});
buffer.update(cx, |buffer, cx| {
let format_on_save = buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
let settings = language::language_settings::language_settings(
buffer.language().map(|l| l.name()),
buffer.file(),
cx,
);
settings.format_on_save != FormatOnSave::Off
});
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx);
});
format_on_save
})?;
if format_on_save {
let format_task = project.update(cx, |project, cx| {
project.format(
HashSet::from_iter([buffer.clone()]),
LspFormatTarget::Buffers,
false,
FormatTrigger::Save,
cx,
)
})?;
format_task.await.log_err();
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx);
})?;
}
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
.await
@ -1606,8 +1693,8 @@ impl AcpThread {
self.entries.iter().map(|e| e.to_markdown(cx)).collect()
}
pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::ServerExited(status));
pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::LoadError(error));
}
}
@ -1658,7 +1745,7 @@ mod tests {
use super::*;
use anyhow::anyhow;
use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity};
use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::{FakeFs, Fs};
use rand::Rng as _;
@ -2145,7 +2232,7 @@ mod tests {
"}
);
});
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
.await
@ -2175,7 +2262,10 @@ mod tests {
});
assert_eq!(
fs.files(),
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
vec![
Path::new(path!("/test/file-0")),
Path::new(path!("/test/file-1"))
]
);
// Checkpoint isn't stored when there are no changes.
@ -2216,7 +2306,10 @@ mod tests {
});
assert_eq!(
fs.files(),
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
vec![
Path::new(path!("/test/file-0")),
Path::new(path!("/test/file-1"))
]
);
// Rewinding the conversation truncates the history and restores the checkpoint.
@ -2244,7 +2337,7 @@ mod tests {
"}
);
});
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
}
async fn run_until_first_tool_call(
@ -2328,7 +2421,7 @@ mod tests {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::App,
cx: &mut App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(
rand::thread_rng()
@ -2338,8 +2431,16 @@ mod tests {
.collect::<String>()
.into(),
);
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}

View file

@ -5,11 +5,12 @@ use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
use project::Project;
use serde::{Deserialize, Serialize};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
@ -208,6 +209,7 @@ impl AgentModelList {
mod test_support {
use std::sync::Arc;
use action_log::ActionLog;
use collections::HashMap;
use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity};
@ -295,8 +297,16 @@ mod test_support {
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(
session_id,
Session {

View file

@ -1,4 +1,3 @@
use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
@ -21,17 +20,13 @@ pub enum Diff {
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
pub fn finalized(
path: PathBuf,
old_text: Option<String>,
new_text: String,
language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));

View file

@ -2,6 +2,7 @@ use agent::ThreadId;
use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId};
use serde::{Deserialize, Serialize};
use std::{
fmt,
ops::Range,
@ -11,7 +12,7 @@ use std::{
use ui::{App, IconName, SharedString};
use url::Url;
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MentionUri {
File {
abs_path: PathBuf,
@ -51,6 +52,7 @@ impl MentionUri {
let path = url.path();
match url.scheme() {
"file" => {
let path = url.to_file_path().ok().context("Extracting file path")?;
if let Some(fragment) = url.fragment() {
let range = fragment
.strip_prefix("L")
@ -71,24 +73,16 @@ impl MentionUri {
if let Some(name) = single_query_param(&url, "symbol")? {
Ok(Self::Symbol {
name,
path: path.into(),
path,
line_range,
})
} else {
Ok(Self::Selection {
path: path.into(),
line_range,
})
Ok(Self::Selection { path, line_range })
}
} else if input.ends_with("/") {
Ok(Self::Directory { abs_path: path })
} else {
let abs_path =
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
if input.ends_with("/") {
Ok(Self::Directory { abs_path })
} else {
Ok(Self::File { abs_path })
}
Ok(Self::File { abs_path: path })
}
}
"zed" => {
@ -161,27 +155,17 @@ impl MentionUri {
pub fn to_uri(&self) -> Url {
match self {
MentionUri::File { abs_path } => {
let mut url = Url::parse("file:///").unwrap();
let path = abs_path.to_string_lossy();
url.set_path(&path);
url
Url::from_file_path(abs_path).expect("mention path should be absolute")
}
MentionUri::Directory { abs_path } => {
let mut url = Url::parse("file:///").unwrap();
let mut path = abs_path.to_string_lossy().to_string();
if !path.ends_with("/") {
path.push_str("/");
}
url.set_path(&path);
url
Url::from_directory_path(abs_path).expect("mention path should be absolute")
}
MentionUri::Symbol {
path,
name,
line_range,
} => {
let mut url = Url::parse("file:///").unwrap();
url.set_path(&path.to_string_lossy());
let mut url = Url::from_file_path(path).expect("mention path should be absolute");
url.query_pairs_mut().append_pair("symbol", name);
url.set_fragment(Some(&format!(
"L{}:{}",
@ -191,8 +175,7 @@ impl MentionUri {
url
}
MentionUri::Selection { path, line_range } => {
let mut url = Url::parse("file:///").unwrap();
url.set_path(&path.to_string_lossy());
let mut url = Url::from_file_path(path).expect("mention path should be absolute");
url.set_fragment(Some(&format!(
"L{}:{}",
line_range.start + 1,
@ -265,15 +248,17 @@ pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
#[cfg(test)]
mod tests {
use util::{path, uri};
use super::*;
#[test]
fn test_parse_file_uri() {
let file_uri = "file:///path/to/file.rs";
let file_uri = uri!("file:///path/to/file.rs");
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
MentionUri::File { abs_path } => {
assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs");
assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/file.rs"));
}
_ => panic!("Expected File variant"),
}
@ -282,11 +267,11 @@ mod tests {
#[test]
fn test_parse_directory_uri() {
let file_uri = "file:///path/to/dir/";
let file_uri = uri!("file:///path/to/dir/");
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
MentionUri::Directory { abs_path } => {
assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir/");
assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/dir/"));
}
_ => panic!("Expected Directory variant"),
}
@ -296,22 +281,24 @@ mod tests {
#[test]
fn test_to_directory_uri_with_slash() {
let uri = MentionUri::Directory {
abs_path: PathBuf::from("/path/to/dir/"),
abs_path: PathBuf::from(path!("/path/to/dir/")),
};
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
let expected = uri!("file:///path/to/dir/");
assert_eq!(uri.to_uri().to_string(), expected);
}
#[test]
fn test_to_directory_uri_without_slash() {
let uri = MentionUri::Directory {
abs_path: PathBuf::from("/path/to/dir"),
abs_path: PathBuf::from(path!("/path/to/dir")),
};
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
let expected = uri!("file:///path/to/dir/");
assert_eq!(uri.to_uri().to_string(), expected);
}
#[test]
fn test_parse_symbol_uri() {
let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
let symbol_uri = uri!("file:///path/to/file.rs?symbol=MySymbol#L10:20");
let parsed = MentionUri::parse(symbol_uri).unwrap();
match &parsed {
MentionUri::Symbol {
@ -319,7 +306,7 @@ mod tests {
name,
line_range,
} => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs"));
assert_eq!(name, "MySymbol");
assert_eq!(line_range.start, 9);
assert_eq!(line_range.end, 19);
@ -331,11 +318,11 @@ mod tests {
#[test]
fn test_parse_selection_uri() {
let selection_uri = "file:///path/to/file.rs#L5:15";
let selection_uri = uri!("file:///path/to/file.rs#L5:15");
let parsed = MentionUri::parse(selection_uri).unwrap();
match &parsed {
MentionUri::Selection { path, line_range } => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs"));
assert_eq!(line_range.start, 4);
assert_eq!(line_range.end, 14);
}
@ -417,32 +404,35 @@ mod tests {
#[test]
fn test_invalid_line_range_format() {
// Missing L prefix
assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#10:20")).is_err());
// Missing colon separator
assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1020")).is_err());
// Invalid numbers
assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:abc")).is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#Labc:20")).is_err());
}
#[test]
fn test_invalid_query_parameters() {
// Invalid query parameter name
assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:20?invalid=test")).is_err());
// Too many query parameters
assert!(
MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err()
MentionUri::parse(uri!(
"file:///path/to/file.rs#L10:20?symbol=test&another=param"
))
.is_err()
);
}
#[test]
fn test_zero_based_line_numbers() {
// Test that 0-based line numbers are rejected (should be 1-based)
assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:10")).is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1:0")).is_err());
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:0")).is_err());
}
}

View file

@ -116,7 +116,7 @@ impl ActionLog {
} else if buffer
.read(cx)
.file()
.map_or(false, |file| file.disk_state().exists())
.is_some_and(|file| file.disk_state().exists())
{
TrackedBufferStatus::Created {
existing_file_content: Some(buffer.read(cx).as_rope().clone()),
@ -215,7 +215,7 @@ impl ActionLog {
if buffer
.read(cx)
.file()
.map_or(false, |file| file.disk_state() == DiskState::Deleted)
.is_some_and(|file| file.disk_state() == DiskState::Deleted)
{
// If the buffer had been edited by a tool, but it got
// deleted externally, we want to stop tracking it.
@ -227,7 +227,7 @@ impl ActionLog {
if buffer
.read(cx)
.file()
.map_or(false, |file| file.disk_state() != DiskState::Deleted)
.is_some_and(|file| file.disk_state() != DiskState::Deleted)
{
// If the buffer had been deleted by a tool, but it got
// resurrected externally, we want to clear the edits we
@ -264,15 +264,14 @@ impl ActionLog {
if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) {
cx.update(|cx| {
let mut old_head = buffer_repo.read(cx).head_commit.clone();
Some(cx.subscribe(git_diff, move |_, event, cx| match event {
buffer_diff::BufferDiffEvent::DiffChanged { .. } => {
Some(cx.subscribe(git_diff, move |_, event, cx| {
if let buffer_diff::BufferDiffEvent::DiffChanged { .. } = event {
let new_head = buffer_repo.read(cx).head_commit.clone();
if new_head != old_head {
old_head = new_head;
git_diff_updates_tx.send(()).ok();
}
}
_ => {}
}))
})?
} else {
@ -614,10 +613,10 @@ impl ActionLog {
false
}
});
if tracked_buffer.unreviewed_edits.is_empty() {
if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status {
tracked_buffer.status = TrackedBufferStatus::Modified;
}
if tracked_buffer.unreviewed_edits.is_empty()
&& let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status
{
tracked_buffer.status = TrackedBufferStatus::Modified;
}
tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
}
@ -811,7 +810,7 @@ impl ActionLog {
tracked.version != buffer.version
&& buffer
.file()
.map_or(false, |file| file.disk_state() != DiskState::Deleted)
.is_some_and(|file| file.disk_state() != DiskState::Deleted)
})
.map(|(buffer, _)| buffer)
}
@ -847,7 +846,7 @@ fn apply_non_conflicting_edits(
conflict = true;
if new_edits
.peek()
.map_or(false, |next_edit| next_edit.old.overlaps(&old_edit.new))
.is_some_and(|next_edit| next_edit.old.overlaps(&old_edit.new))
{
new_edit = new_edits.next().unwrap();
} else {

View file

@ -103,26 +103,21 @@ impl ActivityIndicator {
cx.subscribe_in(
&workspace_handle,
window,
|activity_indicator, _, event, window, cx| match event {
workspace::Event::ClearActivityIndicator { .. } => {
if activity_indicator.statuses.pop().is_some() {
activity_indicator.dismiss_error_message(
&DismissErrorMessage,
window,
cx,
);
cx.notify();
}
|activity_indicator, _, event, window, cx| {
if let workspace::Event::ClearActivityIndicator { .. } = event
&& activity_indicator.statuses.pop().is_some()
{
activity_indicator.dismiss_error_message(&DismissErrorMessage, window, cx);
cx.notify();
}
_ => {}
},
)
.detach();
cx.subscribe(
&project.read(cx).lsp_store(),
|activity_indicator, _, event, cx| match event {
LspStoreEvent::LanguageServerUpdate { name, message, .. } => {
|activity_indicator, _, event, cx| {
if let LspStoreEvent::LanguageServerUpdate { name, message, .. } = event {
if let proto::update_language_server::Variant::StatusUpdate(status_update) =
message
{
@ -191,7 +186,6 @@ impl ActivityIndicator {
}
cx.notify()
}
_ => {}
},
)
.detach();
@ -206,9 +200,10 @@ impl ActivityIndicator {
cx.subscribe(
&project.read(cx).git_store().clone(),
|_, _, event: &GitStoreEvent, cx| match event {
project::git_store::GitStoreEvent::JobsUpdated => cx.notify(),
_ => {}
|_, _, event: &GitStoreEvent, cx| {
if let project::git_store::GitStoreEvent::JobsUpdated = event {
cx.notify()
}
},
)
.detach();
@ -458,26 +453,24 @@ impl ActivityIndicator {
.map(|r| r.read(cx))
.and_then(Repository::current_job);
// Show any long-running git command
if let Some(job_info) = current_job {
if Instant::now() - job_info.start >= GIT_OPERATION_DELAY {
return Some(Content {
icon: Some(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(delta)))
},
)
.into_any_element(),
),
message: job_info.message.into(),
on_click: None,
tooltip_message: None,
});
}
if let Some(job_info) = current_job
&& Instant::now() - job_info.start >= GIT_OPERATION_DELAY
{
return Some(Content {
icon: Some(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
)
.into_any_element(),
),
message: job_info.message.into(),
on_click: None,
tooltip_message: None,
});
}
// Show any language server installation info.
@ -740,21 +733,20 @@ impl ActivityIndicator {
if let Some(extension_store) =
ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx))
&& let Some(extension_id) = extension_store.outstanding_operations().keys().next()
{
if let Some(extension_id) = extension_store.outstanding_operations().keys().next() {
return Some(Content {
icon: Some(
Icon::new(IconName::Download)
.size(IconSize::Small)
.into_any_element(),
),
message: format!("Updating {extension_id} extension…"),
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
})),
tooltip_message: None,
});
}
return Some(Content {
icon: Some(
Icon::new(IconName::Download)
.size(IconSize::Small)
.into_any_element(),
),
message: format!("Updating {extension_id} extension…"),
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
})),
tooltip_message: None,
});
}
None

View file

@ -90,7 +90,7 @@ impl AgentProfile {
return false;
};
return Self::is_enabled(settings, source, tool_name);
Self::is_enabled(settings, source, tool_name)
}
fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {

View file

@ -201,24 +201,24 @@ impl FileContextHandle {
parse_status.changed().await.log_err();
}
if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) {
if let Some(outline) = snapshot.outline(None) {
let items = outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot));
if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot())
&& let Some(outline) = snapshot.outline(None)
{
let items = outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot));
if let Ok(outline_text) =
outline::render_outline(items, None, 0, usize::MAX).await
{
let context = AgentContext::File(FileContext {
handle: self,
full_path,
text: outline_text.into(),
is_outline: true,
});
return Some((context, vec![buffer]));
}
if let Ok(outline_text) =
outline::render_outline(items, None, 0, usize::MAX).await
{
let context = AgentContext::File(FileContext {
handle: self,
full_path,
text: outline_text.into(),
is_outline: true,
});
return Some((context, vec![buffer]));
}
}
}

View file

@ -338,11 +338,9 @@ impl ContextStore {
image_task,
context_id: self.next_context_id.post_inc(),
});
if self.has_context(&context) {
if remove_if_exists {
self.remove_context(&context, cx);
return None;
}
if self.has_context(&context) && remove_if_exists {
self.remove_context(&context, cx);
return None;
}
self.insert_context(context.clone(), cx);

View file

@ -1645,15 +1645,13 @@ impl Thread {
self.tool_use
.request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
let pending_tool_use = self.tool_use.insert_tool_output(
self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
tool_output,
self.configured_model.as_ref(),
self.completion_mode,
);
pending_tool_use
)
}
pub fn stream_completion(
@ -1967,11 +1965,9 @@ impl Thread {
if let Some(prev_message) =
thread.messages.get(ix - 1)
{
if prev_message.role == Role::Assistant {
&& prev_message.role == Role::Assistant {
break;
}
}
}
}
@ -2476,13 +2472,13 @@ impl Thread {
.ok()?;
// Save thread so its summary can be reused later
if let Some(thread) = thread.upgrade() {
if let Ok(Ok(save_task)) = cx.update(|cx| {
if let Some(thread) = thread.upgrade()
&& let Ok(Ok(save_task)) = cx.update(|cx| {
thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
}) {
save_task.await.log_err();
}
})
{
save_task.await.log_err();
}
Some(())
@ -2730,12 +2726,11 @@ impl Thread {
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if !canceled {
self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
}
}
if self.all_tools_finished()
&& let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
&& !canceled
{
self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
}
cx.emit(ThreadEvent::ToolFinished {
@ -2922,11 +2917,11 @@ impl Thread {
let buffer_store = project.read(app_cx).buffer_store();
for buffer_handle in buffer_store.read(app_cx).buffers() {
let buffer = buffer_handle.read(app_cx);
if buffer.is_dirty() {
if let Some(file) = buffer.file() {
let path = file.path().to_string_lossy().to_string();
unsaved_buffers.push(path);
}
if buffer.is_dirty()
&& let Some(file) = buffer.file()
{
let path = file.path().to_string_lossy().to_string();
unsaved_buffers.push(path);
}
}
})
@ -3178,13 +3173,13 @@ impl Thread {
.model
.max_token_count_for_mode(self.completion_mode().into());
if let Some(exceeded_error) = &self.exceeded_window_error {
if model.model.id() == exceeded_error.model_id {
return Some(TotalTokenUsage {
total: exceeded_error.token_count,
max,
});
}
if let Some(exceeded_error) = &self.exceeded_window_error
&& model.model.id() == exceeded_error.model_id
{
return Some(TotalTokenUsage {
total: exceeded_error.token_count,
max,
});
}
let total = self

View file

@ -42,7 +42,7 @@ use std::{
use util::ResultExt as _;
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
@ -74,7 +74,7 @@ impl Column for DataType {
}
}
const RULES_FILE_NAMES: [&'static str; 9] = [
const RULES_FILE_NAMES: [&str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
@ -581,33 +581,32 @@ impl ThreadStore {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(response) = protocol
if protocol.capable(context_server::protocol::ServerCapability::Tools)
&& let Some(response) = protocol
.request::<context_server::types::requests::ListTools>(())
.await
.log_err()
{
let tool_ids = tool_working_set
.update(cx, |tool_working_set, cx| {
tool_working_set.extend(
response.tools.into_iter().map(|tool| {
Arc::new(ContextServerTool::new(
context_server_store.clone(),
server.id(),
tool,
)) as Arc<dyn Tool>
}),
cx,
)
})
.log_err();
{
let tool_ids = tool_working_set
.update(cx, |tool_working_set, cx| {
tool_working_set.extend(
response.tools.into_iter().map(|tool| {
Arc::new(ContextServerTool::new(
context_server_store.clone(),
server.id(),
tool,
)) as Arc<dyn Tool>
}),
cx,
)
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, _| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();
}
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, _| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();
}
}
})
@ -697,13 +696,14 @@ impl SerializedThreadV0_1_0 {
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
for message in self.0.messages {
if message.role == Role::User && !message.tool_results.is_empty() {
if let Some(last_message) = messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
if message.role == Role::User
&& !message.tool_results.is_empty()
&& let Some(last_message) = messages.last_mut()
{
debug_assert!(last_message.role == Role::Assistant);
last_message.tool_results = message.tool_results;
continue;
}
last_message.tool_results = message.tool_results;
continue;
}
messages.push(message);
@ -893,7 +893,7 @@ impl ThreadsDatabase {
let needs_migration_from_heed = mdb_path.exists();
let connection = if *ZED_STATELESS {
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
} else {
Connection::open_file(&sqlite_path.to_string_lossy())

View file

@ -112,19 +112,13 @@ impl ToolUseState {
},
);
if let Some(window) = &mut window {
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) = tool.deserialize_card(
output,
project.clone(),
window,
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
if let Some(window) = &mut window
&& let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
&& let Some(output) = tool_result.output.clone()
&& let Some(card) =
tool.deserialize_card(output, project.clone(), window, cx)
{
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
@ -281,7 +275,7 @@ impl ToolUseState {
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
.is_some_and(|results| !results.is_empty())
}
pub fn tool_result(

View file

@ -14,18 +14,22 @@ workspace = true
[dependencies]
acp_thread.workspace = true
action_log.workspace = true
agent.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
assistant_context.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true
db.workspace = true
fs.workspace = true
futures.workspace = true
git.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
@ -37,6 +41,7 @@ language_model.workspace = true
language_models.workspace = true
log.workspace = true
open.workspace = true
parking_lot.workspace = true
paths.workspace = true
portable-pty.workspace = true
project.workspace = true
@ -47,6 +52,7 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
sqlez.workspace = true
task.workspace = true
terminal.workspace = true
text.workspace = true
@ -57,8 +63,11 @@ watch.workspace = true
web_search.workspace = true
which.workspace = true
workspace-hack.workspace = true
zstd.workspace = true
[dev-dependencies]
agent = { workspace = true, "features" = ["test-support"] }
assistant_context = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
@ -66,6 +75,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }

View file

@ -1,10 +1,10 @@
use crate::HistoryStore;
use crate::{
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
UserMessageContent, templates::Templates,
};
use acp_thread::AgentModelSelector;
use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
@ -28,7 +28,7 @@ use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt;
const RULES_FILE_NAMES: [&'static str; 9] = [
const RULES_FILE_NAMES: [&str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
@ -50,7 +50,8 @@ struct Session {
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: WeakEntity<acp_thread::AcpThread>,
_subscription: Subscription,
pending_save: Task<()>,
_subscriptions: Vec<Subscription>,
}
pub struct LanguageModels {
@ -154,6 +155,7 @@ impl LanguageModels {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
history: Entity<HistoryStore>,
/// Shared project context for all threads
project_context: Entity<ProjectContext>,
project_context_needs_refresh: watch::Sender<()>,
@ -172,6 +174,7 @@ pub struct NativeAgent {
impl NativeAgent {
pub async fn new(
project: Entity<Project>,
history: Entity<HistoryStore>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
@ -199,6 +202,7 @@ impl NativeAgent {
watch::channel(());
Self {
sessions: HashMap::new(),
history,
project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
@ -217,6 +221,55 @@ impl NativeAgent {
})
}
fn register_session(
&mut self,
thread_handle: Entity<Thread>,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
let registry = LanguageModelRegistry::read_global(cx);
let summarization_model = registry.thread_summary_model().map(|c| c.model);
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
thread.add_default_tools(cx)
});
let thread = thread_handle.read(cx);
let session_id = thread.id().clone();
let title = thread.title();
let project = thread.project.clone();
let action_log = thread.action_log.clone();
let acp_thread = cx.new(|_cx| {
acp_thread::AcpThread::new(
title,
connection,
project.clone(),
action_log.clone(),
session_id.clone(),
)
});
let subscriptions = vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
cx.observe(&thread_handle, move |this, thread, cx| {
this.save_thread(thread.clone(), cx)
}),
];
self.sessions.insert(
session_id,
Session {
thread: thread_handle,
acp_thread: acp_thread.downgrade(),
_subscriptions: subscriptions,
pending_save: Task::ready(()),
},
);
acp_thread
}
pub fn models(&self) -> &LanguageModels {
&self.models
}
@ -427,21 +480,79 @@ impl NativeAgent {
) {
self.models.refresh_list(cx);
let default_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|m| m.model.clone());
let registry = LanguageModelRegistry::read_global(cx);
let default_model = registry.default_model().map(|m| m.model.clone());
let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
if thread.model().is_none()
&& let Some(model) = default_model.clone()
{
thread.set_model(model);
thread.set_model(model, cx);
cx.notify();
}
thread.set_summarization_model(summarization_model.clone(), cx);
});
}
}
pub fn open_thread(
&mut self,
id: acp::SessionId,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
let db_thread = database
.load_thread(id.clone())
.await?
.with_context(|| format!("no thread found with ID: {id:?}"))?;
let thread = this.update(cx, |this, cx| {
let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
cx.new(|cx| {
Thread::from_db(
id.clone(),
db_thread,
this.project.clone(),
this.project_context.clone(),
this.context_server_registry.clone(),
action_log.clone(),
this.templates.clone(),
cx,
)
})
})?;
let acp_thread =
this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
cx.update(|cx| {
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
})?
.await?;
Ok(acp_thread)
})
}
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
let database_future = ThreadsDatabase::connect(cx);
let (id, db_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
let Some(session) = self.sessions.get_mut(&id) else {
return;
};
let history = self.history.clone();
session.pending_save = cx.spawn(async move |_, cx| {
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
return;
};
let db_thread = db_thread.await;
database.save_thread(id, db_thread).await.log_err();
history.update(cx, |history, cx| history.reload(cx)).ok();
});
}
}
/// Wrapper struct that implements the AgentConnection trait
@ -462,10 +573,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
+ FnOnce(
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
@ -477,19 +585,38 @@ impl NativeAgentConnection {
};
log::debug!("Found session for: {}", session_id);
let mut response_stream = match f(thread, cx) {
let response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
Self::handle_thread_events(response_stream, acp_thread, cx)
}
fn handle_thread_events(
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
acp_thread: WeakEntity<AcpThread>,
cx: &App,
) -> Task<Result<acp::PromptResponse>> {
cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
while let Some(result) = events.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
ThreadEvent::UserMessage(message) => {
acp_thread.update(cx, |thread, cx| {
for content in message.content {
thread.push_user_content_block(
Some(message.id.clone()),
content.into(),
cx,
);
}
})?;
}
ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@ -501,7 +628,7 @@ impl NativeAgentConnection {
)
})?;
}
AgentResponseEvent::Thinking(text) => {
ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@ -513,7 +640,7 @@ impl NativeAgentConnection {
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
@ -536,22 +663,31 @@ impl NativeAgentConnection {
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})??;
}
AgentResponseEvent::ToolCallUpdate(update) => {
ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Retry(status) => {
ThreadEvent::TokenUsageUpdate(usage) => {
acp_thread.update(cx, |thread, cx| {
thread.update_token_usage(Some(usage), cx)
})?;
}
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
}
ThreadEvent::Retry(status) => {
acp_thread.update(cx, |thread, cx| {
thread.update_retry_status(status, cx)
})?;
}
AgentResponseEvent::Stop(stop_reason) => {
ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
@ -604,8 +740,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, _cx| {
thread.set_model(model.clone());
thread.update(cx, |thread, cx| {
thread.set_model(model.clone(), cx);
});
update_settings_file::<AgentSettings>(
@ -665,31 +801,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
@ -701,7 +819,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
});
let thread = cx.new(|cx| {
let mut thread = Thread::new(
Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
@ -709,45 +827,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
agent.templates.clone(),
default_model,
cx,
);
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
thread
)
});
Ok(thread)
},
)??;
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
session_id,
Session {
thread,
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
},
);
})?;
Ok(acp_thread)
agent.update(cx, |agent, cx| agent.register_session(thread, cx))
})
}
@ -803,7 +889,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel());
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
}
});
}
@ -814,10 +900,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
self.0.update(cx, |agent, _cx| {
agent
.sessions
.get(session_id)
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
agent.sessions.get(session_id).map(|session| {
Rc::new(NativeAgentSessionEditor {
thread: session.thread.clone(),
acp_thread: session.acp_thread.clone(),
}) as _
})
})
}
@ -826,11 +914,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}
}
struct NativeAgentSessionEditor(Entity<Thread>);
struct NativeAgentSessionEditor {
thread: Entity<Thread>,
acp_thread: WeakEntity<AcpThread>,
}
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
match self.thread.update(cx, |thread, cx| {
thread.truncate(message_id.clone(), cx)?;
Ok(thread.latest_token_usage())
}) {
Ok(usage) => {
self.acp_thread
.update(cx, |thread, cx| {
thread.update_token_usage(usage, cx);
})
.ok();
Task::ready(Ok(()))
}
Err(error) => Task::ready(Err(error)),
}
}
}
@ -869,8 +973,11 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let agent = NativeAgent::new(
project.clone(),
history_store,
Templates::new(),
None,
fs.clone(),
@ -924,9 +1031,12 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
history_store,
Templates::new(),
None,
fs.clone(),
@ -977,9 +1087,13 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
history_store,
Templates::new(),
None,
fs.clone(),

View file

@ -1,13 +1,18 @@
mod agent;
mod db;
mod history_store;
mod native_agent_server;
mod templates;
mod thread;
mod tool_schema;
mod tools;
#[cfg(test)]
mod tests;
pub use agent::*;
pub use db::*;
pub use history_store::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;

483
crates/agent2/src/db.rs Normal file
View file

@ -0,0 +1,483 @@
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
use acp_thread::UserMessageId;
use agent::thread_store;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
use collections::{HashMap, IndexMap};
use futures::{FutureExt, future::Shared};
use gpui::{BackgroundExecutor, Global, Task};
use indoc::indoc;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use sqlez::{
bindable::{Bind, Column},
connection::Connection,
statement::Statement,
};
use std::sync::Arc;
use ui::{App, SharedString};
pub type DbMessage = crate::Message;
pub type DbSummary = agent::thread::DetailedSummaryState;
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
pub id: acp::SessionId,
#[serde(alias = "summary")]
pub title: SharedString,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
pub title: SharedString,
pub messages: Vec<DbMessage>,
pub updated_at: DateTime<Utc>,
#[serde(default)]
pub summary: DbSummary,
#[serde(default)]
pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
#[serde(default)]
pub cumulative_token_usage: language_model::TokenUsage,
#[serde(default)]
pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
#[serde(default)]
pub model: Option<DbLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub profile: Option<AgentProfileId>,
}
impl DbThread {
pub const VERSION: &'static str = "0.3.0";
pub fn from_json(json: &[u8]) -> Result<Self> {
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
match saved_thread_json.get("version") {
Some(serde_json::Value::String(version)) => match version.as_str() {
Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
},
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
}
}
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
let mut messages = Vec::new();
let mut request_token_usage = HashMap::default();
let mut last_user_message_id = None;
for (ix, msg) in thread.messages.into_iter().enumerate() {
let message = match msg.role {
language_model::Role::User => {
let mut content = Vec::new();
// Convert segments to content
for segment in msg.segments {
match segment {
thread_store::SerializedMessageSegment::Text { text } => {
content.push(UserMessageContent::Text(text));
}
thread_store::SerializedMessageSegment::Thinking { text, .. } => {
// User messages don't have thinking segments, but handle gracefully
content.push(UserMessageContent::Text(text));
}
thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
// User messages don't have redacted thinking, skip.
}
}
}
// If no content was added, add context as text if available
if content.is_empty() && !msg.context.is_empty() {
content.push(UserMessageContent::Text(msg.context));
}
let id = UserMessageId::new();
last_user_message_id = Some(id.clone());
crate::Message::User(UserMessage {
// MessageId from old format can't be meaningfully converted, so generate a new one
id,
content,
})
}
language_model::Role::Assistant => {
let mut content = Vec::new();
// Convert segments to content
for segment in msg.segments {
match segment {
thread_store::SerializedMessageSegment::Text { text } => {
content.push(AgentMessageContent::Text(text));
}
thread_store::SerializedMessageSegment::Thinking {
text,
signature,
} => {
content.push(AgentMessageContent::Thinking { text, signature });
}
thread_store::SerializedMessageSegment::RedactedThinking { data } => {
content.push(AgentMessageContent::RedactedThinking(data));
}
}
}
// Convert tool uses
let mut tool_names_by_id = HashMap::default();
for tool_use in msg.tool_uses {
tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
content.push(AgentMessageContent::ToolUse(
language_model::LanguageModelToolUse {
id: tool_use.id,
name: tool_use.name.into(),
raw_input: serde_json::to_string(&tool_use.input)
.unwrap_or_default(),
input: tool_use.input,
is_input_complete: true,
},
));
}
// Convert tool results
let mut tool_results = IndexMap::default();
for tool_result in msg.tool_results {
let name = tool_names_by_id
.remove(&tool_result.tool_use_id)
.unwrap_or_else(|| SharedString::from("unknown"));
tool_results.insert(
tool_result.tool_use_id.clone(),
language_model::LanguageModelToolResult {
tool_use_id: tool_result.tool_use_id,
tool_name: name.into(),
is_error: tool_result.is_error,
content: tool_result.content,
output: tool_result.output,
},
);
}
if let Some(last_user_message_id) = &last_user_message_id
&& let Some(token_usage) = thread.request_token_usage.get(ix).copied()
{
request_token_usage.insert(last_user_message_id.clone(), token_usage);
}
crate::Message::Agent(AgentMessage {
content,
tool_results,
})
}
language_model::Role::System => {
// Skip system messages as they're not supported in the new format
continue;
}
};
messages.push(message);
}
Ok(Self {
title: thread.summary,
messages,
updated_at: thread.updated_at,
summary: thread.detailed_summary_state,
initial_project_snapshot: thread.initial_project_snapshot,
cumulative_token_usage: thread.cumulative_token_usage,
request_token_usage,
model: thread.model,
completion_mode: thread.completion_mode,
profile: thread.profile,
})
}
}
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
#[serde(rename = "json")]
Json,
#[serde(rename = "zstd")]
Zstd,
}
impl Bind for DataType {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let value = match self {
DataType::Json => "json",
DataType::Zstd => "zstd",
};
value.bind(statement, start_index)
}
}
impl Column for DataType {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (value, next_index) = String::column(statement, start_index)?;
let data_type = match value.as_str() {
"json" => DataType::Json,
"zstd" => DataType::Zstd,
_ => anyhow::bail!("Unknown data type: {}", value),
};
Ok((data_type, next_index))
}
}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
connection: Arc<Mutex<Connection>>,
}
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
impl Global for GlobalThreadsDatabase {}
impl ThreadsDatabase {
pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
if cx.has_global::<GlobalThreadsDatabase>() {
return cx.global::<GlobalThreadsDatabase>().0.clone();
}
let executor = cx.background_executor().clone();
let task = executor
.spawn({
let executor = executor.clone();
async move {
match ThreadsDatabase::new(executor) {
Ok(db) => Ok(Arc::new(db)),
Err(err) => Err(Arc::new(err)),
}
}
})
.shared();
cx.set_global(GlobalThreadsDatabase(task.clone()));
task
}
pub fn new(executor: BackgroundExecutor) -> Result<Self> {
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
} else {
let threads_dir = paths::data_dir().join("threads");
std::fs::create_dir_all(&threads_dir)?;
let sqlite_path = threads_dir.join("threads.db");
Connection::open_file(&sqlite_path.to_string_lossy())
};
connection.exec(indoc! {"
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
summary TEXT NOT NULL,
updated_at TEXT NOT NULL,
data_type TEXT NOT NULL,
data BLOB NOT NULL
)
"})?()
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
let db = Self {
executor: executor.clone(),
connection: Arc::new(Mutex::new(connection)),
};
Ok(db)
}
fn save_thread_sync(
connection: &Arc<Mutex<Connection>>,
id: acp::SessionId,
thread: DbThread,
) -> Result<()> {
const COMPRESSION_LEVEL: i32 = 3;
#[derive(Serialize)]
struct SerializedThread {
#[serde(flatten)]
thread: DbThread,
version: &'static str,
}
let title = thread.title.to_string();
let updated_at = thread.updated_at.to_rfc3339();
let json_data = serde_json::to_string(&SerializedThread {
thread,
version: DbThread::VERSION,
})?;
let connection = connection.lock();
let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
let data_type = DataType::Zstd;
let data = compressed;
let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
"})?;
insert((id.0.clone(), title, updated_at, data_type, data))?;
Ok(())
}
pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let mut select =
connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
"})?;
let rows = select(())?;
let mut threads = Vec::new();
for (id, summary, updated_at) in rows {
threads.push(DbThreadMetadata {
id: acp::SessionId(id),
title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
});
}
Ok(threads)
})
}
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
"})?;
let rows = select(id.0)?;
if let Some((data_type, data)) = rows.into_iter().next() {
let json_data = match data_type {
DataType::Zstd => {
let decompressed = zstd::decode_all(&data[..])?;
String::from_utf8(decompressed)?
}
DataType::Json => String::from_utf8(data)?,
};
let thread = DbThread::from_json(json_data.as_bytes())?;
Ok(Some(thread))
} else {
Ok(None)
}
})
}
pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
let connection = self.connection.clone();
self.executor
.spawn(async move { Self::save_thread_sync(&connection, id, thread) })
}
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
DELETE FROM threads WHERE id = ?
"})?;
delete(id.0)?;
Ok(())
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use agent::MessageSegment;
use agent::context::LoadedContext;
use client::Client;
use fs::FakeFs;
use gpui::AppContext;
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use language_model::Role;
use project::Project;
use settings::SettingsStore;
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
let http_client = FakeHttpClient::with_404_response();
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
agent::init(cx);
agent_settings::init(cx);
language_model::init(client.clone(), cx);
});
}
#[gpui::test]
async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
// Save a thread using the old agent.
let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
thread.update(cx, |thread, cx| {
thread.insert_message(
Role::User,
vec![MessageSegment::Text("Hey!".into())],
LoadedContext::default(),
vec![],
false,
cx,
);
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text("How're you doing?".into())],
LoadedContext::default(),
vec![],
false,
cx,
)
});
thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
.await
.unwrap();
// Open that same thread using the new agent.
let db = cx.update(ThreadsDatabase::connect).await.unwrap();
let threads = db.list_threads().await.unwrap();
assert_eq!(threads.len(), 1);
let thread = db
.load_thread(threads[0].id.clone())
.await
.unwrap()
.unwrap();
assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
assert_eq!(
thread.messages[1].to_markdown(),
"## Assistant\n\nHow're you doing?\n"
);
}
}

View file

@ -0,0 +1,318 @@
use crate::{DbThreadMetadata, ThreadsDatabase};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use assistant_context::SavedContextMetadata;
use chrono::{DateTime, Utc};
use db::kvp::KEY_VALUE_STORE;
use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
use itertools::Itertools;
use paths::contexts_dir;
use serde::{Deserialize, Serialize};
use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
use util::ResultExt as _;
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads";
const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
#[derive(Clone, Debug)]
pub enum HistoryEntry {
AcpThread(DbThreadMetadata),
TextThread(SavedContextMetadata),
}
impl HistoryEntry {
pub fn updated_at(&self) -> DateTime<Utc> {
match self {
HistoryEntry::AcpThread(thread) => thread.updated_at,
HistoryEntry::TextThread(context) => context.mtime.to_utc(),
}
}
pub fn id(&self) -> HistoryEntryId {
match self {
HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()),
}
}
pub fn title(&self) -> &SharedString {
match self {
HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE,
HistoryEntry::AcpThread(thread) => &thread.title,
HistoryEntry::TextThread(context) => &context.title,
}
}
}
/// Generic identifier for a history entry.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum HistoryEntryId {
AcpThread(acp::SessionId),
TextThread(Arc<Path>),
}
#[derive(Serialize, Deserialize, Debug)]
enum SerializedRecentOpen {
AcpThread(String),
TextThread(String),
}
pub struct HistoryStore {
threads: Vec<DbThreadMetadata>,
context_store: Entity<assistant_context::ContextStore>,
recently_opened_entries: VecDeque<HistoryEntryId>,
_subscriptions: Vec<gpui::Subscription>,
_save_recently_opened_entries_task: Task<()>,
}
impl HistoryStore {
pub fn new(
context_store: Entity<assistant_context::ContextStore>,
cx: &mut Context<Self>,
) -> Self {
let subscriptions = vec![cx.observe(&context_store, |_, _, cx| cx.notify())];
cx.spawn(async move |this, cx| {
let entries = Self::load_recently_opened_entries(cx).await;
this.update(cx, |this, cx| {
if let Some(entries) = entries.log_err() {
this.recently_opened_entries = entries;
}
this.reload(cx);
})
.ok();
})
.detach();
Self {
context_store,
recently_opened_entries: VecDeque::default(),
threads: Vec::default(),
_subscriptions: subscriptions,
_save_recently_opened_entries_task: Task::ready(()),
}
}
pub fn delete_thread(
&mut self,
id: acp::SessionId,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.delete_thread(id.clone()).await?;
this.update(cx, |this, cx| this.reload(cx))
})
}
pub fn delete_text_thread(
&mut self,
path: Arc<Path>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.context_store.update(cx, |context_store, cx| {
context_store.delete_local_context(path, cx)
})
}
pub fn reload(&self, cx: &mut Context<Self>) {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let threads = database_future
.await
.map_err(|err| anyhow!(err))?
.list_threads()
.await?;
this.update(cx, |this, cx| {
if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES {
for thread in threads
.iter()
.take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len())
.rev()
{
this.push_recently_opened_entry(
HistoryEntryId::AcpThread(thread.id.clone()),
cx,
)
}
}
this.threads = threads;
cx.notify();
})
})
.detach_and_log_err(cx);
}
pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
let mut history_entries = Vec::new();
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
return history_entries;
}
history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
history_entries.extend(
self.context_store
.read(cx)
.unordered_contexts()
.cloned()
.map(HistoryEntry::TextThread),
);
history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
history_entries
}
pub fn is_empty(&self, cx: &App) -> bool {
self.threads.is_empty()
&& self
.context_store
.read(cx)
.unordered_contexts()
.next()
.is_none()
}
pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
self.entries(cx).into_iter().take(limit).collect()
}
pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
return Vec::new();
}
let thread_entries = self.threads.iter().flat_map(|thread| {
self.recently_opened_entries
.iter()
.enumerate()
.flat_map(|(index, entry)| match entry {
HistoryEntryId::AcpThread(id) if &thread.id == id => {
Some((index, HistoryEntry::AcpThread(thread.clone())))
}
_ => None,
})
});
let context_entries =
self.context_store
.read(cx)
.unordered_contexts()
.flat_map(|context| {
self.recently_opened_entries
.iter()
.enumerate()
.flat_map(|(index, entry)| match entry {
HistoryEntryId::TextThread(path) if &context.path == path => {
Some((index, HistoryEntry::TextThread(context.clone())))
}
_ => None,
})
});
thread_entries
.chain(context_entries)
// optimization to halt iteration early
.take(self.recently_opened_entries.len())
.sorted_unstable_by_key(|(index, _)| *index)
.map(|(_, entry)| entry)
.collect()
}
fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
let serialized_entries = self
.recently_opened_entries
.iter()
.filter_map(|entry| match entry {
HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
SerializedRecentOpen::TextThread(file.to_string_lossy().to_string())
}),
HistoryEntryId::AcpThread(id) => {
Some(SerializedRecentOpen::AcpThread(id.to_string()))
}
})
.collect::<Vec<_>>();
self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
let content = serde_json::to_string(&serialized_entries).unwrap();
cx.background_executor()
.timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
.await;
KEY_VALUE_STORE
.write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content)
.await
.log_err();
});
}
fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<VecDeque<HistoryEntryId>>> {
cx.background_spawn(async move {
let json = KEY_VALUE_STORE
.read_kvp(RECENTLY_OPENED_THREADS_KEY)?
.unwrap_or("[]".to_string());
let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&json)
.context("deserializing persisted agent panel navigation history")?
.into_iter()
.take(MAX_RECENTLY_OPENED_ENTRIES)
.flat_map(|entry| match entry {
SerializedRecentOpen::AcpThread(id) => Some(HistoryEntryId::AcpThread(
acp::SessionId(id.as_str().into()),
)),
SerializedRecentOpen::TextThread(file_name) => Some(
HistoryEntryId::TextThread(contexts_dir().join(file_name).into()),
),
})
.collect();
Ok(entries)
})
}
pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != &entry);
self.recently_opened_entries.push_front(entry);
self.recently_opened_entries
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
self.recently_opened_entries.retain(|entry| match entry {
HistoryEntryId::AcpThread(thread_id) if thread_id == &id => false,
_ => true,
});
self.save_recently_opened_entries(cx);
}
pub fn replace_recently_opened_text_thread(
&mut self,
old_path: &Path,
new_path: &Arc<Path>,
cx: &mut Context<Self>,
) {
for entry in &mut self.recently_opened_entries {
match entry {
HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
*entry = HistoryEntryId::TextThread(new_path.clone());
break;
}
_ => {}
}
}
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != entry);
self.save_recently_opened_entries(cx);
}
}

View file

@ -7,16 +7,17 @@ use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)]
pub struct NativeAgentServer {
fs: Arc<dyn Fs>,
history: Entity<HistoryStore>,
}
impl NativeAgentServer {
pub fn new(fs: Arc<dyn Fs>) -> Self {
Self { fs }
pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
Self { fs, history }
}
}
@ -50,6 +51,7 @@ impl AgentServer for NativeAgentServer {
);
let project = project.clone();
let fs = self.fs.clone();
let history = self.history.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
@ -57,7 +59,8 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
let agent =
NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View file

@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let message = thread.last_message().unwrap();
@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
);
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
ThreadEvent::ToolCall(tool_call) => tool_call,
event => {
panic!("Unexpected event {event:?}");
}
@ -752,7 +750,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@ -760,9 +758,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
return update;
}
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
event => {
panic!("Unexpected event {event:?}");
}
@ -770,7 +766,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@ -778,7 +774,7 @@ async fn next_tool_call_authorization(
.await
.expect("no tool call authorization event received")
.unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
@ -945,13 +941,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false;
while let Some(event) = events.next().await {
match event.unwrap() {
AgentResponseEvent::ToolCall(tool_call) => {
ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" {
echo_id = Some(tool_call.id);
}
}
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate {
id,
fields:
@ -973,13 +969,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel());
thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
matches!(
last_event,
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
),
"unexpected event {last_event:?}"
);
@ -1121,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
}
#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
async fn test_truncate_first_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -1141,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hello
"}
);
assert_eq!(thread.latest_token_usage(), None);
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@ -1158,14 +1163,22 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hey!
"}
);
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 32_000 + 16_000,
max_tokens: 1_000_000,
})
);
});
thread
.update(cx, |thread, _cx| thread.truncate(message_id))
.update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), "");
assert_eq!(thread.latest_token_usage(), None);
});
// Ensure we can still send a new message after truncation.
@ -1186,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@ -1200,9 +1221,171 @@ async fn test_truncate(cx: &mut TestAppContext) {
Ahoy!
"}
);
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 40_000 + 20_000,
max_tokens: 1_000_000,
})
);
});
}
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 1"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let assert_first_message_state = |cx: &mut TestAppContext| {
thread.clone().read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Message 1
## Assistant
Message 1 response
"}
);
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 32_000 + 16_000,
max_tokens: 1_000_000,
})
);
});
};
assert_first_message_state(cx);
let second_message_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(second_message_id.clone(), ["Message 2"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Message 1
## Assistant
Message 1 response
## User
Message 2
## Assistant
Message 2 response
"}
);
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 40_000 + 20_000,
max_tokens: 1_000_000,
})
);
});
thread
.update(cx, |thread, cx| thread.truncate(second_message_id, cx))
.unwrap();
cx.run_until_parked();
assert_first_message_state(cx);
}
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let summary_model = Arc::new(FakeLanguageModel::default());
thread.update(cx, |thread, cx| {
thread.set_summarization_model(Some(summary_model.clone()), cx)
});
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
// Ensure the summary model has been invoked to generate a title.
summary_model.send_last_completion_stream_text_chunk("Hello ");
summary_model.send_last_completion_stream_text_chunk("world\nG");
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
summary_model.end_last_completion_stream();
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
// Send another message, ensuring no title is generated this time.
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello again"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey again!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
assert_eq!(summary_model.pending_completions(), Vec::new());
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@ -1230,10 +1413,13 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
// Create agent and connection
let agent = NativeAgent::new(
project.clone(),
history_store,
templates.clone(),
None,
fake_fs.clone(),
@ -1442,7 +1628,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@ -1454,10 +1640,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
AgentResponseEvent::Retry(retry_status) => {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
AgentResponseEvent::Stop(..) => break,
ThreadEvent::Stop(..) => break,
_ => {}
}
}
@ -1486,7 +1672,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@ -1507,10 +1693,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
AgentResponseEvent::Retry(retry_status) => {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
AgentResponseEvent::Stop(..) => break,
ThreadEvent::Stop(..) => break,
_ => {}
}
}
@ -1543,7 +1729,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@ -1565,10 +1751,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(AgentResponseEvent::Retry(retry_status)) => {
Ok(ThreadEvent::Retry(retry_status)) => {
retry_events.push(retry_status);
}
Ok(AgentResponseEvent::Stop(..)) => break,
Ok(ThreadEvent::Stop(..)) => break,
Err(error) => errors.push(error),
_ => {}
}
@ -1592,11 +1778,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,43 @@
use language_model::LanguageModelToolSchemaFormat;
use schemars::{
JsonSchema, Schema,
generate::SchemaSettings,
transform::{Transform, transform_subschemas},
};
pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
})
.with_transform(ToJsonSchemaSubsetTransform)
.into_generator(),
};
generator.root_schema_for::<T>()
}
#[derive(Debug, Clone)]
struct ToJsonSchemaSubsetTransform;
impl Transform for ToJsonSchemaSubsetTransform {
fn transform(&mut self, schema: &mut Schema) {
// Ensure that the type field is not an array, this happens when we use
// Option<T>, the type will be [T, "null"].
if let Some(type_field) = schema.get_mut("type")
&& let Some(types) = type_field.as_array()
&& let Some(first_type) = types.first()
{
*type_field = first_type.clone();
}
// oneOf is not supported, use anyOf instead
if let Some(one_of) = schema.remove("oneOf") {
schema.insert("anyOf".to_string(), one_of);
}
transform_subschemas(self, schema);
}
}

View file

@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
})
})
}
fn replay(
&self,
_input: serde_json::Value,
_output: serde_json::Value,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
Ok(())
}
}

View file

@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf,
project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
#[serde(default)]
diff: String,
#[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput,
}
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct EditFileTool {
thread: Entity<Thread>,
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
}
impl EditFileTool {
pub fn new(thread: Entity<Thread>) -> Self {
Self { thread }
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
Self {
thread,
language_registry,
}
}
fn authorize(
@ -156,19 +162,22 @@ impl EditFileTool {
// It's also possible that the global config dir is configured to be inside the project,
// so check for that edge case too.
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
if canonical_path.starts_with(paths::config_dir()) {
return event_stream.authorize(
format!("{} (global settings)", input.display_description),
cx,
);
}
if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
&& canonical_path.starts_with(paths::config_dir())
{
return event_stream.authorize(
format!("{} (global settings)", input.display_description),
cx,
);
}
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
let thread = self.thread.read(cx);
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone();
let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
});
}
let Some(request) = self.thread.update(cx, |thread, cx| {
thread
.build_completion_request(CompletionIntent::ToolResults, cx)
.ok()
}) else {
return Task::ready(Err(anyhow!("Failed to build completion request")));
};
let thread = self.thread.read(cx);
let Some(model) = thread.model().cloned() else {
return Task::ready(Err(anyhow!("No language model configured")));
};
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
(request, thread.model().cloned(), thread.action_log().clone())
})?;
let request = request?;
let model = model.context("No language model configured")?;
let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new(
model,
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput {
input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
})
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path,
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
}
/// Validate that the file path is valid, meaning:
@ -515,6 +541,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -537,7 +564,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
})
.await;
assert_eq!(
@ -624,8 +655,7 @@ mod tests {
mode: mode.clone(),
};
let result = cx.update(|cx| resolve_path(&input, project, cx));
result
cx.update(|cx| resolve_path(&input, project, cx))
}
fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
@ -750,9 +780,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
@ -806,7 +837,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the unformatted content
@ -850,6 +885,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@ -887,9 +923,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
@ -938,10 +975,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
.run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the content with trailing whitespace
@ -976,6 +1014,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@ -989,7 +1028,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@ -1111,6 +1150,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1126,7 +1166,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@ -1220,7 +1260,7 @@ mod tests {
cx,
)
.await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1236,7 +1276,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees
let test_cases = vec![
@ -1302,6 +1342,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1317,7 +1358,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases
let test_cases = vec![
@ -1386,6 +1427,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1401,7 +1443,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values
let modes = vec![
@ -1467,6 +1509,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1482,7 +1525,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!(
tool.initial_title(Err(json!({

View file

@ -179,15 +179,14 @@ impl AgentTool for GrepTool {
// Check if this file should be excluded based on its worktree settings
if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| {
project.find_project_path(&path, cx)
}) {
if cx.update(|cx| {
})
&& cx.update(|cx| {
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
worktree_settings.is_path_excluded(&project_path.path)
|| worktree_settings.is_path_private(&project_path.path)
}).unwrap_or(false) {
continue;
}
}
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
@ -275,12 +274,11 @@ impl AgentTool for GrepTool {
output.extend(snapshot.text_for_range(range));
output.push_str("\n```\n");
if let Some(ancestor_range) = ancestor_range {
if end_row < ancestor_range.end.row {
if let Some(ancestor_range) = ancestor_range
&& end_row < ancestor_range.end.row {
let remaining_lines = ancestor_range.end.row - end_row;
writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
}
}
matches_found += 1;
}

View file

@ -175,7 +175,7 @@ impl AgentTool for ReadFileTool {
buffer
.file()
.as_ref()
.map_or(true, |file| !file.disk_state().exists())
.is_none_or(|file| !file.disk_state().exists())
})? {
anyhow::bail!("{file_path} not found");
}

View file

@ -47,12 +47,9 @@ impl TerminalTool {
}
if which::which("bash").is_ok() {
log::info!("agent selected bash for terminal tool");
"bash".into()
} else {
let shell = get_system_shell();
log::info!("agent selected {shell} for terminal tool");
shell
get_system_shell()
}
});
Self {
@ -271,7 +268,7 @@ fn working_dir(
let project = project.read(cx);
let cd = &input.cd;
if cd == "." || cd == "" {
if cd == "." || cd.is_empty() {
// Accept "." or "" as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
@ -296,10 +293,8 @@ fn working_dir(
{
return Ok(Some(input_path.into()));
}
} else {
if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
} else if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees.");
@ -319,7 +314,7 @@ mod tests {
use theme::ThemeSettings;
use util::test::TempTree;
use crate::AgentResponseEvent;
use crate::ThreadEvent;
use super::*;
@ -396,7 +391,7 @@ mod tests {
});
cx.run_until_parked();
let event = stream_rx.try_next();
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
auth.response.send(auth.options[0].id.clone()).unwrap();
}

View file

@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
}
};
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
emit_update(&response, &event_stream);
Ok(WebSearchToolOutput(response))
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
emit_update(&output.0, &event_stream);
Ok(())
}
}
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
}

View file

@ -18,6 +18,7 @@ doctest = false
[dependencies]
acp_thread.workspace = true
action_log.workspace = true
agent-client-protocol.workspace = true
agent_settings.workspace = true
agentic-coding-protocol.workspace = true
@ -28,6 +29,7 @@ futures.workspace = true
gpui.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true
@ -35,6 +37,7 @@ paths.workspace = true
project.workspace = true
rand.workspace = true
schemars.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true

View file

@ -1,4 +1,5 @@
// Translates old acp agents into the new schema
use action_log::ActionLog;
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow};
@ -148,7 +149,7 @@ impl acp_old::Client for OldAcpClientDelegate {
Ok(acp_old::RequestToolCallConfirmationResponse {
id: acp_old::ToolCallId(old_acp_id),
outcome: outcome,
outcome,
})
}
@ -265,7 +266,7 @@ impl acp_old::Client for OldAcpClientDelegate {
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
acp::ToolCall {
id: id,
id,
title: request.label,
kind: acp_kind_from_old_icon(request.icon),
status: acp::ToolCallStatus::InProgress,
@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.name, self.clone(), project, session_id, cx)
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
});
current_thread.replace(thread.downgrade());
thread

View file

@ -1,3 +1,4 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _};
use anyhow::anyhow;
use collections::HashMap;
@ -13,7 +14,7 @@ use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::{AgentServerCommand, acp::UnsupportedVersion};
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
pub struct AcpConnection {
server_name: &'static str,
@ -86,7 +87,9 @@ impl AcpConnection {
for session in sessions.borrow().values() {
session
.thread
.update(cx, |thread, cx| thread.emit_server_exited(status, cx))
.update(cx, |thread, cx| {
thread.emit_load_error(LoadError::Exited { status }, cx)
})
.ok();
}
@ -153,14 +156,14 @@ impl AgentConnection for AcpConnection {
})?;
let session_id = response.session_id;
let thread = cx.new(|cx| {
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
action_log,
session_id.clone(),
cx,
)
})?;

View file

@ -104,7 +104,7 @@ impl AgentServerCommand {
cx: &mut AsyncApp,
) -> Option<Self> {
if let Some(agent_settings) = settings {
return Some(Self {
Some(Self {
path: agent_settings.command.path,
args: agent_settings
.command
@ -113,7 +113,7 @@ impl AgentServerCommand {
.chain(extra_args.iter().map(|arg| arg.to_string()))
.collect(),
env: agent_settings.command.env,
});
})
} else {
match find_bin_in_path(path_bin_name, project, cx).await {
Some(path) => Some(Self {

View file

@ -1,6 +1,11 @@
mod edit_tool;
mod mcp_server;
mod permission_tool;
mod read_tool;
pub mod tools;
mod write_tool;
use action_log::ActionLog;
use collections::HashMap;
use context_server::listener::McpServerTool;
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
@ -10,8 +15,9 @@ use smol::process::Child;
use std::any::Any;
use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use util::command::new_smol_command;
use uuid::Uuid;
use agent_client_protocol as acp;
@ -31,7 +37,7 @@ use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError, MentionUri};
#[derive(Clone)]
pub struct ClaudeCode;
@ -98,7 +104,11 @@ impl AgentConnection for ClaudeAgentConnection {
)
.await
else {
anyhow::bail!("Failed to find claude binary");
return Err(LoadError::NotInstalled {
error_message: "Failed to find Claude Code binary".into(),
install_message: "Install Claude Code".into(),
install_command: "npm install -g @anthropic-ai/claude-code@latest".into(),
}.into());
};
let api_key =
@ -203,20 +213,50 @@ impl AgentConnection for ClaudeAgentConnection {
.await
}
if let Some(status) = child.status().await.log_err() {
if let Some(thread) = thread_rx.recv().await.ok() {
thread
.update(cx, |thread, cx| {
thread.emit_server_exited(status, cx);
})
.ok();
}
if let Some(status) = child.status().await.log_err()
&& let Some(thread) = thread_rx.recv().await.ok()
{
let version = claude_version(command.path.clone(), cx).await.log_err();
let help = claude_help(command.path.clone(), cx).await.log_err();
thread
.update(cx, |thread, cx| {
let error = if let Some(version) = version
&& let Some(help) = help
&& (!help.contains("--input-format")
|| !help.contains("--session-id"))
{
LoadError::Unsupported {
error_message: format!(
"Your installed version of Claude Code ({}, version {}) does not have required features for use with Zed.",
command.path.to_string_lossy(),
version,
)
.into(),
upgrade_message: "Upgrade Claude Code to latest".into(),
upgrade_command: format!(
"{} update",
command.path.to_string_lossy()
),
}
} else {
LoadError::Exited { status }
};
thread.emit_load_error(error, cx);
})
.ok();
}
}
});
let thread = cx.new(|cx| {
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
AcpThread::new(
"Claude Code",
self.clone(),
project,
action_log,
session_id.clone(),
)
})?;
thread_tx.send(thread.downgrade())?;
@ -259,27 +299,12 @@ impl AgentConnection for ClaudeAgentConnection {
let (end_tx, end_rx) = oneshot::channel();
session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new();
for chunk in params.prompt {
match chunk {
acp::ContentBlock::Text(text_content) => {
content.push_str(&text_content.text);
}
acp::ContentBlock::ResourceLink(resource_link) => {
content.push_str(&format!("@{}", resource_link.uri));
}
acp::ContentBlock::Audio(_)
| acp::ContentBlock::Image(_)
| acp::ContentBlock::Resource(_) => {
// TODO
}
}
}
let content = acp_content_to_claude(params.prompt);
if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
message: Message {
role: Role::User,
content: Content::UntaggedText(content),
content: Content::Chunks(content),
id: None,
model: None,
stop_reason: None,
@ -358,18 +383,16 @@ fn spawn_claude(
&format!(
"mcp__{}__{}",
mcp_server::SERVER_NAME,
mcp_server::PermissionTool::NAME,
permission_tool::PermissionTool::NAME,
),
"--allowedTools",
&format!(
"mcp__{}__{},mcp__{}__{}",
"mcp__{}__{}",
mcp_server::SERVER_NAME,
mcp_server::EditTool::NAME,
mcp_server::SERVER_NAME,
mcp_server::ReadTool::NAME
read_tool::ReadTool::NAME
),
"--disallowedTools",
"Read,Edit",
"Read,Write,Edit,MultiEdit",
])
.args(match mode {
ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
@ -388,6 +411,27 @@ fn spawn_claude(
Ok(child)
}
fn claude_version(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<semver::Version>> {
cx.background_spawn(async move {
let output = new_smol_command(path).arg("--version").output().await?;
let output = String::from_utf8(output.stdout)?;
let version = output
.trim()
.strip_suffix(" (Claude Code)")
.context("parsing Claude version")?;
let version = semver::Version::parse(version)?;
anyhow::Ok(version)
})
}
fn claude_help(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<String>> {
cx.background_spawn(async move {
let output = new_smol_command(path).arg("--help").output().await?;
let output = String::from_utf8(output.stdout)?;
anyhow::Ok(output)
})
}
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
turn_state: Rc<RefCell<TurnState>>,
@ -477,9 +521,16 @@ impl ClaudeAgentSession {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
let id = acp::ToolCallId(tool_use_id.into());
let set_new_content = !content.is_empty()
&& thread.tool_call(&id).is_none_or(|(_, tool_call)| {
// preserve rich diff if we have one
tool_call.diffs().next().is_none()
});
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
id,
fields: acp::ToolCallUpdateFields {
status: if turn_state.borrow().is_canceled() {
// Do not set to completed if turn was canceled
@ -487,7 +538,7 @@ impl ClaudeAgentSession {
} else {
Some(acp::ToolCallStatus::Completed)
},
content: (!content.is_empty())
content: set_new_content
.then(|| vec![content.into()]),
..Default::default()
},
@ -505,10 +556,17 @@ impl ClaudeAgentSession {
chunk
);
}
ContentChunk::Image { source } => {
if !turn_state.borrow().is_canceled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(None, source.into(), cx)
})
.log_err();
}
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::WebSearchToolResult => {
ContentChunk::Document | ContentChunk::WebSearchToolResult => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
@ -594,7 +652,14 @@ impl ClaudeAgentSession {
"Should not get tool results with role: assistant. should we handle this?"
);
}
ContentChunk::Image | ContentChunk::Document => {
ContentChunk::Image { source } => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(source.into(), false, cx)
})
.log_err();
}
ContentChunk::Document => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
@ -760,14 +825,44 @@ enum ContentChunk {
thinking: String,
},
RedactedThinking,
Image {
source: ImageSource,
},
// TODO
Image,
Document,
WebSearchToolResult,
#[serde(untagged)]
UntaggedText(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ImageSource {
Base64 { data: String, media_type: String },
Url { url: String },
}
impl Into<acp::ContentBlock> for ImageSource {
fn into(self) -> acp::ContentBlock {
match self {
ImageSource::Base64 { data, media_type } => {
acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data,
mime_type: media_type,
uri: None,
})
}
ImageSource::Url { url } => acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data: "".to_string(),
mime_type: "".to_string(),
uri: Some(url),
}),
}
}
}
impl Display for ContentChunk {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@ -776,7 +871,7 @@ impl Display for ContentChunk {
ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
ContentChunk::UntaggedText(text) => write!(f, "{}", text),
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
ContentChunk::Image
ContentChunk::Image { .. }
| ContentChunk::Document
| ContentChunk::ToolUse { .. }
| ContentChunk::WebSearchToolResult => {
@ -888,6 +983,75 @@ impl Display for ResultErrorType {
}
}
fn acp_content_to_claude(prompt: Vec<acp::ContentBlock>) -> Vec<ContentChunk> {
let mut content = Vec::with_capacity(prompt.len());
let mut context = Vec::with_capacity(prompt.len());
for chunk in prompt {
match chunk {
acp::ContentBlock::Text(text_content) => {
content.push(ContentChunk::Text {
text: text_content.text,
});
}
acp::ContentBlock::ResourceLink(resource_link) => {
match MentionUri::parse(&resource_link.uri) {
Ok(uri) => {
content.push(ContentChunk::Text {
text: format!("{}", uri.as_link()),
});
}
Err(_) => {
content.push(ContentChunk::Text {
text: resource_link.uri,
});
}
}
}
acp::ContentBlock::Resource(resource) => match resource.resource {
acp::EmbeddedResourceResource::TextResourceContents(resource) => {
match MentionUri::parse(&resource.uri) {
Ok(uri) => {
content.push(ContentChunk::Text {
text: format!("{}", uri.as_link()),
});
}
Err(_) => {
content.push(ContentChunk::Text {
text: resource.uri.clone(),
});
}
}
context.push(ContentChunk::Text {
text: format!(
"\n<context ref=\"{}\">\n{}\n</context>",
resource.uri, resource.text
),
});
}
acp::EmbeddedResourceResource::BlobResourceContents(_) => {
// Unsupported by SDK
}
},
acp::ContentBlock::Image(acp::ImageContent {
data, mime_type, ..
}) => content.push(ContentChunk::Image {
source: ImageSource::Base64 {
data,
media_type: mime_type,
},
}),
acp::ContentBlock::Audio(_) => {
// Unsupported by SDK
}
}
}
content.extend(context);
content
}
fn new_request_id() -> String {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
@ -1104,4 +1268,100 @@ pub(crate) mod tests {
_ => panic!("Expected ToolResult variant"),
}
}
#[test]
fn test_acp_content_to_claude() {
let acp_content = vec![
acp::ContentBlock::Text(acp::TextContent {
text: "Hello world".to_string(),
annotations: None,
}),
acp::ContentBlock::Image(acp::ImageContent {
data: "base64data".to_string(),
mime_type: "image/png".to_string(),
annotations: None,
uri: None,
}),
acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: "file:///path/to/example.rs".to_string(),
name: "example.rs".to_string(),
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}),
acp::ContentBlock::Resource(acp::EmbeddedResource {
annotations: None,
resource: acp::EmbeddedResourceResource::TextResourceContents(
acp::TextResourceContents {
mime_type: None,
text: "fn main() { println!(\"Hello!\"); }".to_string(),
uri: "file:///path/to/code.rs".to_string(),
},
),
}),
acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: "invalid_uri_format".to_string(),
name: "invalid.txt".to_string(),
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}),
];
let claude_content = acp_content_to_claude(acp_content);
assert_eq!(claude_content.len(), 6);
match &claude_content[0] {
ContentChunk::Text { text } => assert_eq!(text, "Hello world"),
_ => panic!("Expected Text chunk"),
}
match &claude_content[1] {
ContentChunk::Image { source } => match source {
ImageSource::Base64 { data, media_type } => {
assert_eq!(data, "base64data");
assert_eq!(media_type, "image/png");
}
_ => panic!("Expected Base64 image source"),
},
_ => panic!("Expected Image chunk"),
}
match &claude_content[2] {
ContentChunk::Text { text } => {
assert!(text.contains("example.rs"));
assert!(text.contains("file:///path/to/example.rs"));
}
_ => panic!("Expected Text chunk for ResourceLink"),
}
match &claude_content[3] {
ContentChunk::Text { text } => {
assert!(text.contains("code.rs"));
assert!(text.contains("file:///path/to/code.rs"));
}
_ => panic!("Expected Text chunk for Resource"),
}
match &claude_content[4] {
ContentChunk::Text { text } => {
assert_eq!(text, "invalid_uri_format");
}
_ => panic!("Expected Text chunk for invalid URI"),
}
match &claude_content[5] {
ContentChunk::Text { text } => {
assert!(text.contains("<context ref=\"file:///path/to/code.rs\">"));
assert!(text.contains("fn main() { println!(\"Hello!\"); }"));
assert!(text.contains("</context>"));
}
_ => panic!("Expected Text chunk for context"),
}
}
}

View file

@ -0,0 +1,178 @@
use acp_thread::AcpThread;
use anyhow::Result;
use context_server::{
listener::{McpServerTool, ToolResponse},
types::{ToolAnnotations, ToolResponseContent},
};
use gpui::{AsyncApp, WeakEntity};
use language::unified_diff;
use util::markdown::MarkdownCodeBlock;
use crate::tools::EditToolParams;
#[derive(Clone)]
pub struct EditTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl EditTool {
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { thread_rx }
}
}
impl McpServerTool for EditTool {
type Input = EditToolParams;
type Output = ();
const NAME: &'static str = "Edit";
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Edit file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
})?
.await?;
let (new_content, diff) = cx
.background_executor()
.spawn(async move {
let new_content = content.replace(&input.old_text, &input.new_text);
if new_content == content {
return Err(anyhow::anyhow!("Failed to find `old_text`",));
}
let diff = unified_diff(&content, &new_content);
Ok((new_content, diff))
})
.await?;
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, new_content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: MarkdownCodeBlock {
tag: "diff",
text: diff.as_str().trim_end_matches('\n'),
}
.to_string(),
}],
structured_content: (),
})
}
}
#[cfg(test)]
mod tests {
use std::rc::Rc;
use acp_thread::{AgentConnection, StubAgentConnection};
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
use super::*;
#[gpui::test]
async fn old_text_not_found(cx: &mut TestAppContext) {
let (_thread, tool) = init_test(cx).await;
let result = tool
.run(
EditToolParams {
abs_path: path!("/root/file.txt").into(),
old_text: "hi".into(),
new_text: "bye".into(),
},
&mut cx.to_async(),
)
.await;
assert_eq!(result.unwrap_err().to_string(), "Failed to find `old_text`");
}
#[gpui::test]
async fn found_and_replaced(cx: &mut TestAppContext) {
let (_thread, tool) = init_test(cx).await;
let result = tool
.run(
EditToolParams {
abs_path: path!("/root/file.txt").into(),
old_text: "hello".into(),
new_text: "hi".into(),
},
&mut cx.to_async(),
)
.await;
assert_eq!(
result.unwrap().content[0].text().unwrap(),
indoc! {
r"
```diff
@@ -1,1 +1,1 @@
-hello
+hi
```
"
}
);
}
async fn init_test(cx: &mut TestAppContext) -> (Entity<AcpThread>, EditTool) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
let connection = Rc::new(StubAgentConnection::new());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"file.txt": "hello"
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let thread = cx
.update(|cx| connection.new_thread(project, path!("/test").as_ref(), cx))
.await
.unwrap();
thread_tx.send(thread.downgrade()).unwrap();
(thread, EditTool::new(thread_rx))
}
}

View file

@ -1,23 +1,22 @@
use std::path::PathBuf;
use std::sync::Arc;
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
use crate::claude::edit_tool::EditTool;
use crate::claude::permission_tool::PermissionTool;
use crate::claude::read_tool::ReadTool;
use crate::claude::write_tool::WriteTool;
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context, Result};
#[cfg(not(test))]
use anyhow::Context as _;
use anyhow::Result;
use collections::HashMap;
use context_server::listener::{McpServerTool, ToolResponse};
use context_server::types::{
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests,
ToolsCapabilities, requests,
};
use gpui::{App, AsyncApp, Task, WeakEntity};
use project::Fs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings as _, update_settings_file};
use util::debug_panic;
use serde::Serialize;
pub struct ClaudeZedMcpServer {
server: context_server::listener::McpServer,
@ -34,16 +33,10 @@ impl ClaudeZedMcpServer {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.add_tool(PermissionTool {
thread_rx: thread_rx.clone(),
fs: fs.clone(),
});
mcp_server.add_tool(ReadTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(EditTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(PermissionTool::new(fs.clone(), thread_rx.clone()));
mcp_server.add_tool(ReadTool::new(thread_rx.clone()));
mcp_server.add_tool(EditTool::new(thread_rx.clone()));
mcp_server.add_tool(WriteTool::new(thread_rx.clone()));
Ok(Self { server: mcp_server })
}
@ -104,249 +97,3 @@ pub struct McpServerConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub env: Option<HashMap<String, String>>,
}
// Tools
#[derive(Clone)]
pub struct PermissionTool {
fs: Arc<dyn Fs>,
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
#[derive(Deserialize, JsonSchema, Debug)]
pub struct PermissionToolParams {
tool_name: String,
input: serde_json::Value,
tool_use_id: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionToolResponse {
behavior: PermissionToolBehavior,
updated_input: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
Allow,
Deny,
}
impl McpServerTool for PermissionTool {
type Input = PermissionToolParams;
type Output = ();
const NAME: &'static str = "Confirmation";
fn description(&self) -> &'static str {
"Request permission for tool calls"
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
if agent_settings::AgentSettings::try_read_global(cx, |settings| {
settings.always_allow_tool_actions
})
.unwrap_or(false)
{
let response = PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
};
return Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
});
}
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
const ALWAYS_ALLOW: &'static str = "always_allow";
const ALLOW: &'static str = "allow";
const REJECT: &'static str = "reject";
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id).into(),
vec![
acp::PermissionOption {
id: acp::PermissionOptionId(ALWAYS_ALLOW.into()),
name: "Always Allow".into(),
kind: acp::PermissionOptionKind::AllowAlways,
},
acp::PermissionOption {
id: acp::PermissionOptionId(ALLOW.into()),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: acp::PermissionOptionId(REJECT.into()),
name: "Reject".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
cx,
)
})??
.await?;
let response = match chosen_option.0.as_ref() {
ALWAYS_ALLOW => {
cx.update(|cx| {
update_settings_file::<AgentSettings>(self.fs.clone(), cx, |settings, _| {
settings.set_always_allow_tool_actions(true);
});
})?;
PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
}
}
ALLOW => PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
},
REJECT => PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
},
opt => {
debug_panic!("Unexpected option: {}", opt);
PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
}
}
};
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
})
}
}
#[derive(Clone)]
pub struct ReadTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for ReadTool {
type Input = ReadToolParams;
type Output = ();
const NAME: &'static str = "Read";
fn description(&self) -> &'static str {
"Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
}
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Read file".to_string()),
read_only_hint: Some(true),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: None,
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text { text: content }],
structured_content: (),
})
}
}
#[derive(Clone)]
pub struct EditTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for EditTool {
type Input = EditToolParams;
type Output = ();
const NAME: &'static str = "Edit";
fn description(&self) -> &'static str {
"Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
}
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Edit file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
})?
.await?;
let new_content = content.replace(&input.old_text, &input.new_text);
if new_content == content {
return Err(anyhow::anyhow!("The old_text was not found in the content"));
}
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, new_content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View file

@ -0,0 +1,158 @@
use std::sync::Arc;
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use context_server::{
listener::{McpServerTool, ToolResponse},
types::ToolResponseContent,
};
use gpui::{AsyncApp, WeakEntity};
use project::Fs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings as _, update_settings_file};
use util::debug_panic;
use crate::tools::ClaudeTool;
#[derive(Clone)]
pub struct PermissionTool {
fs: Arc<dyn Fs>,
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
/// Request permission for tool calls
#[derive(Deserialize, JsonSchema, Debug)]
pub struct PermissionToolParams {
tool_name: String,
input: serde_json::Value,
tool_use_id: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionToolResponse {
behavior: PermissionToolBehavior,
updated_input: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
Allow,
Deny,
}
impl PermissionTool {
pub fn new(fs: Arc<dyn Fs>, thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { fs, thread_rx }
}
}
impl McpServerTool for PermissionTool {
type Input = PermissionToolParams;
type Output = ();
const NAME: &'static str = "Confirmation";
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
if agent_settings::AgentSettings::try_read_global(cx, |settings| {
settings.always_allow_tool_actions
})
.unwrap_or(false)
{
let response = PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
};
return Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
});
}
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
const ALWAYS_ALLOW: &str = "always_allow";
const ALLOW: &str = "allow";
const REJECT: &str = "reject";
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id).into(),
vec![
acp::PermissionOption {
id: acp::PermissionOptionId(ALWAYS_ALLOW.into()),
name: "Always Allow".into(),
kind: acp::PermissionOptionKind::AllowAlways,
},
acp::PermissionOption {
id: acp::PermissionOptionId(ALLOW.into()),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: acp::PermissionOptionId(REJECT.into()),
name: "Reject".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
cx,
)
})??
.await?;
let response = match chosen_option.0.as_ref() {
ALWAYS_ALLOW => {
cx.update(|cx| {
update_settings_file::<AgentSettings>(self.fs.clone(), cx, |settings, _| {
settings.set_always_allow_tool_actions(true);
});
})?;
PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
}
}
ALLOW => PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
},
REJECT => PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
},
opt => {
debug_panic!("Unexpected option: {}", opt);
PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
}
}
};
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
})
}
}

View file

@ -0,0 +1,59 @@
use acp_thread::AcpThread;
use anyhow::Result;
use context_server::{
listener::{McpServerTool, ToolResponse},
types::{ToolAnnotations, ToolResponseContent},
};
use gpui::{AsyncApp, WeakEntity};
use crate::tools::ReadToolParams;
#[derive(Clone)]
pub struct ReadTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl ReadTool {
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { thread_rx }
}
}
impl McpServerTool for ReadTool {
type Input = ReadToolParams;
type Output = ();
const NAME: &'static str = "Read";
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Read file".to_string()),
read_only_hint: Some(true),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: None,
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text { text: content }],
structured_content: (),
})
}
}

View file

@ -34,6 +34,7 @@ impl ClaudeTool {
// Known tools
"mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
"mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
"mcp__zed__Write" => Self::Write(serde_json::from_value(input).log_err()),
"MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
"Write" => Self::Write(serde_json::from_value(input).log_err()),
"LS" => Self::Ls(serde_json::from_value(input).log_err()),
@ -93,7 +94,7 @@ impl ClaudeTool {
}
Self::MultiEdit(None) => "Multi Edit".into(),
Self::Write(Some(params)) => {
format!("Write {}", params.file_path.display())
format!("Write {}", params.abs_path.display())
}
Self::Write(None) => "Write".into(),
Self::Glob(Some(params)) => {
@ -153,7 +154,7 @@ impl ClaudeTool {
}],
Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: params.file_path.clone(),
path: params.abs_path.clone(),
old_text: None,
new_text: params.content.clone(),
},
@ -229,7 +230,10 @@ impl ClaudeTool {
line: None,
}]
}
Self::Write(Some(WriteToolParams { file_path, .. })) => {
Self::Write(Some(WriteToolParams {
abs_path: file_path,
..
})) => {
vec![acp::ToolCallLocation {
path: file_path.clone(),
line: None,
@ -302,6 +306,20 @@ impl ClaudeTool {
}
}
/// Edit a file.
///
/// In sessions with mcp__zed__Edit always use it instead of Edit as it will
/// allow the user to conveniently review changes.
///
/// File editing instructions:
/// - The `old_text` param must match existing file content, including indentation.
/// - The `old_text` param must come from the actual file, not an outline.
/// - The `old_text` section must not be empty.
/// - Be minimal with replacements:
/// - For unique lines, include only those lines.
/// - For non-unique lines, include enough context to identify them.
/// - Do not escape quotes, newlines, or other characters.
/// - Only edit the specified file.
#[derive(Deserialize, JsonSchema, Debug)]
pub struct EditToolParams {
/// The absolute path to the file to read.
@ -312,6 +330,11 @@ pub struct EditToolParams {
pub new_text: String,
}
/// Reads the content of the given file in the project.
///
/// Never attempt to read a path that hasn't been previously mentioned.
///
/// In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.
#[derive(Deserialize, JsonSchema, Debug)]
pub struct ReadToolParams {
/// The absolute path to the file to read.
@ -324,11 +347,15 @@ pub struct ReadToolParams {
pub limit: Option<u32>,
}
/// Writes content to the specified file in the project.
///
/// In sessions with mcp__zed__Write always use it instead of Write as it will
/// allow the user to conveniently review changes.
#[derive(Deserialize, JsonSchema, Debug)]
pub struct WriteToolParams {
/// Absolute path for new file
pub file_path: PathBuf,
/// File content
/// The absolute path of the file to write.
pub abs_path: PathBuf,
/// The full content to write.
pub content: String,
}

View file

@ -0,0 +1,59 @@
use acp_thread::AcpThread;
use anyhow::Result;
use context_server::{
listener::{McpServerTool, ToolResponse},
types::ToolAnnotations,
};
use gpui::{AsyncApp, WeakEntity};
use crate::tools::WriteToolParams;
#[derive(Clone)]
pub struct WriteTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl WriteTool {
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { thread_rx }
}
}
impl McpServerTool for WriteTool {
type Input = WriteToolParams;
type Output = ();
const NAME: &'static str = "Write";
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Write file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, input.content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View file

@ -428,12 +428,9 @@ pub async fn new_test_thread(
.await
.unwrap();
let thread = cx
.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
.await
.unwrap();
thread
.unwrap()
}
pub async fn run_until_first_tool_call(
@ -471,7 +468,7 @@ pub fn get_zed_path() -> PathBuf {
while zed_path
.file_name()
.map_or(true, |name| name.to_string_lossy() != "debug")
.is_none_or(|name| name.to_string_lossy() != "debug")
{
if !zed_path.pop() {
panic!("Could not find target directory");

View file

@ -50,7 +50,11 @@ impl AgentServer for Gemini {
let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
else {
anyhow::bail!("Failed to find gemini binary");
return Err(LoadError::NotInstalled {
error_message: "Failed to find Gemini CLI binary".into(),
install_message: "Install Gemini CLI".into(),
install_command: "npm install -g @google/gemini-cli@latest".into()
}.into());
};
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
@ -75,10 +79,11 @@ impl AgentServer for Gemini {
if !supported {
return Err(LoadError::Unsupported {
error_message: format!(
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
"Your installed version of Gemini CLI ({}, version {}) doesn't support the Agentic Coding Protocol (ACP).",
command.path.to_string_lossy(),
current_version
).into(),
upgrade_message: "Upgrade Gemini to Latest".into(),
upgrade_message: "Upgrade Gemini CLI to latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
}.into())
}

View file

@ -58,7 +58,7 @@ impl AgentProfileSettings {
|| self
.context_servers
.get(server_id)
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
.is_some_and(|preset| preset.tools.get(tool_name) == Some(&true))
}
}

View file

@ -116,15 +116,15 @@ pub struct LanguageModelParameters {
impl LanguageModelParameters {
pub fn matches(&self, model: &Arc<dyn LanguageModel>) -> bool {
if let Some(provider) = &self.provider {
if provider.0 != model.provider_id().0 {
return false;
}
if let Some(provider) = &self.provider
&& provider.0 != model.provider_id().0
{
return false;
}
if let Some(setting_model) = &self.model {
if *setting_model != model.id().0 {
return false;
}
if let Some(setting_model) = &self.model
&& *setting_model != model.id().0
{
return false;
}
true
}

View file

@ -3,8 +3,10 @@ mod entry_view_state;
mod message_editor;
mod model_selector;
mod model_selector_popover;
mod thread_history;
mod thread_view;
pub use model_selector::AcpModelSelector;
pub use model_selector_popover::AcpModelSelectorPopover;
pub use thread_history::*;
pub use thread_view::AcpThreadView;

View file

@ -763,14 +763,16 @@ fn confirm_completion_callback(
message_editor
.clone()
.update(cx, |message_editor, cx| {
message_editor.confirm_completion(
crease_text,
start,
content_len,
mention_uri,
window,
cx,
)
message_editor
.confirm_completion(
crease_text,
start,
content_len,
mention_uri,
window,
cx,
)
.detach();
})
.ok();
});
@ -795,7 +797,7 @@ impl MentionCompletion {
&& line
.chars()
.nth(last_mention_start - 1)
.map_or(false, |c| !c.is_whitespace())
.is_some_and(|c| !c.is_whitespace())
{
return None;
}

View file

@ -26,7 +26,7 @@ use gpui::{
};
use language::{Buffer, Language};
use language_model::LanguageModelImage;
use project::{CompletionIntent, Project, ProjectPath, Worktree};
use project::{Project, ProjectPath, Worktree};
use rope::Point;
use settings::Settings;
use std::{
@ -134,8 +134,8 @@ impl MessageEditor {
if prevent_slash_commands {
subscriptions.push(cx.subscribe_in(&editor, window, {
let semantics_provider = semantics_provider.clone();
move |this, editor, event, window, cx| match event {
EditorEvent::Edited { .. } => {
move |this, editor, event, window, cx| {
if let EditorEvent::Edited { .. } = event {
this.highlight_slash_command(
semantics_provider.clone(),
editor.clone(),
@ -143,7 +143,6 @@ impl MessageEditor {
cx,
);
}
_ => {}
}
}));
}
@ -202,18 +201,18 @@ impl MessageEditor {
mention_uri: MentionUri,
window: &mut Window,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let snapshot = self
.editor
.update(cx, |editor, cx| editor.snapshot(window, cx));
let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else {
return;
return Task::ready(());
};
let Some(anchor) = snapshot
.buffer_snapshot
.anchor_in_excerpt(*excerpt_id, start)
else {
return;
return Task::ready(());
};
if let MentionUri::File { abs_path, .. } = &mention_uri {
@ -228,7 +227,7 @@ impl MessageEditor {
.read(cx)
.project_path_for_absolute_path(abs_path, cx)
else {
return;
return Task::ready(());
};
let image = cx
.spawn(async move |_, cx| {
@ -252,9 +251,9 @@ impl MessageEditor {
window,
cx,
) else {
return;
return Task::ready(());
};
self.confirm_mention_for_image(
return self.confirm_mention_for_image(
crease_id,
anchor,
Some(abs_path.clone()),
@ -262,7 +261,6 @@ impl MessageEditor {
window,
cx,
);
return;
}
}
@ -276,27 +274,28 @@ impl MessageEditor {
window,
cx,
) else {
return;
return Task::ready(());
};
match mention_uri {
MentionUri::Fetch { url } => {
self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx);
self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx)
}
MentionUri::Directory { abs_path } => {
self.confirm_mention_for_directory(crease_id, anchor, abs_path, window, cx);
self.confirm_mention_for_directory(crease_id, anchor, abs_path, window, cx)
}
MentionUri::Thread { id, name } => {
self.confirm_mention_for_thread(crease_id, anchor, id, name, window, cx);
self.confirm_mention_for_thread(crease_id, anchor, id, name, window, cx)
}
MentionUri::TextThread { path, name } => {
self.confirm_mention_for_text_thread(crease_id, anchor, path, name, window, cx);
self.confirm_mention_for_text_thread(crease_id, anchor, path, name, window, cx)
}
MentionUri::File { .. }
| MentionUri::Symbol { .. }
| MentionUri::Rule { .. }
| MentionUri::Selection { .. } => {
self.mention_set.insert_uri(crease_id, mention_uri.clone());
Task::ready(())
}
}
}
@ -308,7 +307,7 @@ impl MessageEditor {
abs_path: PathBuf,
window: &mut Window,
cx: &mut Context<Self>,
) {
) -> Task<()> {
fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<(Arc<Path>, PathBuf)> {
let mut files = Vec::new();
@ -331,13 +330,13 @@ impl MessageEditor {
.read(cx)
.project_path_for_absolute_path(&abs_path, cx)
else {
return;
return Task::ready(());
};
let Some(entry) = self.project.read(cx).entry_for_path(&project_path, cx) else {
return;
return Task::ready(());
};
let Some(worktree) = self.project.read(cx).worktree_for_entry(entry.id, cx) else {
return;
return Task::ready(());
};
let project = self.project.clone();
let task = cx.spawn(async move |_, cx| {
@ -396,7 +395,9 @@ impl MessageEditor {
})
.shared();
self.mention_set.directories.insert(abs_path, task.clone());
self.mention_set
.directories
.insert(abs_path.clone(), task.clone());
let editor = self.editor.clone();
cx.spawn_in(window, async move |this, cx| {
@ -414,9 +415,12 @@ impl MessageEditor {
editor.remove_creases([crease_id], cx);
})
.ok();
this.update(cx, |this, _cx| {
this.mention_set.directories.remove(&abs_path);
})
.ok();
}
})
.detach();
}
fn confirm_mention_for_fetch(
@ -426,13 +430,13 @@ impl MessageEditor {
url: url::Url,
window: &mut Window,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let Some(http_client) = self
.workspace
.update(cx, |workspace, _cx| workspace.client().http_client())
.ok()
else {
return;
return Task::ready(());
};
let url_string = url.to_string();
@ -450,9 +454,9 @@ impl MessageEditor {
cx.spawn_in(window, async move |this, cx| {
let fetch = fetch.await.notify_async_err(cx);
this.update(cx, |this, cx| {
let mention_uri = MentionUri::Fetch { url };
if fetch.is_some() {
this.mention_set.insert_uri(crease_id, mention_uri.clone());
this.mention_set
.insert_uri(crease_id, MentionUri::Fetch { url });
} else {
// Remove crease if we failed to fetch
this.editor.update(cx, |editor, cx| {
@ -461,11 +465,11 @@ impl MessageEditor {
});
editor.remove_creases([crease_id], cx);
});
this.mention_set.fetch_results.remove(&url);
}
})
.ok();
})
.detach();
}
pub fn confirm_mention_for_selection(
@ -528,7 +532,7 @@ impl MessageEditor {
name: String,
window: &mut Window,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let uri = MentionUri::Thread {
id: id.clone(),
name,
@ -546,7 +550,7 @@ impl MessageEditor {
})
.shared();
self.mention_set.insert_thread(id, task.clone());
self.mention_set.insert_thread(id.clone(), task.clone());
let editor = self.editor.clone();
cx.spawn_in(window, async move |this, cx| {
@ -564,9 +568,12 @@ impl MessageEditor {
editor.remove_creases([crease_id], cx);
})
.ok();
this.update(cx, |this, _| {
this.mention_set.thread_summaries.remove(&id);
})
.ok();
}
})
.detach();
}
fn confirm_mention_for_text_thread(
@ -577,7 +584,7 @@ impl MessageEditor {
name: String,
window: &mut Window,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let uri = MentionUri::TextThread {
path: path.clone(),
name,
@ -595,7 +602,8 @@ impl MessageEditor {
})
.shared();
self.mention_set.insert_text_thread(path, task.clone());
self.mention_set
.insert_text_thread(path.clone(), task.clone());
let editor = self.editor.clone();
cx.spawn_in(window, async move |this, cx| {
@ -613,9 +621,12 @@ impl MessageEditor {
editor.remove_creases([crease_id], cx);
})
.ok();
this.update(cx, |this, _| {
this.mention_set.text_thread_summaries.remove(&path);
})
.ok();
}
})
.detach();
}
pub fn contents(
@ -784,13 +795,15 @@ impl MessageEditor {
) else {
return;
};
self.confirm_mention_for_image(crease_id, anchor, None, task, window, cx);
self.confirm_mention_for_image(crease_id, anchor, None, task, window, cx)
.detach();
}
}
pub fn insert_dragged_files(
&self,
&mut self,
paths: Vec<project::ProjectPath>,
added_worktrees: Vec<Entity<Worktree>>,
window: &mut Window,
cx: &mut Context<Self>,
) {
@ -798,6 +811,7 @@ impl MessageEditor {
let Some(buffer) = buffer.read(cx).as_singleton() else {
return;
};
let mut tasks = Vec::new();
for path in paths {
let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
continue;
@ -805,39 +819,44 @@ impl MessageEditor {
let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else {
continue;
};
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
let path_prefix = abs_path
.file_name()
.unwrap_or(path.path.as_os_str())
.display()
.to_string();
let Some(completion) = ContextPickerCompletionProvider::completion_for_path(
path,
&path_prefix,
false,
entry.is_dir(),
anchor..anchor,
cx.weak_entity(),
self.project.clone(),
cx,
) else {
continue;
let (file_name, _) =
crate::context_picker::file_context_picker::extract_file_name_and_directory(
&path.path,
&path_prefix,
);
let uri = if entry.is_dir() {
MentionUri::Directory { abs_path }
} else {
MentionUri::File { abs_path }
};
let new_text = format!("{} ", uri.as_link());
let content_len = new_text.len() - 1;
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
self.editor.update(cx, |message_editor, cx| {
message_editor.edit(
[(
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
completion.new_text,
new_text,
)],
cx,
);
});
if let Some(confirm) = completion.confirm.clone() {
confirm(CompletionIntent::Complete, window, cx);
}
tasks.push(self.confirm_completion(file_name, anchor, content_len, uri, window, cx));
}
cx.spawn(async move |_, _| {
join_all(tasks).await;
drop(added_worktrees);
})
.detach();
}
pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context<Self>) {
@ -855,7 +874,7 @@ impl MessageEditor {
image: Shared<Task<Result<Arc<Image>, String>>>,
window: &mut Window,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let editor = self.editor.clone();
let task = cx
.spawn_in(window, {
@ -900,9 +919,12 @@ impl MessageEditor {
editor.remove_creases([crease_id], cx);
})
.ok();
this.update(cx, |this, _cx| {
this.mention_set.images.remove(&crease_id);
})
.ok();
}
})
.detach();
}
pub fn set_mode(&mut self, mode: EditorMode, cx: &mut Context<Self>) {
@ -1529,14 +1551,14 @@ impl SemanticsProvider for SlashCommandSemanticsProvider {
return None;
}
let range = snapshot.anchor_after(start)..snapshot.anchor_after(end);
return Some(Task::ready(vec![project::Hover {
Some(Task::ready(vec![project::Hover {
contents: vec![project::HoverBlock {
text: "Slash commands are not supported".into(),
kind: project::HoverBlockKind::PlainText,
}],
range: Some(range),
language: None,
}]));
}]))
}
fn inline_values(
@ -1640,7 +1662,7 @@ mod tests {
use serde_json::json;
use text::Point;
use ui::{App, Context, IntoElement, Render, SharedString, Window};
use util::path;
use util::{path, uri};
use workspace::{AppState, Item, Workspace};
use crate::acp::{
@ -1950,13 +1972,12 @@ mod tests {
editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx);
});
let url_one = uri!("file:///dir/a/one.txt");
editor.update(&mut cx, |editor, cx| {
assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) ");
let text = editor.text(cx);
assert_eq!(text, format!("Lorem [@one.txt]({url_one}) "));
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 39)]
);
assert_eq!(fold_ranges(editor, cx).len(), 1);
});
let contents = message_editor
@ -1977,47 +1998,35 @@ mod tests {
contents,
[Mention::Text {
content: "1".into(),
uri: "file:///dir/a/one.txt".parse().unwrap()
uri: url_one.parse().unwrap()
}]
);
cx.simulate_input(" ");
editor.update(&mut cx, |editor, cx| {
assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) ");
let text = editor.text(cx);
assert_eq!(text, format!("Lorem [@one.txt]({url_one}) "));
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 39)]
);
assert_eq!(fold_ranges(editor, cx).len(), 1);
});
cx.simulate_input("Ipsum ");
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:///dir/a/one.txt) Ipsum ",
);
let text = editor.text(cx);
assert_eq!(text, format!("Lorem [@one.txt]({url_one}) Ipsum "),);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 39)]
);
assert_eq!(fold_ranges(editor, cx).len(), 1);
});
cx.simulate_input("@file ");
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:///dir/a/one.txt) Ipsum @file ",
);
let text = editor.text(cx);
assert_eq!(text, format!("Lorem [@one.txt]({url_one}) Ipsum @file "),);
assert!(editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 39)]
);
assert_eq!(fold_ranges(editor, cx).len(), 1);
});
editor.update_in(&mut cx, |editor, window, cx| {
@ -2041,28 +2050,23 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(contents.len(), 2);
let url_eight = uri!("file:///dir/b/eight.txt");
pretty_assertions::assert_eq!(
contents[1],
Mention::Text {
content: "8".to_string(),
uri: "file:///dir/b/eight.txt".parse().unwrap(),
uri: url_eight.parse().unwrap(),
}
);
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) "
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 39),
Point::new(0, 47)..Point::new(0, 84)
]
);
});
assert_eq!(
editor.text(cx),
format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) ")
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(fold_ranges(editor, cx).len(), 2);
});
let plain_text_language = Arc::new(language::Language::new(
language::LanguageConfig {
@ -2108,7 +2112,7 @@ mod tests {
let fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.set_request_handler::<lsp::WorkspaceSymbolRequest, _, _>(
|_, _| async move {
move |_, _| async move {
Ok(Some(lsp::WorkspaceSymbolResponse::Flat(vec![
#[allow(deprecated)]
lsp::SymbolInformation {
@ -2132,18 +2136,13 @@ mod tests {
cx.simulate_input("@symbol ");
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) @symbol "
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
current_completion_labels(editor),
&[
"MySymbol",
]
);
});
assert_eq!(
editor.text(cx),
format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) @symbol ")
);
assert!(editor.has_visible_completions_menu());
assert_eq!(current_completion_labels(editor), &["MySymbol"]);
});
editor.update_in(&mut cx, |editor, window, cx| {
editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx);
@ -2165,18 +2164,16 @@ mod tests {
contents[2],
Mention::Text {
content: "1".into(),
uri: "file:///dir/a/one.txt?symbol=MySymbol#L1:1"
.parse()
.unwrap(),
uri: format!("{url_one}?symbol=MySymbol#L1:1").parse().unwrap(),
}
);
cx.run_until_parked();
editor.read_with(&mut cx, |editor, cx| {
editor.read_with(&cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) [@MySymbol](file:///dir/a/one.txt?symbol=MySymbol#L1:1) "
format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ")
);
});
}

View file

@ -0,0 +1,766 @@
use crate::RemoveSelectedThread;
use agent2::{HistoryEntry, HistoryStore};
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
UniformListScrollHandle, Window, uniform_list,
};
use std::{fmt::Display, ops::Range, sync::Arc};
use time::{OffsetDateTime, UtcOffset};
use ui::{
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
Tooltip, prelude::*,
};
use util::ResultExt;
pub struct AcpThreadHistory {
pub(crate) history_store: Entity<HistoryStore>,
scroll_handle: UniformListScrollHandle,
selected_index: usize,
hovered_index: Option<usize>,
search_editor: Entity<Editor>,
all_entries: Arc<Vec<HistoryEntry>>,
// When the search is empty, we display date separators between history entries
// This vector contains an enum of either a separator or an actual entry
separated_items: Vec<ListItemType>,
// Maps entry indexes to list item indexes
separated_item_indexes: Vec<u32>,
_separated_items_task: Option<Task<()>>,
search_state: SearchState,
scrollbar_visibility: bool,
scrollbar_state: ScrollbarState,
local_timezone: UtcOffset,
_subscriptions: Vec<gpui::Subscription>,
}
enum SearchState {
Empty,
Searching {
query: SharedString,
_task: Task<()>,
},
Searched {
query: SharedString,
matches: Vec<StringMatch>,
},
}
enum ListItemType {
BucketSeparator(TimeBucket),
Entry {
index: usize,
format: EntryTimeFormat,
},
}
pub enum ThreadHistoryEvent {
Open(HistoryEntry),
}
impl EventEmitter<ThreadHistoryEvent> for AcpThreadHistory {}
impl AcpThreadHistory {
pub(crate) fn new(
history_store: Entity<agent2::HistoryStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let search_editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text("Search threads...", cx);
editor
});
let search_editor_subscription =
cx.subscribe(&search_editor, |this, search_editor, event, cx| {
if let EditorEvent::BufferEdited = event {
let query = search_editor.read(cx).text(cx);
this.search(query.into(), cx);
}
});
let history_store_subscription = cx.observe(&history_store, |this, _, cx| {
this.update_all_entries(cx);
});
let scroll_handle = UniformListScrollHandle::default();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
let mut this = Self {
history_store,
scroll_handle,
selected_index: 0,
hovered_index: None,
search_state: SearchState::Empty,
all_entries: Default::default(),
separated_items: Default::default(),
separated_item_indexes: Default::default(),
search_editor,
scrollbar_visibility: true,
scrollbar_state,
local_timezone: UtcOffset::from_whole_seconds(
chrono::Local::now().offset().local_minus_utc(),
)
.unwrap(),
_subscriptions: vec![search_editor_subscription, history_store_subscription],
_separated_items_task: None,
};
this.update_all_entries(cx);
this
}
fn update_all_entries(&mut self, cx: &mut Context<Self>) {
let new_entries: Arc<Vec<HistoryEntry>> = self
.history_store
.update(cx, |store, cx| store.entries(cx))
.into();
self._separated_items_task.take();
let mut items = Vec::with_capacity(new_entries.len() + 1);
let mut indexes = Vec::with_capacity(new_entries.len() + 1);
let bg_task = cx.background_spawn(async move {
let mut bucket = None;
let today = Local::now().naive_local().date();
for (index, entry) in new_entries.iter().enumerate() {
let entry_date = entry
.updated_at()
.with_timezone(&Local)
.naive_local()
.date();
let entry_bucket = TimeBucket::from_dates(today, entry_date);
if Some(entry_bucket) != bucket {
bucket = Some(entry_bucket);
items.push(ListItemType::BucketSeparator(entry_bucket));
}
indexes.push(items.len() as u32);
items.push(ListItemType::Entry {
index,
format: entry_bucket.into(),
});
}
(new_entries, items, indexes)
});
let task = cx.spawn(async move |this, cx| {
let (new_entries, items, indexes) = bg_task.await;
this.update(cx, |this, cx| {
let previously_selected_entry =
this.all_entries.get(this.selected_index).map(|e| e.id());
this.all_entries = new_entries;
this.separated_items = items;
this.separated_item_indexes = indexes;
match &this.search_state {
SearchState::Empty => {
if this.selected_index >= this.all_entries.len() {
this.set_selected_entry_index(
this.all_entries.len().saturating_sub(1),
cx,
);
} else if let Some(prev_id) = previously_selected_entry
&& let Some(new_ix) = this
.all_entries
.iter()
.position(|probe| probe.id() == prev_id)
{
this.set_selected_entry_index(new_ix, cx);
}
}
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
this.search(query.clone(), cx);
}
}
cx.notify();
})
.log_err();
});
self._separated_items_task = Some(task);
}
fn search(&mut self, query: SharedString, cx: &mut Context<Self>) {
if query.is_empty() {
self.search_state = SearchState::Empty;
cx.notify();
return;
}
let all_entries = self.all_entries.clone();
let fuzzy_search_task = cx.background_spawn({
let query = query.clone();
let executor = cx.background_executor().clone();
async move {
let mut candidates = Vec::with_capacity(all_entries.len());
for (idx, entry) in all_entries.iter().enumerate() {
candidates.push(StringMatchCandidate::new(idx, entry.title()));
}
const MAX_MATCHES: usize = 100;
fuzzy::match_strings(
&candidates,
&query,
false,
true,
MAX_MATCHES,
&Default::default(),
executor,
)
.await
}
});
let task = cx.spawn({
let query = query.clone();
async move |this, cx| {
let matches = fuzzy_search_task.await;
this.update(cx, |this, cx| {
let SearchState::Searching {
query: current_query,
_task,
} = &this.search_state
else {
return;
};
if &query == current_query {
this.search_state = SearchState::Searched {
query: query.clone(),
matches,
};
this.set_selected_entry_index(0, cx);
cx.notify();
};
})
.log_err();
}
});
self.search_state = SearchState::Searching { query, _task: task };
cx.notify();
}
fn matched_count(&self) -> usize {
match &self.search_state {
SearchState::Empty => self.all_entries.len(),
SearchState::Searching { .. } => 0,
SearchState::Searched { matches, .. } => matches.len(),
}
}
fn list_item_count(&self) -> usize {
match &self.search_state {
SearchState::Empty => self.separated_items.len(),
SearchState::Searching { .. } => 0,
SearchState::Searched { matches, .. } => matches.len(),
}
}
fn search_produced_no_matches(&self) -> bool {
match &self.search_state {
SearchState::Empty => false,
SearchState::Searching { .. } => false,
SearchState::Searched { matches, .. } => matches.is_empty(),
}
}
fn get_match(&self, ix: usize) -> Option<&HistoryEntry> {
match &self.search_state {
SearchState::Empty => self.all_entries.get(ix),
SearchState::Searching { .. } => None,
SearchState::Searched { matches, .. } => matches
.get(ix)
.and_then(|m| self.all_entries.get(m.candidate_id)),
}
}
pub fn select_previous(
&mut self,
_: &menu::SelectPrevious,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
if self.selected_index == 0 {
self.set_selected_entry_index(count - 1, cx);
} else {
self.set_selected_entry_index(self.selected_index - 1, cx);
}
}
}
pub fn select_next(
&mut self,
_: &menu::SelectNext,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
if self.selected_index == count - 1 {
self.set_selected_entry_index(0, cx);
} else {
self.set_selected_entry_index(self.selected_index + 1, cx);
}
}
}
fn select_first(
&mut self,
_: &menu::SelectFirst,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
self.set_selected_entry_index(0, cx);
}
}
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
let count = self.matched_count();
if count > 0 {
self.set_selected_entry_index(count - 1, cx);
}
}
fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context<Self>) {
self.selected_index = entry_index;
let scroll_ix = match self.search_state {
SearchState::Empty | SearchState::Searching { .. } => self
.separated_item_indexes
.get(entry_index)
.map(|ix| *ix as usize)
.unwrap_or(entry_index + 1),
SearchState::Searched { .. } => entry_index,
};
self.scroll_handle
.scroll_to_item(scroll_ix, ScrollStrategy::Top);
cx.notify();
}
fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
return None;
}
Some(
div()
.occlude()
.id("thread-history-scroll")
.h_full()
.bg(cx.theme().colors().panel_background.opacity(0.8))
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.absolute()
.right_1()
.top_0()
.bottom_0()
.w_4()
.pl_1()
.cursor_default()
.on_mouse_move(cx.listener(|_, _, _window, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _window, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _window, cx| {
cx.stop_propagation();
})
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
cx.notify();
}))
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}
fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
self.confirm_entry(self.selected_index, cx);
}
fn confirm_entry(&mut self, ix: usize, cx: &mut Context<Self>) {
let Some(entry) = self.get_match(ix) else {
return;
};
cx.emit(ThreadHistoryEvent::Open(entry.clone()));
}
fn remove_selected_thread(
&mut self,
_: &RemoveSelectedThread,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.remove_thread(self.selected_index, cx)
}
fn remove_thread(&mut self, ix: usize, cx: &mut Context<Self>) {
let Some(entry) = self.get_match(ix) else {
return;
};
let task = match entry {
HistoryEntry::AcpThread(thread) => self
.history_store
.update(cx, |this, cx| this.delete_thread(thread.id.clone(), cx)),
HistoryEntry::TextThread(context) => self.history_store.update(cx, |this, cx| {
this.delete_text_thread(context.path.clone(), cx)
}),
};
task.detach_and_log_err(cx);
}
fn list_items(
&mut self,
range: Range<usize>,
_window: &mut Window,
cx: &mut Context<Self>,
) -> Vec<AnyElement> {
match &self.search_state {
SearchState::Empty => self
.separated_items
.get(range)
.iter()
.flat_map(|items| {
items
.iter()
.map(|item| self.render_list_item(item, vec![], cx))
})
.collect(),
SearchState::Searched { matches, .. } => matches[range]
.iter()
.filter_map(|m| {
let entry = self.all_entries.get(m.candidate_id)?;
Some(self.render_history_entry(
entry,
EntryTimeFormat::DateAndTime,
m.candidate_id,
m.positions.clone(),
cx,
))
})
.collect(),
SearchState::Searching { .. } => {
vec![]
}
}
}
fn render_list_item(
&self,
item: &ListItemType,
highlight_positions: Vec<usize>,
cx: &Context<Self>,
) -> AnyElement {
match item {
ListItemType::Entry { index, format } => match self.all_entries.get(*index) {
Some(entry) => self
.render_history_entry(entry, *format, *index, highlight_positions, cx)
.into_any(),
None => Empty.into_any_element(),
},
ListItemType::BucketSeparator(bucket) => div()
.px(DynamicSpacing::Base06.rems(cx))
.pt_2()
.pb_1()
.child(
Label::new(bucket.to_string())
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.into_any_element(),
}
}
fn render_history_entry(
&self,
entry: &HistoryEntry,
format: EntryTimeFormat,
list_entry_ix: usize,
highlight_positions: Vec<usize>,
cx: &Context<Self>,
) -> AnyElement {
let selected = list_entry_ix == self.selected_index;
let hovered = Some(list_entry_ix) == self.hovered_index;
let timestamp = entry.updated_at().timestamp();
let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone);
h_flex()
.w_full()
.pb_1()
.child(
ListItem::new(list_entry_ix)
.rounded()
.toggle_state(selected)
.spacing(ListItemSpacing::Sparse)
.start_slot(
h_flex()
.w_full()
.gap_2()
.justify_between()
.child(
HighlightedLabel::new(entry.title(), highlight_positions)
.size(LabelSize::Small)
.truncate(),
)
.child(
Label::new(thread_timestamp)
.color(Color::Muted)
.size(LabelSize::XSmall),
),
)
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
if *is_hovered {
this.hovered_index = Some(list_entry_ix);
} else if this.hovered_index == Some(list_entry_ix) {
this.hovered_index = None;
}
cx.notify();
}))
.end_slot::<IconButton>(if hovered || selected {
Some(
IconButton::new("delete", IconName::Trash)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})
.on_click(cx.listener(move |this, _, _, cx| {
this.remove_thread(list_entry_ix, cx)
})),
)
} else {
None
})
.on_click(
cx.listener(move |this, _, _, cx| this.confirm_entry(list_entry_ix, cx)),
),
)
.into_any_element()
}
}
impl Focusable for AcpThreadHistory {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.search_editor.focus_handle(cx)
}
}
impl Render for AcpThreadHistory {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.key_context("ThreadHistory")
.size_full()
.on_action(cx.listener(Self::select_previous))
.on_action(cx.listener(Self::select_next))
.on_action(cx.listener(Self::select_first))
.on_action(cx.listener(Self::select_last))
.on_action(cx.listener(Self::confirm))
.on_action(cx.listener(Self::remove_selected_thread))
.when(!self.all_entries.is_empty(), |parent| {
parent.child(
h_flex()
.h(px(41.)) // Match the toolbar perfectly
.w_full()
.py_1()
.px_2()
.gap_2()
.justify_between()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
Icon::new(IconName::MagnifyingGlass)
.color(Color::Muted)
.size(IconSize::Small),
)
.child(self.search_editor.clone()),
)
})
.child({
let view = v_flex()
.id("list-container")
.relative()
.overflow_hidden()
.flex_grow();
if self.all_entries.is_empty() {
view.justify_center()
.child(
h_flex().w_full().justify_center().child(
Label::new("You don't have any past threads yet.")
.size(LabelSize::Small),
),
)
} else if self.search_produced_no_matches() {
view.justify_center().child(
h_flex().w_full().justify_center().child(
Label::new("No threads match your search.").size(LabelSize::Small),
),
)
} else {
view.pr_5()
.child(
uniform_list(
"thread-history",
self.list_item_count(),
cx.processor(|this, range: Range<usize>, window, cx| {
this.list_items(range, window, cx)
}),
)
.p_1()
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
.when_some(self.render_scrollbar(cx), |div, scrollbar| {
div.child(scrollbar)
})
}
})
}
}
#[derive(Clone, Copy)]
pub enum EntryTimeFormat {
DateAndTime,
TimeOnly,
}
impl EntryTimeFormat {
fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String {
let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
match self {
EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp(
timestamp,
OffsetDateTime::now_utc(),
timezone,
time_format::TimestampFormat::EnhancedAbsolute,
),
EntryTimeFormat::TimeOnly => time_format::format_time(timestamp),
}
}
}
impl From<TimeBucket> for EntryTimeFormat {
fn from(bucket: TimeBucket) -> Self {
match bucket {
TimeBucket::Today => EntryTimeFormat::TimeOnly,
TimeBucket::Yesterday => EntryTimeFormat::TimeOnly,
TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime,
TimeBucket::PastWeek => EntryTimeFormat::DateAndTime,
TimeBucket::All => EntryTimeFormat::DateAndTime,
}
}
}
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum TimeBucket {
Today,
Yesterday,
ThisWeek,
PastWeek,
All,
}
impl TimeBucket {
fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
if date == reference {
return TimeBucket::Today;
}
if date == reference - TimeDelta::days(1) {
return TimeBucket::Yesterday;
}
let week = date.iso_week();
if reference.iso_week() == week {
return TimeBucket::ThisWeek;
}
let last_week = (reference - TimeDelta::days(7)).iso_week();
if week == last_week {
return TimeBucket::PastWeek;
}
TimeBucket::All
}
}
impl Display for TimeBucket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TimeBucket::Today => write!(f, "Today"),
TimeBucket::Yesterday => write!(f, "Yesterday"),
TimeBucket::ThisWeek => write!(f, "This Week"),
TimeBucket::PastWeek => write!(f, "Past Week"),
TimeBucket::All => write!(f, "All"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::NaiveDate;
#[test]
fn test_time_bucket_from_dates() {
let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap();
let date = today;
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today);
let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday);
let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
// All: not in this week or last week
let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All);
// Test year boundary cases
let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap();
assert_eq!(
TimeBucket::from_dates(new_year, date),
TimeBucket::Yesterday
);
let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap();
assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek);
}
}

View file

@ -9,6 +9,7 @@ use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::{self as acp};
use agent_servers::{AgentServer, ClaudeCode};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
use agent2::{DbThreadMetadata, HistoryEntryId, HistoryStore};
use anyhow::bail;
use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
@ -36,7 +37,7 @@ use rope::Point;
use settings::{Settings as _, SettingsStore};
use std::sync::Arc;
use std::time::Instant;
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
use std::{collections::BTreeMap, rc::Rc, time::Duration};
use text::Anchor;
use theme::ThemeSettings;
use ui::{
@ -101,7 +102,7 @@ impl ProfileProvider for Entity<agent2::Thread> {
fn profiles_supported(&self, cx: &App) -> bool {
self.read(cx)
.model()
.map_or(false, |model| model.supports_tools())
.is_some_and(|model| model.supports_tools())
}
}
@ -110,6 +111,7 @@ pub struct AcpThreadView {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_state: ThreadState,
history_store: Entity<HistoryStore>,
entry_view_state: Entity<EntryViewState>,
message_editor: Entity<MessageEditor>,
model_selector: Option<Entity<AcpModelSelectorPopover>>,
@ -147,16 +149,15 @@ enum ThreadState {
configuration_view: Option<AnyView>,
_subscription: Option<Subscription>,
},
ServerExited {
status: ExitStatus,
},
}
impl AcpThreadView {
pub fn new(
agent: Rc<dyn AgentServer>,
resume_thread: Option<DbThreadMetadata>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
history_store: Entity<HistoryStore>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
window: &mut Window,
@ -203,7 +204,7 @@ impl AcpThreadView {
workspace: workspace.clone(),
project: project.clone(),
entry_view_state,
thread_state: Self::initial_state(agent, workspace, project, window, cx),
thread_state: Self::initial_state(agent, resume_thread, workspace, project, window, cx),
message_editor,
model_selector: None,
profile_selector: None,
@ -221,6 +222,7 @@ impl AcpThreadView {
plan_expanded: false,
editor_expanded: false,
terminal_expanded: true,
history_store,
_subscriptions: subscriptions,
_cancel_task: None,
}
@ -228,6 +230,7 @@ impl AcpThreadView {
fn initial_state(
agent: Rc<dyn AgentServer>,
resume_thread: Option<DbThreadMetadata>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
window: &mut Window,
@ -254,28 +257,27 @@ impl AcpThreadView {
}
};
// this.update_in(cx, |_this, _window, cx| {
// let status = connection.exit_status(cx);
// cx.spawn(async move |this, cx| {
// let status = status.await.ok();
// this.update(cx, |this, cx| {
// this.thread_state = ThreadState::ServerExited { status };
// cx.notify();
// })
// .ok();
// })
// .detach();
// })
// .ok();
let Some(result) = cx
.update(|_, cx| {
let result = if let Some(native_agent) = connection
.clone()
.downcast::<agent2::NativeAgentConnection>()
&& let Some(resume) = resume_thread.clone()
{
cx.update(|_, cx| {
native_agent
.0
.update(cx, |agent, cx| agent.open_thread(resume.id, cx))
})
.log_err()
} else {
cx.update(|_, cx| {
connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
})
.log_err()
else {
};
let Some(result) = result else {
return;
};
@ -303,8 +305,22 @@ impl AcpThreadView {
let action_log_subscription =
cx.observe(&action_log, |_, _, cx| cx.notify());
this.list_state
.splice(0..0, thread.read(cx).entries().len());
let count = thread.read(cx).entries().len();
this.list_state.splice(0..0, count);
this.entry_view_state.update(cx, |view_state, cx| {
for ix in 0..count {
view_state.sync_entry(ix, &thread, window, cx);
}
});
if let Some(resume) = resume_thread {
this.history_store.update(cx, |history, cx| {
history.push_recently_opened_entry(
HistoryEntryId::AcpThread(resume.id),
cx,
);
});
}
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
@ -371,20 +387,21 @@ impl AcpThreadView {
let provider_id = provider_id.clone();
let this = this.clone();
move |_, ev, window, cx| {
if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev {
if &provider_id == updated_provider_id {
this.update(cx, |this, cx| {
this.thread_state = Self::initial_state(
agent.clone(),
this.workspace.clone(),
this.project.clone(),
window,
cx,
);
cx.notify();
})
.ok();
}
if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev
&& &provider_id == updated_provider_id
{
this.update(cx, |this, cx| {
this.thread_state = Self::initial_state(
agent.clone(),
None,
this.workspace.clone(),
this.project.clone(),
window,
cx,
);
cx.notify();
})
.ok();
}
}
});
@ -431,8 +448,7 @@ impl AcpThreadView {
ThreadState::Ready { thread, .. } => Some(thread),
ThreadState::Unauthenticated { .. }
| ThreadState::Loading { .. }
| ThreadState::LoadError(..)
| ThreadState::ServerExited { .. } => None,
| ThreadState::LoadError { .. } => None,
}
}
@ -442,7 +458,6 @@ impl AcpThreadView {
ThreadState::Loading { .. } => "Loading…".into(),
ThreadState::LoadError(_) => "Failed to load".into(),
ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
}
}
@ -547,11 +562,17 @@ impl AcpThreadView {
}
fn send(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if let Some(thread) = self.thread() {
if thread.read(cx).status() != ThreadStatus::Idle {
self.stop_current_and_send_new_message(window, cx);
return;
}
let Some(thread) = self.thread() else { return };
self.history_store.update(cx, |history, cx| {
history.push_recently_opened_entry(
HistoryEntryId::AcpThread(thread.read(cx).session_id().clone()),
cx,
);
});
if thread.read(cx).status() != ThreadStatus::Idle {
self.stop_current_and_send_new_message(window, cx);
return;
}
let contents = self
@ -628,25 +649,24 @@ impl AcpThreadView {
return;
};
if let Some(index) = self.editing_message.take() {
if let Some(editor) = self
if let Some(index) = self.editing_message.take()
&& let Some(editor) = self
.entry_view_state
.read(cx)
.entry(index)
.and_then(|e| e.message_editor())
.cloned()
{
editor.update(cx, |editor, cx| {
if let Some(user_message) = thread
.read(cx)
.entries()
.get(index)
.and_then(|e| e.user_message())
{
editor.set_message(user_message.chunks.clone(), window, cx);
}
})
}
{
editor.update(cx, |editor, cx| {
if let Some(user_message) = thread
.read(cx)
.entries()
.get(index)
.and_then(|e| e.user_message())
{
editor.set_message(user_message.chunks.clone(), window, cx);
}
})
};
self.focus_handle(cx).focus(window);
cx.notify();
@ -805,10 +825,11 @@ impl AcpThreadView {
cx,
);
}
AcpThreadEvent::ServerExited(status) => {
AcpThreadEvent::LoadError(error) => {
self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status };
self.thread_state = ThreadState::LoadError(error.clone());
}
AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated => {}
}
cx.notify();
}
@ -837,6 +858,7 @@ impl AcpThreadView {
} else {
this.thread_state = Self::initial_state(
agent,
None,
this.workspace.clone(),
project.clone(),
window,
@ -2115,7 +2137,7 @@ impl AcpThreadView {
.map(|view| div().px_4().w_full().max_w_128().child(view)),
)
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().into_iter().map(|method| {
connection.auth_methods().iter().map(|method| {
Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
.on_click({
let method_id = method.id.clone();
@ -2127,28 +2149,6 @@ impl AcpThreadView {
))
}
fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement {
v_flex()
.items_center()
.justify_center()
.child(self.render_error_agent_logo())
.child(
v_flex()
.mt_4()
.mb_2()
.gap_0p5()
.text_center()
.items_center()
.child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium))
.child(
Label::new(format!("Exit status: {}", status.code().unwrap_or(-127)))
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.into_any_element()
}
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
let mut container = v_flex()
.items_center()
@ -2177,39 +2177,102 @@ impl AcpThreadView {
{
let upgrade_message = upgrade_message.clone();
let upgrade_command = upgrade_command.clone();
container = container.child(Button::new("upgrade", upgrade_message).on_click(
cx.listener(move |this, _, window, cx| {
this.workspace
.update(cx, |workspace, cx| {
let project = workspace.project().read(cx);
let cwd = project.first_project_directory(cx);
let shell = project.terminal_settings(&cwd, cx).shell.clone();
let spawn_in_terminal = task::SpawnInTerminal {
id: task::TaskId("install".to_string()),
full_label: upgrade_command.clone(),
label: upgrade_command.clone(),
command: Some(upgrade_command.clone()),
args: Vec::new(),
command_label: upgrade_command.clone(),
cwd,
env: Default::default(),
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: Default::default(),
reveal_target: Default::default(),
hide: Default::default(),
shell,
show_summary: true,
show_command: true,
show_rerun: false,
};
workspace
.spawn_in_terminal(spawn_in_terminal, window, cx)
.detach();
container = container.child(
Button::new("upgrade", upgrade_message)
.tooltip(Tooltip::text(upgrade_command.clone()))
.on_click(cx.listener(move |this, _, window, cx| {
let task = this
.workspace
.update(cx, |workspace, cx| {
let project = workspace.project().read(cx);
let cwd = project.first_project_directory(cx);
let shell = project.terminal_settings(&cwd, cx).shell.clone();
let spawn_in_terminal = task::SpawnInTerminal {
id: task::TaskId("upgrade".to_string()),
full_label: upgrade_command.clone(),
label: upgrade_command.clone(),
command: Some(upgrade_command.clone()),
args: Vec::new(),
command_label: upgrade_command.clone(),
cwd,
env: Default::default(),
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: Default::default(),
reveal_target: Default::default(),
hide: Default::default(),
shell,
show_summary: true,
show_command: true,
show_rerun: false,
};
workspace.spawn_in_terminal(spawn_in_terminal, window, cx)
})
.ok();
let Some(task) = task else { return };
cx.spawn_in(window, async move |this, cx| {
if let Some(Ok(_)) = task.await {
this.update_in(cx, |this, window, cx| {
this.reset(window, cx);
})
.ok();
}
})
.ok();
}),
));
.detach()
})),
);
} else if let LoadError::NotInstalled {
install_message,
install_command,
..
} = e
{
let install_message = install_message.clone();
let install_command = install_command.clone();
container = container.child(
Button::new("install", install_message)
.tooltip(Tooltip::text(install_command.clone()))
.on_click(cx.listener(move |this, _, window, cx| {
let task = this
.workspace
.update(cx, |workspace, cx| {
let project = workspace.project().read(cx);
let cwd = project.first_project_directory(cx);
let shell = project.terminal_settings(&cwd, cx).shell.clone();
let spawn_in_terminal = task::SpawnInTerminal {
id: task::TaskId("install".to_string()),
full_label: install_command.clone(),
label: install_command.clone(),
command: Some(install_command.clone()),
args: Vec::new(),
command_label: install_command.clone(),
cwd,
env: Default::default(),
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: Default::default(),
reveal_target: Default::default(),
hide: Default::default(),
shell,
show_summary: true,
show_command: true,
show_rerun: false,
};
workspace.spawn_in_terminal(spawn_in_terminal, window, cx)
})
.ok();
let Some(task) = task else { return };
cx.spawn_in(window, async move |this, cx| {
if let Some(Ok(_)) = task.await {
this.update_in(cx, |this, window, cx| {
this.reset(window, cx);
})
.ok();
}
})
.detach()
})),
);
}
container.into_any()
@ -2565,7 +2628,7 @@ impl AcpThreadView {
) -> Div {
let editor_bg_color = cx.theme().colors().editor_background;
v_flex().children(changed_buffers.into_iter().enumerate().flat_map(
v_flex().children(changed_buffers.iter().enumerate().flat_map(
|(index, (buffer, _diff))| {
let file = buffer.read(cx).file()?;
let path = file.path();
@ -2785,6 +2848,7 @@ impl AcpThreadView {
.child(
h_flex()
.gap_1()
.children(self.render_token_usage(cx))
.children(self.profile_selector.clone())
.children(self.model_selector.clone())
.child(self.render_send_button(cx)),
@ -2807,6 +2871,44 @@ impl AcpThreadView {
.thread(acp_thread.session_id(), cx)
}
fn render_token_usage(&self, cx: &mut Context<Self>) -> Option<Div> {
let thread = self.thread()?.read(cx);
let usage = thread.token_usage()?;
let is_generating = thread.status() != ThreadStatus::Idle;
let used = crate::text_thread_editor::humanize_token_count(usage.used_tokens);
let max = crate::text_thread_editor::humanize_token_count(usage.max_tokens);
Some(
h_flex()
.flex_shrink_0()
.gap_0p5()
.mr_1()
.child(
Label::new(used)
.size(LabelSize::Small)
.color(Color::Muted)
.map(|label| {
if is_generating {
label
.with_animation(
"used-tokens-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any()
} else {
label.into_any_element()
}
}),
)
.child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
.child(Label::new(max).size(LabelSize::Small).color(Color::Muted)),
)
}
fn toggle_burn_mode(
&mut self,
_: &ToggleBurnMode,
@ -2817,12 +2919,15 @@ impl AcpThreadView {
return;
};
thread.update(cx, |thread, _cx| {
thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
thread.set_completion_mode(
match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
},
cx,
);
});
}
@ -2831,7 +2936,7 @@ impl AcpThreadView {
if thread
.model()
.map_or(true, |model| !model.supports_burn_mode())
.is_none_or(|model| !model.supports_burn_mode())
{
return None;
}
@ -2863,9 +2968,9 @@ impl AcpThreadView {
fn render_send_button(&self, cx: &mut Context<Self>) -> AnyElement {
let is_editor_empty = self.message_editor.read(cx).is_empty(cx);
let is_generating = self.thread().map_or(false, |thread| {
thread.read(cx).status() != ThreadStatus::Idle
});
let is_generating = self
.thread()
.is_some_and(|thread| thread.read(cx).status() != ThreadStatus::Idle);
if is_generating && is_editor_empty {
IconButton::new("stop-generation", IconName::Stop)
@ -3265,62 +3370,61 @@ impl AcpThreadView {
})
})
.log_err()
&& let Some(pop_up) = screen_window.entity(cx).log_err()
{
if let Some(pop_up) = screen_window.entity(cx).log_err() {
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push(cx.subscribe_in(&pop_up, window, {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true);
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push(cx.subscribe_in(&pop_up, window, {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true);
let workspace_handle = this.workspace.clone();
let workspace_handle = this.workspace.clone();
// If there are multiple Zed windows, activate the correct one.
cx.defer(move |cx| {
handle
.update(cx, |_view, window, _cx| {
window.activate_window();
// If there are multiple Zed windows, activate the correct one.
cx.defer(move |cx| {
handle
.update(cx, |_view, window, _cx| {
window.activate_window();
if let Some(workspace) = workspace_handle.upgrade() {
workspace.update(_cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
});
}
})
.log_err();
});
if let Some(workspace) = workspace_handle.upgrade() {
workspace.update(_cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
});
}
})
.log_err();
});
this.dismiss_notifications(cx);
}
AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
this.dismiss_notifications(cx);
}
}));
AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
}
}));
self.notifications.push(screen_window);
self.notifications.push(screen_window);
// If the user manually refocuses the original window, dismiss the popup.
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push({
let pop_up_weak = pop_up.downgrade();
// If the user manually refocuses the original window, dismiss the popup.
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push({
let pop_up_weak = pop_up.downgrade();
cx.observe_window_activation(window, move |_, window, cx| {
if window.is_window_active() {
if let Some(pop_up) = pop_up_weak.upgrade() {
pop_up.update(cx, |_, cx| {
cx.emit(AgentNotificationEvent::Dismissed);
});
}
}
})
});
}
cx.observe_window_activation(window, move |_, window, cx| {
if window.is_window_active()
&& let Some(pop_up) = pop_up_weak.upgrade()
{
pop_up.update(cx, |_, cx| {
cx.emit(AgentNotificationEvent::Dismissed);
});
}
})
});
}
}
@ -3418,8 +3522,7 @@ impl AcpThreadView {
cx: &mut Context<Self>,
) {
self.message_editor.update(cx, |message_editor, cx| {
message_editor.insert_dragged_files(paths, window, cx);
drop(added_worktrees);
message_editor.insert_dragged_files(paths, added_worktrees, window, cx);
})
}
@ -3445,18 +3548,16 @@ impl AcpThreadView {
} else {
format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.")
}
} else if next_attempt_in_secs == 1 {
format!(
"Retrying. Next attempt in 1 second (Attempt {} of {}).",
state.attempt, state.max_attempts,
)
} else {
if next_attempt_in_secs == 1 {
format!(
"Retrying. Next attempt in 1 second (Attempt {} of {}).",
state.attempt, state.max_attempts,
)
} else {
format!(
"Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
state.attempt, state.max_attempts,
)
}
format!(
"Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
state.attempt, state.max_attempts,
)
};
Some(
@ -3542,7 +3643,7 @@ impl AcpThreadView {
let supports_burn_mode = thread
.read(cx)
.model()
.map_or(false, |model| model.supports_burn_mode());
.is_some_and(|model| model.supports_burn_mode());
let focus_handle = self.focus_handle(cx);
@ -3574,8 +3675,9 @@ impl AcpThreadView {
))
.on_click({
cx.listener(move |this, _, _window, cx| {
thread.update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
thread.update(cx, |thread, cx| {
thread
.set_completion_mode(CompletionMode::Burn, cx);
});
this.resume_chat(cx);
})
@ -3639,6 +3741,18 @@ impl AcpThreadView {
}
}))
}
fn reset(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.thread_state = Self::initial_state(
self.agent.clone(),
None,
self.workspace.clone(),
self.project.clone(),
window,
cx,
);
cx.notify();
}
}
impl Focusable for AcpThreadView {
@ -3677,12 +3791,6 @@ impl Render for AcpThreadView {
.items_center()
.justify_center()
.child(self.render_load_error(e, cx)),
ThreadState::ServerExited { status } => v_flex()
.p_2()
.flex_1()
.items_center()
.justify_center()
.child(self.render_server_exited(*status, cx)),
ThreadState::Ready { thread, .. } => {
let thread_clone = thread.clone();
@ -3894,6 +4002,7 @@ pub(crate) mod tests {
use acp_thread::StubAgentConnection;
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::SessionId;
use assistant_context::ContextStore;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext};
@ -4031,13 +4140,19 @@ pub(crate) mod tests {
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
let text_thread_store =
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
let context_store =
cx.update(|_window, cx| cx.new(|cx| ContextStore::fake(project.clone(), cx)));
let history_store =
cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(context_store, cx)));
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(agent),
None,
workspace.downgrade(),
project,
history_store,
thread_store.clone(),
text_thread_store.clone(),
window,
@ -4158,12 +4273,13 @@ pub(crate) mod tests {
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx.new(|cx| {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(
"SaboteurAgentConnection",
self,
project,
action_log,
SessionId("test".into()),
cx,
)
})))
}
@ -4233,14 +4349,20 @@ pub(crate) mod tests {
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
let text_thread_store =
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
let context_store =
cx.update(|_window, cx| cx.new(|cx| ContextStore::fake(project.clone(), cx)));
let history_store =
cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(context_store, cx)));
let connection = Rc::new(StubAgentConnection::new());
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(StubAgentServer::new(connection.as_ref().clone())),
None,
workspace.downgrade(),
project.clone(),
history_store.clone(),
thread_store.clone(),
text_thread_store.clone(),
window,

View file

@ -1072,8 +1072,8 @@ impl ActiveThread {
}
ThreadEvent::MessageEdited(message_id) => {
self.clear_last_error();
if let Some(index) = self.messages.iter().position(|id| id == message_id) {
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
if let Some(index) = self.messages.iter().position(|id| id == message_id)
&& let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
let mut rendered_message = RenderedMessage {
language_registry: self.language_registry.clone(),
@ -1084,14 +1084,14 @@ impl ActiveThread {
}
rendered_message
})
}) {
self.list_state.splice(index..index + 1, 1);
self.rendered_messages_by_id
.insert(*message_id, rendered_message);
self.scroll_to_bottom(cx);
self.save_thread(cx);
cx.notify();
}
})
{
self.list_state.splice(index..index + 1, 1);
self.rendered_messages_by_id
.insert(*message_id, rendered_message);
self.scroll_to_bottom(cx);
self.save_thread(cx);
cx.notify();
}
}
ThreadEvent::MessageDeleted(message_id) => {
@ -1272,62 +1272,61 @@ impl ActiveThread {
})
})
.log_err()
&& let Some(pop_up) = screen_window.entity(cx).log_err()
{
if let Some(pop_up) = screen_window.entity(cx).log_err() {
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push(cx.subscribe_in(&pop_up, window, {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true);
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push(cx.subscribe_in(&pop_up, window, {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true);
let workspace_handle = this.workspace.clone();
let workspace_handle = this.workspace.clone();
// If there are multiple Zed windows, activate the correct one.
cx.defer(move |cx| {
handle
.update(cx, |_view, window, _cx| {
window.activate_window();
// If there are multiple Zed windows, activate the correct one.
cx.defer(move |cx| {
handle
.update(cx, |_view, window, _cx| {
window.activate_window();
if let Some(workspace) = workspace_handle.upgrade() {
workspace.update(_cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
});
}
})
.log_err();
});
if let Some(workspace) = workspace_handle.upgrade() {
workspace.update(_cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
});
}
})
.log_err();
});
this.dismiss_notifications(cx);
}
AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
this.dismiss_notifications(cx);
}
}));
AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
}
}));
self.notifications.push(screen_window);
self.notifications.push(screen_window);
// If the user manually refocuses the original window, dismiss the popup.
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push({
let pop_up_weak = pop_up.downgrade();
// If the user manually refocuses the original window, dismiss the popup.
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push({
let pop_up_weak = pop_up.downgrade();
cx.observe_window_activation(window, move |_, window, cx| {
if window.is_window_active() {
if let Some(pop_up) = pop_up_weak.upgrade() {
pop_up.update(cx, |_, cx| {
cx.emit(AgentNotificationEvent::Dismissed);
});
}
}
})
});
}
cx.observe_window_activation(window, move |_, window, cx| {
if window.is_window_active()
&& let Some(pop_up) = pop_up_weak.upgrade()
{
pop_up.update(cx, |_, cx| {
cx.emit(AgentNotificationEvent::Dismissed);
});
}
})
});
}
}
@ -1374,12 +1373,12 @@ impl ActiveThread {
editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
});
let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.update_editing_message_token_count(true, cx);
}
_ => {}
});
let buffer_edited_subscription =
cx.subscribe(&editor, |this, _, event: &EditorEvent, cx| {
if event == &EditorEvent::BufferEdited {
this.update_editing_message_token_count(true, cx);
}
});
let context_picker_menu_handle = PopoverMenuHandle::default();
let context_strip = cx.new(|cx| {
@ -2247,9 +2246,7 @@ impl ActiveThread {
let after_editing_message = self
.editing_message
.as_ref()
.map_or(false, |(editing_message_id, _)| {
message_id > *editing_message_id
});
.is_some_and(|(editing_message_id, _)| message_id > *editing_message_id);
let backdrop = div()
.id(("backdrop", ix))
@ -2269,13 +2266,12 @@ impl ActiveThread {
let mut error = None;
if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint()
&& last_restore_checkpoint.message_id() == message_id
{
if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint {
LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Error { error: err, .. } => {
error = Some(err.clone());
}
match last_restore_checkpoint {
LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Error { error: err, .. } => {
error = Some(err.clone());
}
}
}

View file

@ -96,7 +96,7 @@ impl AgentConfiguration {
let mut expanded_provider_configurations = HashMap::default();
if LanguageModelRegistry::read_global(cx)
.provider(&ZED_CLOUD_PROVIDER_ID)
.map_or(false, |cloud_provider| cloud_provider.must_accept_terms(cx))
.is_some_and(|cloud_provider| cloud_provider.must_accept_terms(cx))
{
expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true);
}
@ -958,7 +958,7 @@ impl AgentConfiguration {
}
parent.child(v_flex().py_1p5().px_1().gap_1().children(
tools.into_iter().enumerate().map(|(ix, tool)| {
tools.iter().enumerate().map(|(ix, tool)| {
h_flex()
.id(("tool-item", ix))
.px_1()

View file

@ -163,10 +163,10 @@ impl ConfigurationSource {
.read(cx)
.text(cx);
let settings = serde_json_lenient::from_str::<serde_json::Value>(&text)?;
if let Some(settings_validator) = settings_validator {
if let Err(error) = settings_validator.validate(&settings) {
return Err(anyhow::anyhow!(error.to_string()));
}
if let Some(settings_validator) = settings_validator
&& let Err(error) = settings_validator.validate(&settings)
{
return Err(anyhow::anyhow!(error.to_string()));
}
Ok((
id.clone(),
@ -487,7 +487,7 @@ impl ConfigureContextServerModal {
}
fn render_modal_description(&self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
const MODAL_DESCRIPTION: &'static str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables.";
const MODAL_DESCRIPTION: &str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables.";
if let ConfigurationSource::Extension {
installation_instructions: Some(installation_instructions),
@ -716,24 +716,24 @@ fn wait_for_context_server(
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
ContextServerStatus::Running => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Ok(()));
}
if server_id == &context_server_id
&& let Some(tx) = tx.lock().unwrap().take()
{
let _ = tx.send(Ok(()));
}
}
ContextServerStatus::Stopped => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err("Context server stopped running".into()));
}
if server_id == &context_server_id
&& let Some(tx) = tx.lock().unwrap().take()
{
let _ = tx.send(Err("Context server stopped running".into()));
}
}
ContextServerStatus::Error(error) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err(error.clone()));
}
if server_id == &context_server_id
&& let Some(tx) = tx.lock().unwrap().take()
{
let _ = tx.send(Err(error.clone()));
}
}
_ => {}

View file

@ -191,10 +191,10 @@ impl PickerDelegate for ToolPickerDelegate {
BTreeMap::default();
for item in all_items.iter() {
if let PickerItem::Tool { server_id, name } = item.clone() {
if name.contains(&query) {
tools_by_provider.entry(server_id).or_default().push(name);
}
if let PickerItem::Tool { server_id, name } = item.clone()
&& name.contains(&query)
{
tools_by_provider.entry(server_id).or_default().push(name);
}
}

View file

@ -199,24 +199,21 @@ impl AgentDiffPane {
let action_log = thread.action_log(cx).clone();
let mut this = Self {
_subscriptions: [
Some(
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
this.update_excerpts(window, cx)
}),
),
_subscriptions: vec![
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
this.update_excerpts(window, cx)
}),
match &thread {
AgentDiffThread::Native(thread) => {
Some(cx.subscribe(thread, |this, _thread, event, cx| {
this.handle_thread_event(event, cx)
}))
}
AgentDiffThread::AcpThread(_) => None,
AgentDiffThread::Native(thread) => cx
.subscribe(thread, |this, _thread, event, cx| {
this.handle_native_thread_event(event, cx)
}),
AgentDiffThread::AcpThread(thread) => cx
.subscribe(thread, |this, _thread, event, cx| {
this.handle_acp_thread_event(event, cx)
}),
},
]
.into_iter()
.flatten()
.collect(),
],
title: SharedString::default(),
multibuffer,
editor,
@ -288,7 +285,7 @@ impl AgentDiffPane {
&& buffer
.read(cx)
.file()
.map_or(false, |file| file.disk_state() == DiskState::Deleted)
.is_some_and(|file| file.disk_state() == DiskState::Deleted)
{
editor.fold_buffer(snapshot.text.remote_id(), cx)
}
@ -324,10 +321,15 @@ impl AgentDiffPane {
}
}
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
match event {
ThreadEvent::SummaryGenerated => self.update_title(cx),
_ => {}
fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
if let ThreadEvent::SummaryGenerated = event {
self.update_title(cx)
}
}
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
if let AcpThreadEvent::TitleUpdated = event {
self.update_title(cx)
}
}
@ -1043,23 +1045,23 @@ impl ToolbarItemView for AgentDiffToolbar {
return self.location(cx);
}
if let Some(editor) = item.act_as::<Editor>(cx) {
if editor.read(cx).mode().is_full() {
let agent_diff = AgentDiff::global(cx);
if let Some(editor) = item.act_as::<Editor>(cx)
&& editor.read(cx).mode().is_full()
{
let agent_diff = AgentDiff::global(cx);
self.active_item = Some(AgentDiffToolbarItem::Editor {
editor: editor.downgrade(),
state: agent_diff.read(cx).editor_state(&editor.downgrade()),
_diff_subscription: cx.observe(&agent_diff, Self::handle_diff_notify),
});
self.active_item = Some(AgentDiffToolbarItem::Editor {
editor: editor.downgrade(),
state: agent_diff.read(cx).editor_state(&editor.downgrade()),
_diff_subscription: cx.observe(&agent_diff, Self::handle_diff_notify),
});
return self.location(cx);
}
return self.location(cx);
}
}
self.active_item = None;
return self.location(cx);
self.location(cx)
}
fn pane_focus_update(
@ -1505,7 +1507,7 @@ impl AgentDiff {
.read(cx)
.entries()
.last()
.map_or(false, |entry| entry.diffs().next().is_some())
.is_some_and(|entry| entry.diffs().next().is_some())
{
self.update_reviewing_editors(workspace, window, cx);
}
@ -1515,15 +1517,17 @@ impl AgentDiff {
.read(cx)
.entries()
.get(*ix)
.map_or(false, |entry| entry.diffs().next().is_some())
.is_some_and(|entry| entry.diffs().next().is_some())
{
self.update_reviewing_editors(workspace, window, cx);
}
}
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::LoadError(_) => {
self.update_reviewing_editors(workspace, window, cx);
}
AcpThreadEvent::EntriesRemoved(_)
AcpThreadEvent::TitleUpdated
| AcpThreadEvent::TokenUsageUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Retry(_) => {}
}
@ -1536,21 +1540,11 @@ impl AgentDiff {
window: &mut Window,
cx: &mut Context<Self>,
) {
match event {
workspace::Event::ItemAdded { item } => {
if let Some(editor) = item.downcast::<Editor>() {
if let Some(buffer) = Self::full_editor_buffer(editor.read(cx), cx) {
self.register_editor(
workspace.downgrade(),
buffer.clone(),
editor,
window,
cx,
);
}
}
}
_ => {}
if let workspace::Event::ItemAdded { item } = event
&& let Some(editor) = item.downcast::<Editor>()
&& let Some(buffer) = Self::full_editor_buffer(editor.read(cx), cx)
{
self.register_editor(workspace.downgrade(), buffer.clone(), editor, window, cx);
}
}
@ -1710,7 +1704,7 @@ impl AgentDiff {
.read_with(cx, |editor, _cx| editor.workspace())
.ok()
.flatten()
.map_or(false, |editor_workspace| {
.is_some_and(|editor_workspace| {
editor_workspace.entity_id() == workspace.entity_id()
});
@ -1850,26 +1844,26 @@ impl AgentDiff {
let thread = thread.upgrade()?;
if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) {
if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() {
let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx);
if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx)
&& let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton()
{
let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx);
let mut keys = changed_buffers.keys().cycle();
keys.find(|k| *k == &curr_buffer);
let next_project_path = keys
.next()
.filter(|k| *k != &curr_buffer)
.and_then(|after| after.read(cx).project_path(cx));
let mut keys = changed_buffers.keys().cycle();
keys.find(|k| *k == &curr_buffer);
let next_project_path = keys
.next()
.filter(|k| *k != &curr_buffer)
.and_then(|after| after.read(cx).project_path(cx));
if let Some(path) = next_project_path {
let task = workspace.open_path(path, None, true, window, cx);
let task = cx.spawn(async move |_, _cx| task.await.map(|_| ()));
return Some(task);
}
if let Some(path) = next_project_path {
let task = workspace.open_path(path, None, true, window, cx);
let task = cx.spawn(async move |_, _cx| task.await.map(|_| ()));
return Some(task);
}
}
return Some(Task::ready(Ok(())));
Some(Task::ready(Ok(())))
}
}

View file

@ -4,11 +4,11 @@ use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use agent_servers::AgentServer;
use agent2::{DbThreadMetadata, HistoryEntry};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use serde::{Deserialize, Serialize};
use crate::NewExternalAgentThread;
use crate::acp::{AcpThreadHistory, ThreadHistoryEvent};
use crate::agent_diff::AgentDiffThread;
use crate::{
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
@ -29,6 +29,7 @@ use crate::{
thread_history::{HistoryEntryElement, ThreadHistory},
ui::{AgentOnboardingModal, EndTrialUpsell},
};
use crate::{ExternalAgent, NewExternalAgentThread};
use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
context_store::ContextStore,
@ -44,7 +45,7 @@ use assistant_tool::ToolWorkingSet;
use client::{UserStore, zed_urls};
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use feature_flags::{self, FeatureFlagAppExt};
use feature_flags::{self, ClaudeCodeFeatureFlag, FeatureFlagAppExt, GeminiAndNativeFeatureFlag};
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
@ -118,7 +119,7 @@ pub fn init(cx: &mut App) {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.new_external_thread(action.agent, window, cx)
panel.external_thread(action.agent, None, window, cx)
});
}
})
@ -353,7 +354,7 @@ impl ActiveView {
Self::Thread {
change_title_editor: editor,
thread: active_thread,
message_editor: message_editor,
message_editor,
_subscriptions: subscriptions,
}
}
@ -361,6 +362,7 @@ impl ActiveView {
pub fn prompt_editor(
context_editor: Entity<TextThreadEditor>,
history_store: Entity<HistoryStore>,
acp_history_store: Entity<agent2::HistoryStore>,
language_registry: Arc<LanguageRegistry>,
window: &mut Window,
cx: &mut App,
@ -438,6 +440,18 @@ impl ActiveView {
);
}
});
acp_history_store.update(cx, |history_store, cx| {
if let Some(old_path) = old_path {
history_store
.replace_recently_opened_text_thread(old_path, new_path, cx);
} else {
history_store.push_recently_opened_entry(
agent2::HistoryEntryId::TextThread(new_path.clone()),
cx,
);
}
});
}
_ => {}
}
@ -466,6 +480,8 @@ pub struct AgentPanel {
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
acp_history: Entity<AcpThreadHistory>,
acp_history_store: Entity<agent2::HistoryStore>,
_default_model_subscription: Subscription,
context_store: Entity<TextThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
@ -632,6 +648,28 @@ impl AgentPanel {
)
});
let acp_history_store = cx.new(|cx| agent2::HistoryStore::new(context_store.clone(), cx));
let acp_history = cx.new(|cx| AcpThreadHistory::new(acp_history_store.clone(), window, cx));
cx.subscribe_in(
&acp_history,
window,
|this, _, event, window, cx| match event {
ThreadHistoryEvent::Open(HistoryEntry::AcpThread(thread)) => {
this.external_thread(
Some(crate::ExternalAgent::NativeAgent),
Some(thread.clone()),
window,
cx,
);
}
ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => {
this.open_saved_prompt_editor(thread.path.clone(), window, cx)
.detach_and_log_err(cx);
}
},
)
.detach();
cx.observe(&history_store, |_, _, cx| cx.notify()).detach();
let active_thread = cx.new(|cx| {
@ -670,6 +708,7 @@ impl AgentPanel {
ActiveView::prompt_editor(
context_editor,
history_store.clone(),
acp_history_store.clone(),
language_registry.clone(),
window,
cx,
@ -686,7 +725,11 @@ impl AgentPanel {
let assistant_navigation_menu =
ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| {
if let Some(panel) = panel.upgrade() {
menu = Self::populate_recently_opened_menu_section(menu, panel, cx);
if cx.has_flag::<GeminiAndNativeFeatureFlag>() {
menu = Self::populate_recently_opened_menu_section_new(menu, panel, cx);
} else {
menu = Self::populate_recently_opened_menu_section_old(menu, panel, cx);
}
}
menu.action("View All", Box::new(OpenHistory))
.end_slot_action(DeleteRecentlyOpenThread.boxed_clone())
@ -712,25 +755,25 @@ impl AgentPanel {
.ok();
});
let _default_model_subscription = cx.subscribe(
&LanguageModelRegistry::global(cx),
|this, _, event: &language_model::Event, cx| match event {
language_model::Event::DefaultModelChanged => match &this.active_view {
ActiveView::Thread { thread, .. } => {
thread
.read(cx)
.thread()
.clone()
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
let _default_model_subscription =
cx.subscribe(
&LanguageModelRegistry::global(cx),
|this, _, event: &language_model::Event, cx| {
if let language_model::Event::DefaultModelChanged = event {
match &this.active_view {
ActiveView::Thread { thread, .. } => {
thread.read(cx).thread().clone().update(cx, |thread, cx| {
thread.get_or_init_configured_model(cx)
});
}
ActiveView::ExternalAgentThread { .. }
| ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
}
}
ActiveView::ExternalAgentThread { .. }
| ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
},
_ => {}
},
);
);
let onboarding = cx.new(|cx| {
AgentPanelOnboarding::new(
@ -774,6 +817,8 @@ impl AgentPanel {
zoomed: false,
pending_serialization: None,
onboarding,
acp_history,
acp_history_store,
selected_agent: AgentType::default(),
}
}
@ -836,6 +881,9 @@ impl AgentPanel {
}
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
if cx.has_flag::<GeminiAndNativeFeatureFlag>() {
return self.new_agent_thread(AgentType::NativeAgent, window, cx);
}
// Preserve chat box text when using creating new thread
let preserved_text = self
.active_message_editor()
@ -940,6 +988,7 @@ impl AgentPanel {
ActiveView::prompt_editor(
context_editor.clone(),
self.history_store.clone(),
self.acp_history_store.clone(),
self.language_registry.clone(),
window,
cx,
@ -950,9 +999,10 @@ impl AgentPanel {
context_editor.focus_handle(cx).focus(window);
}
fn new_external_thread(
fn external_thread(
&mut self,
agent_choice: Option<crate::ExternalAgent>,
resume_thread: Option<DbThreadMetadata>,
window: &mut Window,
cx: &mut Context<Self>,
) {
@ -969,9 +1019,10 @@ impl AgentPanel {
let thread_store = self.thread_store.clone();
let text_thread_store = self.context_store.clone();
let history = self.acp_history_store.clone();
cx.spawn_in(window, async move |this, cx| {
let server: Rc<dyn AgentServer> = match agent_choice {
let ext_agent = match agent_choice {
Some(agent) => {
cx.background_spawn(async move {
if let Some(serialized) =
@ -985,10 +1036,10 @@ impl AgentPanel {
})
.detach();
agent.server(fs)
agent
}
None => cx
.background_spawn(async move {
None => {
cx.background_spawn(async move {
KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY)
})
.await
@ -999,15 +1050,32 @@ impl AgentPanel {
})
.unwrap_or_default()
.agent
.server(fs),
}
};
let server = ext_agent.server(fs, history);
this.update_in(cx, |this, window, cx| {
match ext_agent {
crate::ExternalAgent::Gemini | crate::ExternalAgent::NativeAgent => {
if !cx.has_flag::<GeminiAndNativeFeatureFlag>() {
return;
}
}
crate::ExternalAgent::ClaudeCode => {
if !cx.has_flag::<ClaudeCodeFeatureFlag>() {
return;
}
}
}
let thread_view = cx.new(|cx| {
crate::acp::AcpThreadView::new(
server,
resume_thread,
workspace.clone(),
project,
this.acp_history_store.clone(),
thread_store.clone(),
text_thread_store.clone(),
window,
@ -1100,6 +1168,7 @@ impl AgentPanel {
ActiveView::prompt_editor(
editor.clone(),
self.history_store.clone(),
self.acp_history_store.clone(),
self.language_registry.clone(),
window,
cx,
@ -1397,15 +1466,14 @@ impl AgentPanel {
AssistantConfigurationEvent::NewThread(provider) => {
if LanguageModelRegistry::read_global(cx)
.default_model()
.map_or(true, |model| model.provider.id() != provider.id())
.is_none_or(|model| model.provider.id() != provider.id())
&& let Some(model) = provider.default_model(cx)
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AgentSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.set_model(model),
);
}
update_settings_file::<AgentSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.set_model(model),
);
}
self.new_thread(&NewThread::default(), window, cx);
@ -1524,17 +1592,14 @@ impl AgentPanel {
let current_is_special = current_is_history || current_is_config;
let new_is_special = new_is_history || new_is_config;
match &self.active_view {
ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx);
if thread.is_empty() {
let id = thread.thread().read(cx).id().clone();
self.history_store.update(cx, |store, cx| {
store.remove_recently_opened_thread(id, cx);
});
}
if let ActiveView::Thread { thread, .. } = &self.active_view {
let thread = thread.read(cx);
if thread.is_empty() {
let id = thread.thread().read(cx).id().clone();
self.history_store.update(cx, |store, cx| {
store.remove_recently_opened_thread(id, cx);
});
}
_ => {}
}
match &new_view {
@ -1547,6 +1612,14 @@ impl AgentPanel {
if let Some(path) = context_editor.read(cx).context().read(cx).path() {
store.push_recently_opened_entry(HistoryEntryId::Context(path.clone()), cx)
}
});
self.acp_history_store.update(cx, |store, cx| {
if let Some(path) = context_editor.read(cx).context().read(cx).path() {
store.push_recently_opened_entry(
agent2::HistoryEntryId::TextThread(path.clone()),
cx,
)
}
})
}
ActiveView::ExternalAgentThread { .. } => {}
@ -1567,7 +1640,7 @@ impl AgentPanel {
self.focus_handle(cx).focus(window);
}
fn populate_recently_opened_menu_section(
fn populate_recently_opened_menu_section_old(
mut menu: ContextMenu,
panel: Entity<Self>,
cx: &mut Context<ContextMenu>,
@ -1631,6 +1704,72 @@ impl AgentPanel {
menu
}
fn populate_recently_opened_menu_section_new(
mut menu: ContextMenu,
panel: Entity<Self>,
cx: &mut Context<ContextMenu>,
) -> ContextMenu {
let entries = panel
.read(cx)
.acp_history_store
.read(cx)
.recently_opened_entries(cx);
if entries.is_empty() {
return menu;
}
menu = menu.header("Recently Opened");
for entry in entries {
let title = entry.title().clone();
menu = menu.entry_with_end_slot_on_hover(
title,
None,
{
let panel = panel.downgrade();
let entry = entry.clone();
move |window, cx| {
let entry = entry.clone();
panel
.update(cx, move |this, cx| match &entry {
agent2::HistoryEntry::AcpThread(entry) => this.external_thread(
Some(ExternalAgent::NativeAgent),
Some(entry.clone()),
window,
cx,
),
agent2::HistoryEntry::TextThread(entry) => this
.open_saved_prompt_editor(entry.path.clone(), window, cx)
.detach_and_log_err(cx),
})
.ok();
}
},
IconName::Close,
"Close Entry".into(),
{
let panel = panel.downgrade();
let id = entry.id();
move |_window, cx| {
panel
.update(cx, |this, cx| {
this.acp_history_store.update(cx, |history_store, cx| {
history_store.remove_recently_opened_entry(&id, cx);
});
})
.ok();
}
},
);
}
menu = menu.separator();
menu
}
pub fn set_selected_agent(
&mut self,
agent: AgentType,
@ -1640,8 +1779,8 @@ impl AgentPanel {
if self.selected_agent != agent {
self.selected_agent = agent;
self.serialize(cx);
self.new_agent_thread(agent, window, cx);
}
self.new_agent_thread(agent, window, cx);
}
pub fn selected_agent(&self) -> AgentType {
@ -1668,13 +1807,13 @@ impl AgentPanel {
window.dispatch_action(NewTextThread.boxed_clone(), cx);
}
AgentType::NativeAgent => {
self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx)
self.external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx)
}
AgentType::Gemini => {
self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx)
self.external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx)
}
AgentType::ClaudeCode => {
self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx)
self.external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx)
}
}
}
@ -1685,7 +1824,13 @@ impl Focusable for AgentPanel {
match &self.active_view {
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx),
ActiveView::History => self.history.focus_handle(cx),
ActiveView::History => {
if cx.has_flag::<feature_flags::GeminiAndNativeFeatureFlag>() {
self.acp_history.focus_handle(cx)
} else {
self.history.focus_handle(cx)
}
}
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::Configuration => {
if let Some(configuration) = self.configuration.as_ref() {
@ -2244,9 +2389,9 @@ impl AgentPanel {
})
.item(
ContextMenuEntry::new("New Thread")
.icon(IconName::Thread)
.icon_color(Color::Muted)
.action(NewThread::default().boxed_clone())
.icon(IconName::ZedAssistant)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
move |window, cx| {
@ -2257,7 +2402,7 @@ impl AgentPanel {
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
AgentType::Zed,
AgentType::NativeAgent,
window,
cx,
);
@ -2294,83 +2439,62 @@ impl AgentPanel {
}
}),
)
.item(
ContextMenuEntry::new("New Native Agent Thread")
.icon(IconName::ZedAssistant)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
move |window, cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) =
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
AgentType::NativeAgent,
window,
cx,
);
});
}
});
}
}
}),
)
.separator()
.header("External Agents")
.item(
ContextMenuEntry::new("New Gemini Thread")
.icon(IconName::AiGemini)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
move |window, cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) =
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
AgentType::Gemini,
window,
cx,
);
});
}
});
.when(cx.has_flag::<GeminiAndNativeFeatureFlag>(), |menu| {
menu.item(
ContextMenuEntry::new("New Gemini Thread")
.icon(IconName::AiGemini)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
move |window, cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) =
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
AgentType::Gemini,
window,
cx,
);
});
}
});
}
}
}
}),
)
.item(
ContextMenuEntry::new("New Claude Code Thread")
.icon(IconName::AiClaude)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
move |window, cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) =
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
AgentType::ClaudeCode,
window,
cx,
);
});
}
});
}),
)
})
.when(cx.has_flag::<ClaudeCodeFeatureFlag>(), |menu| {
menu.item(
ContextMenuEntry::new("New Claude Code Thread")
.icon(IconName::AiClaude)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
move |window, cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) =
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
AgentType::ClaudeCode,
window,
cx,
);
});
}
});
}
}
}
}),
);
}),
)
});
menu
}))
}
@ -2440,7 +2564,9 @@ impl AgentPanel {
}
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
if cx.has_flag::<feature_flags::AcpFeatureFlag>() {
if cx.has_flag::<feature_flags::GeminiAndNativeFeatureFlag>()
|| cx.has_flag::<feature_flags::ClaudeCodeFeatureFlag>()
{
self.render_toolbar_new(window, cx).into_any_element()
} else {
self.render_toolbar_old(window, cx).into_any_element()
@ -2565,9 +2691,7 @@ impl AgentPanel {
}
ActiveView::ExternalAgentThread { .. }
| ActiveView::History
| ActiveView::Configuration => {
return None;
}
| ActiveView::Configuration => None,
}
}
@ -2583,7 +2707,7 @@ impl AgentPanel {
.thread()
.read(cx)
.configured_model()
.map_or(false, |model| {
.is_some_and(|model| {
model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID
})
{
@ -2594,7 +2718,7 @@ impl AgentPanel {
if LanguageModelRegistry::global(cx)
.read(cx)
.default_model()
.map_or(false, |model| {
.is_some_and(|model| {
model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID
})
{
@ -2625,9 +2749,12 @@ impl AgentPanel {
false
}
_ => {
let history_is_empty = self
.history_store
.update(cx, |store, cx| store.recent_entries(1, cx).is_empty());
let history_is_empty = if cx.has_flag::<GeminiAndNativeFeatureFlag>() {
self.acp_history_store.read(cx).is_empty(cx)
} else {
self.history_store
.update(cx, |store, cx| store.recent_entries(1, cx).is_empty())
};
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
.providers()
@ -2908,9 +3035,7 @@ impl AgentPanel {
let zed_provider_configured = AgentSettings::get_global(cx)
.default_model
.as_ref()
.map_or(false, |selection| {
selection.provider.0.as_str() == "zed.dev"
});
.is_some_and(|selection| selection.provider.0.as_str() == "zed.dev");
let callout = if zed_provider_configured {
Callout::new()
@ -3326,7 +3451,7 @@ impl AgentPanel {
.on_drop(cx.listener(move |this, paths: &ExternalPaths, window, cx| {
let tasks = paths
.paths()
.into_iter()
.iter()
.map(|path| {
Workspace::project_path_for_path(this.project.clone(), path, false, cx)
})
@ -3515,7 +3640,13 @@ impl Render for AgentPanel {
ActiveView::ExternalAgentThread { thread_view, .. } => parent
.child(thread_view.clone())
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::History => {
if cx.has_flag::<feature_flags::GeminiAndNativeFeatureFlag>() {
parent.child(self.acp_history.clone())
} else {
parent.child(self.history.clone())
}
}
ActiveView::TextThread {
context_editor,
buffer_search_bar,

View file

@ -156,11 +156,15 @@ enum ExternalAgent {
}
impl ExternalAgent {
pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> {
pub fn server(
&self,
fs: Arc<dyn fs::Fs>,
history: Entity<agent2::HistoryStore>,
) -> Rc<dyn agent_servers::AgentServer> {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs, history)),
}
}
}

View file

@ -352,12 +352,12 @@ impl CodegenAlternative {
event: &multi_buffer::Event,
cx: &mut Context<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.transformation_transaction_id == Some(*transaction_id) {
self.transformation_transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
}
if let multi_buffer::Event::TransactionUndone { transaction_id } = event
&& self.transformation_transaction_id == Some(*transaction_id)
{
self.transformation_transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
}
}
@ -576,38 +576,34 @@ impl CodegenAlternative {
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
if line_indent.is_none()
&& let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta = line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(selection_start.column as usize);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
if line_indent.is_some() {

View file

@ -385,12 +385,11 @@ impl ContextPicker {
}
pub fn select_first(&mut self, window: &mut Window, cx: &mut Context<Self>) {
match &self.mode {
ContextPickerState::Default(entity) => entity.update(cx, |entity, cx| {
// Other variants already select their first entry on open automatically
if let ContextPickerState::Default(entity) = &self.mode {
entity.update(cx, |entity, cx| {
entity.select_first(&Default::default(), window, cx)
}),
// Other variants already select their first entry on open automatically
_ => {}
})
}
}
@ -610,9 +609,7 @@ pub(crate) fn available_context_picker_entries(
.read(cx)
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
.map_or(false, |editor| {
editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx))
});
.is_some_and(|editor| editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx)));
if has_selection {
entries.push(ContextPickerEntry::Action(
ContextPickerAction::AddSelections,
@ -680,7 +677,7 @@ pub(crate) fn recent_context_picker_entries(
.filter(|(_, abs_path)| {
abs_path
.as_ref()
.map_or(true, |path| !exclude_paths.contains(path.as_path()))
.is_none_or(|path| !exclude_paths.contains(path.as_path()))
})
.take(4)
.filter_map(|(project_path, _)| {

View file

@ -1020,7 +1020,7 @@ impl MentionCompletion {
&& line
.chars()
.nth(last_mention_start - 1)
.map_or(false, |c| !c.is_whitespace())
.is_some_and(|c| !c.is_whitespace())
{
return None;
}

View file

@ -226,9 +226,10 @@ impl PickerDelegate for FetchContextPickerDelegate {
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store.read(cx).includes_url(&self.url)
});
let added = self
.context_store
.upgrade()
.is_some_and(|context_store| context_store.read(cx).includes_url(&self.url));
Some(
ListItem::new(ix)

View file

@ -239,9 +239,7 @@ pub(crate) fn search_files(
PathMatchCandidateSet {
snapshot: worktree.snapshot(),
include_ignored: worktree
.root_entry()
.map_or(false, |entry| entry.is_ignored),
include_ignored: worktree.root_entry().is_some_and(|entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Entries,
}

View file

@ -159,7 +159,7 @@ pub fn render_thread_context_entry(
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |context_store| {
let added = context_store.upgrade().is_some_and(|context_store| {
context_store
.read(cx)
.includes_user_rules(user_rules.prompt_id)

View file

@ -294,7 +294,7 @@ pub(crate) fn search_symbols(
.partition(|candidate| {
project
.entry_for_path(&symbols[candidate.id].path, cx)
.map_or(false, |e| !e.is_ignored)
.is_some_and(|e| !e.is_ignored)
})
})
.log_err()

View file

@ -236,12 +236,10 @@ pub fn render_thread_context_entry(
let is_added = match entry {
ThreadContextEntry::Thread { id, .. } => context_store
.upgrade()
.map_or(false, |ctx_store| ctx_store.read(cx).includes_thread(id)),
ThreadContextEntry::Context { path, .. } => {
context_store.upgrade().map_or(false, |ctx_store| {
ctx_store.read(cx).includes_text_thread(path)
})
}
.is_some_and(|ctx_store| ctx_store.read(cx).includes_thread(id)),
ThreadContextEntry::Context { path, .. } => context_store
.upgrade()
.is_some_and(|ctx_store| ctx_store.read(cx).includes_text_thread(path)),
};
h_flex()

View file

@ -368,10 +368,10 @@ impl ContextStrip {
_window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(suggested) = self.suggested_context(cx) {
if self.is_suggested_focused(&self.added_contexts(cx)) {
self.add_suggested_context(&suggested, cx);
}
if let Some(suggested) = self.suggested_context(cx)
&& self.is_suggested_focused(&self.added_contexts(cx))
{
self.add_suggested_context(&suggested, cx);
}
}

View file

@ -182,13 +182,13 @@ impl InlineAssistant {
match event {
workspace::Event::UserSavedItem { item, .. } => {
// When the user manually saves an editor, automatically accepts all finished transformations.
if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
for assist_id in editor_assists.assist_ids.clone() {
let assist = &self.assists[&assist_id];
if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
self.finish_assist(assist_id, false, window, cx)
}
if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx))
&& let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade())
{
for assist_id in editor_assists.assist_ids.clone() {
let assist = &self.assists[&assist_id];
if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
self.finish_assist(assist_id, false, window, cx)
}
}
}
@ -342,13 +342,11 @@ impl InlineAssistant {
)
.await
.ok();
if let Some(answer) = answer {
if answer == 0 {
cx.update(|window, cx| {
window.dispatch_action(Box::new(OpenSettings), cx)
})
if let Some(answer) = answer
&& answer == 0
{
cx.update(|window, cx| window.dispatch_action(Box::new(OpenSettings), cx))
.ok();
}
}
anyhow::Ok(())
})
@ -435,11 +433,11 @@ impl InlineAssistant {
}
}
if let Some(prev_selection) = selections.last_mut() {
if selection.start <= prev_selection.end {
prev_selection.end = selection.end;
continue;
}
if let Some(prev_selection) = selections.last_mut()
&& selection.start <= prev_selection.end
{
prev_selection.end = selection.end;
continue;
}
let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
@ -985,14 +983,13 @@ impl InlineAssistant {
EditorEvent::SelectionsChanged { .. } => {
for assist_id in editor_assists.assist_ids.clone() {
let assist = &self.assists[&assist_id];
if let Some(decorations) = assist.decorations.as_ref() {
if decorations
if let Some(decorations) = assist.decorations.as_ref()
&& decorations
.prompt_editor
.focus_handle(cx)
.is_focused(window)
{
return;
}
{
return;
}
}
@ -1123,7 +1120,7 @@ impl InlineAssistant {
if editor_assists
.scroll_lock
.as_ref()
.map_or(false, |lock| lock.assist_id == assist_id)
.is_some_and(|lock| lock.assist_id == assist_id)
{
editor_assists.scroll_lock = None;
}
@ -1503,20 +1500,18 @@ impl InlineAssistant {
window: &mut Window,
cx: &mut App,
) -> Option<InlineAssistTarget> {
if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx) {
if terminal_panel
if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx)
&& terminal_panel
.read(cx)
.focus_handle(cx)
.contains_focused(window, cx)
{
if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
pane.read(cx)
.active_item()
.and_then(|t| t.downcast::<TerminalView>())
}) {
return Some(InlineAssistTarget::Terminal(terminal_view));
}
}
&& let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
pane.read(cx)
.active_item()
.and_then(|t| t.downcast::<TerminalView>())
})
{
return Some(InlineAssistTarget::Terminal(terminal_view));
}
let context_editor = agent_panel
@ -1741,22 +1736,20 @@ impl InlineAssist {
return;
};
if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
if assist.decorations.is_none() {
if let Some(workspace) = assist.workspace.upgrade() {
let error = format!("Inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
if let CodegenStatus::Error(error) = codegen.read(cx).status(cx)
&& assist.decorations.is_none()
&& let Some(workspace) = assist.workspace.upgrade()
{
let error = format!("Inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
let id =
NotificationId::composite::<InlineAssistantError>(
assist_id.0,
);
let id = NotificationId::composite::<InlineAssistantError>(
assist_id.0,
);
workspace.show_toast(Toast::new(id, error), cx);
})
}
}
workspace.show_toast(Toast::new(id, error), cx);
})
}
if assist.decorations.is_none() {
@ -1821,18 +1814,18 @@ impl CodeActionProvider for AssistantCodeActionProvider {
has_diagnostics = true;
}
if has_diagnostics {
if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
if let Some(symbol) = symbols_containing_start.last() {
range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
}
if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None)
&& let Some(symbol) = symbols_containing_start.last()
{
range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
}
if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
if let Some(symbol) = symbols_containing_end.last() {
range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
}
if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None)
&& let Some(symbol) = symbols_containing_end.last()
{
range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
}
Task::ready(Ok(vec![CodeAction {

View file

@ -345,7 +345,7 @@ impl<T: 'static> PromptEditor<T> {
let prompt = self.editor.read(cx).text(cx);
if self
.prompt_history_ix
.map_or(true, |ix| self.prompt_history[ix] != prompt)
.is_none_or(|ix| self.prompt_history[ix] != prompt)
{
self.prompt_history_ix.take();
self.pending_prompt = prompt;

View file

@ -117,7 +117,7 @@ pub(crate) fn create_editor(
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
min_lines,
max_lines: max_lines,
max_lines,
},
buffer,
None,
@ -156,7 +156,7 @@ impl ProfileProvider for Entity<Thread> {
fn profiles_supported(&self, cx: &App) -> bool {
self.read(cx)
.configured_model()
.map_or(false, |model| model.model.supports_tools())
.is_some_and(|model| model.model.supports_tools())
}
fn profile_id(&self, cx: &App) -> AgentProfileId {
@ -215,9 +215,10 @@ impl MessageEditor {
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => this.handle_message_changed(cx),
_ => {}
cx.subscribe(&editor, |this, _, event: &EditorEvent, cx| {
if event == &EditorEvent::BufferEdited {
this.handle_message_changed(cx)
}
}),
cx.observe(&context_store, |this, _, cx| {
// When context changes, reload it for token counting.
@ -1132,7 +1133,7 @@ impl MessageEditor {
)
.when(is_edit_changes_expanded, |parent| {
parent.child(
v_flex().children(changed_buffers.into_iter().enumerate().flat_map(
v_flex().children(changed_buffers.iter().enumerate().flat_map(
|(index, (buffer, _diff))| {
let file = buffer.read(cx).file()?;
let path = file.path();
@ -1289,7 +1290,7 @@ impl MessageEditor {
self.thread
.read(cx)
.configured_model()
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
.is_some_and(|model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
}
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
@ -1442,7 +1443,7 @@ impl MessageEditor {
let message_text = editor.read(cx).text(cx);
if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
&& loaded_context.is_none_or(|loaded_context| loaded_context.is_empty())
{
return None;
}
@ -1605,7 +1606,8 @@ pub fn extract_message_creases(
.collect::<HashMap<_, _>>();
// Filter the addon's list of creases based on what the editor reports,
// since the addon might have removed creases in it.
let creases = editor.display_map.update(cx, |display_map, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map
.snapshot(cx)
.crease_snapshot
@ -1629,8 +1631,7 @@ pub fn extract_message_creases(
}
})
.collect()
});
creases
})
}
impl EventEmitter<MessageEditorEvent> for MessageEditor {}

View file

@ -140,12 +140,10 @@ impl PickerDelegate for SlashCommandDelegate {
);
ret.push(index - 1);
}
} else {
if let SlashCommandEntry::Advert { .. } = command {
previous_is_advert = true;
if index != 0 {
ret.push(index - 1);
}
} else if let SlashCommandEntry::Advert { .. } = command {
previous_is_advert = true;
if index != 0 {
ret.push(index - 1);
}
}
}
@ -329,9 +327,7 @@ where
};
let picker_view = cx.new(|cx| {
let picker =
Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()));
picker
Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
});
let handle = self

View file

@ -388,20 +388,20 @@ impl TerminalInlineAssistant {
window: &mut Window,
cx: &mut App,
) {
if let Some(assist) = self.assists.get_mut(&assist_id) {
if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
assist
.terminal
.update(cx, |terminal, cx| {
terminal.clear_block_below_cursor(cx);
let block = terminal_view::BlockProperties {
height,
render: Box::new(move |_| prompt_editor.clone().into_any_element()),
};
terminal.set_block_below_cursor(block, window, cx);
})
.log_err();
}
if let Some(assist) = self.assists.get_mut(&assist_id)
&& let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned()
{
assist
.terminal
.update(cx, |terminal, cx| {
terminal.clear_block_below_cursor(cx);
let block = terminal_view::BlockProperties {
height,
render: Box::new(move |_| prompt_editor.clone().into_any_element()),
};
terminal.set_block_below_cursor(block, window, cx);
})
.log_err();
}
}
}
@ -450,23 +450,20 @@ impl TerminalInlineAssist {
return;
};
if let CodegenStatus::Error(error) = &codegen.read(cx).status {
if assist.prompt_editor.is_none() {
if let Some(workspace) = assist.workspace.upgrade() {
let error =
format!("Terminal inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
if let CodegenStatus::Error(error) = &codegen.read(cx).status
&& assist.prompt_editor.is_none()
&& let Some(workspace) = assist.workspace.upgrade()
{
let error = format!("Terminal inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
let id =
NotificationId::composite::<InlineAssistantError>(
assist_id.0,
);
let id = NotificationId::composite::<InlineAssistantError>(
assist_id.0,
);
workspace.show_toast(Toast::new(id, error), cx);
})
}
}
workspace.show_toast(Toast::new(id, error), cx);
})
}
if assist.prompt_editor.is_none() {

View file

@ -373,7 +373,7 @@ impl TextThreadEditor {
.map(|default| default.provider);
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
.is_some_and(|provider| provider.must_accept_terms(cx))
{
self.show_accept_terms = true;
cx.notify();
@ -457,7 +457,7 @@ impl TextThreadEditor {
|| snapshot
.chars_at(newest_cursor)
.next()
.map_or(false, |ch| ch != '\n')
.is_some_and(|ch| ch != '\n')
{
editor.move_to_end_of_line(
&MoveToEndOfLine {
@ -540,7 +540,7 @@ impl TextThreadEditor {
let context = self.context.read(cx);
let sections = context
.slash_command_output_sections()
.into_iter()
.iter()
.filter(|section| section.is_valid(context.buffer().read(cx)))
.cloned()
.collect::<Vec<_>>();
@ -745,28 +745,27 @@ impl TextThreadEditor {
) {
if let Some(invoked_slash_command) =
self.context.read(cx).invoked_slash_command(&command_id)
&& let InvokedSlashCommandStatus::Finished = invoked_slash_command.status
{
if let InvokedSlashCommandStatus::Finished = invoked_slash_command.status {
let run_commands_in_ranges = invoked_slash_command.run_commands_in_ranges.clone();
for range in run_commands_in_ranges {
let commands = self.context.update(cx, |context, cx| {
context.reparse(cx);
context
.pending_commands_for_range(range.clone(), cx)
.to_vec()
});
let run_commands_in_ranges = invoked_slash_command.run_commands_in_ranges.clone();
for range in run_commands_in_ranges {
let commands = self.context.update(cx, |context, cx| {
context.reparse(cx);
context
.pending_commands_for_range(range.clone(), cx)
.to_vec()
});
for command in commands {
self.run_command(
command.source_range,
&command.name,
&command.arguments,
false,
self.workspace.clone(),
window,
cx,
);
}
for command in commands {
self.run_command(
command.source_range,
&command.name,
&command.arguments,
false,
self.workspace.clone(),
window,
cx,
);
}
}
}
@ -1238,7 +1237,7 @@ impl TextThreadEditor {
let mut new_blocks = vec![];
let mut block_index_to_message = vec![];
for message in self.context.read(cx).messages(cx) {
if let Some(_) = blocks_to_remove.remove(&message.id) {
if blocks_to_remove.remove(&message.id).is_some() {
// This is an old message that we might modify.
let Some((meta, block_id)) = old_blocks.get_mut(&message.id) else {
debug_assert!(
@ -1276,7 +1275,7 @@ impl TextThreadEditor {
context_editor_view: &Entity<TextThreadEditor>,
cx: &mut Context<Workspace>,
) -> Option<(String, bool)> {
const CODE_FENCE_DELIMITER: &'static str = "```";
const CODE_FENCE_DELIMITER: &str = "```";
let context_editor = context_editor_view.read(cx).editor.clone();
context_editor.update(cx, |context_editor, cx| {
@ -2162,8 +2161,8 @@ impl TextThreadEditor {
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Option<Range<usize>> {
const CODE_BLOCK_NODE: &'static str = "fenced_code_block";
const CODE_BLOCK_CONTENT: &'static str = "code_fence_content";
const CODE_BLOCK_NODE: &str = "fenced_code_block";
const CODE_BLOCK_CONTENT: &str = "code_fence_content";
let layer = snapshot.syntax_layers().next()?;
@ -3130,7 +3129,7 @@ mod tests {
let context_editor = window
.update(&mut cx, |_, window, cx| {
cx.new(|cx| {
let editor = TextThreadEditor::for_context(
TextThreadEditor::for_context(
context.clone(),
fs,
workspace.downgrade(),
@ -3138,8 +3137,7 @@ mod tests {
None,
window,
cx,
);
editor
)
})
})
.unwrap();

View file

@ -166,14 +166,13 @@ impl ThreadHistory {
this.all_entries.len().saturating_sub(1),
cx,
);
} else if let Some(prev_id) = previously_selected_entry {
if let Some(new_ix) = this
} else if let Some(prev_id) = previously_selected_entry
&& let Some(new_ix) = this
.all_entries
.iter()
.position(|probe| probe.id() == prev_id)
{
this.set_selected_entry_index(new_ix, cx);
}
{
this.set_selected_entry_index(new_ix, cx);
}
}
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {

View file

@ -14,13 +14,11 @@ pub struct IncompatibleToolsState {
impl IncompatibleToolsState {
pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription =
cx.subscribe(&thread, |this, _, event, _| match event {
ThreadEvent::ProfileChanged => {
this.cache.clear();
}
_ => {}
});
let _tool_working_set_subscription = cx.subscribe(&thread, |this, _, event, _| {
if let ThreadEvent::ProfileChanged = event {
this.cache.clear();
}
});
Self {
cache: HashMap::default(),

View file

@ -177,11 +177,11 @@ impl AskPassSession {
_ = askpass_opened_rx.fuse() => {
// Note: this await can only resolve after we are dropped.
askpass_kill_master_rx.await.ok();
return AskPassResult::CancelledByUser
AskPassResult::CancelledByUser
}
_ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
return AskPassResult::Timedout
AskPassResult::Timedout
}
}
}
@ -215,7 +215,7 @@ pub fn main(socket: &str) {
}
#[cfg(target_os = "windows")]
while buffer.last().map_or(false, |&b| b == b'\n' || b == b'\r') {
while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
buffer.pop();
}
if buffer.last() != Some(&b'\0') {

View file

@ -590,7 +590,7 @@ impl From<&Message> for MessageMetadata {
impl MessageMetadata {
pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool {
let result = match &self.cache {
match &self.cache {
Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range(
cached_at,
Range {
@ -599,8 +599,7 @@ impl MessageMetadata {
},
),
_ => false,
};
result
}
}
}
@ -1023,9 +1022,11 @@ impl AssistantContext {
summary: new_summary,
..
} => {
if self.summary.timestamp().map_or(true, |current_timestamp| {
new_summary.timestamp > current_timestamp
}) {
if self
.summary
.timestamp()
.is_none_or(|current_timestamp| new_summary.timestamp > current_timestamp)
{
self.summary = ContextSummary::Content(new_summary);
summary_generated = true;
}
@ -1076,20 +1077,20 @@ impl AssistantContext {
timestamp,
..
} => {
if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id) {
if timestamp > slash_command.timestamp {
slash_command.timestamp = timestamp;
match error_message {
Some(message) => {
slash_command.status =
InvokedSlashCommandStatus::Error(message.into());
}
None => {
slash_command.status = InvokedSlashCommandStatus::Finished;
}
if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id)
&& timestamp > slash_command.timestamp
{
slash_command.timestamp = timestamp;
match error_message {
Some(message) => {
slash_command.status =
InvokedSlashCommandStatus::Error(message.into());
}
None => {
slash_command.status = InvokedSlashCommandStatus::Finished;
}
cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
}
cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
}
}
ContextOperation::BufferOperation(_) => unreachable!(),
@ -1339,7 +1340,7 @@ impl AssistantContext {
let is_invalid = self
.messages_metadata
.get(&message_id)
.map_or(true, |metadata| {
.is_none_or(|metadata| {
!metadata.is_cache_valid(&buffer, &message.offset_range)
|| *encountered_invalid
});
@ -1368,10 +1369,10 @@ impl AssistantContext {
continue;
}
if let Some(last_anchor) = last_anchor {
if message.id == last_anchor {
hit_last_anchor = true;
}
if let Some(last_anchor) = last_anchor
&& message.id == last_anchor
{
hit_last_anchor = true;
}
new_anchor_needs_caching = new_anchor_needs_caching
@ -1406,10 +1407,10 @@ impl AssistantContext {
if !self.pending_completions.is_empty() {
return;
}
if let Some(cache_configuration) = cache_configuration {
if !cache_configuration.should_speculate {
return;
}
if let Some(cache_configuration) = cache_configuration
&& !cache_configuration.should_speculate
{
return;
}
let request = {
@ -1552,25 +1553,24 @@ impl AssistantContext {
})
.map(ToOwned::to_owned)
.collect::<SmallVec<_>>();
if let Some(command) = self.slash_commands.command(name, cx) {
if !command.requires_argument() || !arguments.is_empty() {
let start_ix = offset + command_line.name.start - 1;
let end_ix = offset
+ command_line
.arguments
.last()
.map_or(command_line.name.end, |argument| argument.end);
let source_range =
buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
let pending_command = ParsedSlashCommand {
name: name.to_string(),
arguments,
source_range,
status: PendingSlashCommandStatus::Idle,
};
updated.push(pending_command.clone());
new_commands.push(pending_command);
}
if let Some(command) = self.slash_commands.command(name, cx)
&& (!command.requires_argument() || !arguments.is_empty())
{
let start_ix = offset + command_line.name.start - 1;
let end_ix = offset
+ command_line
.arguments
.last()
.map_or(command_line.name.end, |argument| argument.end);
let source_range = buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
let pending_command = ParsedSlashCommand {
name: name.to_string(),
arguments,
source_range,
status: PendingSlashCommandStatus::Idle,
};
updated.push(pending_command.clone());
new_commands.push(pending_command);
}
}
@ -1799,14 +1799,13 @@ impl AssistantContext {
});
let end = this.buffer.read(cx).anchor_before(insert_position);
if run_commands_in_text {
if let Some(invoked_slash_command) =
if run_commands_in_text
&& let Some(invoked_slash_command) =
this.invoked_slash_commands.get_mut(&command_id)
{
invoked_slash_command
.run_commands_in_ranges
.push(start..end);
}
{
invoked_slash_command
.run_commands_in_ranges
.push(start..end);
}
}
SlashCommandEvent::EndSection => {
@ -1862,7 +1861,7 @@ impl AssistantContext {
{
let newline_offset = insert_position.saturating_sub(1);
if buffer.contains_str_at(newline_offset, "\n")
&& last_section_range.map_or(true, |last_section_range| {
&& last_section_range.is_none_or(|last_section_range| {
!last_section_range
.to_offset(buffer)
.contains(&newline_offset)
@ -2081,15 +2080,12 @@ impl AssistantContext {
match event {
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
match status_update {
CompletionRequestStatus::UsageUpdated { amount, limit } => {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
_ => {}
if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
}
LanguageModelCompletionEvent::StartMessage { .. } => {}
@ -2315,10 +2311,7 @@ impl AssistantContext {
let mut request_message = LanguageModelRequestMessage {
role: message.role,
content: Vec::new(),
cache: message
.cache
.as_ref()
.map_or(false, |cache| cache.is_anchor),
cache: message.cache.as_ref().is_some_and(|cache| cache.is_anchor),
};
while let Some(content) = contents.peek() {
@ -2741,10 +2734,10 @@ impl AssistantContext {
}
this.read_with(cx, |this, _cx| {
if let Some(summary) = this.summary.content() {
if summary.text.is_empty() {
bail!("Model generated an empty summary");
}
if let Some(summary) = this.summary.content()
&& summary.text.is_empty()
{
bail!("Model generated an empty summary");
}
Ok(())
})??;
@ -2799,7 +2792,7 @@ impl AssistantContext {
let mut current_message = messages.next();
while let Some(offset) = offsets.next() {
// Locate the message that contains the offset.
while current_message.as_ref().map_or(false, |message| {
while current_message.as_ref().is_some_and(|message| {
!message.offset_range.contains(&offset) && messages.peek().is_some()
}) {
current_message = messages.next();
@ -2809,7 +2802,7 @@ impl AssistantContext {
};
// Skip offsets that are in the same message.
while offsets.peek().map_or(false, |offset| {
while offsets.peek().is_some_and(|offset| {
message.offset_range.contains(offset) || messages.peek().is_none()
}) {
offsets.next();
@ -2924,18 +2917,18 @@ impl AssistantContext {
fs.create_dir(contexts_dir().as_ref()).await?;
// rename before write ensures that only one file exists
if let Some(old_path) = old_path.as_ref() {
if new_path.as_path() != old_path.as_ref() {
fs.rename(
old_path,
&new_path,
RenameOptions {
overwrite: true,
ignore_if_exists: true,
},
)
.await?;
}
if let Some(old_path) = old_path.as_ref()
&& new_path.as_path() != old_path.as_ref()
{
fs.rename(
old_path,
&new_path,
RenameOptions {
overwrite: true,
ignore_if_exists: true,
},
)
.await?;
}
// update path before write in case it fails

View file

@ -1055,7 +1055,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!(
messages_cache(&context, cx)
.iter()
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.count(),
0,
"Empty messages should not have any cache anchors."
@ -1083,7 +1083,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!(
messages_cache(&context, cx)
.iter()
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.count(),
0,
"Messages should not be marked for cache before going over the token minimum."
@ -1098,7 +1098,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.collect::<Vec<bool>>(),
vec![true, true, false],
"Last message should not be an anchor on speculative request."
@ -1116,7 +1116,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.collect::<Vec<bool>>(),
vec![false, true, true, false],
"Most recent message should also be cached if not a speculative request."

View file

@ -320,7 +320,7 @@ impl ContextStore {
.client
.subscribe_to_entity(remote_id)
.log_err()
.map(|subscription| subscription.set_entity(&cx.entity(), &mut cx.to_async()));
.map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
self.advertise_contexts(cx);
} else {
self.client_subscription = None;
@ -789,7 +789,7 @@ impl ContextStore {
let fs = self.fs.clone();
cx.spawn(async move |this, cx| {
pub static ZED_STATELESS: LazyLock<bool> =
LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
if *ZED_STATELESS {
return Ok(());
}
@ -894,34 +894,33 @@ impl ContextStore {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(response) = protocol
if protocol.capable(context_server::protocol::ServerCapability::Prompts)
&& let Some(response) = protocol
.request::<context_server::types::requests::PromptsList>(())
.await
.log_err()
{
let slash_command_ids = response
.prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!("registering context server command: {:?}", prompt.name);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_store.clone(),
server.id(),
prompt,
),
))
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
{
let slash_command_ids = response
.prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::debug!("registering context server command: {:?}", prompt.name);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_store.clone(),
server.id(),
prompt,
),
))
})
.log_err();
}
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
}
})
.detach();

View file

@ -39,10 +39,10 @@ impl SlashCommand for ContextServerSlashCommand {
fn label(&self, cx: &App) -> language::CodeLabel {
let mut parts = vec![self.prompt.name.as_str()];
if let Some(args) = &self.prompt.arguments {
if let Some(arg) = args.first() {
parts.push(arg.name.as_str());
}
if let Some(args) = &self.prompt.arguments
&& let Some(arg) = args.first()
{
parts.push(arg.name.as_str());
}
create_label_for_command(parts[0], &parts[1..], cx)
}
@ -62,9 +62,10 @@ impl SlashCommand for ContextServerSlashCommand {
}
fn requires_argument(&self) -> bool {
self.prompt.arguments.as_ref().map_or(false, |args| {
args.iter().any(|arg| arg.required == Some(true))
})
self.prompt
.arguments
.as_ref()
.is_some_and(|args| args.iter().any(|arg| arg.required == Some(true)))
}
fn complete_argument(

View file

@ -66,23 +66,22 @@ impl SlashCommand for DeltaSlashCommand {
.metadata
.as_ref()
.and_then(|value| serde_json::from_value::<FileCommandMetadata>(value.clone()).ok())
&& paths.insert(metadata.path.clone())
{
if paths.insert(metadata.path.clone()) {
file_command_old_outputs.push(
context_buffer
.as_rope()
.slice(section.range.to_offset(&context_buffer)),
);
file_command_new_outputs.push(Arc::new(FileSlashCommand).run(
std::slice::from_ref(&metadata.path),
context_slash_command_output_sections,
context_buffer.clone(),
workspace.clone(),
delegate.clone(),
window,
cx,
));
}
file_command_old_outputs.push(
context_buffer
.as_rope()
.slice(section.range.to_offset(&context_buffer)),
);
file_command_new_outputs.push(Arc::new(FileSlashCommand).run(
std::slice::from_ref(&metadata.path),
context_slash_command_output_sections,
context_buffer.clone(),
workspace.clone(),
delegate.clone(),
window,
cx,
));
}
}
@ -95,25 +94,25 @@ impl SlashCommand for DeltaSlashCommand {
.into_iter()
.zip(file_command_new_outputs)
{
if let Ok(new_output) = new_output {
if let Ok(new_output) = SlashCommandOutput::from_event_stream(new_output).await
{
if let Some(file_command_range) = new_output.sections.first() {
let new_text = &new_output.text[file_command_range.range.clone()];
if old_text.chars().ne(new_text.chars()) {
changes_detected = true;
output.sections.extend(new_output.sections.into_iter().map(
|section| SlashCommandOutputSection {
range: output.text.len() + section.range.start
..output.text.len() + section.range.end,
icon: section.icon,
label: section.label,
metadata: section.metadata,
},
));
output.text.push_str(&new_output.text);
}
}
if let Ok(new_output) = new_output
&& let Ok(new_output) = SlashCommandOutput::from_event_stream(new_output).await
&& let Some(file_command_range) = new_output.sections.first()
{
let new_text = &new_output.text[file_command_range.range.clone()];
if old_text.chars().ne(new_text.chars()) {
changes_detected = true;
output
.sections
.extend(new_output.sections.into_iter().map(|section| {
SlashCommandOutputSection {
range: output.text.len() + section.range.start
..output.text.len() + section.range.end,
icon: section.icon,
label: section.label,
metadata: section.metadata,
}
}));
output.text.push_str(&new_output.text);
}
}
}

View file

@ -61,7 +61,7 @@ impl DiagnosticsSlashCommand {
snapshot: worktree.snapshot(),
include_ignored: worktree
.root_entry()
.map_or(false, |entry| entry.is_ignored),
.is_some_and(|entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Entries,
}
@ -280,10 +280,10 @@ fn collect_diagnostics(
let mut project_summary = DiagnosticSummary::default();
for (project_path, path, summary) in diagnostic_summaries {
if let Some(path_matcher) = &options.path_matcher {
if !path_matcher.is_match(&path) {
continue;
}
if let Some(path_matcher) = &options.path_matcher
&& !path_matcher.is_match(&path)
{
continue;
}
project_summary.error_count += summary.error_count;

View file

@ -92,7 +92,7 @@ impl FileSlashCommand {
snapshot: worktree.snapshot(),
include_ignored: worktree
.root_entry()
.map_or(false, |entry| entry.is_ignored),
.is_some_and(|entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Entries,
}
@ -223,7 +223,7 @@ fn collect_files(
cx: &mut App,
) -> impl Stream<Item = Result<SlashCommandEvent>> + use<> {
let Ok(matchers) = glob_inputs
.into_iter()
.iter()
.map(|glob_input| {
custom_path_matcher::PathMatcher::new(&[glob_input.to_owned()])
.with_context(|| format!("invalid path {glob_input}"))
@ -379,7 +379,7 @@ fn collect_files(
}
}
while let Some(_) = directory_stack.pop() {
while directory_stack.pop().is_some() {
events_tx.unbounded_send(Ok(SlashCommandEvent::EndSection))?;
}
}
@ -491,7 +491,7 @@ mod custom_path_matcher {
impl PathMatcher {
pub fn new(globs: &[String]) -> Result<Self, globset::Error> {
let globs = globs
.into_iter()
.iter()
.map(|glob| Glob::new(&SanitizedPath::from(glob).to_glob_string()))
.collect::<Result<Vec<_>, _>>()?;
let sources = globs.iter().map(|glob| glob.glob().to_owned()).collect();
@ -536,7 +536,7 @@ mod custom_path_matcher {
let path_str = path.to_string_lossy();
let separator = std::path::MAIN_SEPARATOR_STR;
if path_str.ends_with(separator) {
return false;
false
} else {
self.glob.is_match(path_str.to_string() + separator)
}

View file

@ -195,16 +195,14 @@ fn tab_items_for_queries(
}
for editor in workspace.items_of_type::<Editor>(cx) {
if let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton() {
if let Some(timestamp) =
if let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton()
&& let Some(timestamp) =
timestamps_by_entity_id.get(&editor.entity_id())
{
if visited_buffers.insert(buffer.read(cx).remote_id()) {
let snapshot = buffer.read(cx).snapshot();
let full_path = snapshot.resolve_file_path(cx, true);
open_buffers.push((full_path, snapshot, *timestamp));
}
}
&& visited_buffers.insert(buffer.read(cx).remote_id())
{
let snapshot = buffer.read(cx).snapshot();
let full_path = snapshot.resolve_file_path(cx, true);
open_buffers.push((full_path, snapshot, *timestamp));
}
}

View file

@ -24,16 +24,16 @@ pub fn adapt_schema_to_format(
fn preprocess_json_schema(json: &mut Value) -> Result<()> {
// `additionalProperties` defaults to `false` unless explicitly specified.
// This prevents models from hallucinating tool parameters.
if let Value::Object(obj) = json {
if matches!(obj.get("type"), Some(Value::String(s)) if s == "object") {
if !obj.contains_key("additionalProperties") {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}
if let Value::Object(obj) = json
&& matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
{
if !obj.contains_key("additionalProperties") {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}
// OpenAI API requires non-missing `properties`
if !obj.contains_key("properties") {
obj.insert("properties".to_string(), Value::Object(Default::default()));
}
// OpenAI API requires non-missing `properties`
if !obj.contains_key("properties") {
obj.insert("properties".to_string(), Value::Object(Default::default()));
}
}
Ok(())
@ -59,10 +59,10 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
("optional", |value| value.is_boolean()),
];
for (key, predicate) in KEYS_TO_REMOVE {
if let Some(value) = obj.get(key) {
if predicate(value) {
obj.remove(key);
}
if let Some(value) = obj.get(key)
&& predicate(value)
{
obj.remove(key);
}
}
@ -77,12 +77,12 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
}
// Handle oneOf -> anyOf conversion
if let Some(subschemas) = obj.get_mut("oneOf") {
if subschemas.is_array() {
let subschemas_clone = subschemas.clone();
obj.remove("oneOf");
obj.insert("anyOf".to_string(), subschemas_clone);
}
if let Some(subschemas) = obj.get_mut("oneOf")
&& subschemas.is_array()
{
let subschemas_clone = subschemas.clone();
obj.remove("oneOf");
obj.insert("anyOf".to_string(), subschemas_clone);
}
// Recursively process all nested objects and arrays

View file

@ -156,13 +156,13 @@ fn resolve_context_server_tool_name_conflicts(
if duplicated_tool_names.is_empty() {
return context_server_tools
.into_iter()
.iter()
.map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
.collect();
}
context_server_tools
.into_iter()
.iter()
.filter_map(|tool| {
let mut tool_name = resolve_tool_name(tool);
if !duplicated_tool_names.contains(&tool_name) {

View file

@ -72,11 +72,10 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
move |registry, event, cx| {
if let language_model::Event::DefaultModelChanged = event {
register_web_search_tool(&registry, cx);
}
_ => {}
},
)
.detach();
@ -86,7 +85,7 @@ fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut A
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
.is_some_and(|default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {

View file

@ -672,29 +672,30 @@ impl EditAgent {
cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
let mut messages_iter = conversation.messages.iter_mut();
if let Some(last_message) = messages_iter.next_back() {
if last_message.role == Role::Assistant {
let old_content_len = last_message.content.len();
last_message
.content
.retain(|content| !matches!(content, MessageContent::ToolUse(_)));
let new_content_len = last_message.content.len();
if let Some(last_message) = messages_iter.next_back()
&& last_message.role == Role::Assistant
{
let old_content_len = last_message.content.len();
last_message
.content
.retain(|content| !matches!(content, MessageContent::ToolUse(_)));
let new_content_len = last_message.content.len();
// We just removed pending tool uses from the content of the
// last message, so it doesn't make sense to cache it anymore
// (e.g., the message will look very different on the next
// request). Thus, we move the flag to the message prior to it,
// as it will still be a valid prefix of the conversation.
if old_content_len != new_content_len && last_message.cache {
if let Some(prev_message) = messages_iter.next_back() {
last_message.cache = false;
prev_message.cache = true;
}
}
// We just removed pending tool uses from the content of the
// last message, so it doesn't make sense to cache it anymore
// (e.g., the message will look very different on the next
// request). Thus, we move the flag to the message prior to it,
// as it will still be a valid prefix of the conversation.
if old_content_len != new_content_len
&& last_message.cache
&& let Some(prev_message) = messages_iter.next_back()
{
last_message.cache = false;
prev_message.cache = true;
}
if last_message.content.is_empty() {
conversation.messages.pop();
}
if last_message.content.is_empty() {
conversation.messages.pop();
}
}

View file

@ -1283,14 +1283,14 @@ impl EvalAssertion {
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionOutcome {
score,
message: Some(output),
});
}
if let Some(captures) = re.captures(&output)
&& let Some(score_match) = captures.get(1)
{
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionOutcome {
score,
message: Some(output),
});
}
anyhow::bail!("No score found in response. Raw output: {output}");
@ -1586,7 +1586,7 @@ impl EditAgentTest {
let has_system_prompt = eval
.conversation
.first()
.map_or(false, |msg| msg.role == Role::System);
.is_some_and(|msg| msg.role == Role::System);
let messages = if has_system_prompt {
eval.conversation
} else {

View file

@ -155,10 +155,10 @@ impl Tool for EditFileTool {
// It's also possible that the global config dir is configured to be inside the project,
// so check for that edge case too.
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
if canonical_path.starts_with(paths::config_dir()) {
return true;
}
if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
&& canonical_path.starts_with(paths::config_dir())
{
return true;
}
// Check if path is inside the global config directory
@ -199,10 +199,10 @@ impl Tool for EditFileTool {
.any(|c| c.as_os_str() == local_settings_folder.as_os_str())
{
description.push_str(" (local settings)");
} else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
if canonical_path.starts_with(paths::config_dir()) {
description.push_str(" (global settings)");
}
} else if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
&& canonical_path.starts_with(paths::config_dir())
{
description.push_str(" (global settings)");
}
description
@ -1356,8 +1356,7 @@ mod tests {
mode: mode.clone(),
};
let result = cx.update(|cx| resolve_path(&input, project, cx));
result
cx.update(|cx| resolve_path(&input, project, cx))
}
fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {

View file

@ -188,15 +188,14 @@ impl Tool for GrepTool {
// Check if this file should be excluded based on its worktree settings
if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| {
project.find_project_path(&path, cx)
}) {
if cx.update(|cx| {
})
&& cx.update(|cx| {
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
worktree_settings.is_path_excluded(&project_path.path)
|| worktree_settings.is_path_private(&project_path.path)
}).unwrap_or(false) {
continue;
}
}
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
@ -284,12 +283,11 @@ impl Tool for GrepTool {
output.extend(snapshot.text_for_range(range));
output.push_str("\n```\n");
if let Some(ancestor_range) = ancestor_range {
if end_row < ancestor_range.end.row {
if let Some(ancestor_range) = ancestor_range
&& end_row < ancestor_range.end.row {
let remaining_lines = ancestor_range.end.row - end_row;
writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
}
}
matches_found += 1;
}

View file

@ -201,7 +201,7 @@ impl Tool for ReadFileTool {
buffer
.file()
.as_ref()
.map_or(true, |file| !file.disk_state().exists())
.is_none_or(|file| !file.disk_state().exists())
})? {
anyhow::bail!("{file_path} not found");
}

View file

@ -43,12 +43,11 @@ impl Transform for ToJsonSchemaSubsetTransform {
fn transform(&mut self, schema: &mut Schema) {
// Ensure that the type field is not an array, this happens when we use
// Option<T>, the type will be [T, "null"].
if let Some(type_field) = schema.get_mut("type") {
if let Some(types) = type_field.as_array() {
if let Some(first_type) = types.first() {
*type_field = first_type.clone();
}
}
if let Some(type_field) = schema.get_mut("type")
&& let Some(types) = type_field.as_array()
&& let Some(first_type) = types.first()
{
*type_field = first_type.clone();
}
// oneOf is not supported, use anyOf instead

View file

@ -59,12 +59,9 @@ impl TerminalTool {
}
if which::which("bash").is_ok() {
log::info!("agent selected bash for terminal tool");
"bash".into()
} else {
let shell = get_system_shell();
log::info!("agent selected {shell} for terminal tool");
shell
get_system_shell()
}
});
Self {
@ -216,7 +213,8 @@ impl Tool for TerminalTool {
async move |cx| {
let program = program.await;
let env = env.await;
let terminal = project
project
.update(cx, |project, cx| {
project.create_terminal(
TerminalKind::Task(task::SpawnInTerminal {
@ -229,8 +227,7 @@ impl Tool for TerminalTool {
cx,
)
})?
.await;
terminal
.await
}
});
@ -387,7 +384,7 @@ fn working_dir(
let project = project.read(cx);
let cd = &input.cd;
if cd == "." || cd == "" {
if cd == "." || cd.is_empty() {
// Accept "." or "" as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
@ -412,10 +409,8 @@ fn working_dir(
{
return Ok(Some(input_path.into()));
}
} else {
if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
} else if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees.");

View file

@ -543,7 +543,7 @@ impl AutoUpdater {
async fn update(this: Entity<Self>, mut cx: AsyncApp) -> Result<()> {
let (client, installed_version, previous_status, release_channel) =
this.read_with(&mut cx, |this, cx| {
this.read_with(&cx, |this, cx| {
(
this.http_client.clone(),
this.current_version,

View file

@ -186,11 +186,11 @@ unsafe extern "system" fn wnd_proc(
}),
WM_TERMINATE => {
with_dialog_data(hwnd, |data| {
if let Ok(result) = data.borrow_mut().rx.recv() {
if let Err(e) = result {
log::error!("Failed to update Zed: {:?}", e);
show_error(format!("Error: {:?}", e));
}
if let Ok(result) = data.borrow_mut().rx.recv()
&& let Err(e) = result
{
log::error!("Failed to update Zed: {:?}", e);
show_error(format!("Error: {:?}", e));
}
});
unsafe { PostQuitMessage(0) };

View file

@ -54,11 +54,7 @@ pub async fn stream_completion(
)])));
}
if request
.tools
.as_ref()
.map_or(false, |t| !t.tools.is_empty())
{
if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) {
response = response.set_tool_config(request.tools);
}

View file

@ -82,11 +82,12 @@ impl Render for Breadcrumbs {
}
text_style.color = Color::Muted.color(cx);
if index == 0 && !TabBarSettings::get_global(cx).show && active_item.is_dirty(cx) {
if let Some(styled_element) = apply_dirty_filename_style(&segment, &text_style, cx)
{
return styled_element;
}
if index == 0
&& !TabBarSettings::get_global(cx).show
&& active_item.is_dirty(cx)
&& let Some(styled_element) = apply_dirty_filename_style(&segment, &text_style, cx)
{
return styled_element;
}
StyledText::new(segment.text.replace('\n', ""))

View file

@ -572,14 +572,14 @@ impl BufferDiffInner {
pending_range.end.column = 0;
}
if pending_range == (start_point..end_point) {
if !buffer.has_edits_since_in_range(
if pending_range == (start_point..end_point)
&& !buffer.has_edits_since_in_range(
&pending_hunk.buffer_version,
start_anchor..end_anchor,
) {
has_pending = true;
secondary_status = pending_hunk.new_status;
}
)
{
has_pending = true;
secondary_status = pending_hunk.new_status;
}
}
@ -1036,16 +1036,15 @@ impl BufferDiff {
_ => (true, Some(text::Anchor::MIN..text::Anchor::MAX)),
};
if let Some(secondary_changed_range) = secondary_diff_change {
if let Some(secondary_hunk_range) =
if let Some(secondary_changed_range) = secondary_diff_change
&& let Some(secondary_hunk_range) =
self.range_to_hunk_range(secondary_changed_range, buffer, cx)
{
if let Some(range) = &mut changed_range {
range.start = secondary_hunk_range.start.min(&range.start, buffer);
range.end = secondary_hunk_range.end.max(&range.end, buffer);
} else {
changed_range = Some(secondary_hunk_range);
}
{
if let Some(range) = &mut changed_range {
range.start = secondary_hunk_range.start.min(&range.start, buffer);
range.end = secondary_hunk_range.end.max(&range.end, buffer);
} else {
changed_range = Some(secondary_hunk_range);
}
}

View file

@ -116,7 +116,7 @@ impl ActiveCall {
envelope: TypedEnvelope<proto::IncomingCall>,
mut cx: AsyncApp,
) -> Result<proto::Ack> {
let user_store = this.read_with(&mut cx, |this, _| this.user_store.clone())?;
let user_store = this.read_with(&cx, |this, _| this.user_store.clone())?;
let call = IncomingCall {
room_id: envelope.payload.room_id,
participants: user_store
@ -147,7 +147,7 @@ impl ActiveCall {
let mut incoming_call = this.incoming_call.0.borrow_mut();
if incoming_call
.as_ref()
.map_or(false, |call| call.room_id == envelope.payload.room_id)
.is_some_and(|call| call.room_id == envelope.payload.room_id)
{
incoming_call.take();
}

View file

@ -64,7 +64,7 @@ pub struct RemoteParticipant {
impl RemoteParticipant {
pub fn has_video_tracks(&self) -> bool {
return !self.video_tracks.is_empty();
!self.video_tracks.is_empty()
}
pub fn can_write(&self) -> bool {

Some files were not shown because too many files have changed in this diff Show more