Compare commits

..

32 commits

Author SHA1 Message Date
David Kleingeld
1a31961dac Extract get docs and show actual hover into function 2025-08-25 17:39:52 +02:00
Richard Feldman
78ca73e0b9
Minimize diff some more
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 11:30:45 -04:00
Richard Feldman
c70178b7ea
Go back to using previous_valid_anchor
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 11:10:09 -04:00
Richard Feldman
4df010d33a
Remove unnecessary deferred conditional
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 11:07:03 -04:00
Richard Feldman
c44f9dfb17
Remove redundant tag type check
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 11:02:23 -04:00
Richard Feldman
7c5738142a
Minimize diff a bit
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 11:00:52 -04:00
Richard Feldman
23c9eb875a
Go back to old visible inlay hint mapping
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 11:00:00 -04:00
Richard Feldman
81578c0d9e
Restore previous hovered_offset logic
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 10:49:28 -04:00
Richard Feldman
70918609d8
Fix has_pending_selection logic for inlay hovers
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 10:41:34 -04:00
Richard Feldman
835651fc08
Remove conditionals reintroduced by mistake in the merge
Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-25 10:25:33 -04:00
Richard Feldman
267e44e513
Merge remote-tracking branch 'origin/main' into inlay-hint-tooltip 2025-08-25 09:50:07 -04:00
Richard Feldman
7ef2a8211f
Merge remote-tracking branch 'origin/main' into inlay-hint-tooltip 2025-07-15 10:53:58 -04:00
Richard Feldman
a509ae241a
Revert "wip"
This reverts commit 0bb1a5f98a.
2025-07-11 12:00:43 -04:00
Richard Feldman
0bb1a5f98a
wip 2025-07-11 12:00:41 -04:00
Richard Feldman
79f376d752
Clean up inlay hint hover logic 2025-07-11 11:05:23 -04:00
Richard Feldman
fe8b3fe53d
Drop debug logging 2025-07-10 22:30:29 -04:00
Richard Feldman
2cd812e54f
Delete unused import 2025-07-10 22:23:52 -04:00
Richard Feldman
6adf082e43
Revert "Add a hover when hovering over inlays"
This reverts commit 0d6232b373.
2025-07-10 22:23:39 -04:00
Richard Feldman
e4963e70cc
Revert "Attempt to fix hover"
This reverts commit e7c6f228d5.
2025-07-10 22:23:34 -04:00
Richard Feldman
e7c6f228d5
Attempt to fix hover 2025-07-10 22:23:26 -04:00
Richard Feldman
0d6232b373
Add a hover when hovering over inlays 2025-07-10 19:14:36 -04:00
Richard Feldman
ca4df68f31
Remove flashed black circle 2025-07-10 17:54:22 -04:00
Richard Feldman
c96b6a06f0
Don't show loading message 2025-07-10 17:46:13 -04:00
Richard Feldman
b1cd20a435
Remove all debug logging from inlay hint hover implementation 2025-07-10 17:46:08 -04:00
Richard Feldman
509375c83c
Remove some debug logging 2025-07-10 17:39:37 -04:00
Richard Feldman
b6bd9c0682
It works 2025-07-10 17:32:12 -04:00
Richard Feldman
8ee82395b8
Kinda make this work 2025-07-10 17:14:30 -04:00
Richard Feldman
a322aa33c7
wip - currently just shows a generic message, not the docs 2025-07-09 16:37:37 -04:00
Richard Feldman
2ff30d20e3
Fix inlay hint hover by not clearing hover when mouse is over inlay
When hovering over an inlay hint, point_for_position.as_valid() returns None
because inlays don't have valid text positions. This was causing hover_at(editor, None)
to be called, which would hide any active hovers.

The fix is simple: don't call hover_at when we're over an inlay position.
The inlay hover is already handled by update_hovered_link, so we don't need
to do anything else.
2025-07-09 13:15:43 -04:00
Richard Feldman
01d7b3345b
Fix inlay hint caching 2025-07-09 13:04:11 -04:00
Richard Feldman
17f7312fc0
Extract resolve_hint
Co-authored-by: Cole Miller <cole@zed.dev>
2025-07-07 16:48:33 -04:00
Richard Feldman
a61e478152
Reproduce #33715 in a test 2025-07-02 15:50:02 -04:00
132 changed files with 4621 additions and 8045 deletions

View file

@ -14,7 +14,7 @@ body:
### Description
<!-- Describe with sufficient detail to reproduce from a clean Zed install.
- Any code must be sufficient to reproduce (include context!)
- Include code as text, not just as a screenshot.
- Code must as text, not just as a screenshot.
- Issues with insufficient detail may be summarily closed.
-->

1
Cargo.lock generated
View file

@ -13521,7 +13521,6 @@ dependencies = [
"smol",
"sysinfo",
"telemetry_events",
"thiserror 2.0.12",
"toml 0.8.20",
"unindent",
"util",

View file

@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 12.375H13" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M3 11.125L6.75003 7.375L3 3.62497" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 336 B

File diff suppressed because it is too large Load diff

Before

Width:  |  Height:  |  Size: 176 KiB

View file

@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="160" height="61" fill="none"><g clip-path="url(#a)"><path fill="#000" d="M130.75.385c5.428 0 10.297 2.81 13.011 7.511l14.214 24.618-.013-.005c2.599 4.504 2.707 9.932.28 14.513-2.618 4.944-7.862 8.015-13.679 8.015h-31.811c-.452 0-.873-.242-1.103-.637a1.268 1.268 0 0 1 0-1.274l3.919-6.78c.223-.394.65-.636 1.102-.636h28.288a5.622 5.622 0 0 0 4.925-2.849 5.615 5.615 0 0 0 0-5.69l-14.214-24.617a5.621 5.621 0 0 0-4.925-2.848 5.621 5.621 0 0 0-4.925 2.848l-14.214 24.618a6.267 6.267 0 0 0-.319.643.998.998 0 0 1-.069.14L101.724 54.4l-.823 1.313-2.529 4.39a1.27 1.27 0 0 1-1.103.636h-7.83c-.452 0-.873-.242-1.102-.637-.23-.394-.23-.879 0-1.274l2.188-3.791H66.803c-3.32 0-6.454-1.122-8.818-3.167a17.141 17.141 0 0 1-3.394-3.96 1.261 1.261 0 0 1-.091-.137L34.2 12.573a5.622 5.622 0 0 0-4.925-2.849 5.621 5.621 0 0 0-4.924 2.85L10.137 37.19a5.615 5.615 0 0 0 0 5.69 5.63 5.63 0 0 0 4.925 2.841h29.862a1.276 1.276 0 0 1 1.102 1.912l-3.912 6.778a1.27 1.27 0 0 1-1.102.638H14.495c-3.32 0-6.454-1.128-8.817-3.173-5.906-5.104-7.36-12.883-3.62-19.363L16.267 7.89C18.872 3.385 23.517.583 28.697.39c.184-.006.356-.006.534-.006 5.378 0 10.45 3.007 13.246 7.85l12.986 22.372L68.58 7.891C71.186 3.385 75.83.582 81.01.39c.185-.006.358-.006.536-.006 4.453 0 8.71 2.039 11.672 5.588.337.407.388.98.127 1.446l-3.765 6.6a1.268 1.268 0 0 1-2.205.006l-.847-1.465a5.623 5.623 0 0 0-4.926-2.848 5.622 5.622 0 0 0-4.924 2.848L62.464 37.18a5.614 5.614 0 0 0 0 5.689 5.628 5.628 0 0 0 4.925 2.842H95.91L117.76 7.87c2.714-4.683 7.575-7.486 12.99-7.486Z"/></g><defs><clipPath id="a"><path fill="#fff" d="M0 .385h160v60.36H0z"/></clipPath></defs></svg>

Before

Width:  |  Height:  |  Size: 1.6 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 14 KiB

View file

@ -40,7 +40,7 @@
"shift-f11": "debugger::StepOut",
"f11": "zed::ToggleFullScreen",
"ctrl-alt-z": "edit_prediction::RateCompletions",
"ctrl-alt-shift-i": "edit_prediction::ToggleMenu",
"ctrl-shift-i": "edit_prediction::ToggleMenu",
"ctrl-alt-l": "lsp_tool::ToggleMenu"
}
},
@ -120,7 +120,7 @@
"alt-g m": "git::OpenModifiedFiles",
"menu": "editor::OpenContextMenu",
"shift-f10": "editor::OpenContextMenu",
"ctrl-alt-shift-e": "editor::ToggleEditPrediction",
"ctrl-shift-e": "editor::ToggleEditPrediction",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint"
}

File diff suppressed because it is too large Load diff

View file

@ -38,7 +38,6 @@
"alt-;": ["editor::ToggleComments", { "advance_downwards": false }],
"ctrl-x ctrl-;": "editor::ToggleComments",
"alt-.": "editor::GoToDefinition", // xref-find-definitions
"alt-?": "editor::FindAllReferences", // xref-find-references
"alt-,": "pane::GoBack", // xref-pop-marker-stack
"ctrl-x h": "editor::SelectAll", // mark-whole-buffer
"ctrl-d": "editor::Delete", // delete-char

View file

@ -38,7 +38,6 @@
"alt-;": ["editor::ToggleComments", { "advance_downwards": false }],
"ctrl-x ctrl-;": "editor::ToggleComments",
"alt-.": "editor::GoToDefinition", // xref-find-definitions
"alt-?": "editor::FindAllReferences", // xref-find-references
"alt-,": "pane::GoBack", // xref-pop-marker-stack
"ctrl-x h": "editor::SelectAll", // mark-whole-buffer
"ctrl-d": "editor::Delete", // delete-char

View file

@ -428,13 +428,11 @@
"g h": "vim::StartOfLine",
"g s": "vim::FirstNonWhitespace", // "g s" default behavior is "space s"
"g e": "vim::EndOfDocument",
"g .": "vim::HelixGotoLastModification", // go to last modification
"g r": "editor::FindAllReferences", // zed specific
"g t": "vim::WindowTop",
"g c": "vim::WindowMiddle",
"g b": "vim::WindowBottom",
"shift-r": "editor::Paste",
"x": "editor::SelectLine",
"shift-x": "editor::SelectLine",
"%": "editor::SelectAll",

View file

@ -162,12 +162,6 @@
// 2. Always quit the application
// "on_last_window_closed": "quit_app",
"on_last_window_closed": "platform_default",
// Whether to show padding for zoomed panels.
// When enabled, zoomed center panels (e.g. code editor) will have padding all around,
// while zoomed bottom/left/right panels will have padding to the top/right/left (respectively).
//
// Default: true
"zoomed_padding": true,
// Whether to use the system provided dialogs for Open and Save As.
// When set to false, Zed will use the built-in keyboard-first pickers.
"use_system_path_prompts": true,
@ -653,8 +647,6 @@
// "never"
"show": "always"
},
// Whether to enable drag-and-drop operations in the project panel.
"drag_and_drop": true,
// Whether to hide the root entry when only one folder is open in the window.
"hide_root": false
},
@ -1637,9 +1629,6 @@
"allowed": true
}
},
"Kotlin": {
"language_servers": ["kotlin-language-server", "!kotlin-lsp", "..."]
},
"LaTeX": {
"formatter": "language_server",
"language_servers": ["texlab", "..."],

View file

@ -43,8 +43,8 @@
// "args": ["--login"]
// }
// }
"shell": "system"
"shell": "system",
// Represents the tags for inline runnable indicators, or spawning multiple tasks at once.
// "tags": []
"tags": []
}
]

View file

@ -183,15 +183,16 @@ impl ToolCall {
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
first_line.to_owned() + ""
} else {
tool_call.title
};
Self {
id: tool_call.id,
label: cx
.new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
label: cx.new(|cx| {
Markdown::new(
tool_call.title.into(),
Some(language_registry.clone()),
None,
cx,
)
}),
kind: tool_call.kind,
content: tool_call
.content
@ -232,11 +233,7 @@ impl ToolCall {
if let Some(title) = title {
self.label.update(cx, |label, cx| {
if let Some((first_line, _)) = title.split_once("\n") {
label.replace(first_line.to_owned() + "", cx)
} else {
label.replace(title, cx);
}
label.replace(title, cx);
});
}
@ -759,8 +756,6 @@ pub struct AcpThread {
connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId,
token_usage: Option<TokenUsage>,
prompt_capabilities: acp::PromptCapabilities,
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
#[derive(Debug)]
@ -775,12 +770,11 @@ pub enum AcpThreadEvent {
Stopped,
Error,
LoadError(LoadError),
PromptCapabilitiesUpdated,
}
impl EventEmitter<AcpThreadEvent> for AcpThread {}
#[derive(PartialEq, Eq, Debug)]
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
Idle,
WaitingForToolConfirmation,
@ -827,20 +821,7 @@ impl AcpThread {
project: Entity<Project>,
action_log: Entity<ActionLog>,
session_id: acp::SessionId,
mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
cx: &mut Context<Self>,
) -> Self {
let prompt_capabilities = *prompt_capabilities_rx.borrow();
let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
loop {
let caps = prompt_capabilities_rx.recv().await?;
this.update(cx, |this, cx| {
this.prompt_capabilities = caps;
cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
})?;
}
});
Self {
action_log,
shared_buffers: Default::default(),
@ -852,15 +833,9 @@ impl AcpThread {
connection,
session_id,
token_usage: None,
prompt_capabilities,
_observe_prompt_capabilities: task,
}
}
pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
&self.connection
}
@ -2624,19 +2599,13 @@ mod tests {
.into(),
);
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|cx| {
let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
)
});
self.sessions.lock().insert(session_id, thread.downgrade());
@ -2670,6 +2639,14 @@ mod tests {
}
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.lock();
let thread = sessions.get(session_id).unwrap().clone();

View file

@ -38,6 +38,8 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn prompt_capabilities(&self) -> acp::PromptCapabilities;
fn resume(
&self,
_session_id: &acp::SessionId,
@ -327,19 +329,13 @@ mod test_support {
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|cx| {
let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
)
});
self.sessions.lock().insert(
@ -352,6 +348,14 @@ mod test_support {
Task::ready(Ok(thread))
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn authenticate(
&self,
_method_id: acp::AuthMethodId,

View file

@ -21,12 +21,12 @@ use ui::prelude::*;
use util::ResultExt as _;
use workspace::{Item, Workspace};
actions!(dev, [OpenAcpLogs]);
actions!(acp, [OpenDebugTools]);
pub fn init(cx: &mut App) {
cx.observe_new(
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
workspace.register_action(|workspace, _: &OpenDebugTools, window, cx| {
let acp_tools =
Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);

View file

@ -664,7 +664,7 @@ impl Thread {
}
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
if self.configured_model.is_none() {
if self.configured_model.is_none() || self.messages.is_empty() {
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
}
self.configured_model.clone()
@ -2097,7 +2097,7 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
println!("No thread summary model");
return;
};
@ -2416,7 +2416,7 @@ impl Thread {
}
let Some(ConfiguredModel { model, provider }) =
LanguageModelRegistry::read_global(cx).thread_summary_model()
LanguageModelRegistry::read_global(cx).thread_summary_model(cx)
else {
return;
};
@ -5410,13 +5410,10 @@ fn main() {{
}),
cx,
);
registry.set_thread_summary_model(
Some(ConfiguredModel {
provider,
model: model.clone(),
}),
cx,
);
registry.set_thread_summary_model(Some(ConfiguredModel {
provider,
model: model.clone(),
}));
})
});

View file

@ -228,7 +228,7 @@ impl NativeAgent {
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
let registry = LanguageModelRegistry::read_global(cx);
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
@ -240,16 +240,13 @@ impl NativeAgent {
let title = thread.title();
let project = thread.project.clone();
let action_log = thread.action_log.clone();
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
let acp_thread = cx.new(|cx| {
let acp_thread = cx.new(|_cx| {
acp_thread::AcpThread::new(
title,
connection,
project.clone(),
action_log.clone(),
session_id.clone(),
prompt_capabilities_rx,
cx,
)
});
let subscriptions = vec![
@ -524,7 +521,7 @@ impl NativeAgent {
let registry = LanguageModelRegistry::read_global(cx);
let default_model = registry.default_model().map(|m| m.model);
let summarization_model = registry.thread_summary_model().map(|m| m.model);
let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
@ -928,6 +925,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}
}
fn resume(
&self,
session_id: &acp::SessionId,

View file

@ -22,10 +22,6 @@ impl NativeAgentServer {
}
impl AgentServer for NativeAgentServer {
fn telemetry_id(&self) -> &'static str {
"zed"
}
fn name(&self) -> SharedString {
"Zed Agent".into()
}

View file

@ -72,7 +72,6 @@ async fn test_echo(cx: &mut TestAppContext) {
}
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_thinking(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -472,7 +471,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
tool_name: ToolRequiringPermission::name().into(),
is_error: true,
content: "Permission to run tool denied by user".into(),
output: Some("Permission to run tool denied by user".into())
output: None
})
]
);
@ -1348,7 +1347,6 @@ async fn test_cancellation(cx: &mut TestAppContext) {
}
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -1687,7 +1685,6 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
}
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_title_generation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -1822,11 +1819,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
Project::init_settings(cx);
agent_settings::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
Project::init_settings(cx);
LanguageModelRegistry::test(cx);
agent_settings::init(cx);
});
cx.executor().forbid_parking();

View file

@ -575,22 +575,11 @@ pub struct Thread {
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
summarization_model: Option<Arc<dyn LanguageModel>>,
prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
pub(crate) project: Entity<Project>,
pub(crate) action_log: Entity<ActionLog>,
}
impl Thread {
fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
let image = model.map_or(true, |model| model.supports_images());
acp::PromptCapabilities {
image,
audio: false,
embedded_context: true,
}
}
pub fn new(
project: Entity<Project>,
project_context: Entity<ProjectContext>,
@ -601,8 +590,6 @@ impl Thread {
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
Self {
id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
prompt_id: PromptId::new(),
@ -630,8 +617,6 @@ impl Thread {
templates,
model,
summarization_model: None,
prompt_capabilities_tx,
prompt_capabilities_rx,
project,
action_log,
}
@ -732,17 +717,7 @@ impl Thread {
stream.update_tool_call_fields(
&tool_use.id,
acp::ToolCallUpdateFields {
status: Some(
tool_result
.as_ref()
.map_or(acp::ToolCallStatus::Failed, |result| {
if result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}
}),
),
status: Some(acp::ToolCallStatus::Completed),
raw_output: output,
..Default::default()
},
@ -775,8 +750,6 @@ impl Thread {
.or_else(|| registry.default_model())
.map(|model| model.model)
});
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
Self {
id,
@ -806,8 +779,6 @@ impl Thread {
project,
action_log,
updated_at: db_thread.updated_at,
prompt_capabilities_tx,
prompt_capabilities_rx,
}
}
@ -975,12 +946,10 @@ impl Thread {
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
let old_usage = self.latest_token_usage();
self.model = Some(model);
let new_caps = Self::prompt_capabilities(self.model.as_deref());
let new_usage = self.latest_token_usage();
if old_usage != new_usage {
cx.emit(TokenUsageUpdated(new_usage));
}
self.prompt_capabilities_tx.send(new_caps).log_err();
cx.notify()
}
@ -1173,7 +1142,37 @@ impl Thread {
_task: cx.spawn(async move |this, cx| {
log::debug!("Starting agent turn execution");
let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await;
let turn_result: Result<()> = async {
let mut intent = CompletionIntent::UserPrompt;
loop {
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
let mut end_turn = true;
this.update(cx, |this, cx| {
// Generate title if needed.
if this.title.is_none() && this.pending_title_generation.is_none() {
this.generate_title(cx);
}
// End the turn if the model didn't use tools.
let message = this.pending_message.as_ref();
end_turn =
message.map_or(true, |message| message.tool_results.is_empty());
this.flush_pending_message(cx);
})?;
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
log::info!("Tool use limit reached, completing turn");
return Err(language_model::ToolUseLimitReachedError.into());
} else if end_turn {
log::debug!("No tool uses found, completing turn");
return Ok(());
} else {
intent = CompletionIntent::ToolResults;
}
}
}
.await;
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
match turn_result {
@ -1204,17 +1203,20 @@ impl Thread {
Ok(events_rx)
}
async fn run_turn_internal(
async fn stream_completion(
this: &WeakEntity<Self>,
model: Arc<dyn LanguageModel>,
model: &Arc<dyn LanguageModel>,
completion_intent: CompletionIntent,
event_stream: &ThreadEventStream,
cx: &mut AsyncApp,
) -> Result<()> {
let mut attempt = 0;
let mut intent = CompletionIntent::UserPrompt;
log::debug!("Stream completion started successfully");
let mut attempt = None;
loop {
let request =
this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;
let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
})??;
telemetry::event!(
"Agent Thread Completion",
@ -1225,19 +1227,23 @@ impl Thread {
attempt
);
log::debug!("Calling model.stream_completion, attempt {}", attempt);
log::debug!(
"Calling model.stream_completion, attempt {}",
attempt.unwrap_or(0)
);
let mut events = model
.stream_completion(request, cx)
.await
.map_err(|error| anyhow!(error))?;
let mut tool_results = FuturesUnordered::new();
let mut error = None;
while let Some(event) = events.next().await {
log::trace!("Received completion event: {:?}", event);
match event {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
tool_results.extend(this.update(cx, |this, cx| {
this.handle_completion_event(event, event_stream, cx)
this.handle_streamed_completion_event(event, event_stream, cx)
})??);
}
Err(err) => {
@ -1247,7 +1253,6 @@ impl Thread {
}
}
let end_turn = tool_results.is_empty();
while let Some(tool_result) = tool_results.next().await {
log::debug!("Tool finished {:?}", tool_result);
@ -1270,83 +1275,65 @@ impl Thread {
})?;
}
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
if this.title.is_none() && this.pending_title_generation.is_none() {
this.generate_title(cx);
}
})?;
if let Some(error) = error {
attempt += 1;
let retry =
this.update(cx, |this, _| this.handle_completion_error(error, attempt))??;
let timer = cx.background_executor().timer(retry.duration);
event_stream.send_retry(retry);
timer.await;
this.update(cx, |this, _cx| {
let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(anyhow!(error))?;
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(anyhow!(error))?;
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let attempt = attempt.get_or_insert(0u8);
*attempt += 1;
let attempt = *attempt;
if attempt > max_attempts {
return Err(anyhow!(error))?;
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
event_stream.send_retry(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
});
cx.background_executor().timer(delay).await;
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
if let Some(Message::Agent(message)) = this.messages.last() {
if message.tool_results.is_empty() {
intent = CompletionIntent::UserPrompt;
this.messages.push(Message::Resume);
}
}
})?;
} else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
return Err(language_model::ToolUseLimitReachedError.into());
} else if end_turn {
return Ok(());
} else {
intent = CompletionIntent::ToolResults;
attempt = 0;
return Ok(());
}
}
}
fn handle_completion_error(
&mut self,
error: LanguageModelCompletionError,
attempt: u8,
) -> Result<acp_thread::RetryStatus> {
if self.completion_mode == CompletionMode::Normal {
return Err(anyhow!(error));
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(anyhow!(error));
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
if attempt > max_attempts {
return Err(anyhow!(error));
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
Ok(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
})
}
/// A helper method that's called on every streamed completion event.
/// Returns an optional tool result task, which the main agentic loop will
/// send back to the model when it resolves.
fn handle_completion_event(
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
event_stream: &ThreadEventStream,
@ -1567,7 +1554,7 @@ impl Thread {
tool_name: tool_use.name,
is_error: true,
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
output: Some(error.to_string().into()),
output: None,
},
}
}))
@ -2469,30 +2456,6 @@ impl ToolCallEventStreamReceiver {
}
}
pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
update,
)))) = event
{
update.fields
} else {
panic!("Expected update fields but got: {:?}", event);
}
}
pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
update,
)))) = event
{
update.diff
} else {
panic!("Expected diff but got: {:?}", event);
}
}
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(

View file

@ -273,13 +273,6 @@ impl AgentTool for EditFileTool {
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
event_stream.update_diff(diff.clone());
let _finalize_diff = util::defer({
let diff = diff.downgrade();
let mut cx = cx.clone();
move || {
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
}
});
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
@ -396,6 +389,8 @@ impl AgentTool for EditFileTool {
})
.await;
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
let input_path = input.path.display();
if unified_diff.is_empty() {
anyhow::ensure!(
@ -1550,100 +1545,6 @@ mod tests {
);
}
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/", json!({"main.rs": ""})).await;
let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
let languages = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(),
Templates::new(),
Some(model.clone()),
cx,
)
});
// Ensure the diff is finalized after the edit completes.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
cx.run_until_parked();
model.end_last_completion_stream();
edit.await.unwrap();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
}
// Ensure the diff is finalized if an error occurs while editing.
{
model.forbid_requests();
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
edit.await.unwrap_err();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
model.allow_requests();
}
// Ensure the diff is finalized if the tool call gets dropped.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
drop(edit);
cx.run_until_parked();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
}
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);

View file

@ -136,17 +136,12 @@ impl AgentTool for FetchTool {
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let authorize = event_stream.authorize(input.url.clone(), cx);
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move {
authorize.await?;
Self::build_message(http_client, &input.url).await
}
async move { Self::build_message(http_client, &input.url).await }
});
cx.foreground_executor().spawn(async move {

View file

@ -165,17 +165,16 @@ fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Resu
.collect();
cx.background_spawn(async move {
let mut results = Vec::new();
for snapshot in snapshots {
for entry in snapshot.entries(false, 0) {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
if path_matcher.is_match(root_name.join(&entry.path)) {
results.push(snapshot.abs_path().join(entry.path.as_ref()));
}
}
}
Ok(results)
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
@ -216,8 +215,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from(path!("/root/apple/banana/carrot")),
PathBuf::from(path!("/root/apple/bandana/carbonara"))
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
@ -228,8 +227,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from(path!("/root/apple/banana/carrot")),
PathBuf::from(path!("/root/apple/bandana/carbonara"))
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}

View file

@ -11,7 +11,6 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{path::Path, sync::Arc};
use util::markdown::MarkdownCodeBlock;
use crate::{AgentTool, ToolCallEventStream};
@ -244,19 +243,6 @@ impl AgentTool for ReadFileTool {
}]),
..Default::default()
});
if let Ok(LanguageModelToolResultContent::Text(text)) = &result {
let markdown = MarkdownCodeBlock {
tag: &input.path,
text,
}
.to_string();
event_stream.update_fields(ToolCallUpdateFields {
content: Some(vec![acp::ToolCallContent::Content {
content: markdown.into(),
}]),
..Default::default()
})
}
}
})?;

View file

@ -162,34 +162,12 @@ impl AgentConnection for AcpConnection {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
let context_server_store = project.read(cx).context_server_store().read(cx);
let mcp_servers = context_server_store
.configured_server_ids()
.iter()
.filter_map(|id| {
let configuration = context_server_store.configuration_for_server(id)?;
let command = configuration.command();
Some(acp::McpServer {
name: id.0.to_string(),
command: command.path.clone(),
args: command.args.clone(),
env: if let Some(env) = command.env.as_ref() {
env.iter()
.map(|(name, value)| acp::EnvVariable {
name: name.clone(),
value: value.clone(),
})
.collect()
} else {
vec![]
},
})
})
.collect();
cx.spawn(async move |cx| {
let response = conn
.new_session(acp::NewSessionRequest { mcp_servers, cwd })
.new_session(acp::NewSessionRequest {
mcp_servers: vec![],
cwd,
})
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
@ -207,16 +185,13 @@ impl AgentConnection for AcpConnection {
let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|cx| {
let thread = cx.new(|_cx| {
AcpThread::new(
self.server_name.clone(),
self.clone(),
project,
action_log,
session_id.clone(),
// ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
watch::Receiver::constant(self.prompt_capabilities),
cx,
)
})?;
@ -288,9 +263,7 @@ impl AgentConnection for AcpConnection {
match serde_json::from_value(data.clone()) {
Ok(ErrorDetails { details }) => {
if suppress_abort_err
&& (details.contains("This operation was aborted")
|| details.contains("The user aborted a request"))
if suppress_abort_err && details.contains("This operation was aborted")
{
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled,
@ -306,6 +279,10 @@ impl AgentConnection for AcpConnection {
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
session.suppress_abort_err = true;

View file

@ -0,0 +1,524 @@
// Translates old acp agents into the new schema
use action_log::ActionLog;
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
use crate::AgentServerCommand;
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)]
struct OldAcpClientDelegate {
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
cx: AsyncApp,
next_tool_call_id: Rc<RefCell<u64>>,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl OldAcpClientDelegate {
fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
Self {
thread,
cx,
next_tool_call_id: Rc::new(RefCell::new(0)),
}
}
}
impl acp_old::Client for OldAcpClientDelegate {
async fn stream_assistant_message_chunk(
&self,
params: acp_old::StreamAssistantMessageChunkParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| match params.chunk {
acp_old::AssistantMessageChunk::Text { text } => {
thread.push_assistant_content_block(text.into(), false, cx)
}
acp_old::AssistantMessageChunk::Thought { thought } => {
thread.push_assistant_content_block(thought.into(), true, cx)
}
})
.log_err();
})?;
Ok(())
}
async fn request_tool_call_confirmation(
&self,
request: acp_old::RequestToolCallConfirmationParams,
) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
let tool_call = into_new_tool_call(
acp::ToolCallId(old_acp_id.to_string().into()),
request.tool_call,
);
let mut options = match request.confirmation {
acp_old::ToolCallConfirmation::Edit { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow Edits".to_string(),
)],
acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", root_command),
)],
acp_old::ToolCallConfirmation::Mcp {
server_name,
tool_name,
..
} => vec![
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", server_name),
),
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", tool_name),
),
],
acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
acp_old::ToolCallConfirmation::Other { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
};
options.extend([
(
acp_old::ToolCallConfirmationOutcome::Allow,
acp::PermissionOptionKind::AllowOnce,
"Allow".to_string(),
),
(
acp_old::ToolCallConfirmationOutcome::Reject,
acp::PermissionOptionKind::RejectOnce,
"Reject".to_string(),
),
]);
let mut outcomes = Vec::with_capacity(options.len());
let mut acp_options = Vec::with_capacity(options.len());
for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
outcomes.push(outcome);
acp_options.push(acp::PermissionOption {
id: acp::PermissionOptionId(index.to_string().into()),
name: label,
kind,
})
}
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call.into(), acp_options, cx)
})
})??
.context("Failed to update thread")?
.await;
let outcome = match response {
Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)],
Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
};
Ok(acp_old::RequestToolCallConfirmationResponse {
id: acp_old::ToolCallId(old_acp_id),
outcome,
})
}
async fn push_tool_call(
&self,
request: acp_old::PushToolCallParams,
) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.upsert_tool_call(
into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request),
cx,
)
})
})??
.context("Failed to update thread")?;
Ok(acp_old::PushToolCallResponse {
id: acp_old::ToolCallId(old_acp_id),
})
}
async fn update_tool_call(
&self,
request: acp_old::UpdateToolCallParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(request.tool_call_id.0.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(into_new_tool_call_status(request.status)),
content: Some(
request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect::<Vec<_>>(),
),
..Default::default()
},
},
cx,
)
})
})?
.context("Failed to update thread")??;
Ok(())
}
async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_plan(
acp::Plan {
entries: request
.entries
.into_iter()
.map(into_new_plan_entry)
.collect(),
},
cx,
)
})
})?
.context("Failed to update thread")?;
Ok(())
}
async fn read_text_file(
&self,
acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
let content = self
.cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.read_text_file(path, line, limit, false, cx)
})
})?
.context("Failed to update thread")?
.await?;
Ok(acp_old::ReadTextFileResponse { content })
}
async fn write_text_file(
&self,
acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
) -> Result<(), acp_old::Error> {
self.cx
.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| thread.write_text_file(path, content, cx))
})?
.context("Failed to update thread")?
.await?;
Ok(())
}
}
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
acp::ToolCall {
id,
title: request.label,
kind: acp_kind_from_old_icon(request.icon),
status: acp::ToolCallStatus::InProgress,
content: request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect(),
locations: request
.locations
.into_iter()
.map(into_new_tool_call_location)
.collect(),
raw_input: None,
raw_output: None,
}
}
fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind {
match icon {
acp_old::Icon::FileSearch => acp::ToolKind::Search,
acp_old::Icon::Folder => acp::ToolKind::Search,
acp_old::Icon::Globe => acp::ToolKind::Search,
acp_old::Icon::Hammer => acp::ToolKind::Other,
acp_old::Icon::LightBulb => acp::ToolKind::Think,
acp_old::Icon::Pencil => acp::ToolKind::Edit,
acp_old::Icon::Regex => acp::ToolKind::Search,
acp_old::Icon::Terminal => acp::ToolKind::Execute,
}
}
fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus {
match status {
acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress,
acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed,
acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed,
}
}
fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent {
match content {
acp_old::ToolCallContent::Markdown { markdown } => markdown.into(),
acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff {
diff: into_new_diff(diff),
},
}
}
fn into_new_diff(diff: acp_old::Diff) -> acp::Diff {
acp::Diff {
path: diff.path,
old_text: diff.old_text,
new_text: diff.new_text,
}
}
fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation {
acp::ToolCallLocation {
path: location.path,
line: location.line,
}
}
fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry {
acp::PlanEntry {
content: entry.content,
priority: into_new_plan_priority(entry.priority),
status: into_new_plan_status(entry.status),
}
}
fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority {
match priority {
acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High,
}
}
fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus {
match status {
acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
}
}
pub struct AcpConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub _child_status: Task<Result<()>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
}
impl AcpConnection {
pub fn stdio(
name: &'static str,
command: AgentServerCommand,
root_dir: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Self>> {
let root_dir = root_dir.to_path_buf();
cx.spawn(async move |cx| {
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
log::trace!("Spawned (pid: {})", child.id());
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => Err(anyhow!(result)),
};
drop(io_task);
result
});
Ok(Self {
name,
connection,
_child_status: child_status,
current_thread: thread_rc,
})
})
}
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let task = self.connection.request_any(
acp_old::InitializeParams {
protocol_version: acp_old::ProtocolVersion::latest(),
}
.into_any(),
);
let current_thread = self.current_thread.clone();
cx.spawn(async move |cx| {
let result = task.await?;
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(AuthRequired::new())
}
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
});
current_thread.replace(thread.downgrade());
thread
})
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
Ok(())
})
}
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let chunks = params
.prompt
.into_iter()
.filter_map(|block| match block {
acp::ContentBlock::Text(text) => {
Some(acp_old::UserMessageChunk::Text { text: text.text })
}
acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path {
path: link.uri.into(),
}),
_ => None,
})
.collect();
let task = self
.connection
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: false,
audio: false,
embedded_context: false,
}
}
fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
let task = self
.connection
.request_any(acp_old::CancelSendMessageParams.into_any());
cx.foreground_executor()
.spawn(async move {
task.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx)
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}

View file

@ -0,0 +1,376 @@
use acp_tools::AcpConnectionRegistry;
use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
use anyhow::anyhow;
use collections::HashMap;
use futures::AsyncBufReadExt as _;
use futures::channel::oneshot;
use futures::io::BufReader;
use project::Project;
use serde::Deserialize;
use std::path::Path;
use std::rc::Rc;
use std::{any::Any, cell::RefCell};
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::{AgentServerCommand, acp::UnsupportedVersion};
use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
pub struct AcpConnection {
server_name: &'static str,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
prompt_capabilities: acp::PromptCapabilities,
_io_task: Task<Result<()>>,
}
pub struct AcpSession {
thread: WeakEntity<AcpThread>,
suppress_abort_err: bool,
}
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
impl AcpConnection {
pub async fn stdio(
server_name: &'static str,
command: AgentServerCommand,
root_dir: &Path,
cx: &mut AsyncApp,
) -> Result<Self> {
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
let stdout = child.stdout.take().context("Failed to take stdout")?;
let stdin = child.stdin.take().context("Failed to take stdin")?;
let stderr = child.stderr.take().context("Failed to take stderr")?;
log::trace!("Spawned (pid: {})", child.id());
let sessions = Rc::new(RefCell::new(HashMap::default()));
let client = ClientDelegate {
sessions: sessions.clone(),
cx: cx.clone(),
};
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
let foreground_executor = cx.foreground_executor().clone();
move |fut| {
foreground_executor.spawn(fut).detach();
}
});
let io_task = cx.background_spawn(io_task);
cx.background_spawn(async move {
let mut stderr = BufReader::new(stderr);
let mut line = String::new();
while let Ok(n) = stderr.read_line(&mut line).await
&& n > 0
{
log::warn!("agent stderr: {}", &line);
line.clear();
}
})
.detach();
cx.spawn({
let sessions = sessions.clone();
async move |cx| {
let status = child.status().await?;
for session in sessions.borrow().values() {
session
.thread
.update(cx, |thread, cx| {
thread.emit_load_error(LoadError::Exited { status }, cx)
})
.ok();
}
anyhow::Ok(())
}
})
.detach();
let connection = Rc::new(connection);
cx.update(|cx| {
AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
registry.set_active_connection(server_name, &connection, cx)
});
})?;
let response = connection
.initialize(acp::InitializeRequest {
protocol_version: acp::VERSION,
client_capabilities: acp::ClientCapabilities {
fs: acp::FileSystemCapability {
read_text_file: true,
write_text_file: true,
},
},
})
.await?;
if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
return Err(UnsupportedVersion.into());
}
Ok(Self {
auth_methods: response.auth_methods,
connection,
server_name,
sessions,
prompt_capabilities: response.agent_capabilities.prompt_capabilities,
_io_task: io_task,
})
}
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let response = conn
.new_session(acp::NewSessionRequest {
mcp_servers: vec![],
cwd,
})
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
let mut error = AuthRequired::new();
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
error = error.with_description(err.message);
}
anyhow!(error)
} else {
anyhow!(err)
}
})?;
let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
action_log,
session_id.clone(),
)
})?;
let session = AcpSession {
thread: thread.downgrade(),
suppress_abort_err: false,
};
sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor().spawn(async move {
let result = conn
.authenticate(acp::AuthenticateRequest {
method_id: method_id.clone(),
})
.await?;
Ok(result)
})
}
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let session_id = params.session_id.clone();
cx.foreground_executor().spawn(async move {
let result = conn.prompt(params).await;
let mut suppress_abort_err = false;
if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
suppress_abort_err = session.suppress_abort_err;
session.suppress_abort_err = false;
}
match result {
Ok(response) => Ok(response),
Err(err) => {
if err.code != ErrorCode::INTERNAL_ERROR.code {
anyhow::bail!(err)
}
let Some(data) = &err.data else {
anyhow::bail!(err)
};
// Temporary workaround until the following PR is generally available:
// https://github.com/google-gemini/gemini-cli/pull/6656
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ErrorDetails {
details: Box<str>,
}
match serde_json::from_value(data.clone()) {
Ok(ErrorDetails { details }) => {
if suppress_abort_err && details.contains("This operation was aborted")
{
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled,
})
} else {
Err(anyhow!(details))
}
}
Err(_) => Err(anyhow!(err)),
}
}
}
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
session.suppress_abort_err = true;
}
let conn = self.connection.clone();
let params = acp::CancelNotification {
session_id: session_id.clone(),
};
cx.foreground_executor()
.spawn(async move { conn.cancel(params).await })
.detach();
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct ClientDelegate {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
cx: AsyncApp,
}
impl acp::Client for ClientDelegate {
async fn request_permission(
&self,
arguments: acp::RequestPermissionRequest,
) -> Result<acp::RequestPermissionResponse, acp::Error> {
let cx = &mut self.cx.clone();
let rx = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
})?;
let result = rx?.await;
let outcome = match result {
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
};
Ok(acp::RequestPermissionResponse { outcome })
}
async fn write_text_file(
&self,
arguments: acp::WriteTextFileRequest,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
let task = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.write_text_file(arguments.path, arguments.content, cx)
})?;
task.await?;
Ok(())
}
async fn read_text_file(
&self,
arguments: acp::ReadTextFileRequest,
) -> Result<acp::ReadTextFileResponse, acp::Error> {
let cx = &mut self.cx.clone();
let task = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
})?;
let content = task.await?;
Ok(acp::ReadTextFileResponse { content })
}
async fn session_notification(
&self,
notification: acp::SessionNotification,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
let sessions = self.sessions.borrow();
let session = sessions
.get(&notification.session_id)
.context("Failed to get session")?;
session.thread.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})??;
Ok(())
}
}

View file

@ -36,7 +36,6 @@ pub trait AgentServer: Send {
fn name(&self) -> SharedString;
fn empty_state_headline(&self) -> SharedString;
fn empty_state_message(&self) -> SharedString;
fn telemetry_id(&self) -> &'static str;
fn connect(
&self,
@ -98,7 +97,7 @@ pub struct AgentServerCommand {
}
impl AgentServerCommand {
pub async fn resolve(
pub(crate) async fn resolve(
path_bin_name: &'static str,
extra_args: &[&'static str],
fallback_path: Option<&Path>,

View file

@ -43,10 +43,6 @@ use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError, MentionUri
pub struct ClaudeCode;
impl AgentServer for ClaudeCode {
fn telemetry_id(&self) -> &'static str {
"claude-code"
}
fn name(&self) -> SharedString {
"Claude Code".into()
}
@ -253,19 +249,13 @@ impl AgentConnection for ClaudeAgentConnection {
});
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|cx| {
let thread = cx.new(|_cx| {
AcpThread::new(
"Claude Code",
self.clone(),
project,
action_log,
session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}),
cx,
)
})?;
@ -329,6 +319,14 @@ impl AgentConnection for ClaudeAgentConnection {
cx.foreground_executor().spawn(async move { end_rx.await? })
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(session_id) else {

View file

@ -22,10 +22,6 @@ impl CustomAgentServer {
}
impl crate::AgentServer for CustomAgentServer {
fn telemetry_id(&self) -> &'static str {
"custom"
}
fn name(&self) -> SharedString {
self.name.clone()
}

View file

@ -17,10 +17,6 @@ pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp";
impl AgentServer for Gemini {
fn telemetry_id(&self) -> &'static str {
"gemini-cli"
}
fn name(&self) -> SharedString {
"Gemini CLI".into()
}
@ -57,7 +53,7 @@ impl AgentServer for Gemini {
return Err(LoadError::NotInstalled {
error_message: "Failed to find Gemini CLI binary".into(),
install_message: "Install Gemini CLI".into(),
install_command: Self::install_command().into(),
install_command: "npm install -g @google/gemini-cli@preview".into()
}.into());
};
@ -92,7 +88,7 @@ impl AgentServer for Gemini {
current_version
).into(),
upgrade_message: "Upgrade Gemini CLI to latest".into(),
upgrade_command: Self::upgrade_command().into(),
upgrade_command: "npm install -g @google/gemini-cli@preview".into(),
}.into())
}
}
@ -105,20 +101,6 @@ impl AgentServer for Gemini {
}
}
impl Gemini {
pub fn binary_name() -> &'static str {
"gemini"
}
pub fn install_command() -> &'static str {
"npm install -g @google/gemini-cli@preview"
}
pub fn upgrade_command() -> &'static str {
"npm install -g @google/gemini-cli@preview"
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;

View file

@ -6,7 +6,7 @@ use agent2::HistoryStore;
use collections::HashMap;
use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable, ScrollHandle,
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
TextStyleRefinement, WeakEntity, Window,
};
use language::language_settings::SoftWrap;
@ -154,22 +154,10 @@ impl EntryViewState {
});
}
}
AgentThreadEntry::AssistantMessage(message) => {
let entry = if let Some(Entry::AssistantMessage(entry)) =
self.entries.get_mut(index)
{
entry
} else {
self.set_entry(
index,
Entry::AssistantMessage(AssistantMessageEntry::default()),
);
let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
unreachable!()
};
entry
};
entry.sync(message);
AgentThreadEntry::AssistantMessage(_) => {
if index == self.entries.len() {
self.entries.push(Entry::empty())
}
}
};
}
@ -189,7 +177,7 @@ impl EntryViewState {
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
match entry {
Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
Entry::UserMessage { .. } => {}
Entry::Content(response_views) => {
for view in response_views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
@ -220,29 +208,9 @@ pub enum ViewEvent {
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
#[derive(Default, Debug)]
pub struct AssistantMessageEntry {
scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
}
impl AssistantMessageEntry {
pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
self.scroll_handles_by_chunk_index.get(&ix).cloned()
}
pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
let ix = message.chunks.len() - 1;
let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
handle.scroll_to_bottom();
}
}
}
#[derive(Debug)]
pub enum Entry {
UserMessage(Entity<MessageEditor>),
AssistantMessage(AssistantMessageEntry),
Content(HashMap<EntityId, AnyEntity>),
}
@ -250,7 +218,7 @@ impl Entry {
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
match self {
Self::UserMessage(editor) => Some(editor),
Self::AssistantMessage(_) | Self::Content(_) => None,
Entry::Content(_) => None,
}
}
@ -271,16 +239,6 @@ impl Entry {
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
pub fn scroll_handle_for_assistant_message_chunk(
&self,
chunk_ix: usize,
) -> Option<ScrollHandle> {
match self {
Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
Self::UserMessage(_) | Self::Content(_) => None,
}
}
fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
match self {
Self::Content(map) => Some(map),
@ -296,7 +254,7 @@ impl Entry {
pub fn has_content(&self) -> bool {
match self {
Self::Content(map) => !map.is_empty(),
Self::UserMessage(_) | Self::AssistantMessage(_) => false,
Self::UserMessage(_) => false,
}
}
}

View file

@ -373,7 +373,7 @@ impl MessageEditor {
if Img::extensions().contains(&extension) && !extension.contains("svg") {
if !self.prompt_capabilities.get().image {
return Task::ready(Err(anyhow!("This model does not support images yet")));
return Task::ready(Err(anyhow!("This agent does not support images yet")));
}
let task = self
.project

View file

@ -462,7 +462,7 @@ impl AcpThreadHistory {
cx.notify();
}))
.end_slot::<IconButton>(if hovered {
.end_slot::<IconButton>(if hovered || selected {
Some(
IconButton::new("delete", IconName::Trash)
.shape(IconButtonShape::Square)

File diff suppressed because it is too large Load diff

View file

@ -3,23 +3,19 @@ mod configure_context_server_modal;
mod manage_profiles_modal;
mod tool_picker;
use std::{ops::Range, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use agent_servers::{AgentServerCommand, AgentServerSettings, AllAgentServersSettings, Gemini};
use agent_settings::AgentSettings;
use anyhow::Result;
use assistant_tool::{ToolSource, ToolWorkingSet};
use cloud_llm_client::Plan;
use collections::HashMap;
use context_server::ContextServerId;
use editor::{Editor, SelectionEffects, scroll::Autoscroll};
use extension::ExtensionManifest;
use extension_host::ExtensionStore;
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyView, App, AsyncWindowContext, Corner, Entity,
EventEmitter, FocusHandle, Focusable, Hsla, ScrollHandle, Subscription, Task, Transformation,
WeakEntity, percentage,
Action, Animation, AnimationExt as _, AnyView, App, Corner, Entity, EventEmitter, FocusHandle,
Focusable, ScrollHandle, Subscription, Task, Transformation, WeakEntity, percentage,
};
use language::LanguageRegistry;
use language_model::{
@ -27,24 +23,23 @@ use language_model::{
};
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
Project,
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
project_settings::{ContextServerSettings, ProjectSettings},
};
use settings::{Settings, SettingsStore, update_settings_file};
use settings::{Settings, update_settings_file};
use ui::{
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::{Workspace, create_and_open_local_file};
use workspace::Workspace;
use zed_actions::ExtensionCategoryFilter;
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
pub(crate) use manage_profiles_modal::ManageProfilesModal;
use crate::{
AddContextServer, ExternalAgent, NewExternalAgentThread,
AddContextServer,
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
};
@ -52,7 +47,6 @@ pub struct AgentConfiguration {
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
focus_handle: FocusHandle,
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_store: Entity<ContextServerStore>,
@ -62,8 +56,6 @@ pub struct AgentConfiguration {
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
scrollbar_state: ScrollbarState,
gemini_is_installed: bool,
_check_for_gemini: Task<()>,
}
impl AgentConfiguration {
@ -73,7 +65,6 @@ impl AgentConfiguration {
tools: Entity<ToolWorkingSet>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -98,11 +89,6 @@ impl AgentConfiguration {
cx.subscribe(&context_server_store, |_, _, _, cx| cx.notify())
.detach();
cx.observe_global_in::<SettingsStore>(window, |this, _, cx| {
this.check_for_gemini(cx);
cx.notify();
})
.detach();
let scroll_handle = ScrollHandle::new();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
@ -111,7 +97,6 @@ impl AgentConfiguration {
fs,
language_registry,
workspace,
project,
focus_handle,
configuration_views_by_provider: HashMap::default(),
context_server_store,
@ -121,11 +106,8 @@ impl AgentConfiguration {
_registry_subscription: registry_subscription,
scroll_handle,
scrollbar_state,
gemini_is_installed: false,
_check_for_gemini: Task::ready(()),
};
this.build_provider_configuration_views(window, cx);
this.check_for_gemini(cx);
this
}
@ -155,34 +137,6 @@ impl AgentConfiguration {
self.configuration_views_by_provider
.insert(provider.id(), configuration_view);
}
fn check_for_gemini(&mut self, cx: &mut Context<Self>) {
let project = self.project.clone();
let settings = AllAgentServersSettings::get_global(cx).clone();
self._check_for_gemini = cx.spawn({
async move |this, cx| {
let Some(project) = project.upgrade() else {
return;
};
let gemini_is_installed = AgentServerCommand::resolve(
Gemini::binary_name(),
&[],
// TODO expose fallback path from the Gemini/CC types so we don't have to hardcode it again here
None,
settings.gemini,
&project,
cx,
)
.await
.is_some();
this.update(cx, |this, cx| {
this.gemini_is_installed = gemini_is_installed;
cx.notify();
})
.ok();
}
});
}
}
impl Focusable for AgentConfiguration {
@ -257,6 +211,7 @@ impl AgentConfiguration {
.child(
h_flex()
.id(provider_id_string.clone())
.cursor_pointer()
.px_2()
.py_0p5()
.w_full()
@ -276,7 +231,10 @@ impl AgentConfiguration {
h_flex()
.w_full()
.gap_1()
.child(Label::new(provider_name.clone()))
.child(
Label::new(provider_name.clone())
.size(LabelSize::Large),
)
.map(|this| {
if is_zed_provider && is_signed_in {
this.child(
@ -321,7 +279,7 @@ impl AgentConfiguration {
"Start New Thread",
)
.icon_position(IconPosition::Start)
.icon(IconName::Thread)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.label_size(LabelSize::Small)
@ -420,7 +378,7 @@ impl AgentConfiguration {
),
)
.child(
Label::new("Add at least one provider to use AI-powered features with Zed's native agent.")
Label::new("Add at least one provider to use AI-powered features.")
.color(Color::Muted),
),
),
@ -561,14 +519,6 @@ impl AgentConfiguration {
}
}
fn card_item_bg_color(&self, cx: &mut Context<Self>) -> Hsla {
cx.theme().colors().background.opacity(0.25)
}
fn card_item_border_color(&self, cx: &mut Context<Self>) -> Hsla {
cx.theme().colors().border.opacity(0.6)
}
fn render_context_servers_section(
&mut self,
window: &mut Window,
@ -586,12 +536,7 @@ impl AgentConfiguration {
v_flex()
.gap_0p5()
.child(Headline::new("Model Context Protocol (MCP) Servers"))
.child(
Label::new(
"All context servers connected through the Model Context Protocol.",
)
.color(Color::Muted),
),
.child(Label::new("Connect to context servers through the Model Context Protocol, either using Zed extensions or directly.").color(Color::Muted)),
)
.children(
context_server_ids.into_iter().map(|context_server_id| {
@ -601,7 +546,7 @@ impl AgentConfiguration {
.child(
h_flex()
.justify_between()
.gap_1p5()
.gap_2()
.child(
h_flex().w_full().child(
Button::new("add-context-server", "Add Custom Server")
@ -692,6 +637,8 @@ impl AgentConfiguration {
.map_or([].as_slice(), |tools| tools.as_slice());
let tool_count = tools.len();
let border_color = cx.theme().colors().border.opacity(0.6);
let (source_icon, source_tooltip) = if is_from_extension {
(
IconName::ZedMcpExtension,
@ -834,8 +781,8 @@ impl AgentConfiguration {
.id(item_id.clone())
.border_1()
.rounded_md()
.border_color(self.card_item_border_color(cx))
.bg(self.card_item_bg_color(cx))
.border_color(border_color)
.bg(cx.theme().colors().background.opacity(0.2))
.overflow_hidden()
.child(
h_flex()
@ -843,11 +790,7 @@ impl AgentConfiguration {
.justify_between()
.when(
error.is_some() || are_tools_expanded && tool_count >= 1,
|element| {
element
.border_b_1()
.border_color(self.card_item_border_color(cx))
},
|element| element.border_b_1().border_color(border_color),
)
.child(
h_flex()
@ -1029,195 +972,6 @@ impl AgentConfiguration {
))
})
}
fn render_agent_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AllAgentServersSettings::get_global(cx).clone();
let user_defined_agents = settings
.custom
.iter()
.map(|(name, settings)| {
self.render_agent_server(
IconName::Ai,
name.clone(),
ExternalAgent::Custom {
name: name.clone(),
settings: settings.clone(),
},
None,
cx,
)
.into_any_element()
})
.collect::<Vec<_>>();
v_flex()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.child(
v_flex()
.gap_0p5()
.child(
h_flex()
.w_full()
.gap_2()
.justify_between()
.child(Headline::new("External Agents"))
.child(
Button::new("add-agent", "Add Agent")
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.label_size(LabelSize::Small)
.on_click(
move |_, window, cx| {
if let Some(workspace) = window.root().flatten() {
let workspace = workspace.downgrade();
window
.spawn(cx, async |cx| {
open_new_agent_servers_entry_in_settings_editor(
workspace,
cx,
).await
})
.detach_and_log_err(cx);
}
}
),
)
)
.child(
Label::new(
"Bring the agent of your choice to Zed via our new Agent Client Protocol.",
)
.color(Color::Muted),
),
)
.child(self.render_agent_server(
IconName::AiGemini,
"Gemini CLI",
ExternalAgent::Gemini,
(!self.gemini_is_installed).then_some(Gemini::install_command().into()),
cx,
))
// TODO add CC
.children(user_defined_agents),
)
}
fn render_agent_server(
&self,
icon: IconName,
name: impl Into<SharedString>,
agent: ExternalAgent,
install_command: Option<SharedString>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let name = name.into();
h_flex()
.p_1()
.pl_2()
.gap_1p5()
.justify_between()
.border_1()
.rounded_md()
.border_color(self.card_item_border_color(cx))
.bg(self.card_item_bg_color(cx))
.overflow_hidden()
.child(
h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(Color::Muted))
.child(Label::new(name.clone())),
)
.map(|this| {
if let Some(install_command) = install_command {
this.child(
Button::new(
SharedString::from(format!("install_external_agent-{name}")),
"Install Agent",
)
.label_size(LabelSize::Small)
.icon(IconName::Plus)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(Tooltip::text(install_command.clone()))
.on_click(cx.listener(
move |this, _, window, cx| {
let Some(project) = this.project.upgrade() else {
return;
};
let Some(workspace) = this.workspace.upgrade() else {
return;
};
let cwd = project.read(cx).first_project_directory(cx);
let shell =
project.read(cx).terminal_settings(&cwd, cx).shell.clone();
let spawn_in_terminal = task::SpawnInTerminal {
id: task::TaskId(install_command.to_string()),
full_label: install_command.to_string(),
label: install_command.to_string(),
command: Some(install_command.to_string()),
args: Vec::new(),
command_label: install_command.to_string(),
cwd,
env: Default::default(),
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: Default::default(),
reveal_target: Default::default(),
hide: Default::default(),
shell,
show_summary: true,
show_command: true,
show_rerun: false,
};
let task = workspace.update(cx, |workspace, cx| {
workspace.spawn_in_terminal(spawn_in_terminal, window, cx)
});
cx.spawn(async move |this, cx| {
task.await;
this.update(cx, |this, cx| {
this.check_for_gemini(cx);
})
.ok();
})
.detach();
},
)),
)
} else {
this.child(
h_flex().gap_1().child(
Button::new(
SharedString::from(format!("start_acp_thread-{name}")),
"Start New Thread",
)
.label_size(LabelSize::Small)
.icon(IconName::Thread)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, window, cx| {
window.dispatch_action(
NewExternalAgentThread {
agent: Some(agent.clone()),
}
.boxed_clone(),
cx,
);
}),
),
)
}
})
}
}
impl Render for AgentConfiguration {
@ -1237,7 +991,6 @@ impl Render for AgentConfiguration {
.size_full()
.overflow_y_scroll()
.child(self.render_general_settings_section(cx))
.child(self.render_agent_servers_section(cx))
.child(self.render_context_servers_section(window, cx))
.child(self.render_provider_configuration_section(cx)),
)
@ -1356,109 +1109,3 @@ fn show_unable_to_uninstall_extension_with_context_server(
workspace.toggle_status_toast(status_toast, cx);
}
async fn open_new_agent_servers_entry_in_settings_editor(
workspace: WeakEntity<Workspace>,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let settings_editor = workspace
.update_in(cx, |_, window, cx| {
create_and_open_local_file(paths::settings_file(), window, cx, || {
settings::initial_user_settings_content().as_ref().into()
})
})?
.await?
.downcast::<Editor>()
.unwrap();
settings_editor
.downgrade()
.update_in(cx, |item, window, cx| {
let text = item.buffer().read(cx).snapshot(cx).text();
let settings = cx.global::<SettingsStore>();
let mut unique_server_name = None;
let edits = settings.edits_for_update::<AllAgentServersSettings>(&text, |file| {
let server_name: Option<SharedString> = (0..u8::MAX)
.map(|i| {
if i == 0 {
"your_agent".into()
} else {
format!("your_agent_{}", i).into()
}
})
.find(|name| !file.custom.contains_key(name));
if let Some(server_name) = server_name {
unique_server_name = Some(server_name.clone());
file.custom.insert(
server_name,
AgentServerSettings {
command: AgentServerCommand {
path: "path_to_executable".into(),
args: vec![],
env: Some(HashMap::default()),
},
},
);
}
});
if edits.is_empty() {
return;
}
let ranges = edits
.iter()
.map(|(range, _)| range.clone())
.collect::<Vec<_>>();
item.edit(edits, cx);
if let Some((unique_server_name, buffer)) =
unique_server_name.zip(item.buffer().read(cx).as_singleton())
{
let snapshot = buffer.read(cx).snapshot();
if let Some(range) =
find_text_in_buffer(&unique_server_name, ranges[0].start, &snapshot)
{
item.change_selections(
SelectionEffects::scroll(Autoscroll::newest()),
window,
cx,
|selections| {
selections.select_ranges(vec![range]);
},
);
}
}
})
}
fn find_text_in_buffer(
text: &str,
start: usize,
snapshot: &language::BufferSnapshot,
) -> Option<Range<usize>> {
let chars = text.chars().collect::<Vec<char>>();
let mut offset = start;
let mut char_offset = 0;
for c in snapshot.chars_at(start) {
if char_offset >= chars.len() {
break;
}
offset += 1;
if c == chars[char_offset] {
char_offset += 1;
} else {
char_offset = 0;
}
}
if char_offset == chars.len() {
Some(offset.saturating_sub(chars.len())..offset)
} else {
None
}
}

View file

@ -1529,7 +1529,6 @@ impl AgentDiff {
| AcpThreadEvent::TokenUsageUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::PromptCapabilitiesUpdated
| AcpThreadEvent::Retry(_) => {}
}
}

View file

@ -9,12 +9,9 @@ use agent_servers::AgentServerSettings;
use agent2::{DbThreadMetadata, HistoryEntry};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use serde::{Deserialize, Serialize};
use zed_actions::OpenBrowser;
use zed_actions::agent::ReauthenticateAgent;
use crate::acp::{AcpThreadHistory, ThreadHistoryEvent};
use crate::agent_diff::AgentDiffThread;
use crate::ui::AcpOnboardingModal;
use crate::{
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread,
@ -78,10 +75,7 @@ use workspace::{
};
use zed_actions::{
DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize,
agent::{
OpenAcpOnboardingModal, OpenOnboardingModal, OpenSettings, ResetOnboarding,
ToggleModelSelector,
},
agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector},
assistant::{OpenRulesLibrary, ToggleFocus},
};
@ -205,9 +199,6 @@ pub fn init(cx: &mut App) {
.register_action(|workspace, _: &OpenOnboardingModal, window, cx| {
AgentOnboardingModal::toggle(workspace, window, cx)
})
.register_action(|workspace, _: &OpenAcpOnboardingModal, window, cx| {
AcpOnboardingModal::toggle(workspace, window, cx)
})
.register_action(|_workspace, _: &ResetOnboarding, window, cx| {
window.dispatch_action(workspace::RestoreBanner.boxed_clone(), cx);
window.refresh();
@ -249,7 +240,6 @@ enum WhichFontSize {
None,
}
// TODO unify this with ExternalAgent
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum AgentType {
#[default]
@ -598,6 +588,17 @@ impl AgentPanel {
None
};
// Wait for the Gemini/Native feature flag to be available.
let client = workspace.read_with(cx, |workspace, _| workspace.client().clone())?;
if !client.status().borrow().is_signed_out() {
cx.update(|_, cx| {
cx.wait_for_flag_or_timeout::<feature_flags::GeminiAndNativeFeatureFlag>(
Duration::from_secs(2),
)
})?
.await;
}
let panel = workspace.update_in(cx, |workspace, window, cx| {
let panel = cx.new(|cx| {
Self::new(
@ -1023,8 +1024,6 @@ impl AgentPanel {
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
telemetry::event!("Agent Thread Started", agent = "zed-text");
let context = self
.context_store
.update(cx, |context_store, cx| context_store.create(cx));
@ -1117,8 +1116,6 @@ impl AgentPanel {
}
};
telemetry::event!("Agent Thread Started", agent = ext_agent.name());
let server = ext_agent.server(fs, history);
this.update_in(cx, |this, window, cx| {
@ -1476,7 +1473,6 @@ impl AgentPanel {
tools,
self.language_registry.clone(),
self.workspace.clone(),
self.project.downgrade(),
window,
cx,
)
@ -1848,6 +1844,19 @@ impl AgentPanel {
menu
}
pub fn set_selected_agent(
&mut self,
agent: AgentType,
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.selected_agent != agent {
self.selected_agent = agent.clone();
self.serialize(cx);
}
self.new_agent_thread(agent, window, cx);
}
pub fn selected_agent(&self) -> AgentType {
self.selected_agent.clone()
}
@ -1858,11 +1867,6 @@ impl AgentPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.selected_agent != agent {
self.selected_agent = agent.clone();
self.serialize(cx);
}
match agent {
AgentType::Zed => {
window.dispatch_action(
@ -2200,8 +2204,6 @@ impl AgentPanel {
"Enable Full Screen"
};
let selected_agent = self.selected_agent.clone();
PopoverMenu::new("agent-options-menu")
.trigger_with_tooltip(
IconButton::new("agent-options-menu", IconName::Ellipsis)
@ -2281,11 +2283,6 @@ impl AgentPanel {
.action("Settings", Box::new(OpenSettings))
.separator()
.action(full_screen_label, Box::new(ToggleZoom));
if selected_agent == AgentType::Gemini {
menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent))
}
menu
}))
}
@ -2320,8 +2317,6 @@ impl AgentPanel {
.menu({
let menu = self.assistant_navigation_menu.clone();
move |window, cx| {
telemetry::event!("View Thread History Clicked");
if let Some(menu) = menu.as_ref() {
menu.update(cx, |_, cx| {
cx.defer_in(window, |menu, window, cx| {
@ -2500,8 +2495,6 @@ impl AgentPanel {
let workspace = self.workspace.clone();
move |window, cx| {
telemetry::event!("New Thread Clicked");
let active_thread = active_thread.clone();
Some(ContextMenu::build(window, cx, |mut menu, _window, cx| {
menu = menu
@ -2543,7 +2536,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
panel.set_selected_agent(
AgentType::NativeAgent,
window,
cx,
@ -2569,7 +2562,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
panel.set_selected_agent(
AgentType::TextThread,
window,
cx,
@ -2597,7 +2590,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
panel.set_selected_agent(
AgentType::Gemini,
window,
cx,
@ -2624,7 +2617,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
panel.set_selected_agent(
AgentType::ClaudeCode,
window,
cx,
@ -2657,7 +2650,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
panel.set_selected_agent(
AgentType::Custom {
name: agent_name
.clone(),
@ -2678,15 +2671,6 @@ impl AgentPanel {
}
menu
})
.when(cx.has_flag::<GeminiAndNativeFeatureFlag>(), |menu| {
menu.separator().link(
"Add Other Agents",
OpenBrowser {
url: zed_urls::external_agents_docs(cx),
}
.boxed_clone(),
)
});
menu
}))
@ -3767,11 +3751,6 @@ impl Render for AgentPanel {
}
}))
.on_action(cx.listener(Self::toggle_burn_mode))
.on_action(cx.listener(|this, _: &ReauthenticateAgent, window, cx| {
if let Some(thread_view) = this.active_thread_view() {
thread_view.update(cx, |thread_view, cx| thread_view.reauthenticate(window, cx))
}
}))
.child(self.render_toolbar(window, cx))
.children(self.render_onboarding(window, cx))
.map(|parent| match &self.active_view {

View file

@ -160,7 +160,6 @@ pub struct NewNativeAgentThreadFromSummary {
from_session_id: agent_client_protocol::SessionId,
}
// TODO unify this with AgentType
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
enum ExternalAgent {
@ -175,15 +174,6 @@ enum ExternalAgent {
}
impl ExternalAgent {
fn name(&self) -> &'static str {
match self {
Self::NativeAgent => "zed",
Self::Gemini => "gemini-cli",
Self::ClaudeCode => "claude-code",
Self::Custom { .. } => "custom",
}
}
pub fn server(
&self,
fs: Arc<dyn fs::Fs>,

View file

@ -6,8 +6,7 @@ use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@ -77,7 +76,6 @@ pub struct LanguageModelPickerDelegate {
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
}
@ -98,7 +96,6 @@ impl LanguageModelPickerDelegate {
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
@ -142,56 +139,6 @@ impl LanguageModelPickerDelegate {
.unwrap_or(0)
}
/// Authenticates all providers in the [`LanguageModelRegistry`].
///
/// We do this so that we can populate the language selector with all of the
/// models from the configured providers.
fn authenticate_all_providers(cx: &mut App) -> Task<()> {
let authenticate_all_providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
cx.spawn(async move |_cx| {
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
if let Err(err) = authenticate_task.await {
if matches!(err, AuthenticateError::CredentialsNotFound) {
// Since we're authenticating these providers in the
// background for the purposes of populating the
// language selector, we don't care about providers
// where the credentials are not found.
} else {
// Some providers have noisy failure states that we
// don't want to spam the logs with every time the
// language model selector is initialized.
//
// Ideally these should have more clear failure modes
// that we know are safe to ignore here, like what we do
// with `CredentialsNotFound` above.
match provider_id.0.as_ref() {
"lmstudio" | "ollama" => {
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
//
// These fail noisily, so we don't log them.
}
"copilot_chat" => {
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
}
_ => {
log::error!(
"Failed to authenticate provider: {}: {err}",
provider_name.0
);
}
}
}
}
}
})
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.get_active_model)(cx)
}

View file

@ -361,7 +361,6 @@ impl TextThreadEditor {
if self.sending_disabled(cx) {
return;
}
telemetry::event!("Agent Message Sent", agent = "zed-text");
self.send_to_model(window, cx);
}

View file

@ -1,4 +1,3 @@
mod acp_onboarding_modal;
mod agent_notification;
mod burn_mode_tooltip;
mod context_pill;
@ -7,7 +6,6 @@ mod onboarding_modal;
pub mod preview;
mod unavailable_editing_tooltip;
pub use acp_onboarding_modal::*;
pub use agent_notification::*;
pub use burn_mode_tooltip::*;
pub use context_pill::*;

View file

@ -1,254 +0,0 @@
use client::zed_urls;
use gpui::{
ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render,
linear_color_stop, linear_gradient,
};
use ui::{TintColor, Vector, VectorName, prelude::*};
use workspace::{ModalView, Workspace};
use crate::agent_panel::{AgentPanel, AgentType};
macro_rules! acp_onboarding_event {
($name:expr) => {
telemetry::event!($name, source = "ACP Onboarding");
};
($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => {
telemetry::event!($name, source = "ACP Onboarding", $($key $(= $value)?),+);
};
}
pub struct AcpOnboardingModal {
focus_handle: FocusHandle,
workspace: Entity<Workspace>,
}
impl AcpOnboardingModal {
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
let workspace_entity = cx.entity();
workspace.toggle_modal(window, cx, |_window, cx| Self {
workspace: workspace_entity,
focus_handle: cx.focus_handle(),
});
}
fn open_panel(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
self.workspace.update(cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
panel.new_agent_thread(AgentType::Gemini, window, cx);
});
}
});
cx.emit(DismissEvent);
acp_onboarding_event!("Open Panel Clicked");
}
fn view_docs(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) {
cx.open_url(&zed_urls::external_agents_docs(cx));
cx.notify();
acp_onboarding_event!("Documentation Link Clicked");
}
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
}
impl EventEmitter<DismissEvent> for AcpOnboardingModal {}
impl Focusable for AcpOnboardingModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl ModalView for AcpOnboardingModal {}
impl Render for AcpOnboardingModal {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let illustration_element = |label: bool, opacity: f32| {
h_flex()
.px_1()
.py_0p5()
.gap_1()
.rounded_sm()
.bg(cx.theme().colors().element_active.opacity(0.05))
.border_1()
.border_color(cx.theme().colors().border)
.border_dashed()
.child(
Icon::new(IconName::Stop)
.size(IconSize::Small)
.color(Color::Custom(cx.theme().colors().text_muted.opacity(0.15))),
)
.map(|this| {
if label {
this.child(
Label::new("Your Agent Here")
.size(LabelSize::Small)
.color(Color::Muted),
)
} else {
this.child(
div().w_16().h_1().rounded_full().bg(cx
.theme()
.colors()
.element_active
.opacity(0.6)),
)
}
})
.opacity(opacity)
};
let illustration = h_flex()
.relative()
.h(rems_from_px(126.))
.bg(cx.theme().colors().editor_background)
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.justify_center()
.gap_8()
.rounded_t_md()
.overflow_hidden()
.child(
div().absolute().inset_0().w(px(515.)).h(px(126.)).child(
Vector::new(VectorName::AcpGrid, rems_from_px(515.), rems_from_px(126.))
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.02))),
),
)
.child(div().absolute().inset_0().size_full().bg(linear_gradient(
0.,
linear_color_stop(
cx.theme().colors().elevated_surface_background.opacity(0.1),
0.9,
),
linear_color_stop(
cx.theme().colors().elevated_surface_background.opacity(0.),
0.,
),
)))
.child(
div()
.absolute()
.inset_0()
.size_full()
.bg(gpui::black().opacity(0.15)),
)
.child(
h_flex()
.gap_4()
.child(
Vector::new(VectorName::AcpLogo, rems_from_px(106.), rems_from_px(40.))
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.8))),
)
.child(
Vector::new(
VectorName::AcpLogoSerif,
rems_from_px(111.),
rems_from_px(41.),
)
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.8))),
),
)
.child(
v_flex()
.gap_1p5()
.child(illustration_element(false, 0.15))
.child(illustration_element(true, 0.3))
.child(
h_flex()
.pl_1()
.pr_2()
.py_0p5()
.gap_1()
.rounded_sm()
.bg(cx.theme().colors().element_active.opacity(0.2))
.border_1()
.border_color(cx.theme().colors().border)
.child(
Icon::new(IconName::AiGemini)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("New Gemini CLI Thread").size(LabelSize::Small)),
)
.child(illustration_element(true, 0.3))
.child(illustration_element(false, 0.15)),
);
let heading = v_flex()
.w_full()
.gap_1()
.child(
Label::new("Now Available")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(Headline::new("Bring Your Own Agent to Zed").size(HeadlineSize::Large));
let copy = "Bring the agent of your choice to Zed via our new Agent Client Protocol (ACP), starting with Google's Gemini CLI integration.";
let open_panel_button = Button::new("open-panel", "Start with Gemini CLI")
.icon_size(IconSize::Indicator)
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::open_panel));
let docs_button = Button::new("add-other-agents", "Add Other Agents")
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Indicator)
.icon_color(Color::Muted)
.full_width()
.on_click(cx.listener(Self::view_docs));
let close_button = h_flex().absolute().top_2().right_2().child(
IconButton::new("cancel", IconName::Close).on_click(cx.listener(
|_, _: &ClickEvent, _window, cx| {
acp_onboarding_event!("Canceled", trigger = "X click");
cx.emit(DismissEvent);
},
)),
);
v_flex()
.id("acp-onboarding")
.key_context("AcpOnboardingModal")
.relative()
.w(rems(34.))
.h_full()
.elevation_3(cx)
.track_focus(&self.focus_handle(cx))
.overflow_hidden()
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| {
acp_onboarding_event!("Canceled", trigger = "Action");
cx.emit(DismissEvent);
}))
.on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| {
this.focus_handle.focus(window);
}))
.child(illustration)
.child(
v_flex()
.p_4()
.gap_2()
.child(heading)
.child(Label::new(copy).color(Color::Muted))
.child(
v_flex()
.w_full()
.mt_2()
.gap_1()
.child(open_panel_button)
.child(docs_button),
),
)
.child(close_button)
}
}

View file

@ -12,11 +12,11 @@ use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions}
#[derive(IntoElement, RegisterComponent)]
pub struct AiUpsellCard {
sign_in_status: SignInStatus,
sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
account_too_young: bool,
user_plan: Option<Plan>,
tab_index: Option<isize>,
pub sign_in_status: SignInStatus,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
pub account_too_young: bool,
pub user_plan: Option<Plan>,
pub tab_index: Option<isize>,
}
impl AiUpsellCard {
@ -43,11 +43,6 @@ impl AiUpsellCard {
tab_index: None,
}
}
pub fn tab_index(mut self, tab_index: Option<isize>) -> Self {
self.tab_index = tab_index;
self
}
}
impl RenderOnce for AiUpsellCard {

View file

@ -118,7 +118,7 @@ impl Tool for FetchTool {
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
true
false
}
fn may_perform_edits(&self) -> bool {

View file

@ -435,8 +435,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from(path!("root/apple/banana/carrot")),
PathBuf::from(path!("root/apple/bandana/carbonara"))
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
@ -447,8 +447,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from(path!("root/apple/banana/carrot")),
PathBuf::from(path!("root/apple/bandana/carbonara"))
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}

View file

@ -68,7 +68,7 @@ impl Tool for ReadFileTool {
}
fn icon(&self) -> IconName {
IconName::ToolSearch
IconName::ToolRead
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {

View file

@ -43,11 +43,3 @@ pub fn ai_privacy_and_security(cx: &App) -> String {
server_url = server_url(cx)
)
}
/// Returns the URL to Zed AI's external agents documentation.
pub fn external_agents_docs(cx: &App) -> String {
format!(
"{server_url}/docs/ai/external-agents",
server_url = server_url(cx)
)
}

View file

@ -1,10 +1,7 @@
use anyhow::Result;
use db::{
query,
sqlez::{
bindable::Column, domain::Domain, statement::Statement,
thread_safe_connection::ThreadSafeConnection,
},
define_connection, query,
sqlez::{bindable::Column, statement::Statement},
sqlez_macros::sql,
};
use serde::{Deserialize, Serialize};
@ -53,11 +50,8 @@ impl Column for SerializedCommandInvocation {
}
}
pub struct CommandPaletteDB(ThreadSafeConnection);
impl Domain for CommandPaletteDB {
const NAME: &str = stringify!(CommandPaletteDB);
const MIGRATIONS: &[&str] = &[sql!(
define_connection!(pub static ref COMMAND_PALETTE_HISTORY: CommandPaletteDB<()> =
&[sql!(
CREATE TABLE IF NOT EXISTS command_invocations(
id INTEGER PRIMARY KEY AUTOINCREMENT,
command_name TEXT NOT NULL,
@ -65,9 +59,7 @@ impl Domain for CommandPaletteDB {
last_invoked INTEGER DEFAULT (unixepoch()) NOT NULL
) STRICT;
)];
}
db::static_connection!(COMMAND_PALETTE_HISTORY, CommandPaletteDB, []);
);
impl CommandPaletteDB {
pub async fn write_command_invocation(

View file

@ -110,14 +110,11 @@ pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
}
/// Implements a basic DB wrapper for a given domain
///
/// Arguments:
/// - static variable name for connection
/// - type of connection wrapper
/// - dependencies, whose migrations should be run prior to this domain's migrations
#[macro_export]
macro_rules! static_connection {
($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => {
macro_rules! define_connection {
(pub static ref $id:ident: $t:ident<()> = $migrations:expr; $($global:ident)?) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;
@ -126,6 +123,16 @@ macro_rules! static_connection {
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
impl $t {
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db(name: &'static str) -> Self {
@ -135,8 +142,7 @@ macro_rules! static_connection {
#[cfg(any(test, feature = "test-support"))]
pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
#[allow(unused_parens)]
$t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id))))
$t($crate::smol::block_on($crate::open_test_db::<$t>(stringify!($id))))
});
#[cfg(not(any(test, feature = "test-support")))]
@ -147,10 +153,46 @@ macro_rules! static_connection {
} else {
$crate::RELEASE_CHANNEL.dev_name()
};
#[allow(unused_parens)]
$t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope)))
$t($crate::smol::block_on($crate::open_db::<$t>(db_dir, scope)))
});
}
};
(pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr; $($global:ident)?) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
$t($crate::smol::block_on($crate::open_test_db::<($($d),+, $t)>(stringify!($id))))
});
#[cfg(not(any(test, feature = "test-support")))]
pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
let db_dir = $crate::database_dir();
let scope = if false $(|| stringify!($global) == "global")? {
"global"
} else {
$crate::RELEASE_CHANNEL.dev_name()
};
$t($crate::smol::block_on($crate::open_db::<($($d),+, $t)>(db_dir, scope)))
});
};
}
pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
@ -177,12 +219,17 @@ mod tests {
enum BadDB {}
impl Domain for BadDB {
const NAME: &str = "db_tests";
const MIGRATIONS: &[&str] = &[
sql!(CREATE TABLE test(value);),
// failure because test already exists
sql!(CREATE TABLE test(value);),
];
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[
sql!(CREATE TABLE test(value);),
// failure because test already exists
sql!(CREATE TABLE test(value);),
]
}
}
let tempdir = tempfile::Builder::new()
@ -204,15 +251,25 @@ mod tests {
enum CorruptedDB {}
impl Domain for CorruptedDB {
const NAME: &str = "db_tests";
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test(value);)]
}
}
enum GoodDB {}
impl Domain for GoodDB {
const NAME: &str = "db_tests"; //Notice same name
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)];
fn name() -> &'static str {
"db_tests" //Notice same name
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test2(value);)] //But different migration
}
}
let tempdir = tempfile::Builder::new()
@ -248,16 +305,25 @@ mod tests {
enum CorruptedDB {}
impl Domain for CorruptedDB {
const NAME: &str = "db_tests";
fn name() -> &'static str {
"db_tests"
}
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test(value);)]
}
}
enum GoodDB {}
impl Domain for GoodDB {
const NAME: &str = "db_tests"; //Notice same name
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration
fn name() -> &'static str {
"db_tests" //Notice same name
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test2(value);)] //But different migration
}
}
let tempdir = tempfile::Builder::new()

View file

@ -2,26 +2,16 @@ use gpui::App;
use sqlez_macros::sql;
use util::ResultExt as _;
use crate::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
write_and_log,
};
use crate::{define_connection, query, write_and_log};
pub struct KeyValueStore(crate::sqlez::thread_safe_connection::ThreadSafeConnection);
impl Domain for KeyValueStore {
const NAME: &str = stringify!(KeyValueStore);
const MIGRATIONS: &[&str] = &[sql!(
define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
&[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)];
}
crate::static_connection!(KEY_VALUE_STORE, KeyValueStore, []);
);
pub trait Dismissable {
const KEY: &'static str;
@ -101,19 +91,15 @@ mod tests {
}
}
pub struct GlobalKeyValueStore(ThreadSafeConnection);
impl Domain for GlobalKeyValueStore {
const NAME: &str = stringify!(GlobalKeyValueStore);
const MIGRATIONS: &[&str] = &[sql!(
define_connection!(pub static ref GLOBAL_KEY_VALUE_STORE: GlobalKeyValueStore<()> =
&[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)];
}
crate::static_connection!(GLOBAL_KEY_VALUE_STORE, GlobalKeyValueStore, [], global);
global
);
impl GlobalKeyValueStore {
query! {

View file

@ -19,10 +19,6 @@ static KEYMAP_LINUX: LazyLock<KeymapFile> = LazyLock::new(|| {
load_keymap("keymaps/default-linux.json").expect("Failed to load Linux keymap")
});
static KEYMAP_WINDOWS: LazyLock<KeymapFile> = LazyLock::new(|| {
load_keymap("keymaps/default-windows.json").expect("Failed to load Windows keymap")
});
static ALL_ACTIONS: LazyLock<Vec<ActionDef>> = LazyLock::new(dump_all_gpui_actions);
const FRONT_MATTER_COMMENT: &str = "<!-- ZED_META {} -->";
@ -220,7 +216,6 @@ fn find_binding(os: &str, action: &str) -> Option<String> {
let keymap = match os {
"macos" => &KEYMAP_MACOS,
"linux" | "freebsd" => &KEYMAP_LINUX,
"windows" => &KEYMAP_WINDOWS,
_ => unreachable!("Not a valid OS: {}", os),
};

View file

@ -2588,7 +2588,7 @@ impl Editor {
|| binding
.keystrokes()
.first()
.is_some_and(|keystroke| keystroke.display_modifiers.modified())
.is_some_and(|keystroke| keystroke.modifiers.modified())
}))
}
@ -7686,16 +7686,16 @@ impl Editor {
.keystroke()
{
modifiers_held = modifiers_held
|| (&accept_keystroke.display_modifiers == modifiers
&& accept_keystroke.display_modifiers.modified());
|| (&accept_keystroke.modifiers == modifiers
&& accept_keystroke.modifiers.modified());
};
if let Some(accept_partial_keystroke) = self
.accept_edit_prediction_keybind(true, window, cx)
.keystroke()
{
modifiers_held = modifiers_held
|| (&accept_partial_keystroke.display_modifiers == modifiers
&& accept_partial_keystroke.display_modifiers.modified());
|| (&accept_partial_keystroke.modifiers == modifiers
&& accept_partial_keystroke.modifiers.modified());
}
if modifiers_held {
@ -9044,7 +9044,7 @@ impl Editor {
let is_platform_style_mac = PlatformStyle::platform() == PlatformStyle::Mac;
let modifiers_color = if accept_keystroke.display_modifiers == window.modifiers() {
let modifiers_color = if accept_keystroke.modifiers == window.modifiers() {
Color::Accent
} else {
Color::Muted
@ -9056,19 +9056,19 @@ impl Editor {
.font(theme::ThemeSettings::get_global(cx).buffer_font.clone())
.text_size(TextSize::XSmall.rems(cx))
.child(h_flex().children(ui::render_modifiers(
&accept_keystroke.display_modifiers,
&accept_keystroke.modifiers,
PlatformStyle::platform(),
Some(modifiers_color),
Some(IconSize::XSmall.rems().into()),
true,
)))
.when(is_platform_style_mac, |parent| {
parent.child(accept_keystroke.display_key.clone())
parent.child(accept_keystroke.key.clone())
})
.when(!is_platform_style_mac, |parent| {
parent.child(
Key::new(
util::capitalize(&accept_keystroke.display_key),
util::capitalize(&accept_keystroke.key),
Some(Color::Default),
)
.size(Some(IconSize::XSmall.rems().into())),
@ -9171,7 +9171,7 @@ impl Editor {
max_width: Pixels,
cursor_point: Point,
style: &EditorStyle,
accept_keystroke: Option<&gpui::KeybindingKeystroke>,
accept_keystroke: Option<&gpui::Keystroke>,
_window: &Window,
cx: &mut Context<Editor>,
) -> Option<AnyElement> {
@ -9249,7 +9249,7 @@ impl Editor {
accept_keystroke.as_ref(),
|el, accept_keystroke| {
el.child(h_flex().children(ui::render_modifiers(
&accept_keystroke.display_modifiers,
&accept_keystroke.modifiers,
PlatformStyle::platform(),
Some(Color::Default),
Some(IconSize::XSmall.rems().into()),
@ -9319,7 +9319,7 @@ impl Editor {
.child(completion),
)
.when_some(accept_keystroke, |el, accept_keystroke| {
if !accept_keystroke.display_modifiers.modified() {
if !accept_keystroke.modifiers.modified() {
return el;
}
@ -9338,7 +9338,7 @@ impl Editor {
.font(theme::ThemeSettings::get_global(cx).buffer_font.clone())
.when(is_platform_style_mac, |parent| parent.gap_1())
.child(h_flex().children(ui::render_modifiers(
&accept_keystroke.display_modifiers,
&accept_keystroke.modifiers,
PlatformStyle::platform(),
Some(if !has_completion {
Color::Muted

View file

@ -43,10 +43,10 @@ use gpui::{
Bounds, ClickEvent, ClipboardItem, ContentMask, Context, Corner, Corners, CursorStyle,
DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId,
GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero,
KeybindingKeystroke, Length, ModifiersChangedEvent, MouseButton, MouseClickEvent,
MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta,
ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement,
Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
Keystroke, Length, ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent,
MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle,
ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled,
TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background,
transparent_black,
};
@ -7150,7 +7150,7 @@ fn header_jump_data(
pub struct AcceptEditPredictionBinding(pub(crate) Option<gpui::KeyBinding>);
impl AcceptEditPredictionBinding {
pub fn keystroke(&self) -> Option<&KeybindingKeystroke> {
pub fn keystroke(&self) -> Option<&Keystroke> {
if let Some(binding) = self.0.as_ref() {
match &binding.keystrokes() {
[keystroke, ..] => Some(keystroke),

View file

@ -1,6 +1,7 @@
use crate::{
Anchor, Editor, EditorSettings, EditorSnapshot, FindAllReferences, GoToDefinition,
GoToTypeDefinition, GotoDefinitionKind, InlayId, Navigated, PointForPosition, SelectPhase,
display_map::InlayOffset,
editor_settings::GoToDefinitionFallback,
hover_popover::{self, InlayHover},
scroll::ScrollAmount,
@ -15,6 +16,7 @@ use project::{
};
use settings::Settings;
use std::ops::Range;
use text;
use theme::ActiveTheme as _;
use util::{ResultExt, TryFutureExt as _, maybe};
@ -121,13 +123,25 @@ impl Editor {
cx: &mut Context<Self>,
) {
let hovered_link_modifier = Editor::multi_cursor_modifier(false, &modifiers, cx);
if !hovered_link_modifier || self.has_pending_selection() {
// When you're dragging to select, and you release the drag to create the selection,
// if you happened to end over something hoverable (including an inlay hint), don't
// have the hovered link appear. That would be annoying, because all you're trying
// to do is to create a selection, not hover to see a hovered link.
if self.has_pending_selection() {
self.hide_hovered_link(cx);
return;
}
match point_for_position.as_valid() {
Some(point) => {
// Hide the underline unless you're holding the modifier key on the keyboard
// which will perform a goto definition.
if !hovered_link_modifier {
self.hide_hovered_link(cx);
return;
}
let trigger_point = TriggerPoint::Text(
snapshot
.buffer_snapshot
@ -319,6 +333,7 @@ pub fn update_inlay_link_and_hover_points(
let inlay_hint_cache = editor.inlay_hint_cache();
let excerpt_id = previous_valid_anchor.excerpt_id;
if let Some(cached_hint) = inlay_hint_cache.hint_by_id(excerpt_id, hovered_hint.id) {
// Check if we should process this hint for hover
match cached_hint.resolve_state {
ResolveState::CanResolve(_, _) => {
if let Some(buffer_id) = snapshot
@ -419,31 +434,46 @@ pub fn update_inlay_link_and_hover_points(
);
hover_updated = true;
}
if let Some((language_server_id, location)) =
hovered_hint_part.location
&& secondary_held
&& !editor.has_pending_nonempty_selection()
{
go_to_definition_updated = true;
show_link_definition(
shift_held,
editor,
TriggerPoint::InlayHint(
highlight,
location,
language_server_id,
),
snapshot,
window,
cx,
);
// Now perform the "Go to Definition" flow to get hover documentation
if let Some(project) = editor.project.clone() {
let highlight = highlight.clone();
let hint_value = hovered_hint_part.value.clone();
let location = location.clone();
get_docs_then_show_hover(
window, cx, highlight, hint_value, location,
project,
);
}
if secondary_held
&& !editor.has_pending_nonempty_selection()
{
go_to_definition_updated = true;
show_link_definition(
shift_held,
editor,
TriggerPoint::InlayHint(
highlight,
location,
language_server_id,
),
snapshot,
window,
cx,
);
}
}
}
}
};
}
}
ResolveState::Resolving => {}
}
};
}
}
}
@ -456,6 +486,145 @@ pub fn update_inlay_link_and_hover_points(
}
}
fn get_docs_then_show_hover(
window: &mut Window,
cx: &mut Context<'_, Editor>,
highlight: InlayHighlight,
hint_value: String,
location: lsp::Location,
project: Entity<Project>,
) {
cx.spawn_in(window, async move |editor, cx| {
async move {
// Small delay to show the loading message first
cx.background_executor()
.timer(std::time::Duration::from_millis(50))
.await;
// Convert LSP URL to file path
let file_path = location
.uri
.to_file_path()
.map_err(|_| anyhow::anyhow!("Invalid file URL"))?;
// Open the definition file
let definition_buffer = project
.update(cx, |project, cx| project.open_local_buffer(file_path, cx))?
.await?;
// Extract documentation directly from the source
let documentation = definition_buffer.update(cx, |buffer, _| {
let line_number = location.range.start.line as usize;
// Get the text of the buffer
let text = buffer.text();
let lines: Vec<&str> = text.lines().collect();
// Look backwards from the definition line to find doc comments
let mut doc_lines = Vec::new();
let mut current_line = line_number.saturating_sub(1);
// Skip any attributes like #[derive(...)]
while current_line > 0
&& lines.get(current_line).map_or(false, |line| {
let trimmed = line.trim();
trimmed.starts_with("#[") || trimmed.is_empty()
})
{
current_line = current_line.saturating_sub(1);
}
// Collect doc comments
while current_line > 0 {
if let Some(line) = lines.get(current_line) {
let trimmed = line.trim();
if trimmed.starts_with("///") {
// Remove the /// and any leading space
let doc_text = trimmed
.strip_prefix("///")
.unwrap_or("")
.strip_prefix(" ")
.unwrap_or_else(|| trimmed.strip_prefix("///").unwrap_or(""));
doc_lines.push(doc_text.to_string());
} else if !trimmed.is_empty() {
// Stop at the first non-doc, non-empty line
break;
}
}
current_line = current_line.saturating_sub(1);
}
// Reverse to get correct order
doc_lines.reverse();
// Also get the actual definition line
let definition = lines
.get(line_number)
.map(|s| s.trim().to_string())
.unwrap_or_else(|| hint_value.clone());
if doc_lines.is_empty() {
None
} else {
let docs = doc_lines.join("\n");
Some((definition, docs))
}
})?;
if let Some((definition, docs)) = documentation {
// Format as markdown with the definition as a code block
let formatted_docs = format!("```rust\n{}\n```\n\n{}", definition, docs);
editor
.update_in(cx, |editor, window, cx| {
hover_popover::hover_at_inlay(
editor,
InlayHover {
tooltip: HoverBlock {
text: formatted_docs,
kind: HoverBlockKind::Markdown,
},
range: highlight,
},
window,
cx,
);
})
.log_err();
} else {
// Fallback to showing just the location info
let fallback_text = format!(
"{}\n\nDefined in at line {}",
hint_value.trim(),
// filename, // TODO
location.range.start.line + 1
);
editor
.update_in(cx, |editor, window, cx| {
hover_popover::hover_at_inlay(
editor,
InlayHover {
tooltip: HoverBlock {
text: fallback_text,
kind: HoverBlockKind::PlainText,
},
range: highlight,
},
window,
cx,
);
})
.log_err();
}
anyhow::Ok(())
}
.log_err()
.await
})
.detach();
}
pub fn show_link_definition(
shift_held: bool,
editor: &mut Editor,

View file

@ -1404,7 +1404,7 @@ impl ProjectItem for Editor {
}
fn for_broken_project_item(
abs_path: &Path,
abs_path: PathBuf,
is_local: bool,
e: &anyhow::Error,
window: &mut Window,

View file

@ -1,17 +1,13 @@
use anyhow::Result;
use db::{
query,
sqlez::{
bindable::{Bind, Column, StaticColumnCount},
domain::Domain,
statement::Statement,
},
sqlez_macros::sql,
};
use db::sqlez::bindable::{Bind, Column, StaticColumnCount};
use db::sqlez::statement::Statement;
use fs::MTime;
use itertools::Itertools as _;
use std::path::PathBuf;
use db::sqlez_macros::sql;
use db::{define_connection, query};
use workspace::{ItemId, WorkspaceDb, WorkspaceId};
#[derive(Clone, Debug, PartialEq, Default)]
@ -87,11 +83,7 @@ impl Column for SerializedEditor {
}
}
pub struct EditorDb(db::sqlez::thread_safe_connection::ThreadSafeConnection);
impl Domain for EditorDb {
const NAME: &str = stringify!(EditorDb);
define_connection!(
// Current schema shape using pseudo-rust syntax:
// editors(
// item_id: usize,
@ -121,8 +113,7 @@ impl Domain for EditorDb {
// start: usize,
// end: usize,
// )
const MIGRATIONS: &[&str] = &[
pub static ref DB: EditorDb<WorkspaceDb> = &[
sql! (
CREATE TABLE editors(
item_id INTEGER NOT NULL,
@ -198,9 +189,7 @@ impl Domain for EditorDb {
) STRICT;
),
];
}
db::static_connection!(DB, EditorDb, [WorkspaceDb]);
);
// https://www.sqlite.org/limits.html
// > <..> the maximum value of a host parameter number is SQLITE_MAX_VARIABLE_NUMBER,

View file

@ -98,10 +98,6 @@ impl FeatureFlag for GeminiAndNativeFeatureFlag {
// integration too, and we'd like to turn Gemini/Native on in new builds
// without enabling Claude Code in old builds.
const NAME: &'static str = "gemini-and-native";
fn enabled_for_all() -> bool {
true
}
}
pub struct ClaudeCodeFeatureFlag;
@ -205,7 +201,7 @@ impl FeatureFlagAppExt for App {
fn has_flag<T: FeatureFlag>(&self) -> bool {
self.try_global::<FeatureFlags>()
.map(|flags| flags.has_flag::<T>())
.unwrap_or(T::enabled_for_all())
.unwrap_or(false)
}
fn is_staff(&self) -> bool {

View file

@ -4466,7 +4466,7 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
is_enabled
.then(|| {
let ConfiguredModel { provider, model } =
LanguageModelRegistry::read_global(cx).commit_message_model()?;
LanguageModelRegistry::read_global(cx).commit_message_model(cx)?;
provider.is_authenticated(cx).then(|| model)
})

View file

@ -37,10 +37,10 @@ use crate::{
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
Reservation, ScreenCaptureSource, SubscriberSet, Subscription, SvgRenderer, Task, TextSystem,
Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
PlatformDisplay, PlatformKeyboardLayout, Point, PromptBuilder, PromptButton, PromptHandle,
PromptLevel, Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource,
SubscriberSet, Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance,
WindowHandle, WindowId, WindowInvalidator,
colors::{Colors, GlobalColors},
current_platform, hash, init_app_menus,
};
@ -263,7 +263,6 @@ pub struct App {
pub(crate) focus_handles: Arc<FocusMap>,
pub(crate) keymap: Rc<RefCell<Keymap>>,
pub(crate) keyboard_layout: Box<dyn PlatformKeyboardLayout>,
pub(crate) keyboard_mapper: Rc<dyn PlatformKeyboardMapper>,
pub(crate) global_action_listeners:
FxHashMap<TypeId, Vec<Rc<dyn Fn(&dyn Any, DispatchPhase, &mut Self)>>>,
pending_effects: VecDeque<Effect>,
@ -313,7 +312,6 @@ impl App {
let text_system = Arc::new(TextSystem::new(platform.text_system()));
let entities = EntityMap::new();
let keyboard_layout = platform.keyboard_layout();
let keyboard_mapper = platform.keyboard_mapper();
let app = Rc::new_cyclic(|this| AppCell {
app: RefCell::new(App {
@ -339,7 +337,6 @@ impl App {
focus_handles: Arc::new(RwLock::new(SlotMap::with_key())),
keymap: Rc::new(RefCell::new(Keymap::default())),
keyboard_layout,
keyboard_mapper,
global_action_listeners: FxHashMap::default(),
pending_effects: VecDeque::new(),
pending_notifications: FxHashSet::default(),
@ -379,7 +376,6 @@ impl App {
if let Some(app) = app.upgrade() {
let cx = &mut app.borrow_mut();
cx.keyboard_layout = cx.platform.keyboard_layout();
cx.keyboard_mapper = cx.platform.keyboard_mapper();
cx.keyboard_layout_observers
.clone()
.retain(&(), move |callback| (callback)(cx));
@ -428,11 +424,6 @@ impl App {
self.keyboard_layout.as_ref()
}
/// Get the current keyboard mapper.
pub fn keyboard_mapper(&self) -> &Rc<dyn PlatformKeyboardMapper> {
&self.keyboard_mapper
}
/// Invokes a handler when the current keyboard layout changes
pub fn on_keyboard_layout_change<F>(&self, mut callback: F) -> Subscription
where

View file

@ -4,7 +4,7 @@ mod context;
pub use binding::*;
pub use context::*;
use crate::{Action, AsKeystroke, Keystroke, is_no_action};
use crate::{Action, Keystroke, is_no_action};
use collections::{HashMap, HashSet};
use smallvec::SmallVec;
use std::any::TypeId;
@ -141,7 +141,7 @@ impl Keymap {
/// only.
pub fn bindings_for_input(
&self,
input: &[impl AsKeystroke],
input: &[Keystroke],
context_stack: &[KeyContext],
) -> (SmallVec<[KeyBinding; 1]>, bool) {
let mut matched_bindings = SmallVec::<[(usize, BindingIndex, &KeyBinding); 1]>::new();
@ -192,6 +192,7 @@ impl Keymap {
(bindings, !pending.is_empty())
}
/// Check if the given binding is enabled, given a certain key context.
/// Returns the deepest depth at which the binding matches, or None if it doesn't match.
fn binding_enabled(&self, binding: &KeyBinding, contexts: &[KeyContext]) -> Option<usize> {
@ -638,7 +639,7 @@ mod tests {
fn assert_bindings(keymap: &Keymap, action: &dyn Action, expected: &[&str]) {
let actual = keymap
.bindings_for_action(action)
.map(|binding| binding.keystrokes[0].inner.unparse())
.map(|binding| binding.keystrokes[0].unparse())
.collect::<Vec<_>>();
assert_eq!(actual, expected, "{:?}", action);
}

View file

@ -1,15 +1,14 @@
use std::rc::Rc;
use crate::{
Action, AsKeystroke, DummyKeyboardMapper, InvalidKeystrokeError, KeyBindingContextPredicate,
KeybindingKeystroke, Keystroke, PlatformKeyboardMapper, SharedString,
};
use collections::HashMap;
use crate::{Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke, SharedString};
use smallvec::SmallVec;
/// A keybinding and its associated metadata, from the keymap.
pub struct KeyBinding {
pub(crate) action: Box<dyn Action>,
pub(crate) keystrokes: SmallVec<[KeybindingKeystroke; 2]>,
pub(crate) keystrokes: SmallVec<[Keystroke; 2]>,
pub(crate) context_predicate: Option<Rc<KeyBindingContextPredicate>>,
pub(crate) meta: Option<KeyBindingMetaIndex>,
/// The json input string used when building the keybinding, if any
@ -33,15 +32,7 @@ impl KeyBinding {
pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
let context_predicate =
context.map(|context| KeyBindingContextPredicate::parse(context).unwrap().into());
Self::load(
keystrokes,
Box::new(action),
context_predicate,
false,
None,
&DummyKeyboardMapper,
)
.unwrap()
Self::load(keystrokes, Box::new(action), context_predicate, None, None).unwrap()
}
/// Load a keybinding from the given raw data.
@ -49,22 +40,24 @@ impl KeyBinding {
keystrokes: &str,
action: Box<dyn Action>,
context_predicate: Option<Rc<KeyBindingContextPredicate>>,
use_key_equivalents: bool,
key_equivalents: Option<&HashMap<char, char>>,
action_input: Option<SharedString>,
keyboard_mapper: &dyn PlatformKeyboardMapper,
) -> std::result::Result<Self, InvalidKeystrokeError> {
let keystrokes: SmallVec<[KeybindingKeystroke; 2]> = keystrokes
let mut keystrokes: SmallVec<[Keystroke; 2]> = keystrokes
.split_whitespace()
.map(|source| {
let keystroke = Keystroke::parse(source)?;
Ok(KeybindingKeystroke::new(
keystroke,
use_key_equivalents,
keyboard_mapper,
))
})
.map(Keystroke::parse)
.collect::<std::result::Result<_, _>>()?;
if let Some(equivalents) = key_equivalents {
for keystroke in keystrokes.iter_mut() {
if keystroke.key.chars().count() == 1
&& let Some(key) = equivalents.get(&keystroke.key.chars().next().unwrap())
{
keystroke.key = key.to_string();
}
}
}
Ok(Self {
keystrokes,
action,
@ -86,13 +79,13 @@ impl KeyBinding {
}
/// Check if the given keystrokes match this binding.
pub fn match_keystrokes(&self, typed: &[impl AsKeystroke]) -> Option<bool> {
pub fn match_keystrokes(&self, typed: &[Keystroke]) -> Option<bool> {
if self.keystrokes.len() < typed.len() {
return None;
}
for (target, typed) in self.keystrokes.iter().zip(typed.iter()) {
if !typed.as_keystroke().should_match(target) {
if !typed.should_match(target) {
return None;
}
}
@ -101,7 +94,7 @@ impl KeyBinding {
}
/// Get the keystrokes associated with this binding
pub fn keystrokes(&self) -> &[KeybindingKeystroke] {
pub fn keystrokes(&self) -> &[Keystroke] {
self.keystrokes.as_slice()
}

View file

@ -231,6 +231,7 @@ pub(crate) trait Platform: 'static {
fn on_quit(&self, callback: Box<dyn FnMut()>);
fn on_reopen(&self, callback: Box<dyn FnMut()>);
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>);
fn set_menus(&self, menus: Vec<Menu>, keymap: &Keymap);
fn get_menus(&self) -> Option<Vec<OwnedMenu>> {
@ -250,6 +251,7 @@ pub(crate) trait Platform: 'static {
fn on_app_menu_action(&self, callback: Box<dyn FnMut(&dyn Action)>);
fn on_will_open_app_menu(&self, callback: Box<dyn FnMut()>);
fn on_validate_app_menu_command(&self, callback: Box<dyn FnMut(&dyn Action) -> bool>);
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout>;
fn compositor_name(&self) -> &'static str {
""
@ -270,10 +272,6 @@ pub(crate) trait Platform: 'static {
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Task<Result<()>>;
fn read_credentials(&self, url: &str) -> Task<Result<Option<(String, Vec<u8>)>>>;
fn delete_credentials(&self, url: &str) -> Task<Result<()>>;
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout>;
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper>;
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>);
}
/// A handle to a platform's display, e.g. a monitor or laptop screen.

View file

@ -1,7 +1,3 @@
use collections::HashMap;
use crate::{KeybindingKeystroke, Keystroke};
/// A trait for platform-specific keyboard layouts
pub trait PlatformKeyboardLayout {
/// Get the keyboard layout ID, which should be unique to the layout
@ -9,33 +5,3 @@ pub trait PlatformKeyboardLayout {
/// Get the keyboard layout display name
fn name(&self) -> &str;
}
/// A trait for platform-specific keyboard mappings
pub trait PlatformKeyboardMapper {
/// Map a key equivalent to its platform-specific representation
fn map_key_equivalent(
&self,
keystroke: Keystroke,
use_key_equivalents: bool,
) -> KeybindingKeystroke;
/// Get the key equivalents for the current keyboard layout,
/// only used on macOS
fn get_key_equivalents(&self) -> Option<&HashMap<char, char>>;
}
/// A dummy implementation of the platform keyboard mapper
pub struct DummyKeyboardMapper;
impl PlatformKeyboardMapper for DummyKeyboardMapper {
fn map_key_equivalent(
&self,
keystroke: Keystroke,
_use_key_equivalents: bool,
) -> KeybindingKeystroke {
KeybindingKeystroke::from_keystroke(keystroke)
}
fn get_key_equivalents(&self) -> Option<&HashMap<char, char>> {
None
}
}

View file

@ -5,14 +5,6 @@ use std::{
fmt::{Display, Write},
};
use crate::PlatformKeyboardMapper;
/// This is a helper trait so that we can simplify the implementation of some functions
pub trait AsKeystroke {
/// Returns the GPUI representation of the keystroke.
fn as_keystroke(&self) -> &Keystroke;
}
/// A keystroke and associated metadata generated by the platform
#[derive(Clone, Debug, Eq, PartialEq, Default, Deserialize, Hash)]
pub struct Keystroke {
@ -32,17 +24,6 @@ pub struct Keystroke {
pub key_char: Option<String>,
}
/// Represents a keystroke that can be used in keybindings and displayed to the user.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct KeybindingKeystroke {
/// The GPUI representation of the keystroke.
pub inner: Keystroke,
/// The modifiers to display.
pub display_modifiers: Modifiers,
/// The key to display.
pub display_key: String,
}
/// Error type for `Keystroke::parse`. This is used instead of `anyhow::Error` so that Zed can use
/// markdown to display it.
#[derive(Debug)]
@ -77,7 +58,7 @@ impl Keystroke {
///
/// This method assumes that `self` was typed and `target' is in the keymap, and checks
/// both possibilities for self against the target.
pub fn should_match(&self, target: &KeybindingKeystroke) -> bool {
pub fn should_match(&self, target: &Keystroke) -> bool {
#[cfg(not(target_os = "windows"))]
if let Some(key_char) = self
.key_char
@ -90,7 +71,7 @@ impl Keystroke {
..Default::default()
};
if &target.inner.key == key_char && target.inner.modifiers == ime_modifiers {
if &target.key == key_char && target.modifiers == ime_modifiers {
return true;
}
}
@ -102,12 +83,12 @@ impl Keystroke {
.filter(|key_char| key_char != &&self.key)
{
// On Windows, if key_char is set, then the typed keystroke produced the key_char
if &target.inner.key == key_char && target.inner.modifiers == Modifiers::none() {
if &target.key == key_char && target.modifiers == Modifiers::none() {
return true;
}
}
target.inner.modifiers == self.modifiers && target.inner.key == self.key
target.modifiers == self.modifiers && target.key == self.key
}
/// key syntax is:
@ -219,7 +200,31 @@ impl Keystroke {
/// Produces a representation of this key that Parse can understand.
pub fn unparse(&self) -> String {
unparse(&self.modifiers, &self.key)
let mut str = String::new();
if self.modifiers.function {
str.push_str("fn-");
}
if self.modifiers.control {
str.push_str("ctrl-");
}
if self.modifiers.alt {
str.push_str("alt-");
}
if self.modifiers.platform {
#[cfg(target_os = "macos")]
str.push_str("cmd-");
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
str.push_str("super-");
#[cfg(target_os = "windows")]
str.push_str("win-");
}
if self.modifiers.shift {
str.push_str("shift-");
}
str.push_str(&self.key);
str
}
/// Returns true if this keystroke left
@ -261,32 +266,6 @@ impl Keystroke {
}
}
impl KeybindingKeystroke {
/// Create a new keybinding keystroke from the given keystroke
pub fn new(
inner: Keystroke,
use_key_equivalents: bool,
keyboard_mapper: &dyn PlatformKeyboardMapper,
) -> Self {
keyboard_mapper.map_key_equivalent(inner, use_key_equivalents)
}
pub(crate) fn from_keystroke(keystroke: Keystroke) -> Self {
let key = keystroke.key.clone();
let modifiers = keystroke.modifiers;
KeybindingKeystroke {
inner: keystroke,
display_modifiers: modifiers,
display_key: key,
}
}
/// Produces a representation of this key that Parse can understand.
pub fn unparse(&self) -> String {
unparse(&self.display_modifiers, &self.display_key)
}
}
fn is_printable_key(key: &str) -> bool {
!matches!(
key,
@ -343,15 +322,65 @@ fn is_printable_key(key: &str) -> bool {
impl std::fmt::Display for Keystroke {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
display_modifiers(&self.modifiers, f)?;
display_key(&self.key, f)
}
}
if self.modifiers.control {
#[cfg(target_os = "macos")]
f.write_char('^')?;
impl std::fmt::Display for KeybindingKeystroke {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
display_modifiers(&self.display_modifiers, f)?;
display_key(&self.display_key, f)
#[cfg(not(target_os = "macos"))]
write!(f, "ctrl-")?;
}
if self.modifiers.alt {
#[cfg(target_os = "macos")]
f.write_char('⌥')?;
#[cfg(not(target_os = "macos"))]
write!(f, "alt-")?;
}
if self.modifiers.platform {
#[cfg(target_os = "macos")]
f.write_char('⌘')?;
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
f.write_char('❖')?;
#[cfg(target_os = "windows")]
f.write_char('⊞')?;
}
if self.modifiers.shift {
#[cfg(target_os = "macos")]
f.write_char('⇧')?;
#[cfg(not(target_os = "macos"))]
write!(f, "shift-")?;
}
let key = match self.key.as_str() {
#[cfg(target_os = "macos")]
"backspace" => '⌫',
#[cfg(target_os = "macos")]
"up" => '↑',
#[cfg(target_os = "macos")]
"down" => '↓',
#[cfg(target_os = "macos")]
"left" => '←',
#[cfg(target_os = "macos")]
"right" => '→',
#[cfg(target_os = "macos")]
"tab" => '⇥',
#[cfg(target_os = "macos")]
"escape" => '⎋',
#[cfg(target_os = "macos")]
"shift" => '⇧',
#[cfg(target_os = "macos")]
"control" => '⌃',
#[cfg(target_os = "macos")]
"alt" => '⌥',
#[cfg(target_os = "macos")]
"platform" => '⌘',
key if key.len() == 1 => key.chars().next().unwrap().to_ascii_uppercase(),
key => return f.write_str(key),
};
f.write_char(key)
}
}
@ -571,110 +600,3 @@ pub struct Capslock {
#[serde(default)]
pub on: bool,
}
impl AsKeystroke for Keystroke {
fn as_keystroke(&self) -> &Keystroke {
self
}
}
impl AsKeystroke for KeybindingKeystroke {
fn as_keystroke(&self) -> &Keystroke {
&self.inner
}
}
fn display_modifiers(modifiers: &Modifiers, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if modifiers.control {
#[cfg(target_os = "macos")]
f.write_char('^')?;
#[cfg(not(target_os = "macos"))]
write!(f, "ctrl-")?;
}
if modifiers.alt {
#[cfg(target_os = "macos")]
f.write_char('⌥')?;
#[cfg(not(target_os = "macos"))]
write!(f, "alt-")?;
}
if modifiers.platform {
#[cfg(target_os = "macos")]
f.write_char('⌘')?;
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
f.write_char('❖')?;
#[cfg(target_os = "windows")]
f.write_char('⊞')?;
}
if modifiers.shift {
#[cfg(target_os = "macos")]
f.write_char('⇧')?;
#[cfg(not(target_os = "macos"))]
write!(f, "shift-")?;
}
Ok(())
}
fn display_key(key: &str, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let key = match key {
#[cfg(target_os = "macos")]
"backspace" => '⌫',
#[cfg(target_os = "macos")]
"up" => '↑',
#[cfg(target_os = "macos")]
"down" => '↓',
#[cfg(target_os = "macos")]
"left" => '←',
#[cfg(target_os = "macos")]
"right" => '→',
#[cfg(target_os = "macos")]
"tab" => '⇥',
#[cfg(target_os = "macos")]
"escape" => '⎋',
#[cfg(target_os = "macos")]
"shift" => '⇧',
#[cfg(target_os = "macos")]
"control" => '⌃',
#[cfg(target_os = "macos")]
"alt" => '⌥',
#[cfg(target_os = "macos")]
"platform" => '⌘',
key if key.len() == 1 => key.chars().next().unwrap().to_ascii_uppercase(),
key => return f.write_str(key),
};
f.write_char(key)
}
#[inline]
fn unparse(modifiers: &Modifiers, key: &str) -> String {
let mut result = String::new();
if modifiers.function {
result.push_str("fn-");
}
if modifiers.control {
result.push_str("ctrl-");
}
if modifiers.alt {
result.push_str("alt-");
}
if modifiers.platform {
#[cfg(target_os = "macos")]
result.push_str("cmd-");
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
result.push_str("super-");
#[cfg(target_os = "windows")]
result.push_str("win-");
}
if modifiers.shift {
result.push_str("shift-");
}
result.push_str(&key);
result
}

View file

@ -25,8 +25,8 @@ use xkbcommon::xkb::{self, Keycode, Keysym, State};
use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
PlatformTextSystem, PlatformWindow, Point, Result, Task, WindowAppearance, WindowParams, px,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow,
Point, Result, Task, WindowAppearance, WindowParams, px,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
@ -144,10 +144,6 @@ impl<P: LinuxClient + 'static> Platform for P {
self.keyboard_layout()
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
Rc::new(crate::DummyKeyboardMapper)
}
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>) {
self.with_common(|common| common.callbacks.keyboard_layout_change = Some(callback));
}

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
use super::{
BoolExt, MacKeyboardLayout, MacKeyboardMapper,
BoolExt, MacKeyboardLayout,
attributed_string::{NSAttributedString, NSMutableAttributedString},
events::key_to_native,
renderer,
@ -8,9 +8,8 @@ use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString,
CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher,
MacDisplay, MacWindow, Menu, MenuItem, OsMenu, OwnedMenu, PathPromptOptions, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem,
PlatformWindow, Result, SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams,
hash,
PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result,
SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, hash,
};
use anyhow::{Context as _, anyhow};
use block::ConcreteBlock;
@ -172,7 +171,6 @@ pub(crate) struct MacPlatformState {
finish_launching: Option<Box<dyn FnOnce()>>,
dock_menu: Option<id>,
menus: Option<Vec<OwnedMenu>>,
keyboard_mapper: Rc<MacKeyboardMapper>,
}
impl Default for MacPlatform {
@ -191,9 +189,6 @@ impl MacPlatform {
#[cfg(not(feature = "font-kit"))]
let text_system = Arc::new(crate::NoopTextSystem::new());
let keyboard_layout = MacKeyboardLayout::new();
let keyboard_mapper = Rc::new(MacKeyboardMapper::new(keyboard_layout.id()));
Self(Mutex::new(MacPlatformState {
headless,
text_system,
@ -214,7 +209,6 @@ impl MacPlatform {
dock_menu: None,
on_keyboard_layout_change: None,
menus: None,
keyboard_mapper,
}))
}
@ -354,19 +348,19 @@ impl MacPlatform {
let mut mask = NSEventModifierFlags::empty();
for (modifier, flag) in &[
(
keystroke.display_modifiers.platform,
keystroke.modifiers.platform,
NSEventModifierFlags::NSCommandKeyMask,
),
(
keystroke.display_modifiers.control,
keystroke.modifiers.control,
NSEventModifierFlags::NSControlKeyMask,
),
(
keystroke.display_modifiers.alt,
keystroke.modifiers.alt,
NSEventModifierFlags::NSAlternateKeyMask,
),
(
keystroke.display_modifiers.shift,
keystroke.modifiers.shift,
NSEventModifierFlags::NSShiftKeyMask,
),
] {
@ -379,7 +373,7 @@ impl MacPlatform {
.initWithTitle_action_keyEquivalent_(
ns_string(name),
selector,
ns_string(key_to_native(&keystroke.display_key).as_ref()),
ns_string(key_to_native(&keystroke.key).as_ref()),
)
.autorelease();
if Self::os_version() >= SemanticVersion::new(12, 0, 0) {
@ -888,10 +882,6 @@ impl Platform for MacPlatform {
Box::new(MacKeyboardLayout::new())
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
self.0.lock().keyboard_mapper.clone()
}
fn app_path(&self) -> Result<PathBuf> {
unsafe {
let bundle: id = NSBundle::mainBundle();
@ -1403,8 +1393,6 @@ extern "C" fn will_terminate(this: &mut Object, _: Sel, _: id) {
extern "C" fn on_keyboard_layout_change(this: &mut Object, _: Sel, _: id) {
let platform = unsafe { get_mac_platform(this) };
let mut lock = platform.0.lock();
let keyboard_layout = MacKeyboardLayout::new();
lock.keyboard_mapper = Rc::new(MacKeyboardMapper::new(keyboard_layout.id()));
if let Some(mut callback) = lock.on_keyboard_layout_change.take() {
drop(lock);
callback();

View file

@ -1,9 +1,8 @@
use crate::{
AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels,
DummyKeyboardMapper, ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay,
PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, PromptButton,
ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, Task,
TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, PlatformKeyboardLayout,
PlatformTextSystem, PromptButton, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream,
SourceMetadata, Task, TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
};
use anyhow::Result;
use collections::VecDeque;
@ -238,10 +237,6 @@ impl Platform for TestPlatform {
Box::new(TestKeyboardLayout)
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
Rc::new(DummyKeyboardMapper)
}
fn on_keyboard_layout_change(&self, _: Box<dyn FnMut()>) {}
fn run(&self, _on_finish_launching: Box<dyn FnOnce()>) {

View file

@ -9,8 +9,10 @@ use parking::Parker;
use parking_lot::Mutex;
use util::ResultExt;
use windows::{
Foundation::TimeSpan,
System::Threading::{
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemOptions,
WorkItemPriority,
},
Win32::{
Foundation::{LPARAM, WPARAM},
@ -54,7 +56,12 @@ impl WindowsDispatcher {
Ok(())
})
};
ThreadPool::RunWithPriorityAsync(&handler, WorkItemPriority::High).log_err();
ThreadPool::RunWithPriorityAndOptionsAsync(
&handler,
WorkItemPriority::High,
WorkItemOptions::TimeSliced,
)
.log_err();
}
fn dispatch_on_threadpool_after(&self, runnable: Runnable, duration: Duration) {
@ -65,7 +72,12 @@ impl WindowsDispatcher {
Ok(())
})
};
ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
let delay = TimeSpan {
// A time period expressed in 100-nanosecond units.
// 10,000,000 ticks per second
Duration: (duration.as_nanos() / 100) as i64,
};
ThreadPoolTimer::CreateTimer(&handler, delay).log_err();
}
}

View file

@ -1,31 +1,22 @@
use anyhow::Result;
use collections::HashMap;
use windows::Win32::UI::{
Input::KeyboardAndMouse::{
GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MAPVK_VK_TO_VSC, MapVirtualKeyW, ToUnicode,
VIRTUAL_KEY, VK_0, VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1,
VK_CONTROL, VK_MENU, VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7,
VK_OEM_8, VK_OEM_102, VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MapVirtualKeyW, ToUnicode, VIRTUAL_KEY, VK_0,
VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1, VK_CONTROL, VK_MENU,
VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7, VK_OEM_8, VK_OEM_102,
VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
},
WindowsAndMessaging::KL_NAMELENGTH,
};
use windows_core::HSTRING;
use crate::{
KeybindingKeystroke, Keystroke, Modifiers, PlatformKeyboardLayout, PlatformKeyboardMapper,
};
use crate::{Modifiers, PlatformKeyboardLayout};
pub(crate) struct WindowsKeyboardLayout {
id: String,
name: String,
}
pub(crate) struct WindowsKeyboardMapper {
key_to_vkey: HashMap<String, (u16, bool)>,
vkey_to_key: HashMap<u16, String>,
vkey_to_shifted: HashMap<u16, String>,
}
impl PlatformKeyboardLayout for WindowsKeyboardLayout {
fn id(&self) -> &str {
&self.id
@ -36,65 +27,6 @@ impl PlatformKeyboardLayout for WindowsKeyboardLayout {
}
}
impl PlatformKeyboardMapper for WindowsKeyboardMapper {
fn map_key_equivalent(
&self,
mut keystroke: Keystroke,
use_key_equivalents: bool,
) -> KeybindingKeystroke {
let Some((vkey, shifted_key)) = self.get_vkey_from_key(&keystroke.key, use_key_equivalents)
else {
return KeybindingKeystroke::from_keystroke(keystroke);
};
if shifted_key && keystroke.modifiers.shift {
log::warn!(
"Keystroke '{}' has both shift and a shifted key, this is likely a bug",
keystroke.key
);
}
let shift = shifted_key || keystroke.modifiers.shift;
keystroke.modifiers.shift = false;
let Some(key) = self.vkey_to_key.get(&vkey).cloned() else {
log::error!(
"Failed to map key equivalent '{:?}' to a valid key",
keystroke
);
return KeybindingKeystroke::from_keystroke(keystroke);
};
keystroke.key = if shift {
let Some(shifted_key) = self.vkey_to_shifted.get(&vkey).cloned() else {
log::error!(
"Failed to map keystroke {:?} with virtual key '{:?}' to a shifted key",
keystroke,
vkey
);
return KeybindingKeystroke::from_keystroke(keystroke);
};
shifted_key
} else {
key.clone()
};
let modifiers = Modifiers {
shift,
..keystroke.modifiers
};
KeybindingKeystroke {
inner: keystroke,
display_modifiers: modifiers,
display_key: key,
}
}
fn get_key_equivalents(&self) -> Option<&HashMap<char, char>> {
None
}
}
impl WindowsKeyboardLayout {
pub(crate) fn new() -> Result<Self> {
let mut buffer = [0u16; KL_NAMELENGTH as usize];
@ -116,41 +48,6 @@ impl WindowsKeyboardLayout {
}
}
impl WindowsKeyboardMapper {
pub(crate) fn new() -> Self {
let mut key_to_vkey = HashMap::default();
let mut vkey_to_key = HashMap::default();
let mut vkey_to_shifted = HashMap::default();
for vkey in CANDIDATE_VKEYS {
if let Some(key) = get_key_from_vkey(*vkey) {
key_to_vkey.insert(key.clone(), (vkey.0, false));
vkey_to_key.insert(vkey.0, key);
}
let scan_code = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_VSC) };
if scan_code == 0 {
continue;
}
if let Some(shifted_key) = get_shifted_key(*vkey, scan_code) {
key_to_vkey.insert(shifted_key.clone(), (vkey.0, true));
vkey_to_shifted.insert(vkey.0, shifted_key);
}
}
Self {
key_to_vkey,
vkey_to_key,
vkey_to_shifted,
}
}
fn get_vkey_from_key(&self, key: &str, use_key_equivalents: bool) -> Option<(u16, bool)> {
if use_key_equivalents {
get_vkey_from_key_with_us_layout(key)
} else {
self.key_to_vkey.get(key).cloned()
}
}
}
pub(crate) fn get_keystroke_key(
vkey: VIRTUAL_KEY,
scan_code: u32,
@ -243,134 +140,3 @@ pub(crate) fn generate_key_char(
_ => None,
}
}
fn get_vkey_from_key_with_us_layout(key: &str) -> Option<(u16, bool)> {
match key {
// ` => VK_OEM_3
"`" => Some((VK_OEM_3.0, false)),
"~" => Some((VK_OEM_3.0, true)),
"1" => Some((VK_1.0, false)),
"!" => Some((VK_1.0, true)),
"2" => Some((VK_2.0, false)),
"@" => Some((VK_2.0, true)),
"3" => Some((VK_3.0, false)),
"#" => Some((VK_3.0, true)),
"4" => Some((VK_4.0, false)),
"$" => Some((VK_4.0, true)),
"5" => Some((VK_5.0, false)),
"%" => Some((VK_5.0, true)),
"6" => Some((VK_6.0, false)),
"^" => Some((VK_6.0, true)),
"7" => Some((VK_7.0, false)),
"&" => Some((VK_7.0, true)),
"8" => Some((VK_8.0, false)),
"*" => Some((VK_8.0, true)),
"9" => Some((VK_9.0, false)),
"(" => Some((VK_9.0, true)),
"0" => Some((VK_0.0, false)),
")" => Some((VK_0.0, true)),
"-" => Some((VK_OEM_MINUS.0, false)),
"_" => Some((VK_OEM_MINUS.0, true)),
"=" => Some((VK_OEM_PLUS.0, false)),
"+" => Some((VK_OEM_PLUS.0, true)),
"[" => Some((VK_OEM_4.0, false)),
"{" => Some((VK_OEM_4.0, true)),
"]" => Some((VK_OEM_6.0, false)),
"}" => Some((VK_OEM_6.0, true)),
"\\" => Some((VK_OEM_5.0, false)),
"|" => Some((VK_OEM_5.0, true)),
";" => Some((VK_OEM_1.0, false)),
":" => Some((VK_OEM_1.0, true)),
"'" => Some((VK_OEM_7.0, false)),
"\"" => Some((VK_OEM_7.0, true)),
"," => Some((VK_OEM_COMMA.0, false)),
"<" => Some((VK_OEM_COMMA.0, true)),
"." => Some((VK_OEM_PERIOD.0, false)),
">" => Some((VK_OEM_PERIOD.0, true)),
"/" => Some((VK_OEM_2.0, false)),
"?" => Some((VK_OEM_2.0, true)),
_ => None,
}
}
const CANDIDATE_VKEYS: &[VIRTUAL_KEY] = &[
VK_OEM_3,
VK_OEM_MINUS,
VK_OEM_PLUS,
VK_OEM_4,
VK_OEM_5,
VK_OEM_6,
VK_OEM_1,
VK_OEM_7,
VK_OEM_COMMA,
VK_OEM_PERIOD,
VK_OEM_2,
VK_OEM_102,
VK_OEM_8,
VK_ABNT_C1,
VK_0,
VK_1,
VK_2,
VK_3,
VK_4,
VK_5,
VK_6,
VK_7,
VK_8,
VK_9,
];
#[cfg(test)]
mod tests {
use crate::{Keystroke, Modifiers, PlatformKeyboardMapper, WindowsKeyboardMapper};
#[test]
fn test_keyboard_mapper() {
let mapper = WindowsKeyboardMapper::new();
// Normal case
let keystroke = Keystroke {
modifiers: Modifiers::control(),
key: "a".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke.clone(), true);
assert_eq!(mapped.inner, keystroke);
assert_eq!(mapped.display_key, "a");
assert_eq!(mapped.display_modifiers, Modifiers::control());
// Shifted case, ctrl-$
let keystroke = Keystroke {
modifiers: Modifiers::control(),
key: "$".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke.clone(), true);
assert_eq!(mapped.inner, keystroke);
assert_eq!(mapped.display_key, "4");
assert_eq!(mapped.display_modifiers, Modifiers::control_shift());
// Shifted case, but shift is true
let keystroke = Keystroke {
modifiers: Modifiers::control_shift(),
key: "$".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke, true);
assert_eq!(mapped.inner.modifiers, Modifiers::control());
assert_eq!(mapped.display_key, "4");
assert_eq!(mapped.display_modifiers, Modifiers::control_shift());
// Windows style
let keystroke = Keystroke {
modifiers: Modifiers::control_shift(),
key: "4".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke, true);
assert_eq!(mapped.inner.modifiers, Modifiers::control());
assert_eq!(mapped.inner.key, "$");
assert_eq!(mapped.display_key, "4");
assert_eq!(mapped.display_modifiers, Modifiers::control_shift());
}
}

View file

@ -351,10 +351,6 @@ impl Platform for WindowsPlatform {
)
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
Rc::new(WindowsKeyboardMapper::new())
}
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>) {
self.state.borrow_mut().callbacks.keyboard_layout_change = Some(callback);
}

View file

@ -215,7 +215,6 @@ pub enum IconName {
Tab,
Terminal,
TerminalAlt,
TerminalGhost,
TextSnippet,
TextThread,
Thread,

View file

@ -401,19 +401,12 @@ pub fn init(cx: &mut App) {
mod persistence {
use std::path::PathBuf;
use db::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
sqlez_macros::sql,
};
use db::{define_connection, query, sqlez_macros::sql};
use workspace::{ItemId, WorkspaceDb, WorkspaceId};
pub struct ImageViewerDb(ThreadSafeConnection);
impl Domain for ImageViewerDb {
const NAME: &str = stringify!(ImageViewerDb);
const MIGRATIONS: &[&str] = &[sql!(
define_connection! {
pub static ref IMAGE_VIEWER: ImageViewerDb<WorkspaceDb> =
&[sql!(
CREATE TABLE image_viewers (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
@ -424,11 +417,9 @@ mod persistence {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
)];
)];
}
db::static_connection!(IMAGE_VIEWER, ImageViewerDb, [WorkspaceDb]);
impl ImageViewerDb {
query! {
pub async fn save_image_path(

View file

@ -1569,21 +1569,11 @@ impl Buffer {
self.send_operation(op, true, cx);
}
pub fn buffer_diagnostics(
&self,
for_server: Option<LanguageServerId>,
) -> Vec<&DiagnosticEntry<Anchor>> {
match for_server {
Some(server_id) => match self.diagnostics.binary_search_by_key(&server_id, |v| v.0) {
Ok(idx) => self.diagnostics[idx].1.iter().collect(),
Err(_) => Vec::new(),
},
None => self
.diagnostics
.iter()
.flat_map(|(_, diagnostic_set)| diagnostic_set.iter())
.collect(),
}
pub fn get_diagnostics(&self, server_id: LanguageServerId) -> Option<&DiagnosticSet> {
let Ok(idx) = self.diagnostics.binary_search_by_key(&server_id, |v| v.0) else {
return None;
};
Some(&self.diagnostics[idx].1)
}
fn request_autoindent(&mut self, cx: &mut Context<Self>) {

View file

@ -4,16 +4,12 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice,
};
use anyhow::anyhow;
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use smol::stream::StreamExt;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering::SeqCst},
};
use std::sync::Arc;
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
@ -110,7 +106,6 @@ pub struct FakeLanguageModel {
>,
)>,
>,
forbid_requests: AtomicBool,
}
impl Default for FakeLanguageModel {
@ -119,20 +114,11 @@ impl Default for FakeLanguageModel {
provider_id: LanguageModelProviderId::from("fake".to_string()),
provider_name: LanguageModelProviderName::from("Fake".to_string()),
current_completion_txs: Mutex::new(Vec::new()),
forbid_requests: AtomicBool::new(false),
}
}
}
impl FakeLanguageModel {
pub fn allow_requests(&self) {
self.forbid_requests.store(false, SeqCst);
}
pub fn forbid_requests(&self) {
self.forbid_requests.store(true, SeqCst);
}
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
@ -265,18 +251,9 @@ impl LanguageModel for FakeLanguageModel {
LanguageModelCompletionError,
>,
> {
if self.forbid_requests.load(SeqCst) {
async move {
Err(LanguageModelCompletionError::Other(anyhow!(
"requests are forbidden"
)))
}
.boxed()
} else {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.boxed()) }.boxed()
}
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.boxed()) }.boxed()
}
fn as_fake(&self) -> &Self {

View file

@ -6,7 +6,6 @@ use collections::BTreeMap;
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
@ -42,7 +41,9 @@ impl std::fmt::Debug for ConfigurationError {
#[derive(Default)]
pub struct LanguageModelRegistry {
default_model: Option<ConfiguredModel>,
default_fast_model: Option<ConfiguredModel>,
/// This model is automatically configured by a user's environment after
/// authenticating all providers. It's only used when default_model is not available.
environment_fallback_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
@ -98,9 +99,6 @@ impl ConfiguredModel {
pub enum Event {
DefaultModelChanged,
InlineAssistantModelChanged,
CommitMessageModelChanged,
ThreadSummaryModelChanged,
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@ -226,7 +224,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_inline_assistant_model(configured_model, cx);
self.set_inline_assistant_model(configured_model);
}
pub fn select_commit_message_model(
@ -235,7 +233,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_commit_message_model(configured_model, cx);
self.set_commit_message_model(configured_model);
}
pub fn select_thread_summary_model(
@ -244,7 +242,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_thread_summary_model(configured_model, cx);
self.set_thread_summary_model(configured_model);
}
/// Selects and sets the inline alternatives for language models based on
@ -278,68 +276,60 @@ impl LanguageModelRegistry {
}
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
match (self.default_model.as_ref(), model.as_ref()) {
match (self.default_model(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
self.default_fast_model = maybe!({
let provider = &model.as_ref()?.provider;
let fast_model = provider.default_fast_model(cx)?;
Some(ConfiguredModel {
provider: provider.clone(),
model: fast_model,
})
});
self.default_model = model;
}
pub fn set_inline_assistant_model(
pub fn set_environment_fallback_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::InlineAssistantModelChanged),
if self.default_model.is_none() {
match (self.environment_fallback_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
}
self.environment_fallback_model = model;
}
pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
self.inline_assistant_model = model;
}
pub fn set_commit_message_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.commit_message_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::CommitMessageModelChanged),
}
pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
self.commit_message_model = model;
}
pub fn set_thread_summary_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.thread_summary_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::ThreadSummaryModelChanged),
}
pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
self.thread_summary_model = model;
}
#[track_caller]
pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
self.default_model.clone()
self.default_model
.clone()
.or_else(|| self.environment_fallback_model.clone())
}
pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
let provider = self.default_model()?.provider;
let fast_model = provider.default_fast_model(cx)?;
Some(ConfiguredModel {
provider,
model: fast_model,
})
}
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@ -353,7 +343,7 @@ impl LanguageModelRegistry {
.or_else(|| self.default_model.clone())
}
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@ -361,11 +351,11 @@ impl LanguageModelRegistry {
self.commit_message_model
.clone()
.or_else(|| self.default_fast_model.clone())
.or_else(|| self.default_fast_model(cx))
.or_else(|| self.default_model.clone())
}
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@ -373,7 +363,7 @@ impl LanguageModelRegistry {
self.thread_summary_model
.clone()
.or_else(|| self.default_fast_model.clone())
.or_else(|| self.default_fast_model(cx))
.or_else(|| self.default_model.clone())
}
@ -410,4 +400,34 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
#[gpui::test]
async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = FakeLanguageModelProvider::default();
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
});
cx.update(|cx| provider.authenticate(cx)).await.unwrap();
registry.update(cx, |registry, cx| {
let provider = registry.provider(&provider.id()).unwrap();
registry.set_environment_fallback_model(
Some(ConfiguredModel {
provider: provider.clone(),
model: provider.default_model(cx).unwrap(),
}),
cx,
);
let default_model = registry.default_model().unwrap();
let fallback_model = registry.environment_fallback_model.clone().unwrap();
assert_eq!(default_model.model.id(), fallback_model.model.id());
assert_eq!(default_model.provider.id(), fallback_model.provider.id());
});
}
}

View file

@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
project.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true

View file

@ -3,8 +3,12 @@ use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use futures::future;
use gpui::{App, AppContext as _, Context, Entity};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
};
use project::DisableAiSettings;
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
@ -13,7 +17,7 @@ pub mod ui;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
use crate::provider::cloud::{self, CloudLanguageModelProvider};
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
@ -48,6 +52,13 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
cx,
);
});
let mut already_authenticated = false;
if !DisableAiSettings::get_global(cx).disable_ai {
authenticate_all_providers(registry.clone(), cx);
already_authenticated = true;
}
cx.observe_global::<SettingsStore>(move |cx| {
let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
.openai_compatible
@ -65,6 +76,12 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
);
});
openai_compatible_providers = openai_compatible_providers_new;
already_authenticated = false;
}
if !DisableAiSettings::get_global(cx).disable_ai && !already_authenticated {
authenticate_all_providers(registry.clone(), cx);
already_authenticated = true;
}
})
.detach();
@ -151,3 +168,83 @@ fn register_language_model_providers(
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
/// Authenticates all providers in the [`LanguageModelRegistry`].
///
/// We do this so that we can populate the language selector with all of the
/// models from the configured providers.
///
/// This function won't do anything if AI is disabled.
fn authenticate_all_providers(registry: Entity<LanguageModelRegistry>, cx: &mut App) {
let providers_to_authenticate = registry
.read(cx)
.providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
let mut tasks = Vec::with_capacity(providers_to_authenticate.len());
for (provider_id, provider_name, authenticate_task) in providers_to_authenticate {
tasks.push(cx.background_spawn(async move {
if let Err(err) = authenticate_task.await {
if matches!(err, AuthenticateError::CredentialsNotFound) {
// Since we're authenticating these providers in the
// background for the purposes of populating the
// language selector, we don't care about providers
// where the credentials are not found.
} else {
// Some providers have noisy failure states that we
// don't want to spam the logs with every time the
// language model selector is initialized.
//
// Ideally these should have more clear failure modes
// that we know are safe to ignore here, like what we do
// with `CredentialsNotFound` above.
match provider_id.0.as_ref() {
"lmstudio" | "ollama" => {
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
//
// These fail noisily, so we don't log them.
}
"copilot_chat" => {
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
}
_ => {
log::error!(
"Failed to authenticate provider: {}: {err}",
provider_name.0
);
}
}
}
}
}));
}
let all_authenticated_future = future::join_all(tasks);
cx.spawn(async move |cx| {
all_authenticated_future.await;
registry
.update(cx, |registry, cx| {
let cloud_provider = registry.provider(&cloud::PROVIDER_ID);
let fallback_model = cloud_provider
.iter()
.chain(registry.providers().iter())
.find(|provider| provider.is_authenticated(cx))
.and_then(|provider| {
Some(ConfiguredModel {
provider: provider.clone(),
model: provider
.default_model(cx)
.or_else(|| provider.recommended_models(cx).first().cloned())?,
})
});
registry.set_environment_fallback_model(fallback_model, cx);
})
.ok();
})
.detach();
}

View file

@ -44,8 +44,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
pub const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
pub const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@ -146,7 +146,7 @@ impl State {
default_fast_model: None,
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async move {
maybe!(async {
let (client, llm_api_token) = this
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

View file

@ -4,6 +4,7 @@ use gpui::{
};
use itertools::Itertools;
use serde_json::json;
use settings::get_key_equivalents;
use ui::{Button, ButtonStyle};
use ui::{
ButtonCommon, Clickable, Context, FluentBuilder, InteractiveElement, Label, LabelCommon,
@ -168,8 +169,7 @@ impl Item for KeyContextView {
impl Render for KeyContextView {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
use itertools::Itertools;
let key_equivalents = cx.keyboard_mapper().get_key_equivalents();
let key_equivalents = get_key_equivalents(cx.keyboard_layout().id());
v_flex()
.id("key-context-view")
.overflow_scroll()

View file

@ -1743,5 +1743,6 @@ pub enum Event {
}
impl EventEmitter<Event> for LogStore {}
impl EventEmitter<Event> for LspLogView {}
impl EventEmitter<EditorEvent> for LspLogView {}
impl EventEmitter<SearchEvent> for LspLogView {}

View file

@ -11,21 +11,6 @@
(#set! injection.language "css"))
)
(call_expression
function: (member_expression
object: (identifier) @_obj (#eq? @_obj "styled")
property: (property_identifier))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (call_expression
function: (identifier) @_name (#eq? @_name "styled"))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (identifier) @_name (#eq? @_name "html")
arguments: (template_string) @injection.content

View file

@ -6,6 +6,9 @@
(self) @variable.special
(field_identifier) @property
(shorthand_field_initializer
(identifier) @property)
(trait_item name: (type_identifier) @type.interface)
(impl_item trait: (type_identifier) @type.interface)
(abstract_type trait: (type_identifier) @type.interface)
@ -38,11 +41,20 @@
(identifier) @function.special
(scoped_identifier
name: (identifier) @function.special)
])
]
"!" @function.special)
(macro_definition
name: (identifier) @function.special.definition)
(mod_item
name: (identifier) @module)
(visibility_modifier [
(crate) @keyword
(super) @keyword
])
; Identifier conventions
; Assume uppercase names are types/enum-constructors
@ -115,9 +127,7 @@
"where"
"while"
"yield"
(crate)
(mutable_specifier)
(super)
] @keyword
[
@ -189,6 +199,7 @@
operator: "/" @operator
(lifetime) @lifetime
(lifetime (identifier) @lifetime)
(parameter (identifier) @variable.parameter)

View file

@ -11,21 +11,6 @@
(#set! injection.language "css"))
)
(call_expression
function: (member_expression
object: (identifier) @_obj (#eq? @_obj "styled")
property: (property_identifier))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (call_expression
function: (identifier) @_name (#eq? @_name "styled"))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (identifier) @_name (#eq? @_name "html")
arguments: (template_string (string_fragment) @injection.content

View file

@ -15,21 +15,6 @@
(#set! injection.language "css"))
)
(call_expression
function: (member_expression
object: (identifier) @_obj (#eq? @_obj "styled")
property: (property_identifier))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (call_expression
function: (identifier) @_name (#eq? @_name "styled"))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (identifier) @_name (#eq? @_name "html")
arguments: (template_string) @injection.content

View file

@ -1323,7 +1323,7 @@ fn render_copy_code_block_button(
.icon_size(IconSize::Small)
.style(ButtonStyle::Filled)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy"))
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let markdown = markdown;
move |_event, _window, cx| {

View file

@ -283,13 +283,17 @@ pub(crate) fn render_ai_setup_page(
v_flex()
.mt_2()
.gap_6()
.child(
AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx)
.tab_index(Some({
tab_index += 1;
tab_index - 1
})),
)
.child({
let mut ai_upsell_card =
AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx);
ai_upsell_card.tab_index = Some({
tab_index += 1;
tab_index - 1
});
ai_upsell_card
})
.child(render_llm_provider_section(
&mut tab_index,
workspace,

View file

@ -850,19 +850,13 @@ impl workspace::SerializableItem for Onboarding {
}
mod persistence {
use db::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
sqlez_macros::sql,
};
use db::{define_connection, query, sqlez_macros::sql};
use workspace::WorkspaceDb;
pub struct OnboardingPagesDb(ThreadSafeConnection);
impl Domain for OnboardingPagesDb {
const NAME: &str = stringify!(OnboardingPagesDb);
const MIGRATIONS: &[&str] = &[sql!(
define_connection! {
pub static ref ONBOARDING_PAGES: OnboardingPagesDb<WorkspaceDb> =
&[
sql!(
CREATE TABLE onboarding_pages (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
@ -872,11 +866,10 @@ mod persistence {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
)];
),
];
}
db::static_connection!(ONBOARDING_PAGES, OnboardingPagesDb, [WorkspaceDb]);
impl OnboardingPagesDb {
query! {
pub async fn save_onboarding_page(

View file

@ -414,19 +414,13 @@ impl workspace::SerializableItem for WelcomePage {
}
mod persistence {
use db::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
sqlez_macros::sql,
};
use db::{define_connection, query, sqlez_macros::sql};
use workspace::WorkspaceDb;
pub struct WelcomePagesDb(ThreadSafeConnection);
impl Domain for WelcomePagesDb {
const NAME: &str = stringify!(WelcomePagesDb);
const MIGRATIONS: &[&str] = (&[sql!(
define_connection! {
pub static ref WELCOME_PAGES: WelcomePagesDb<WorkspaceDb> =
&[
sql!(
CREATE TABLE welcome_pages (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
@ -436,11 +430,10 @@ mod persistence {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
)]);
),
];
}
db::static_connection!(WELCOME_PAGES, WelcomePagesDb, [WorkspaceDb]);
impl WelcomePagesDb {
query! {
pub async fn save_welcome_page(

View file

@ -446,6 +446,7 @@ pub enum ResponseStreamResult {
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
pub model: String,
pub choices: Vec<ChoiceDelta>,
pub usage: Option<Usage>,
}

View file

@ -7588,16 +7588,19 @@ impl LspStore {
let snapshot = buffer_handle.read(cx).snapshot();
let buffer = buffer_handle.read(cx);
let reused_diagnostics = buffer
.buffer_diagnostics(Some(server_id))
.iter()
.filter(|v| merge(buffer, &v.diagnostic, cx))
.map(|v| {
let start = Unclipped(v.range.start.to_point_utf16(&snapshot));
let end = Unclipped(v.range.end.to_point_utf16(&snapshot));
DiagnosticEntry {
range: start..end,
diagnostic: v.diagnostic.clone(),
}
.get_diagnostics(server_id)
.into_iter()
.flat_map(|diag| {
diag.iter()
.filter(|v| merge(buffer, &v.diagnostic, cx))
.map(|v| {
let start = Unclipped(v.range.start.to_point_utf16(&snapshot));
let end = Unclipped(v.range.end.to_point_utf16(&snapshot));
DiagnosticEntry {
range: start..end,
diagnostic: v.diagnostic.clone(),
}
})
})
.collect::<Vec<_>>();
@ -11703,11 +11706,12 @@ impl LspStore {
// Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings.
}
"workspace/symbol" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.workspace_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.workspace_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"workspace/fileOperations" => {
if let Some(options) = reg.register_options {
@ -11731,11 +11735,12 @@ impl LspStore {
}
}
"textDocument/rangeFormatting" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.document_range_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.document_range_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"textDocument/onTypeFormatting" => {
if let Some(options) = reg
@ -11750,32 +11755,36 @@ impl LspStore {
}
}
"textDocument/formatting" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.document_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.document_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"textDocument/rename" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.rename_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.rename_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"textDocument/inlayHint" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.inlay_hint_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.inlay_hint_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"textDocument/documentSymbol" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.document_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.document_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"textDocument/codeAction" => {
if let Some(options) = reg
@ -11791,11 +11800,12 @@ impl LspStore {
}
}
"textDocument/definition" => {
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.definition_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.definition_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
}
"textDocument/completion" => {
if let Some(caps) = reg
@ -12174,10 +12184,10 @@ impl LspStore {
// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/client.ts#L2133
fn parse_register_capabilities<T: serde::de::DeserializeOwned>(
reg: lsp::Registration,
) -> Result<OneOf<bool, T>> {
) -> anyhow::Result<Option<OneOf<bool, T>>> {
Ok(match reg.register_options {
Some(options) => OneOf::Right(serde_json::from_value::<T>(options)?),
None => OneOf::Left(true),
Some(options) => Some(OneOf::Right(serde_json::from_value::<T>(options)?)),
None => Some(OneOf::Left(true)),
})
}

View file

@ -4089,7 +4089,6 @@ impl ProjectPanel {
.when(!is_sticky, |this| {
this
.when(is_highlighted && folded_directory_drag_target.is_none(), |this| this.border_color(transparent_white()).bg(item_colors.drag_over))
.when(settings.drag_and_drop, |this| this
.on_drag_move::<ExternalPaths>(cx.listener(
move |this, event: &DragMoveEvent<ExternalPaths>, _, cx| {
let is_current_target = this.drag_target_entry.as_ref()
@ -4223,7 +4222,7 @@ impl ProjectPanel {
}
this.drag_onto(selections, entry_id, kind.is_file(), window, cx);
}),
))
)
})
.on_mouse_down(
MouseButton::Left,
@ -4434,7 +4433,6 @@ impl ProjectPanel {
div()
.when(!is_sticky, |div| {
div
.when(settings.drag_and_drop, |div| div
.on_drop(cx.listener(move |this, selections: &DraggedSelection, window, cx| {
this.hover_scroll_task.take();
this.drag_target_entry = None;
@ -4466,7 +4464,7 @@ impl ProjectPanel {
}
},
)))
))
})
.child(
Label::new(DELIMITER.clone())
@ -4486,7 +4484,6 @@ impl ProjectPanel {
.when(index != components_len - 1, |div|{
let target_entry_id = folded_ancestors.ancestors.get(components_len - 1 - index).cloned();
div
.when(settings.drag_and_drop, |div| div
.on_drag_move(cx.listener(
move |this, event: &DragMoveEvent<DraggedSelection>, _, _| {
if event.bounds.contains(&event.event.position) {
@ -4524,7 +4521,7 @@ impl ProjectPanel {
target.index == index
), |this| {
this.bg(item_colors.drag_over)
}))
})
})
})
.on_click(cx.listener(move |this, _, _, cx| {
@ -5032,8 +5029,7 @@ impl ProjectPanel {
sticky_parents.reverse();
let panel_settings = ProjectPanelSettings::get_global(cx);
let git_status_enabled = panel_settings.git_status;
let git_status_enabled = ProjectPanelSettings::get_global(cx).git_status;
let root_name = OsStr::new(worktree.root_name());
let git_summaries_by_id = if git_status_enabled {
@ -5117,11 +5113,11 @@ impl Render for ProjectPanel {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let has_worktree = !self.visible_entries.is_empty();
let project = self.project.read(cx);
let panel_settings = ProjectPanelSettings::get_global(cx);
let indent_size = panel_settings.indent_size;
let show_indent_guides = panel_settings.indent_guides.show == ShowIndentGuides::Always;
let indent_size = ProjectPanelSettings::get_global(cx).indent_size;
let show_indent_guides =
ProjectPanelSettings::get_global(cx).indent_guides.show == ShowIndentGuides::Always;
let show_sticky_entries = {
if panel_settings.sticky_scroll {
if ProjectPanelSettings::get_global(cx).sticky_scroll {
let is_scrollable = self.scroll_handle.is_scrollable();
let is_scrolled = self.scroll_handle.offset().y < px(0.);
is_scrollable && is_scrolled
@ -5209,10 +5205,8 @@ impl Render for ProjectPanel {
h_flex()
.id("project-panel")
.group("project-panel")
.when(panel_settings.drag_and_drop, |this| {
this.on_drag_move(cx.listener(handle_drag_move::<ExternalPaths>))
.on_drag_move(cx.listener(handle_drag_move::<DraggedSelection>))
})
.on_drag_move(cx.listener(handle_drag_move::<ExternalPaths>))
.on_drag_move(cx.listener(handle_drag_move::<DraggedSelection>))
.size_full()
.relative()
.on_modifiers_changed(cx.listener(
@ -5550,32 +5544,30 @@ impl Render for ProjectPanel {
})),
)
.when(is_local, |div| {
div.when(panel_settings.drag_and_drop, |div| {
div.drag_over::<ExternalPaths>(|style, _, _, cx| {
style.bg(cx.theme().colors().drop_target_background)
})
.on_drop(cx.listener(
move |this, external_paths: &ExternalPaths, window, cx| {
this.drag_target_entry = None;
this.hover_scroll_task.take();
if let Some(task) = this
.workspace
.update(cx, |workspace, cx| {
workspace.open_workspace_for_paths(
true,
external_paths.paths().to_owned(),
window,
cx,
)
})
.log_err()
{
task.detach_and_log_err(cx);
}
cx.stop_propagation();
},
))
div.drag_over::<ExternalPaths>(|style, _, _, cx| {
style.bg(cx.theme().colors().drop_target_background)
})
.on_drop(cx.listener(
move |this, external_paths: &ExternalPaths, window, cx| {
this.drag_target_entry = None;
this.hover_scroll_task.take();
if let Some(task) = this
.workspace
.update(cx, |workspace, cx| {
workspace.open_workspace_for_paths(
true,
external_paths.paths().to_owned(),
window,
cx,
)
})
.log_err()
{
task.detach_and_log_err(cx);
}
cx.stop_propagation();
},
))
})
}
}

View file

@ -47,7 +47,6 @@ pub struct ProjectPanelSettings {
pub scrollbar: ScrollbarSettings,
pub show_diagnostics: ShowDiagnostics,
pub hide_root: bool,
pub drag_and_drop: bool,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
@ -161,10 +160,6 @@ pub struct ProjectPanelSettingsContent {
///
/// Default: true
pub sticky_scroll: Option<bool>,
/// Whether to enable drag-and-drop operations in the project panel.
///
/// Default: true
pub drag_and_drop: Option<bool>,
}
impl Settings for ProjectPanelSettings {

View file

@ -445,7 +445,7 @@ impl SshSocket {
}
async fn platform(&self) -> Result<SshPlatform> {
let uname = self.run_command("sh", &["-lc", "uname -sm"]).await?;
let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
let Some((os, arch)) = uname.split_once(" ") else {
anyhow::bail!("unknown uname: {uname:?}")
};
@ -476,7 +476,7 @@ impl SshSocket {
}
async fn shell(&self) -> String {
match self.run_command("sh", &["-lc", "echo $SHELL"]).await {
match self.run_command("sh", &["-c", "echo $SHELL"]).await {
Ok(shell) => shell.trim().to_owned(),
Err(e) => {
log::error!("Failed to get shell: {e}");
@ -1533,7 +1533,7 @@ impl RemoteConnection for SshRemoteConnection {
let ssh_proxy_process = match self
.socket
.ssh_command("sh", &["-lc", &start_proxy_command])
.ssh_command("sh", &["-c", &start_proxy_command])
// IMPORTANT: we kill this process when we drop the task that uses it.
.kill_on_drop(true)
.spawn()
@ -1910,7 +1910,7 @@ impl SshRemoteConnection {
.run_command(
"sh",
&[
"-lc",
"-c",
&shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
],
)
@ -1988,7 +1988,7 @@ impl SshRemoteConnection {
.run_command(
"sh",
&[
"-lc",
"-c",
&shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
],
)
@ -2036,7 +2036,7 @@ impl SshRemoteConnection {
dst_path = &dst_path.to_string()
)
};
self.socket.run_command("sh", &["-lc", &script]).await?;
self.socket.run_command("sh", &["-c", &script]).await?;
Ok(())
}

View file

@ -65,7 +65,6 @@ telemetry_events.workspace = true
util.workspace = true
watch.workspace = true
worktree.workspace = true
thiserror.workspace = true
[target.'cfg(not(windows))'.dependencies]
crashes.workspace = true

View file

@ -1,7 +1,6 @@
#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
use clap::Parser;
use remote_server::Commands;
use clap::{Parser, Subcommand};
use std::path::PathBuf;
#[derive(Parser)]
@ -22,34 +21,105 @@ struct Cli {
printenv: bool,
}
#[derive(Subcommand)]
enum Commands {
Run {
#[arg(long)]
log_file: PathBuf,
#[arg(long)]
pid_file: PathBuf,
#[arg(long)]
stdin_socket: PathBuf,
#[arg(long)]
stdout_socket: PathBuf,
#[arg(long)]
stderr_socket: PathBuf,
},
Proxy {
#[arg(long)]
reconnect: bool,
#[arg(long)]
identifier: String,
},
Version,
}
#[cfg(windows)]
fn main() {
unimplemented!()
}
#[cfg(not(windows))]
fn main() -> anyhow::Result<()> {
fn main() {
use release_channel::{RELEASE_CHANNEL, ReleaseChannel};
use remote::proxy::ProxyLaunchError;
use remote_server::unix::{execute_proxy, execute_run};
let cli = Cli::parse();
if let Some(socket_path) = &cli.askpass {
askpass::main(socket_path);
return Ok(());
return;
}
if let Some(socket) = &cli.crash_handler {
crashes::crash_server(socket.as_path());
return Ok(());
return;
}
if cli.printenv {
util::shell_env::print_env();
return Ok(());
return;
}
if let Some(command) = cli.command {
remote_server::run(command)
} else {
eprintln!("usage: remote <run|proxy|version>");
let result = match cli.command {
Some(Commands::Run {
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
}) => execute_run(
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
),
Some(Commands::Proxy {
identifier,
reconnect,
}) => match execute_proxy(identifier, reconnect) {
Ok(_) => Ok(()),
Err(err) => {
if let Some(err) = err.downcast_ref::<ProxyLaunchError>() {
std::process::exit(err.to_exit_code());
}
Err(err)
}
},
Some(Commands::Version) => {
let release_channel = *RELEASE_CHANNEL;
match release_channel {
ReleaseChannel::Stable | ReleaseChannel::Preview => {
println!("{}", env!("ZED_PKG_VERSION"))
}
ReleaseChannel::Nightly | ReleaseChannel::Dev => {
println!(
"{}",
option_env!("ZED_COMMIT_SHA").unwrap_or(release_channel.dev_name())
)
}
};
std::process::exit(0);
}
None => {
eprintln!("usage: remote <run|proxy|version>");
std::process::exit(1);
}
};
if let Err(error) = result {
log::error!("exiting due to error: {}", error);
std::process::exit(1);
}
}

View file

@ -6,78 +6,4 @@ pub mod unix;
#[cfg(test)]
mod remote_editing_tests;
use clap::Subcommand;
use std::path::PathBuf;
pub use headless_project::{HeadlessAppState, HeadlessProject};
#[derive(Subcommand)]
pub enum Commands {
Run {
#[arg(long)]
log_file: PathBuf,
#[arg(long)]
pid_file: PathBuf,
#[arg(long)]
stdin_socket: PathBuf,
#[arg(long)]
stdout_socket: PathBuf,
#[arg(long)]
stderr_socket: PathBuf,
},
Proxy {
#[arg(long)]
reconnect: bool,
#[arg(long)]
identifier: String,
},
Version,
}
#[cfg(not(windows))]
pub fn run(command: Commands) -> anyhow::Result<()> {
use anyhow::Context;
use release_channel::{RELEASE_CHANNEL, ReleaseChannel};
use unix::{ExecuteProxyError, execute_proxy, execute_run};
match command {
Commands::Run {
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
} => execute_run(
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
),
Commands::Proxy {
identifier,
reconnect,
} => execute_proxy(identifier, reconnect)
.inspect_err(|err| {
if let ExecuteProxyError::ServerNotRunning(err) = err {
std::process::exit(err.to_exit_code());
}
})
.context("running proxy on the remote server"),
Commands::Version => {
let release_channel = *RELEASE_CHANNEL;
match release_channel {
ReleaseChannel::Stable | ReleaseChannel::Preview => {
println!("{}", env!("ZED_PKG_VERSION"))
}
ReleaseChannel::Nightly | ReleaseChannel::Dev => {
println!(
"{}",
option_env!("ZED_COMMIT_SHA").unwrap_or(release_channel.dev_name())
)
}
};
Ok(())
}
}
}

View file

@ -36,7 +36,6 @@ use smol::Async;
use smol::{net::unix::UnixListener, stream::StreamExt as _};
use std::ffi::OsStr;
use std::ops::ControlFlow;
use std::process::ExitStatus;
use std::str::FromStr;
use std::sync::LazyLock;
use std::{env, thread};
@ -47,7 +46,6 @@ use std::{
sync::Arc,
};
use telemetry_events::LocationData;
use thiserror::Error;
use util::ResultExt;
pub static VERSION: LazyLock<&str> = LazyLock::new(|| match *RELEASE_CHANNEL {
@ -528,23 +526,7 @@ pub fn execute_run(
Ok(())
}
#[derive(Debug, Error)]
pub(crate) enum ServerPathError {
#[error("Failed to create server_dir `{path}`")]
CreateServerDir {
#[source]
source: std::io::Error,
path: PathBuf,
},
#[error("Failed to create logs_dir `{path}`")]
CreateLogsDir {
#[source]
source: std::io::Error,
path: PathBuf,
},
}
#[derive(Clone, Debug)]
#[derive(Clone)]
struct ServerPaths {
log_file: PathBuf,
pid_file: PathBuf,
@ -554,19 +536,10 @@ struct ServerPaths {
}
impl ServerPaths {
fn new(identifier: &str) -> Result<Self, ServerPathError> {
fn new(identifier: &str) -> Result<Self> {
let server_dir = paths::remote_server_state_dir().join(identifier);
std::fs::create_dir_all(&server_dir).map_err(|source| {
ServerPathError::CreateServerDir {
source,
path: server_dir.clone(),
}
})?;
let log_dir = logs_dir();
std::fs::create_dir_all(log_dir).map_err(|source| ServerPathError::CreateLogsDir {
source: source,
path: log_dir.clone(),
})?;
std::fs::create_dir_all(&server_dir)?;
std::fs::create_dir_all(&logs_dir())?;
let pid_file = server_dir.join("server.pid");
let stdin_socket = server_dir.join("stdin.sock");
@ -584,43 +557,7 @@ impl ServerPaths {
}
}
#[derive(Debug, Error)]
pub(crate) enum ExecuteProxyError {
#[error("Failed to init server paths")]
ServerPath(#[from] ServerPathError),
#[error(transparent)]
ServerNotRunning(#[from] ProxyLaunchError),
#[error("Failed to check PidFile '{path}'")]
CheckPidFile {
#[source]
source: CheckPidError,
path: PathBuf,
},
#[error("Failed to kill existing server with pid '{pid}'")]
KillRunningServer {
#[source]
source: std::io::Error,
pid: u32,
},
#[error("failed to spawn server")]
SpawnServer(#[source] SpawnServerError),
#[error("stdin_task failed")]
StdinTask(#[source] anyhow::Error),
#[error("stdout_task failed")]
StdoutTask(#[source] anyhow::Error),
#[error("stderr_task failed")]
StderrTask(#[source] anyhow::Error),
}
pub(crate) fn execute_proxy(
identifier: String,
is_reconnecting: bool,
) -> Result<(), ExecuteProxyError> {
pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
init_logging_proxy();
let server_paths = ServerPaths::new(&identifier)?;
@ -637,19 +574,12 @@ pub(crate) fn execute_proxy(
log::info!("starting proxy process. PID: {}", std::process::id());
let server_pid = check_pid_file(&server_paths.pid_file).map_err(|source| {
ExecuteProxyError::CheckPidFile {
source,
path: server_paths.pid_file.clone(),
}
})?;
let server_pid = check_pid_file(&server_paths.pid_file)?;
let server_running = server_pid.is_some();
if is_reconnecting {
if !server_running {
log::error!("attempted to reconnect, but no server running");
return Err(ExecuteProxyError::ServerNotRunning(
ProxyLaunchError::ServerNotRunning,
));
anyhow::bail!(ProxyLaunchError::ServerNotRunning);
}
} else {
if let Some(pid) = server_pid {
@ -660,7 +590,7 @@ pub(crate) fn execute_proxy(
kill_running_server(pid, &server_paths)?;
}
spawn_server(&server_paths).map_err(ExecuteProxyError::SpawnServer)?;
spawn_server(&server_paths)?;
};
let stdin_task = smol::spawn(async move {
@ -700,9 +630,9 @@ pub(crate) fn execute_proxy(
if let Err(forwarding_result) = smol::block_on(async move {
futures::select! {
result = stdin_task.fuse() => result.map_err(ExecuteProxyError::StdinTask),
result = stdout_task.fuse() => result.map_err(ExecuteProxyError::StdoutTask),
result = stderr_task.fuse() => result.map_err(ExecuteProxyError::StderrTask),
result = stdin_task.fuse() => result.context("stdin_task failed"),
result = stdout_task.fuse() => result.context("stdout_task failed"),
result = stderr_task.fuse() => result.context("stderr_task failed"),
}
}) {
log::error!(
@ -715,12 +645,12 @@ pub(crate) fn execute_proxy(
Ok(())
}
fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<(), ExecuteProxyError> {
fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> {
log::info!("killing existing server with PID {}", pid);
std::process::Command::new("kill")
.arg(pid.to_string())
.output()
.map_err(|source| ExecuteProxyError::KillRunningServer { source, pid })?;
.context("failed to kill existing server")?;
for file in [
&paths.pid_file,
@ -734,39 +664,18 @@ fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<(), ExecuteProxy
Ok(())
}
#[derive(Debug, Error)]
pub(crate) enum SpawnServerError {
#[error("failed to remove stdin socket")]
RemoveStdinSocket(#[source] std::io::Error),
#[error("failed to remove stdout socket")]
RemoveStdoutSocket(#[source] std::io::Error),
#[error("failed to remove stderr socket")]
RemoveStderrSocket(#[source] std::io::Error),
#[error("failed to get current_exe")]
CurrentExe(#[source] std::io::Error),
#[error("failed to launch server process")]
ProcessStatus(#[source] std::io::Error),
#[error("failed to launch and detach server process: {status}\n{paths}")]
LaunchStatus { status: ExitStatus, paths: String },
}
fn spawn_server(paths: &ServerPaths) -> Result<(), SpawnServerError> {
fn spawn_server(paths: &ServerPaths) -> Result<()> {
if paths.stdin_socket.exists() {
std::fs::remove_file(&paths.stdin_socket).map_err(SpawnServerError::RemoveStdinSocket)?;
std::fs::remove_file(&paths.stdin_socket)?;
}
if paths.stdout_socket.exists() {
std::fs::remove_file(&paths.stdout_socket).map_err(SpawnServerError::RemoveStdoutSocket)?;
std::fs::remove_file(&paths.stdout_socket)?;
}
if paths.stderr_socket.exists() {
std::fs::remove_file(&paths.stderr_socket).map_err(SpawnServerError::RemoveStderrSocket)?;
std::fs::remove_file(&paths.stderr_socket)?;
}
let binary_name = std::env::current_exe().map_err(SpawnServerError::CurrentExe)?;
let binary_name = std::env::current_exe()?;
let mut server_process = std::process::Command::new(binary_name);
server_process
.arg("run")
@ -783,17 +692,11 @@ fn spawn_server(paths: &ServerPaths) -> Result<(), SpawnServerError> {
let status = server_process
.status()
.map_err(SpawnServerError::ProcessStatus)?;
if !status.success() {
return Err(SpawnServerError::LaunchStatus {
status,
paths: format!(
"log file: {:?}, pid file: {:?}",
paths.log_file, paths.pid_file,
),
});
}
.context("failed to launch server process")?;
anyhow::ensure!(
status.success(),
"failed to launch and detach server process"
);
let mut total_time_waited = std::time::Duration::from_secs(0);
let wait_duration = std::time::Duration::from_millis(20);
@ -814,15 +717,7 @@ fn spawn_server(paths: &ServerPaths) -> Result<(), SpawnServerError> {
Ok(())
}
#[derive(Debug, Error)]
#[error("Failed to remove PID file for missing process (pid `{pid}`")]
pub(crate) struct CheckPidError {
#[source]
source: std::io::Error,
pid: u32,
}
fn check_pid_file(path: &Path) -> Result<Option<u32>, CheckPidError> {
fn check_pid_file(path: &Path) -> Result<Option<u32>> {
let Some(pid) = std::fs::read_to_string(&path)
.ok()
.and_then(|contents| contents.parse::<u32>().ok())
@ -847,7 +742,7 @@ fn check_pid_file(path: &Path) -> Result<Option<u32>, CheckPidError> {
log::debug!(
"Found PID file, but process with that PID does not exist. Removing PID file."
);
std::fs::remove_file(&path).map_err(|source| CheckPidError { source, pid })?;
std::fs::remove_file(&path).context("Failed to remove PID file")?;
Ok(None)
}
}

Some files were not shown because too many files have changed in this diff Show more