Compare commits

...
Sign in to create a new pull request.

47 commits

Author SHA1 Message Date
Danilo Leal
bd4e943597
acp: Add onboarding modal & title bar banner (#36784)
Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-08-26 16:59:12 -03:00
Danilo Leal
c5d3c7d790
thread view: Improve agent installation UI (#36957)
Release Notes:

- N/A

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-08-26 16:58:23 -03:00
张小白
fff0ecead1
windows: Fix keystroke & keymap (#36572)
Closes #36300

This PR follows Windows conventions by introducing
`KeybindingKeystroke`, so shortcuts now show up as `ctrl-shift-4`
instead of `ctrl-$`.

It also fixes issues with keyboard layouts: when `use_key_equivalents`
is set to true, keys are remapped based on their virtual key codes. For
example, `ctrl-\` on a standard English layout will be mapped to
`ctrl-ё` on a Russian layout.


Release Notes:

- N/A

---------

Co-authored-by: Kate <kate@zed.dev>
2025-08-27 03:24:50 +08:00
Max Brunsfeld
b1b60bb7fe
Work around duplicate ssh projects in workspace migration (#36946)
Fixes another case where the sqlite migration could fail, reported by
@SomeoneToIgnore.

Release Notes:

- N/A
2025-08-26 10:54:39 -07:00
Adam Mulvany
0e575b2809
helix: Fix buffer search: deploy reset to normal mode (#36917)
## Fix: Preserve Helix mode when using  search

### Problem
When using `buffer search: deploy` in Helix mode, pressing Enter to
dismiss the search incorrectly returned to Vim NORMAL mode instead of
Helix NORMAL mode.

### Root Cause
The `search_deploy` function was resetting the entire `SearchState` to
default values when buffer search: deploy was activated. Since the
default `Mode` is `Normal`, this caused `prior_mode` to be set to Vim's
Normal mode regardless of the actual mode before search.

### Solution
Modified `search_deploy` to preserve the current mode when resetting
search state:
- Store the current mode before resetting
- Reset search state to default
- Restore the saved mode to `prior_mode`

This ensures the editor returns to the correct mode (Helix NORMAL or Vim
NORMAL) after dismissing buffer search.

### Settings

I was able to reproduce and then test the fix was successful with the
following config and have also tested with vim: default_mode commented
out to ensure that's not influencing the mode selection flow:

```
  "helix_mode": true,
  "vim_mode": true,
  "vim": {
    "default_mode": "helix_normal"
  },
```

This is on Kubuntu 24.04.

The following test combinations pass locally:

- `cargo test -p search`
- `cargo test -p vim` 
- `cargo test -p editor`
- `cargo test -p workspace`
- `cargo test -p gpui -- vim`
- `cargo test -p gpui -- helix`

Release Notes:

- Fixed Helix mode switching to Vim normal mode after using `buffer
search: deploy` to search

Closes #36872
2025-08-26 10:38:53 -06:00
Danilo Leal
65c6c709fd
thread view: Refine tool call UI (#36937)
Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-08-26 12:55:40 -03:00
Bennet Bo Fenner
858ab9cc23
Revert "ai: Auto select user model when there's no default" (#36932)
Reverts zed-industries/zed#36722

Release Notes:

- N/A
2025-08-26 13:55:09 +00:00
Daniel Martín
2c64b05ea4
emacs: Add editor::FindAllReferences keybinding (#36840)
This commit maps `editor::FindAllReferences` to Alt+? in the Emacs
keymap.

Release Notes:

- N/A
2025-08-26 13:43:58 +00:00
Peter Tripp
b7dad2cf71
Fix initial_tasks.json triggering diagnostic warning (#36523)
`zed::OpenProjectTasks` without an existing tasks.json will recreate it
from the template.
This file will immediately show a warning.

<img width="810" height="168" alt="Screenshot 2025-08-19 at 17 16 07"
src="https://github.com/user-attachments/assets/bbc8c7a0-7036-4927-8e85-b81b79aeaacb"
/>

Release Notes:

- N/A
2025-08-26 13:41:57 +00:00
Peter Tripp
76dbcde628
Support disabling drag-and-drop in Project Panel (#36719)
Release Notes:

- Added setting for disabling drag and drop in project panel. `{
"project_panel": {"drag_and_drop": false } }`
2025-08-26 13:35:45 +00:00
Peter Tripp
aa0f7a2d09
Fix conflicts in Linux default keymap (#36519)
Closes https://github.com/zed-industries/zed/issues/29746

| Action | New Key | Old Key | Former Conflict |
| - | - | - | - |
| `edit_prediction::ToggleMenu` | `ctrl-alt-shift-i` | `ctrl-shift-i` |
`editor::Format` |
| `editor::ToggleEditPrediction` | `ctrl-alt-shift-e` | `ctrl-shift-e` |
`project_panel::ToggleFocus` |

These aren't great keys and I'm open to alternate suggestions, but the
will work out of the box without conflict.

Release Notes:

- N/A
2025-08-26 09:33:42 -04:00
Bennet Bo Fenner
372b3c7af6
acp: Enable feature flag for everyone (#36928)
Release Notes:

- N/A
2025-08-26 15:30:26 +02:00
Bennet Bo Fenner
10a1140d49
acp: Improve matching logic when adding new entry to agent_servers (#36926)
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-26 11:18:50 +00:00
Bennet Bo Fenner
e96b68bc15
acp: Polish UI (#36927)
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-26 10:55:45 +00:00
Ben Brandt
b249593abe
agent2: Always finalize diffs from the edit tool (#36918)
Previously, we wouldn't finalize the diff if an error occurred during
editing or the tool call was canceled.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-26 09:46:29 +00:00
Bennet Bo Fenner
c14d84cfdb
acp: Add button to configure custom agent in the configuration view (#36923)
Release Notes:

- N/A
2025-08-26 09:20:33 +00:00
Dan Dascalescu
428fc6d483
chore: Fix typo in 10_bug_report.yml (#36922)
Release Notes:

- N/A
2025-08-26 11:05:40 +02:00
Max Brunsfeld
64b14ef848
Fix Sqlite newline syntax in workspace migration (#36916)
Fixes one more case where I incorrectly tried to use a `\n` escape
sequence for a newline in sqlite.

Release Notes:

- N/A
2025-08-25 22:21:05 -07:00
Rui Ning
bf5ed6d1c9
Remote: Change "sh -c" to "sh -lc" to make config in $HOME/.profile effective (#36760)
Closes #ISSUE

Release Notes:

- The environment of original remote dev cannot be changed without sudo
because of the behavior of "sh -c". This PR changes "sh -c" to "sh -lc"
to let the shell source $HOME/.profile and support customized
environment like customized $PATH variable.
2025-08-25 21:40:53 -06:00
Romans Malinovskis
bb5cfe118f
Add "shift-r" and "g ." support for helix mode (#35468)
Related #4642
Compatible with #34136

Release Notes:

- Helix: `Shift+R` works as Paste instead of taking you to ReplaceMode
- Helix: `g .` goes to last modification place (similar to `. in vim)
2025-08-25 21:37:29 -06:00
Conrad Irwin
633ce23ae9
acp: Send user-configured MCP tools (#36910)
Release Notes:

- N/A
2025-08-26 00:55:24 +00:00
Max Brunsfeld
d43df9e841
Fix workspace migration failure (#36911)
This fixes a regression on nightly introduced in
https://github.com/zed-industries/zed/pull/36714

Release Notes:

- N/A
2025-08-26 00:27:52 +00:00
Conrad Irwin
f8667a8379
Remove unused files (#36909)
Closes #ISSUE

Release Notes:

- N/A
2025-08-25 22:23:58 +00:00
Conrad Irwin
1460573dd4
acp: Rename dev command (#36908)
Release Notes:

- N/A
2025-08-25 16:04:44 -06:00
Kirill Bulatov
65de969cc8
Do not show directories in the InvalidBufferView (#36906)
Follow-up of https://github.com/zed-industries/zed/pull/36764

Release Notes:

- N/A
2025-08-25 21:16:37 +00:00
Danilo Leal
628a9cd8ea
thread view: Add link to docs in the toolbar plus menu (#36883)
Release Notes:

- N/A
2025-08-25 17:34:55 -03:00
Gwen Lg
ad25aba990
remote_server: Improve error reporting (#33770)
Closes #33736

Use `thiserror` to implement error stack and `anyhow` to report is to
user.
Also move some code from main to remote_server to have better crate
isolation.

Release Notes:

- N/A

---------

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
2025-08-25 20:23:29 +00:00
Alvaro Parker
99cee8778c
tab_switcher: Add support for diagnostics (#34547)
Support to show diagnostics on the tab switcher in the same way they are
displayed on the tab bar. This follows the setting
`tabs.show_diagnostics`.

This will improve user experience when disabling the tab bar and still
being able to see the diagnostics when switching tabs

Preview:

<img width="768" height="523" alt="Screenshot From 2025-07-16 11-02-42"
src="https://github.com/user-attachments/assets/308873ba-0458-485d-ae05-0de7c1cdfb28"
/>


Release Notes:

- Added diagnostics indicators to the tab switcher

---------

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
2025-08-25 20:18:03 +00:00
Cole Miller
823a0018e5
acp: Show output for read_file tool in a code block (#36900)
Release Notes:

- N/A
2025-08-25 20:10:17 +00:00
Conrad Irwin
9cc006ff74
acp: Update error matching (#36898)
Release Notes:

- N/A
2025-08-25 14:07:10 -06:00
Michael Sloan
0470baca50
open_ai: Remove model field from ResponseStreamEvent (#36902)
Closes #36901

Release Notes:

- Fixed use of Open WebUI as an LLM provider.
2025-08-25 19:50:08 +00:00
John Tur
4605b96630
Fix constant thread creation on Windows (#36779)
See
https://github.com/zed-industries/zed/issues/36057#issuecomment-3215808649

Fixes https://github.com/zed-industries/zed/issues/36057

Release Notes:

- N/A
2025-08-25 12:45:28 -07:00
Danilo Leal
949398cb93
thread view: Fix some design papercuts (#36893)
Release Notes:

- N/A

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Matt Miller <mattrx@gmail.com>
2025-08-25 21:07:30 +02:00
Cretezy
79e74b880b
workspace: Allow disabling of padding on zoomed panels (#31913)
Screenshot:

| Before | After |
| -------|------|
|
![image](https://github.com/user-attachments/assets/629e7da2-6070-4abb-b469-3b0824524ca4)
|
![image](https://github.com/user-attachments/assets/99e54412-2e0b-4df9-9c40-a89b0411f6d8)
|
|
![image](https://github.com/user-attachments/assets/e99da846-f39b-47b5-808e-65c22a1af47b)
|
![image](https://github.com/user-attachments/assets/ccd4408f-8cce-44ec-a69a-81794125ec99)
|


Release Notes:

- Added `zoomed_padding` to allow disabling of padding around zoomed
panels

Co-authored-by: Mikayla Maki <mikayla@zed.dev>
2025-08-25 19:02:19 +00:00
Bennet Bo Fenner
59af2a7d1f
acp: Add telemetry (#36894)
Release Notes:

- N/A

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-08-25 20:51:23 +02:00
Danilo Leal
c786c0150f
agent: Add section for agent servers in settings view (#35206)
Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-08-25 14:45:24 -04:00
Cole Miller
5fd29d37a6
acp: Model-specific prompt capabilities for 1PA (#36879)
Adds support for per-session prompt capabilities and capability changes
on the Zed side (ACP itself still only has per-connection static
capabilities for now), and uses it to reflect image support accurately
in 1PA threads based on the currently-selected model.

Release Notes:

- N/A
2025-08-25 14:28:11 -04:00
Mikayla Maki
f1204dfc33
Revert "workspace: Disable padding on zoomed panels" (#36884)
Reverts zed-industries/zed#36012

We thought we didn't need this UI, but it turns out it was load bearing
:)

Release Notes:

- Restored the zoomed panel padding
2025-08-25 10:46:36 -07:00
Marshall Bowers
2e1ca47241
Make fields of AiUpsellCard private (#36888)
This PR makes the fields of the `AiUpsellCard` private, for better
encapsulation.

Release Notes:

- N/A
2025-08-25 17:21:20 +00:00
Finn Evers
5c346a4ccf
kotlin: Specify default language server (#36871)
As of
db52fc3655,
the Kotlin extension has two language servers. However, following that
change, no default language server for Kotlin was configured within this
repo, which led to two language servers being activated for Kotlin by
default.

This PR makes `kotlin-language-server` the default language server for
the extension. This also ensures that the [documentation within the
repository](https://github.com/zed-extensions/kotlin?tab=readme-ov-file#kotlin-lsp)
matches what is actually the case.


Release Notes:

- kotlin: Made `kotlin-language-server` the default language server.
2025-08-25 19:12:33 +02:00
Conrad Irwin
a102b08743
Require confirmation for fetch tool (#36881)
Using prompt injection, the agent may be tricked into making a fetch
request that includes unexpected data from the conversation in the URL.

As agent conversations may contain sensitive information (like private
code, or
potentially even API keys), this seems bad.

The easiest way to prevent this is to require the user to look at the
URL
before the model is allowed to fetch it.

Thanks to @ant4g0nist for bringing this to our attention.

Release Notes:

- agent panel: The fetch tool now requires confirmation.
2025-08-25 16:03:07 +00:00
Marshall Bowers
2dc4f156b3
Revert "Capture shorthand_field_initializer and modules in Rust highlights (#35842)" (#36880)
This PR reverts https://github.com/zed-industries/zed/pull/35842, as it
broke the syntax highlighting for `crate`:

### Before Revert

<img width="367" height="70" alt="Screenshot 2025-08-25 at 11 29 50 AM"
src="https://github.com/user-attachments/assets/ce9b8b59-4e89-43ed-84c7-95c0156b9168"
/>

### After Revert

<img width="353" height="69" alt="Screenshot 2025-08-25 at 11 32 17 AM"
src="https://github.com/user-attachments/assets/b6df5a21-64db-4abf-aa76-f085236da0c4"
/>

This reverts commit 896a35f7be.

Release Notes:

- Reverted https://github.com/zed-industries/zed/pull/35842.
2025-08-25 15:51:31 +00:00
Bennet Bo Fenner
557753d092
acp: Add Reauthenticate to dropdown (#36878)
Release Notes:

- N/A

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-08-25 15:46:07 +00:00
Conrad Irwin
65fb17e2c9
acp: Remember following state (#36793)
A beta user reported that following was "lost" when asking for
confirmation, I
suspect they moved their cursor in the agent file while reviewing the
change.
Now we will resume following when the agent starts up again.

Release Notes:

- N/A
2025-08-25 09:34:30 -06:00
Smit Barmase
2fe3dbed31
project: Remove redundant Option from parse_register_capabilities (#36874)
Release Notes:

- N/A
2025-08-25 21:00:53 +05:30
Zach Riegel
fda5111dc0
Add CSS language injections for calls to styled (#33966)
…emotion).

Closes: https://github.com/zed-industries/zed/issues/17026

Release Notes:

- Added CSS language injection support for styled-components and emotion
in JavaScript, TypeScript, and TSX files.
2025-08-25 11:30:09 -04:00
Antonio Scandurra
69127d2bea
acp: Simplify control flow for native agent loop (#36868)
Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-08-25 13:38:19 +00:00
131 changed files with 8024 additions and 4431 deletions

View file

@ -14,7 +14,7 @@ body:
### Description
<!-- Describe with sufficient detail to reproduce from a clean Zed install.
- Any code must be sufficient to reproduce (include context!)
- Code must as text, not just as a screenshot.
- Include code as text, not just as a screenshot.
- Issues with insufficient detail may be summarily closed.
-->

1
Cargo.lock generated
View file

@ -13521,6 +13521,7 @@ dependencies = [
"smol",
"sysinfo",
"telemetry_events",
"thiserror 2.0.12",
"toml 0.8.20",
"unindent",
"util",

View file

@ -0,0 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 12.375H13" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M3 11.125L6.75003 7.375L3 3.62497" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 336 B

1257
assets/images/acp_grid.svg Normal file

File diff suppressed because it is too large Load diff

After

Width:  |  Height:  |  Size: 176 KiB

View file

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="160" height="61" fill="none"><g clip-path="url(#a)"><path fill="#000" d="M130.75.385c5.428 0 10.297 2.81 13.011 7.511l14.214 24.618-.013-.005c2.599 4.504 2.707 9.932.28 14.513-2.618 4.944-7.862 8.015-13.679 8.015h-31.811c-.452 0-.873-.242-1.103-.637a1.268 1.268 0 0 1 0-1.274l3.919-6.78c.223-.394.65-.636 1.102-.636h28.288a5.622 5.622 0 0 0 4.925-2.849 5.615 5.615 0 0 0 0-5.69l-14.214-24.617a5.621 5.621 0 0 0-4.925-2.848 5.621 5.621 0 0 0-4.925 2.848l-14.214 24.618a6.267 6.267 0 0 0-.319.643.998.998 0 0 1-.069.14L101.724 54.4l-.823 1.313-2.529 4.39a1.27 1.27 0 0 1-1.103.636h-7.83c-.452 0-.873-.242-1.102-.637-.23-.394-.23-.879 0-1.274l2.188-3.791H66.803c-3.32 0-6.454-1.122-8.818-3.167a17.141 17.141 0 0 1-3.394-3.96 1.261 1.261 0 0 1-.091-.137L34.2 12.573a5.622 5.622 0 0 0-4.925-2.849 5.621 5.621 0 0 0-4.924 2.85L10.137 37.19a5.615 5.615 0 0 0 0 5.69 5.63 5.63 0 0 0 4.925 2.841h29.862a1.276 1.276 0 0 1 1.102 1.912l-3.912 6.778a1.27 1.27 0 0 1-1.102.638H14.495c-3.32 0-6.454-1.128-8.817-3.173-5.906-5.104-7.36-12.883-3.62-19.363L16.267 7.89C18.872 3.385 23.517.583 28.697.39c.184-.006.356-.006.534-.006 5.378 0 10.45 3.007 13.246 7.85l12.986 22.372L68.58 7.891C71.186 3.385 75.83.582 81.01.39c.185-.006.358-.006.536-.006 4.453 0 8.71 2.039 11.672 5.588.337.407.388.98.127 1.446l-3.765 6.6a1.268 1.268 0 0 1-2.205.006l-.847-1.465a5.623 5.623 0 0 0-4.926-2.848 5.622 5.622 0 0 0-4.924 2.848L62.464 37.18a5.614 5.614 0 0 0 0 5.689 5.628 5.628 0 0 0 4.925 2.842H95.91L117.76 7.87c2.714-4.683 7.575-7.486 12.99-7.486Z"/></g><defs><clipPath id="a"><path fill="#fff" d="M0 .385h160v60.36H0z"/></clipPath></defs></svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 14 KiB

View file

@ -40,7 +40,7 @@
"shift-f11": "debugger::StepOut",
"f11": "zed::ToggleFullScreen",
"ctrl-alt-z": "edit_prediction::RateCompletions",
"ctrl-shift-i": "edit_prediction::ToggleMenu",
"ctrl-alt-shift-i": "edit_prediction::ToggleMenu",
"ctrl-alt-l": "lsp_tool::ToggleMenu"
}
},
@ -120,7 +120,7 @@
"alt-g m": "git::OpenModifiedFiles",
"menu": "editor::OpenContextMenu",
"shift-f10": "editor::OpenContextMenu",
"ctrl-shift-e": "editor::ToggleEditPrediction",
"ctrl-alt-shift-e": "editor::ToggleEditPrediction",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint"
}

File diff suppressed because it is too large Load diff

View file

@ -38,6 +38,7 @@
"alt-;": ["editor::ToggleComments", { "advance_downwards": false }],
"ctrl-x ctrl-;": "editor::ToggleComments",
"alt-.": "editor::GoToDefinition", // xref-find-definitions
"alt-?": "editor::FindAllReferences", // xref-find-references
"alt-,": "pane::GoBack", // xref-pop-marker-stack
"ctrl-x h": "editor::SelectAll", // mark-whole-buffer
"ctrl-d": "editor::Delete", // delete-char

View file

@ -38,6 +38,7 @@
"alt-;": ["editor::ToggleComments", { "advance_downwards": false }],
"ctrl-x ctrl-;": "editor::ToggleComments",
"alt-.": "editor::GoToDefinition", // xref-find-definitions
"alt-?": "editor::FindAllReferences", // xref-find-references
"alt-,": "pane::GoBack", // xref-pop-marker-stack
"ctrl-x h": "editor::SelectAll", // mark-whole-buffer
"ctrl-d": "editor::Delete", // delete-char

View file

@ -428,11 +428,13 @@
"g h": "vim::StartOfLine",
"g s": "vim::FirstNonWhitespace", // "g s" default behavior is "space s"
"g e": "vim::EndOfDocument",
"g .": "vim::HelixGotoLastModification", // go to last modification
"g r": "editor::FindAllReferences", // zed specific
"g t": "vim::WindowTop",
"g c": "vim::WindowMiddle",
"g b": "vim::WindowBottom",
"shift-r": "editor::Paste",
"x": "editor::SelectLine",
"shift-x": "editor::SelectLine",
"%": "editor::SelectAll",

View file

@ -162,6 +162,12 @@
// 2. Always quit the application
// "on_last_window_closed": "quit_app",
"on_last_window_closed": "platform_default",
// Whether to show padding for zoomed panels.
// When enabled, zoomed center panels (e.g. code editor) will have padding all around,
// while zoomed bottom/left/right panels will have padding to the top/right/left (respectively).
//
// Default: true
"zoomed_padding": true,
// Whether to use the system provided dialogs for Open and Save As.
// When set to false, Zed will use the built-in keyboard-first pickers.
"use_system_path_prompts": true,
@ -647,6 +653,8 @@
// "never"
"show": "always"
},
// Whether to enable drag-and-drop operations in the project panel.
"drag_and_drop": true,
// Whether to hide the root entry when only one folder is open in the window.
"hide_root": false
},
@ -1629,6 +1637,9 @@
"allowed": true
}
},
"Kotlin": {
"language_servers": ["kotlin-language-server", "!kotlin-lsp", "..."]
},
"LaTeX": {
"formatter": "language_server",
"language_servers": ["texlab", "..."],

View file

@ -43,8 +43,8 @@
// "args": ["--login"]
// }
// }
"shell": "system",
"shell": "system"
// Represents the tags for inline runnable indicators, or spawning multiple tasks at once.
"tags": []
// "tags": []
}
]

View file

@ -183,16 +183,15 @@ impl ToolCall {
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
first_line.to_owned() + ""
} else {
tool_call.title
};
Self {
id: tool_call.id,
label: cx.new(|cx| {
Markdown::new(
tool_call.title.into(),
Some(language_registry.clone()),
None,
cx,
)
}),
label: cx
.new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
kind: tool_call.kind,
content: tool_call
.content
@ -233,7 +232,11 @@ impl ToolCall {
if let Some(title) = title {
self.label.update(cx, |label, cx| {
label.replace(title, cx);
if let Some((first_line, _)) = title.split_once("\n") {
label.replace(first_line.to_owned() + "", cx)
} else {
label.replace(title, cx);
}
});
}
@ -756,6 +759,8 @@ pub struct AcpThread {
connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId,
token_usage: Option<TokenUsage>,
prompt_capabilities: acp::PromptCapabilities,
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
#[derive(Debug)]
@ -770,11 +775,12 @@ pub enum AcpThreadEvent {
Stopped,
Error,
LoadError(LoadError),
PromptCapabilitiesUpdated,
}
impl EventEmitter<AcpThreadEvent> for AcpThread {}
#[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
Idle,
WaitingForToolConfirmation,
@ -821,7 +827,20 @@ impl AcpThread {
project: Entity<Project>,
action_log: Entity<ActionLog>,
session_id: acp::SessionId,
mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
cx: &mut Context<Self>,
) -> Self {
let prompt_capabilities = *prompt_capabilities_rx.borrow();
let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
loop {
let caps = prompt_capabilities_rx.recv().await?;
this.update(cx, |this, cx| {
this.prompt_capabilities = caps;
cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
})?;
}
});
Self {
action_log,
shared_buffers: Default::default(),
@ -833,9 +852,15 @@ impl AcpThread {
connection,
session_id,
token_usage: None,
prompt_capabilities,
_observe_prompt_capabilities: task,
}
}
pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
&self.connection
}
@ -2599,13 +2624,19 @@ mod tests {
.into(),
);
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| {
let thread = cx.new(|cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
)
});
self.sessions.lock().insert(session_id, thread.downgrade());
@ -2639,14 +2670,6 @@ mod tests {
}
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.lock();
let thread = sessions.get(session_id).unwrap().clone();

View file

@ -38,8 +38,6 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn prompt_capabilities(&self) -> acp::PromptCapabilities;
fn resume(
&self,
_session_id: &acp::SessionId,
@ -329,13 +327,19 @@ mod test_support {
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| {
let thread = cx.new(|cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
)
});
self.sessions.lock().insert(
@ -348,14 +352,6 @@ mod test_support {
Task::ready(Ok(thread))
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn authenticate(
&self,
_method_id: acp::AuthMethodId,

View file

@ -21,12 +21,12 @@ use ui::prelude::*;
use util::ResultExt as _;
use workspace::{Item, Workspace};
actions!(acp, [OpenDebugTools]);
actions!(dev, [OpenAcpLogs]);
pub fn init(cx: &mut App) {
cx.observe_new(
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
workspace.register_action(|workspace, _: &OpenDebugTools, window, cx| {
workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
let acp_tools =
Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);

View file

@ -664,7 +664,7 @@ impl Thread {
}
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
if self.configured_model.is_none() || self.messages.is_empty() {
if self.configured_model.is_none() {
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
}
self.configured_model.clone()
@ -2097,7 +2097,7 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
println!("No thread summary model");
return;
};
@ -2416,7 +2416,7 @@ impl Thread {
}
let Some(ConfiguredModel { model, provider }) =
LanguageModelRegistry::read_global(cx).thread_summary_model(cx)
LanguageModelRegistry::read_global(cx).thread_summary_model()
else {
return;
};
@ -5410,10 +5410,13 @@ fn main() {{
}),
cx,
);
registry.set_thread_summary_model(Some(ConfiguredModel {
provider,
model: model.clone(),
}));
registry.set_thread_summary_model(
Some(ConfiguredModel {
provider,
model: model.clone(),
}),
cx,
);
})
});

View file

@ -228,7 +228,7 @@ impl NativeAgent {
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
let registry = LanguageModelRegistry::read_global(cx);
let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
let summarization_model = registry.thread_summary_model().map(|c| c.model);
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
@ -240,13 +240,16 @@ impl NativeAgent {
let title = thread.title();
let project = thread.project.clone();
let action_log = thread.action_log.clone();
let acp_thread = cx.new(|_cx| {
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
let acp_thread = cx.new(|cx| {
acp_thread::AcpThread::new(
title,
connection,
project.clone(),
action_log.clone(),
session_id.clone(),
prompt_capabilities_rx,
cx,
)
});
let subscriptions = vec![
@ -521,7 +524,7 @@ impl NativeAgent {
let registry = LanguageModelRegistry::read_global(cx);
let default_model = registry.default_model().map(|m| m.model);
let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
let summarization_model = registry.thread_summary_model().map(|m| m.model);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
@ -925,14 +928,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}
}
fn resume(
&self,
session_id: &acp::SessionId,

View file

@ -22,6 +22,10 @@ impl NativeAgentServer {
}
impl AgentServer for NativeAgentServer {
fn telemetry_id(&self) -> &'static str {
"zed"
}
fn name(&self) -> SharedString {
"Zed Agent".into()
}

View file

@ -72,6 +72,7 @@ async fn test_echo(cx: &mut TestAppContext) {
}
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_thinking(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -471,7 +472,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
tool_name: ToolRequiringPermission::name().into(),
is_error: true,
content: "Permission to run tool denied by user".into(),
output: None
output: Some("Permission to run tool denied by user".into())
})
]
);
@ -1347,6 +1348,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
}
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -1685,6 +1687,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
}
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_title_generation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@ -1819,11 +1822,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
Project::init_settings(cx);
agent_settings::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
Project::init_settings(cx);
LanguageModelRegistry::test(cx);
agent_settings::init(cx);
});
cx.executor().forbid_parking();

View file

@ -575,11 +575,22 @@ pub struct Thread {
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
summarization_model: Option<Arc<dyn LanguageModel>>,
prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
pub(crate) project: Entity<Project>,
pub(crate) action_log: Entity<ActionLog>,
}
impl Thread {
fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
let image = model.map_or(true, |model| model.supports_images());
acp::PromptCapabilities {
image,
audio: false,
embedded_context: true,
}
}
pub fn new(
project: Entity<Project>,
project_context: Entity<ProjectContext>,
@ -590,6 +601,8 @@ impl Thread {
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
Self {
id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
prompt_id: PromptId::new(),
@ -617,6 +630,8 @@ impl Thread {
templates,
model,
summarization_model: None,
prompt_capabilities_tx,
prompt_capabilities_rx,
project,
action_log,
}
@ -717,7 +732,17 @@ impl Thread {
stream.update_tool_call_fields(
&tool_use.id,
acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
status: Some(
tool_result
.as_ref()
.map_or(acp::ToolCallStatus::Failed, |result| {
if result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}
}),
),
raw_output: output,
..Default::default()
},
@ -750,6 +775,8 @@ impl Thread {
.or_else(|| registry.default_model())
.map(|model| model.model)
});
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
Self {
id,
@ -779,6 +806,8 @@ impl Thread {
project,
action_log,
updated_at: db_thread.updated_at,
prompt_capabilities_tx,
prompt_capabilities_rx,
}
}
@ -946,10 +975,12 @@ impl Thread {
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
let old_usage = self.latest_token_usage();
self.model = Some(model);
let new_caps = Self::prompt_capabilities(self.model.as_deref());
let new_usage = self.latest_token_usage();
if old_usage != new_usage {
cx.emit(TokenUsageUpdated(new_usage));
}
self.prompt_capabilities_tx.send(new_caps).log_err();
cx.notify()
}
@ -1142,37 +1173,7 @@ impl Thread {
_task: cx.spawn(async move |this, cx| {
log::debug!("Starting agent turn execution");
let turn_result: Result<()> = async {
let mut intent = CompletionIntent::UserPrompt;
loop {
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
let mut end_turn = true;
this.update(cx, |this, cx| {
// Generate title if needed.
if this.title.is_none() && this.pending_title_generation.is_none() {
this.generate_title(cx);
}
// End the turn if the model didn't use tools.
let message = this.pending_message.as_ref();
end_turn =
message.map_or(true, |message| message.tool_results.is_empty());
this.flush_pending_message(cx);
})?;
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
log::info!("Tool use limit reached, completing turn");
return Err(language_model::ToolUseLimitReachedError.into());
} else if end_turn {
log::debug!("No tool uses found, completing turn");
return Ok(());
} else {
intent = CompletionIntent::ToolResults;
}
}
}
.await;
let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await;
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
match turn_result {
@ -1203,20 +1204,17 @@ impl Thread {
Ok(events_rx)
}
async fn stream_completion(
async fn run_turn_internal(
this: &WeakEntity<Self>,
model: &Arc<dyn LanguageModel>,
completion_intent: CompletionIntent,
model: Arc<dyn LanguageModel>,
event_stream: &ThreadEventStream,
cx: &mut AsyncApp,
) -> Result<()> {
log::debug!("Stream completion started successfully");
let mut attempt = None;
let mut attempt = 0;
let mut intent = CompletionIntent::UserPrompt;
loop {
let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
})??;
let request =
this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;
telemetry::event!(
"Agent Thread Completion",
@ -1227,23 +1225,19 @@ impl Thread {
attempt
);
log::debug!(
"Calling model.stream_completion, attempt {}",
attempt.unwrap_or(0)
);
log::debug!("Calling model.stream_completion, attempt {}", attempt);
let mut events = model
.stream_completion(request, cx)
.await
.map_err(|error| anyhow!(error))?;
let mut tool_results = FuturesUnordered::new();
let mut error = None;
while let Some(event) = events.next().await {
log::trace!("Received completion event: {:?}", event);
match event {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
tool_results.extend(this.update(cx, |this, cx| {
this.handle_streamed_completion_event(event, event_stream, cx)
this.handle_completion_event(event, event_stream, cx)
})??);
}
Err(err) => {
@ -1253,6 +1247,7 @@ impl Thread {
}
}
let end_turn = tool_results.is_empty();
while let Some(tool_result) = tool_results.next().await {
log::debug!("Tool finished {:?}", tool_result);
@ -1275,65 +1270,83 @@ impl Thread {
})?;
}
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
if this.title.is_none() && this.pending_title_generation.is_none() {
this.generate_title(cx);
}
})?;
if let Some(error) = error {
let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(anyhow!(error))?;
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(anyhow!(error))?;
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let attempt = attempt.get_or_insert(0u8);
*attempt += 1;
let attempt = *attempt;
if attempt > max_attempts {
return Err(anyhow!(error))?;
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
event_stream.send_retry(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
});
cx.background_executor().timer(delay).await;
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
attempt += 1;
let retry =
this.update(cx, |this, _| this.handle_completion_error(error, attempt))??;
let timer = cx.background_executor().timer(retry.duration);
event_stream.send_retry(retry);
timer.await;
this.update(cx, |this, _cx| {
if let Some(Message::Agent(message)) = this.messages.last() {
if message.tool_results.is_empty() {
intent = CompletionIntent::UserPrompt;
this.messages.push(Message::Resume);
}
}
})?;
} else {
} else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
return Err(language_model::ToolUseLimitReachedError.into());
} else if end_turn {
return Ok(());
} else {
intent = CompletionIntent::ToolResults;
attempt = 0;
}
}
}
fn handle_completion_error(
&mut self,
error: LanguageModelCompletionError,
attempt: u8,
) -> Result<acp_thread::RetryStatus> {
if self.completion_mode == CompletionMode::Normal {
return Err(anyhow!(error));
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(anyhow!(error));
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
if attempt > max_attempts {
return Err(anyhow!(error));
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
Ok(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
})
}
/// A helper method that's called on every streamed completion event.
/// Returns an optional tool result task, which the main agentic loop will
/// send back to the model when it resolves.
fn handle_streamed_completion_event(
fn handle_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
event_stream: &ThreadEventStream,
@ -1554,7 +1567,7 @@ impl Thread {
tool_name: tool_use.name,
is_error: true,
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
output: None,
output: Some(error.to_string().into()),
},
}
}))
@ -2456,6 +2469,30 @@ impl ToolCallEventStreamReceiver {
}
}
pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
update,
)))) = event
{
update.fields
} else {
panic!("Expected update fields but got: {:?}", event);
}
}
pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
update,
)))) = event
{
update.diff
} else {
panic!("Expected diff but got: {:?}", event);
}
}
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(

View file

@ -273,6 +273,13 @@ impl AgentTool for EditFileTool {
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
event_stream.update_diff(diff.clone());
let _finalize_diff = util::defer({
let diff = diff.downgrade();
let mut cx = cx.clone();
move || {
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
}
});
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
@ -389,8 +396,6 @@ impl AgentTool for EditFileTool {
})
.await;
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
let input_path = input.path.display();
if unified_diff.is_empty() {
anyhow::ensure!(
@ -1545,6 +1550,100 @@ mod tests {
);
}
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/", json!({"main.rs": ""})).await;
let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
let languages = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(),
Templates::new(),
Some(model.clone()),
cx,
)
});
// Ensure the diff is finalized after the edit completes.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
cx.run_until_parked();
model.end_last_completion_stream();
edit.await.unwrap();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
}
// Ensure the diff is finalized if an error occurs while editing.
{
model.forbid_requests();
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
edit.await.unwrap_err();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
model.allow_requests();
}
// Ensure the diff is finalized if the tool call gets dropped.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
drop(edit);
cx.run_until_parked();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
}
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);

View file

@ -136,12 +136,17 @@ impl AgentTool for FetchTool {
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let authorize = event_stream.authorize(input.url.clone(), cx);
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await }
async move {
authorize.await?;
Self::build_message(http_client, &input.url).await
}
});
cx.foreground_executor().spawn(async move {

View file

@ -165,16 +165,17 @@ fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Resu
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let mut results = Vec::new();
for snapshot in snapshots {
for entry in snapshot.entries(false, 0) {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
if path_matcher.is_match(root_name.join(&entry.path)) {
results.push(snapshot.abs_path().join(entry.path.as_ref()));
}
}
}
Ok(results)
})
}
@ -215,8 +216,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
PathBuf::from(path!("/root/apple/banana/carrot")),
PathBuf::from(path!("/root/apple/bandana/carbonara"))
]
);
@ -227,8 +228,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
PathBuf::from(path!("/root/apple/banana/carrot")),
PathBuf::from(path!("/root/apple/bandana/carbonara"))
]
);
}

View file

@ -11,6 +11,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{path::Path, sync::Arc};
use util::markdown::MarkdownCodeBlock;
use crate::{AgentTool, ToolCallEventStream};
@ -243,6 +244,19 @@ impl AgentTool for ReadFileTool {
}]),
..Default::default()
});
if let Ok(LanguageModelToolResultContent::Text(text)) = &result {
let markdown = MarkdownCodeBlock {
tag: &input.path,
text,
}
.to_string();
event_stream.update_fields(ToolCallUpdateFields {
content: Some(vec![acp::ToolCallContent::Content {
content: markdown.into(),
}]),
..Default::default()
})
}
}
})?;

View file

@ -162,12 +162,34 @@ impl AgentConnection for AcpConnection {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
let context_server_store = project.read(cx).context_server_store().read(cx);
let mcp_servers = context_server_store
.configured_server_ids()
.iter()
.filter_map(|id| {
let configuration = context_server_store.configuration_for_server(id)?;
let command = configuration.command();
Some(acp::McpServer {
name: id.0.to_string(),
command: command.path.clone(),
args: command.args.clone(),
env: if let Some(env) = command.env.as_ref() {
env.iter()
.map(|(name, value)| acp::EnvVariable {
name: name.clone(),
value: value.clone(),
})
.collect()
} else {
vec![]
},
})
})
.collect();
cx.spawn(async move |cx| {
let response = conn
.new_session(acp::NewSessionRequest {
mcp_servers: vec![],
cwd,
})
.new_session(acp::NewSessionRequest { mcp_servers, cwd })
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
@ -185,13 +207,16 @@ impl AgentConnection for AcpConnection {
let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
let thread = cx.new(|cx| {
AcpThread::new(
self.server_name.clone(),
self.clone(),
project,
action_log,
session_id.clone(),
// ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
watch::Receiver::constant(self.prompt_capabilities),
cx,
)
})?;
@ -263,7 +288,9 @@ impl AgentConnection for AcpConnection {
match serde_json::from_value(data.clone()) {
Ok(ErrorDetails { details }) => {
if suppress_abort_err && details.contains("This operation was aborted")
if suppress_abort_err
&& (details.contains("This operation was aborted")
|| details.contains("The user aborted a request"))
{
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled,
@ -279,10 +306,6 @@ impl AgentConnection for AcpConnection {
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
session.suppress_abort_err = true;

View file

@ -1,524 +0,0 @@
// Translates old acp agents into the new schema
use action_log::ActionLog;
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
use crate::AgentServerCommand;
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)]
struct OldAcpClientDelegate {
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
cx: AsyncApp,
next_tool_call_id: Rc<RefCell<u64>>,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl OldAcpClientDelegate {
fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
Self {
thread,
cx,
next_tool_call_id: Rc::new(RefCell::new(0)),
}
}
}
impl acp_old::Client for OldAcpClientDelegate {
async fn stream_assistant_message_chunk(
&self,
params: acp_old::StreamAssistantMessageChunkParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| match params.chunk {
acp_old::AssistantMessageChunk::Text { text } => {
thread.push_assistant_content_block(text.into(), false, cx)
}
acp_old::AssistantMessageChunk::Thought { thought } => {
thread.push_assistant_content_block(thought.into(), true, cx)
}
})
.log_err();
})?;
Ok(())
}
async fn request_tool_call_confirmation(
&self,
request: acp_old::RequestToolCallConfirmationParams,
) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
let tool_call = into_new_tool_call(
acp::ToolCallId(old_acp_id.to_string().into()),
request.tool_call,
);
let mut options = match request.confirmation {
acp_old::ToolCallConfirmation::Edit { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow Edits".to_string(),
)],
acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", root_command),
)],
acp_old::ToolCallConfirmation::Mcp {
server_name,
tool_name,
..
} => vec![
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", server_name),
),
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", tool_name),
),
],
acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
acp_old::ToolCallConfirmation::Other { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
};
options.extend([
(
acp_old::ToolCallConfirmationOutcome::Allow,
acp::PermissionOptionKind::AllowOnce,
"Allow".to_string(),
),
(
acp_old::ToolCallConfirmationOutcome::Reject,
acp::PermissionOptionKind::RejectOnce,
"Reject".to_string(),
),
]);
let mut outcomes = Vec::with_capacity(options.len());
let mut acp_options = Vec::with_capacity(options.len());
for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
outcomes.push(outcome);
acp_options.push(acp::PermissionOption {
id: acp::PermissionOptionId(index.to_string().into()),
name: label,
kind,
})
}
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call.into(), acp_options, cx)
})
})??
.context("Failed to update thread")?
.await;
let outcome = match response {
Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)],
Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
};
Ok(acp_old::RequestToolCallConfirmationResponse {
id: acp_old::ToolCallId(old_acp_id),
outcome,
})
}
async fn push_tool_call(
&self,
request: acp_old::PushToolCallParams,
) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.upsert_tool_call(
into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request),
cx,
)
})
})??
.context("Failed to update thread")?;
Ok(acp_old::PushToolCallResponse {
id: acp_old::ToolCallId(old_acp_id),
})
}
async fn update_tool_call(
&self,
request: acp_old::UpdateToolCallParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(request.tool_call_id.0.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(into_new_tool_call_status(request.status)),
content: Some(
request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect::<Vec<_>>(),
),
..Default::default()
},
},
cx,
)
})
})?
.context("Failed to update thread")??;
Ok(())
}
async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_plan(
acp::Plan {
entries: request
.entries
.into_iter()
.map(into_new_plan_entry)
.collect(),
},
cx,
)
})
})?
.context("Failed to update thread")?;
Ok(())
}
async fn read_text_file(
&self,
acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
let content = self
.cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.read_text_file(path, line, limit, false, cx)
})
})?
.context("Failed to update thread")?
.await?;
Ok(acp_old::ReadTextFileResponse { content })
}
async fn write_text_file(
&self,
acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
) -> Result<(), acp_old::Error> {
self.cx
.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| thread.write_text_file(path, content, cx))
})?
.context("Failed to update thread")?
.await?;
Ok(())
}
}
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
acp::ToolCall {
id,
title: request.label,
kind: acp_kind_from_old_icon(request.icon),
status: acp::ToolCallStatus::InProgress,
content: request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect(),
locations: request
.locations
.into_iter()
.map(into_new_tool_call_location)
.collect(),
raw_input: None,
raw_output: None,
}
}
fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind {
match icon {
acp_old::Icon::FileSearch => acp::ToolKind::Search,
acp_old::Icon::Folder => acp::ToolKind::Search,
acp_old::Icon::Globe => acp::ToolKind::Search,
acp_old::Icon::Hammer => acp::ToolKind::Other,
acp_old::Icon::LightBulb => acp::ToolKind::Think,
acp_old::Icon::Pencil => acp::ToolKind::Edit,
acp_old::Icon::Regex => acp::ToolKind::Search,
acp_old::Icon::Terminal => acp::ToolKind::Execute,
}
}
fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus {
match status {
acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress,
acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed,
acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed,
}
}
fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent {
match content {
acp_old::ToolCallContent::Markdown { markdown } => markdown.into(),
acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff {
diff: into_new_diff(diff),
},
}
}
fn into_new_diff(diff: acp_old::Diff) -> acp::Diff {
acp::Diff {
path: diff.path,
old_text: diff.old_text,
new_text: diff.new_text,
}
}
fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation {
acp::ToolCallLocation {
path: location.path,
line: location.line,
}
}
fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry {
acp::PlanEntry {
content: entry.content,
priority: into_new_plan_priority(entry.priority),
status: into_new_plan_status(entry.status),
}
}
fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority {
match priority {
acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High,
}
}
fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus {
match status {
acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
}
}
pub struct AcpConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub _child_status: Task<Result<()>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
}
impl AcpConnection {
pub fn stdio(
name: &'static str,
command: AgentServerCommand,
root_dir: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Self>> {
let root_dir = root_dir.to_path_buf();
cx.spawn(async move |cx| {
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
log::trace!("Spawned (pid: {})", child.id());
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => Err(anyhow!(result)),
};
drop(io_task);
result
});
Ok(Self {
name,
connection,
_child_status: child_status,
current_thread: thread_rc,
})
})
}
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let task = self.connection.request_any(
acp_old::InitializeParams {
protocol_version: acp_old::ProtocolVersion::latest(),
}
.into_any(),
);
let current_thread = self.current_thread.clone();
cx.spawn(async move |cx| {
let result = task.await?;
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(AuthRequired::new())
}
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
});
current_thread.replace(thread.downgrade());
thread
})
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
Ok(())
})
}
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let chunks = params
.prompt
.into_iter()
.filter_map(|block| match block {
acp::ContentBlock::Text(text) => {
Some(acp_old::UserMessageChunk::Text { text: text.text })
}
acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path {
path: link.uri.into(),
}),
_ => None,
})
.collect();
let task = self
.connection
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: false,
audio: false,
embedded_context: false,
}
}
fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
let task = self
.connection
.request_any(acp_old::CancelSendMessageParams.into_any());
cx.foreground_executor()
.spawn(async move {
task.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx)
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}

View file

@ -1,376 +0,0 @@
use acp_tools::AcpConnectionRegistry;
use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
use anyhow::anyhow;
use collections::HashMap;
use futures::AsyncBufReadExt as _;
use futures::channel::oneshot;
use futures::io::BufReader;
use project::Project;
use serde::Deserialize;
use std::path::Path;
use std::rc::Rc;
use std::{any::Any, cell::RefCell};
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::{AgentServerCommand, acp::UnsupportedVersion};
use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
pub struct AcpConnection {
server_name: &'static str,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
prompt_capabilities: acp::PromptCapabilities,
_io_task: Task<Result<()>>,
}
pub struct AcpSession {
thread: WeakEntity<AcpThread>,
suppress_abort_err: bool,
}
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
impl AcpConnection {
pub async fn stdio(
server_name: &'static str,
command: AgentServerCommand,
root_dir: &Path,
cx: &mut AsyncApp,
) -> Result<Self> {
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
let stdout = child.stdout.take().context("Failed to take stdout")?;
let stdin = child.stdin.take().context("Failed to take stdin")?;
let stderr = child.stderr.take().context("Failed to take stderr")?;
log::trace!("Spawned (pid: {})", child.id());
let sessions = Rc::new(RefCell::new(HashMap::default()));
let client = ClientDelegate {
sessions: sessions.clone(),
cx: cx.clone(),
};
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
let foreground_executor = cx.foreground_executor().clone();
move |fut| {
foreground_executor.spawn(fut).detach();
}
});
let io_task = cx.background_spawn(io_task);
cx.background_spawn(async move {
let mut stderr = BufReader::new(stderr);
let mut line = String::new();
while let Ok(n) = stderr.read_line(&mut line).await
&& n > 0
{
log::warn!("agent stderr: {}", &line);
line.clear();
}
})
.detach();
cx.spawn({
let sessions = sessions.clone();
async move |cx| {
let status = child.status().await?;
for session in sessions.borrow().values() {
session
.thread
.update(cx, |thread, cx| {
thread.emit_load_error(LoadError::Exited { status }, cx)
})
.ok();
}
anyhow::Ok(())
}
})
.detach();
let connection = Rc::new(connection);
cx.update(|cx| {
AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
registry.set_active_connection(server_name, &connection, cx)
});
})?;
let response = connection
.initialize(acp::InitializeRequest {
protocol_version: acp::VERSION,
client_capabilities: acp::ClientCapabilities {
fs: acp::FileSystemCapability {
read_text_file: true,
write_text_file: true,
},
},
})
.await?;
if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
return Err(UnsupportedVersion.into());
}
Ok(Self {
auth_methods: response.auth_methods,
connection,
server_name,
sessions,
prompt_capabilities: response.agent_capabilities.prompt_capabilities,
_io_task: io_task,
})
}
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let response = conn
.new_session(acp::NewSessionRequest {
mcp_servers: vec![],
cwd,
})
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
let mut error = AuthRequired::new();
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
error = error.with_description(err.message);
}
anyhow!(error)
} else {
anyhow!(err)
}
})?;
let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
action_log,
session_id.clone(),
)
})?;
let session = AcpSession {
thread: thread.downgrade(),
suppress_abort_err: false,
};
sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor().spawn(async move {
let result = conn
.authenticate(acp::AuthenticateRequest {
method_id: method_id.clone(),
})
.await?;
Ok(result)
})
}
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let session_id = params.session_id.clone();
cx.foreground_executor().spawn(async move {
let result = conn.prompt(params).await;
let mut suppress_abort_err = false;
if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
suppress_abort_err = session.suppress_abort_err;
session.suppress_abort_err = false;
}
match result {
Ok(response) => Ok(response),
Err(err) => {
if err.code != ErrorCode::INTERNAL_ERROR.code {
anyhow::bail!(err)
}
let Some(data) = &err.data else {
anyhow::bail!(err)
};
// Temporary workaround until the following PR is generally available:
// https://github.com/google-gemini/gemini-cli/pull/6656
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ErrorDetails {
details: Box<str>,
}
match serde_json::from_value(data.clone()) {
Ok(ErrorDetails { details }) => {
if suppress_abort_err && details.contains("This operation was aborted")
{
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled,
})
} else {
Err(anyhow!(details))
}
}
Err(_) => Err(anyhow!(err)),
}
}
}
})
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
session.suppress_abort_err = true;
}
let conn = self.connection.clone();
let params = acp::CancelNotification {
session_id: session_id.clone(),
};
cx.foreground_executor()
.spawn(async move { conn.cancel(params).await })
.detach();
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct ClientDelegate {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
cx: AsyncApp,
}
impl acp::Client for ClientDelegate {
async fn request_permission(
&self,
arguments: acp::RequestPermissionRequest,
) -> Result<acp::RequestPermissionResponse, acp::Error> {
let cx = &mut self.cx.clone();
let rx = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
})?;
let result = rx?.await;
let outcome = match result {
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
};
Ok(acp::RequestPermissionResponse { outcome })
}
async fn write_text_file(
&self,
arguments: acp::WriteTextFileRequest,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
let task = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.write_text_file(arguments.path, arguments.content, cx)
})?;
task.await?;
Ok(())
}
async fn read_text_file(
&self,
arguments: acp::ReadTextFileRequest,
) -> Result<acp::ReadTextFileResponse, acp::Error> {
let cx = &mut self.cx.clone();
let task = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
})?;
let content = task.await?;
Ok(acp::ReadTextFileResponse { content })
}
async fn session_notification(
&self,
notification: acp::SessionNotification,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
let sessions = self.sessions.borrow();
let session = sessions
.get(&notification.session_id)
.context("Failed to get session")?;
session.thread.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})??;
Ok(())
}
}

View file

@ -36,6 +36,7 @@ pub trait AgentServer: Send {
fn name(&self) -> SharedString;
fn empty_state_headline(&self) -> SharedString;
fn empty_state_message(&self) -> SharedString;
fn telemetry_id(&self) -> &'static str;
fn connect(
&self,
@ -97,7 +98,7 @@ pub struct AgentServerCommand {
}
impl AgentServerCommand {
pub(crate) async fn resolve(
pub async fn resolve(
path_bin_name: &'static str,
extra_args: &[&'static str],
fallback_path: Option<&Path>,

View file

@ -43,6 +43,10 @@ use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError, MentionUri
pub struct ClaudeCode;
impl AgentServer for ClaudeCode {
fn telemetry_id(&self) -> &'static str {
"claude-code"
}
fn name(&self) -> SharedString {
"Claude Code".into()
}
@ -249,13 +253,19 @@ impl AgentConnection for ClaudeAgentConnection {
});
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
let thread = cx.new(|cx| {
AcpThread::new(
"Claude Code",
self.clone(),
project,
action_log,
session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}),
cx,
)
})?;
@ -319,14 +329,6 @@ impl AgentConnection for ClaudeAgentConnection {
cx.foreground_executor().spawn(async move { end_rx.await? })
}
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(session_id) else {

View file

@ -22,6 +22,10 @@ impl CustomAgentServer {
}
impl crate::AgentServer for CustomAgentServer {
fn telemetry_id(&self) -> &'static str {
"custom"
}
fn name(&self) -> SharedString {
self.name.clone()
}

View file

@ -17,6 +17,10 @@ pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp";
impl AgentServer for Gemini {
fn telemetry_id(&self) -> &'static str {
"gemini-cli"
}
fn name(&self) -> SharedString {
"Gemini CLI".into()
}
@ -53,7 +57,7 @@ impl AgentServer for Gemini {
return Err(LoadError::NotInstalled {
error_message: "Failed to find Gemini CLI binary".into(),
install_message: "Install Gemini CLI".into(),
install_command: "npm install -g @google/gemini-cli@preview".into()
install_command: Self::install_command().into(),
}.into());
};
@ -88,7 +92,7 @@ impl AgentServer for Gemini {
current_version
).into(),
upgrade_message: "Upgrade Gemini CLI to latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@preview".into(),
upgrade_command: Self::upgrade_command().into(),
}.into())
}
}
@ -101,6 +105,20 @@ impl AgentServer for Gemini {
}
}
impl Gemini {
pub fn binary_name() -> &'static str {
"gemini"
}
pub fn install_command() -> &'static str {
"npm install -g @google/gemini-cli@preview"
}
pub fn upgrade_command() -> &'static str {
"npm install -g @google/gemini-cli@preview"
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;

View file

@ -6,7 +6,7 @@ use agent2::HistoryStore;
use collections::HashMap;
use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable, ScrollHandle,
TextStyleRefinement, WeakEntity, Window,
};
use language::language_settings::SoftWrap;
@ -154,10 +154,22 @@ impl EntryViewState {
});
}
}
AgentThreadEntry::AssistantMessage(_) => {
if index == self.entries.len() {
self.entries.push(Entry::empty())
}
AgentThreadEntry::AssistantMessage(message) => {
let entry = if let Some(Entry::AssistantMessage(entry)) =
self.entries.get_mut(index)
{
entry
} else {
self.set_entry(
index,
Entry::AssistantMessage(AssistantMessageEntry::default()),
);
let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
unreachable!()
};
entry
};
entry.sync(message);
}
};
}
@ -177,7 +189,7 @@ impl EntryViewState {
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
match entry {
Entry::UserMessage { .. } => {}
Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
Entry::Content(response_views) => {
for view in response_views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
@ -208,9 +220,29 @@ pub enum ViewEvent {
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
#[derive(Default, Debug)]
pub struct AssistantMessageEntry {
scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
}
impl AssistantMessageEntry {
pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
self.scroll_handles_by_chunk_index.get(&ix).cloned()
}
pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
let ix = message.chunks.len() - 1;
let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
handle.scroll_to_bottom();
}
}
}
#[derive(Debug)]
pub enum Entry {
UserMessage(Entity<MessageEditor>),
AssistantMessage(AssistantMessageEntry),
Content(HashMap<EntityId, AnyEntity>),
}
@ -218,7 +250,7 @@ impl Entry {
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
match self {
Self::UserMessage(editor) => Some(editor),
Entry::Content(_) => None,
Self::AssistantMessage(_) | Self::Content(_) => None,
}
}
@ -239,6 +271,16 @@ impl Entry {
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
pub fn scroll_handle_for_assistant_message_chunk(
&self,
chunk_ix: usize,
) -> Option<ScrollHandle> {
match self {
Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
Self::UserMessage(_) | Self::Content(_) => None,
}
}
fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
match self {
Self::Content(map) => Some(map),
@ -254,7 +296,7 @@ impl Entry {
pub fn has_content(&self) -> bool {
match self {
Self::Content(map) => !map.is_empty(),
Self::UserMessage(_) => false,
Self::UserMessage(_) | Self::AssistantMessage(_) => false,
}
}
}

View file

@ -373,7 +373,7 @@ impl MessageEditor {
if Img::extensions().contains(&extension) && !extension.contains("svg") {
if !self.prompt_capabilities.get().image {
return Task::ready(Err(anyhow!("This agent does not support images yet")));
return Task::ready(Err(anyhow!("This model does not support images yet")));
}
let task = self
.project

View file

@ -462,7 +462,7 @@ impl AcpThreadHistory {
cx.notify();
}))
.end_slot::<IconButton>(if hovered || selected {
.end_slot::<IconButton>(if hovered {
Some(
IconButton::new("delete", IconName::Trash)
.shape(IconButtonShape::Square)

File diff suppressed because it is too large Load diff

View file

@ -3,19 +3,23 @@ mod configure_context_server_modal;
mod manage_profiles_modal;
mod tool_picker;
use std::{sync::Arc, time::Duration};
use std::{ops::Range, sync::Arc, time::Duration};
use agent_servers::{AgentServerCommand, AgentServerSettings, AllAgentServersSettings, Gemini};
use agent_settings::AgentSettings;
use anyhow::Result;
use assistant_tool::{ToolSource, ToolWorkingSet};
use cloud_llm_client::Plan;
use collections::HashMap;
use context_server::ContextServerId;
use editor::{Editor, SelectionEffects, scroll::Autoscroll};
use extension::ExtensionManifest;
use extension_host::ExtensionStore;
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyView, App, Corner, Entity, EventEmitter, FocusHandle,
Focusable, ScrollHandle, Subscription, Task, Transformation, WeakEntity, percentage,
Action, Animation, AnimationExt as _, AnyView, App, AsyncWindowContext, Corner, Entity,
EventEmitter, FocusHandle, Focusable, Hsla, ScrollHandle, Subscription, Task, Transformation,
WeakEntity, percentage,
};
use language::LanguageRegistry;
use language_model::{
@ -23,23 +27,24 @@ use language_model::{
};
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
Project,
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
project_settings::{ContextServerSettings, ProjectSettings},
};
use settings::{Settings, update_settings_file};
use settings::{Settings, SettingsStore, update_settings_file};
use ui::{
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::Workspace;
use workspace::{Workspace, create_and_open_local_file};
use zed_actions::ExtensionCategoryFilter;
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
pub(crate) use manage_profiles_modal::ManageProfilesModal;
use crate::{
AddContextServer,
AddContextServer, ExternalAgent, NewExternalAgentThread,
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
};
@ -47,6 +52,7 @@ pub struct AgentConfiguration {
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
focus_handle: FocusHandle,
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_store: Entity<ContextServerStore>,
@ -56,6 +62,8 @@ pub struct AgentConfiguration {
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
scrollbar_state: ScrollbarState,
gemini_is_installed: bool,
_check_for_gemini: Task<()>,
}
impl AgentConfiguration {
@ -65,6 +73,7 @@ impl AgentConfiguration {
tools: Entity<ToolWorkingSet>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -89,6 +98,11 @@ impl AgentConfiguration {
cx.subscribe(&context_server_store, |_, _, _, cx| cx.notify())
.detach();
cx.observe_global_in::<SettingsStore>(window, |this, _, cx| {
this.check_for_gemini(cx);
cx.notify();
})
.detach();
let scroll_handle = ScrollHandle::new();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
@ -97,6 +111,7 @@ impl AgentConfiguration {
fs,
language_registry,
workspace,
project,
focus_handle,
configuration_views_by_provider: HashMap::default(),
context_server_store,
@ -106,8 +121,11 @@ impl AgentConfiguration {
_registry_subscription: registry_subscription,
scroll_handle,
scrollbar_state,
gemini_is_installed: false,
_check_for_gemini: Task::ready(()),
};
this.build_provider_configuration_views(window, cx);
this.check_for_gemini(cx);
this
}
@ -137,6 +155,34 @@ impl AgentConfiguration {
self.configuration_views_by_provider
.insert(provider.id(), configuration_view);
}
fn check_for_gemini(&mut self, cx: &mut Context<Self>) {
let project = self.project.clone();
let settings = AllAgentServersSettings::get_global(cx).clone();
self._check_for_gemini = cx.spawn({
async move |this, cx| {
let Some(project) = project.upgrade() else {
return;
};
let gemini_is_installed = AgentServerCommand::resolve(
Gemini::binary_name(),
&[],
// TODO expose fallback path from the Gemini/CC types so we don't have to hardcode it again here
None,
settings.gemini,
&project,
cx,
)
.await
.is_some();
this.update(cx, |this, cx| {
this.gemini_is_installed = gemini_is_installed;
cx.notify();
})
.ok();
}
});
}
}
impl Focusable for AgentConfiguration {
@ -211,7 +257,6 @@ impl AgentConfiguration {
.child(
h_flex()
.id(provider_id_string.clone())
.cursor_pointer()
.px_2()
.py_0p5()
.w_full()
@ -231,10 +276,7 @@ impl AgentConfiguration {
h_flex()
.w_full()
.gap_1()
.child(
Label::new(provider_name.clone())
.size(LabelSize::Large),
)
.child(Label::new(provider_name.clone()))
.map(|this| {
if is_zed_provider && is_signed_in {
this.child(
@ -279,7 +321,7 @@ impl AgentConfiguration {
"Start New Thread",
)
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon(IconName::Thread)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.label_size(LabelSize::Small)
@ -378,7 +420,7 @@ impl AgentConfiguration {
),
)
.child(
Label::new("Add at least one provider to use AI-powered features.")
Label::new("Add at least one provider to use AI-powered features with Zed's native agent.")
.color(Color::Muted),
),
),
@ -519,6 +561,14 @@ impl AgentConfiguration {
}
}
fn card_item_bg_color(&self, cx: &mut Context<Self>) -> Hsla {
cx.theme().colors().background.opacity(0.25)
}
fn card_item_border_color(&self, cx: &mut Context<Self>) -> Hsla {
cx.theme().colors().border.opacity(0.6)
}
fn render_context_servers_section(
&mut self,
window: &mut Window,
@ -536,7 +586,12 @@ impl AgentConfiguration {
v_flex()
.gap_0p5()
.child(Headline::new("Model Context Protocol (MCP) Servers"))
.child(Label::new("Connect to context servers through the Model Context Protocol, either using Zed extensions or directly.").color(Color::Muted)),
.child(
Label::new(
"All context servers connected through the Model Context Protocol.",
)
.color(Color::Muted),
),
)
.children(
context_server_ids.into_iter().map(|context_server_id| {
@ -546,7 +601,7 @@ impl AgentConfiguration {
.child(
h_flex()
.justify_between()
.gap_2()
.gap_1p5()
.child(
h_flex().w_full().child(
Button::new("add-context-server", "Add Custom Server")
@ -637,8 +692,6 @@ impl AgentConfiguration {
.map_or([].as_slice(), |tools| tools.as_slice());
let tool_count = tools.len();
let border_color = cx.theme().colors().border.opacity(0.6);
let (source_icon, source_tooltip) = if is_from_extension {
(
IconName::ZedMcpExtension,
@ -781,8 +834,8 @@ impl AgentConfiguration {
.id(item_id.clone())
.border_1()
.rounded_md()
.border_color(border_color)
.bg(cx.theme().colors().background.opacity(0.2))
.border_color(self.card_item_border_color(cx))
.bg(self.card_item_bg_color(cx))
.overflow_hidden()
.child(
h_flex()
@ -790,7 +843,11 @@ impl AgentConfiguration {
.justify_between()
.when(
error.is_some() || are_tools_expanded && tool_count >= 1,
|element| element.border_b_1().border_color(border_color),
|element| {
element
.border_b_1()
.border_color(self.card_item_border_color(cx))
},
)
.child(
h_flex()
@ -972,6 +1029,195 @@ impl AgentConfiguration {
))
})
}
fn render_agent_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AllAgentServersSettings::get_global(cx).clone();
let user_defined_agents = settings
.custom
.iter()
.map(|(name, settings)| {
self.render_agent_server(
IconName::Ai,
name.clone(),
ExternalAgent::Custom {
name: name.clone(),
settings: settings.clone(),
},
None,
cx,
)
.into_any_element()
})
.collect::<Vec<_>>();
v_flex()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.child(
v_flex()
.gap_0p5()
.child(
h_flex()
.w_full()
.gap_2()
.justify_between()
.child(Headline::new("External Agents"))
.child(
Button::new("add-agent", "Add Agent")
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.label_size(LabelSize::Small)
.on_click(
move |_, window, cx| {
if let Some(workspace) = window.root().flatten() {
let workspace = workspace.downgrade();
window
.spawn(cx, async |cx| {
open_new_agent_servers_entry_in_settings_editor(
workspace,
cx,
).await
})
.detach_and_log_err(cx);
}
}
),
)
)
.child(
Label::new(
"Bring the agent of your choice to Zed via our new Agent Client Protocol.",
)
.color(Color::Muted),
),
)
.child(self.render_agent_server(
IconName::AiGemini,
"Gemini CLI",
ExternalAgent::Gemini,
(!self.gemini_is_installed).then_some(Gemini::install_command().into()),
cx,
))
// TODO add CC
.children(user_defined_agents),
)
}
fn render_agent_server(
&self,
icon: IconName,
name: impl Into<SharedString>,
agent: ExternalAgent,
install_command: Option<SharedString>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let name = name.into();
h_flex()
.p_1()
.pl_2()
.gap_1p5()
.justify_between()
.border_1()
.rounded_md()
.border_color(self.card_item_border_color(cx))
.bg(self.card_item_bg_color(cx))
.overflow_hidden()
.child(
h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(Color::Muted))
.child(Label::new(name.clone())),
)
.map(|this| {
if let Some(install_command) = install_command {
this.child(
Button::new(
SharedString::from(format!("install_external_agent-{name}")),
"Install Agent",
)
.label_size(LabelSize::Small)
.icon(IconName::Plus)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(Tooltip::text(install_command.clone()))
.on_click(cx.listener(
move |this, _, window, cx| {
let Some(project) = this.project.upgrade() else {
return;
};
let Some(workspace) = this.workspace.upgrade() else {
return;
};
let cwd = project.read(cx).first_project_directory(cx);
let shell =
project.read(cx).terminal_settings(&cwd, cx).shell.clone();
let spawn_in_terminal = task::SpawnInTerminal {
id: task::TaskId(install_command.to_string()),
full_label: install_command.to_string(),
label: install_command.to_string(),
command: Some(install_command.to_string()),
args: Vec::new(),
command_label: install_command.to_string(),
cwd,
env: Default::default(),
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: Default::default(),
reveal_target: Default::default(),
hide: Default::default(),
shell,
show_summary: true,
show_command: true,
show_rerun: false,
};
let task = workspace.update(cx, |workspace, cx| {
workspace.spawn_in_terminal(spawn_in_terminal, window, cx)
});
cx.spawn(async move |this, cx| {
task.await;
this.update(cx, |this, cx| {
this.check_for_gemini(cx);
})
.ok();
})
.detach();
},
)),
)
} else {
this.child(
h_flex().gap_1().child(
Button::new(
SharedString::from(format!("start_acp_thread-{name}")),
"Start New Thread",
)
.label_size(LabelSize::Small)
.icon(IconName::Thread)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, window, cx| {
window.dispatch_action(
NewExternalAgentThread {
agent: Some(agent.clone()),
}
.boxed_clone(),
cx,
);
}),
),
)
}
})
}
}
impl Render for AgentConfiguration {
@ -991,6 +1237,7 @@ impl Render for AgentConfiguration {
.size_full()
.overflow_y_scroll()
.child(self.render_general_settings_section(cx))
.child(self.render_agent_servers_section(cx))
.child(self.render_context_servers_section(window, cx))
.child(self.render_provider_configuration_section(cx)),
)
@ -1109,3 +1356,109 @@ fn show_unable_to_uninstall_extension_with_context_server(
workspace.toggle_status_toast(status_toast, cx);
}
async fn open_new_agent_servers_entry_in_settings_editor(
workspace: WeakEntity<Workspace>,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let settings_editor = workspace
.update_in(cx, |_, window, cx| {
create_and_open_local_file(paths::settings_file(), window, cx, || {
settings::initial_user_settings_content().as_ref().into()
})
})?
.await?
.downcast::<Editor>()
.unwrap();
settings_editor
.downgrade()
.update_in(cx, |item, window, cx| {
let text = item.buffer().read(cx).snapshot(cx).text();
let settings = cx.global::<SettingsStore>();
let mut unique_server_name = None;
let edits = settings.edits_for_update::<AllAgentServersSettings>(&text, |file| {
let server_name: Option<SharedString> = (0..u8::MAX)
.map(|i| {
if i == 0 {
"your_agent".into()
} else {
format!("your_agent_{}", i).into()
}
})
.find(|name| !file.custom.contains_key(name));
if let Some(server_name) = server_name {
unique_server_name = Some(server_name.clone());
file.custom.insert(
server_name,
AgentServerSettings {
command: AgentServerCommand {
path: "path_to_executable".into(),
args: vec![],
env: Some(HashMap::default()),
},
},
);
}
});
if edits.is_empty() {
return;
}
let ranges = edits
.iter()
.map(|(range, _)| range.clone())
.collect::<Vec<_>>();
item.edit(edits, cx);
if let Some((unique_server_name, buffer)) =
unique_server_name.zip(item.buffer().read(cx).as_singleton())
{
let snapshot = buffer.read(cx).snapshot();
if let Some(range) =
find_text_in_buffer(&unique_server_name, ranges[0].start, &snapshot)
{
item.change_selections(
SelectionEffects::scroll(Autoscroll::newest()),
window,
cx,
|selections| {
selections.select_ranges(vec![range]);
},
);
}
}
})
}
fn find_text_in_buffer(
text: &str,
start: usize,
snapshot: &language::BufferSnapshot,
) -> Option<Range<usize>> {
let chars = text.chars().collect::<Vec<char>>();
let mut offset = start;
let mut char_offset = 0;
for c in snapshot.chars_at(start) {
if char_offset >= chars.len() {
break;
}
offset += 1;
if c == chars[char_offset] {
char_offset += 1;
} else {
char_offset = 0;
}
}
if char_offset == chars.len() {
Some(offset.saturating_sub(chars.len())..offset)
} else {
None
}
}

View file

@ -1529,6 +1529,7 @@ impl AgentDiff {
| AcpThreadEvent::TokenUsageUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::PromptCapabilitiesUpdated
| AcpThreadEvent::Retry(_) => {}
}
}

View file

@ -9,9 +9,12 @@ use agent_servers::AgentServerSettings;
use agent2::{DbThreadMetadata, HistoryEntry};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use serde::{Deserialize, Serialize};
use zed_actions::OpenBrowser;
use zed_actions::agent::ReauthenticateAgent;
use crate::acp::{AcpThreadHistory, ThreadHistoryEvent};
use crate::agent_diff::AgentDiffThread;
use crate::ui::AcpOnboardingModal;
use crate::{
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread,
@ -75,7 +78,10 @@ use workspace::{
};
use zed_actions::{
DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize,
agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector},
agent::{
OpenAcpOnboardingModal, OpenOnboardingModal, OpenSettings, ResetOnboarding,
ToggleModelSelector,
},
assistant::{OpenRulesLibrary, ToggleFocus},
};
@ -199,6 +205,9 @@ pub fn init(cx: &mut App) {
.register_action(|workspace, _: &OpenOnboardingModal, window, cx| {
AgentOnboardingModal::toggle(workspace, window, cx)
})
.register_action(|workspace, _: &OpenAcpOnboardingModal, window, cx| {
AcpOnboardingModal::toggle(workspace, window, cx)
})
.register_action(|_workspace, _: &ResetOnboarding, window, cx| {
window.dispatch_action(workspace::RestoreBanner.boxed_clone(), cx);
window.refresh();
@ -240,6 +249,7 @@ enum WhichFontSize {
None,
}
// TODO unify this with ExternalAgent
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum AgentType {
#[default]
@ -588,17 +598,6 @@ impl AgentPanel {
None
};
// Wait for the Gemini/Native feature flag to be available.
let client = workspace.read_with(cx, |workspace, _| workspace.client().clone())?;
if !client.status().borrow().is_signed_out() {
cx.update(|_, cx| {
cx.wait_for_flag_or_timeout::<feature_flags::GeminiAndNativeFeatureFlag>(
Duration::from_secs(2),
)
})?
.await;
}
let panel = workspace.update_in(cx, |workspace, window, cx| {
let panel = cx.new(|cx| {
Self::new(
@ -1024,6 +1023,8 @@ impl AgentPanel {
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
telemetry::event!("Agent Thread Started", agent = "zed-text");
let context = self
.context_store
.update(cx, |context_store, cx| context_store.create(cx));
@ -1116,6 +1117,8 @@ impl AgentPanel {
}
};
telemetry::event!("Agent Thread Started", agent = ext_agent.name());
let server = ext_agent.server(fs, history);
this.update_in(cx, |this, window, cx| {
@ -1473,6 +1476,7 @@ impl AgentPanel {
tools,
self.language_registry.clone(),
self.workspace.clone(),
self.project.downgrade(),
window,
cx,
)
@ -1844,19 +1848,6 @@ impl AgentPanel {
menu
}
pub fn set_selected_agent(
&mut self,
agent: AgentType,
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.selected_agent != agent {
self.selected_agent = agent.clone();
self.serialize(cx);
}
self.new_agent_thread(agent, window, cx);
}
pub fn selected_agent(&self) -> AgentType {
self.selected_agent.clone()
}
@ -1867,6 +1858,11 @@ impl AgentPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.selected_agent != agent {
self.selected_agent = agent.clone();
self.serialize(cx);
}
match agent {
AgentType::Zed => {
window.dispatch_action(
@ -2204,6 +2200,8 @@ impl AgentPanel {
"Enable Full Screen"
};
let selected_agent = self.selected_agent.clone();
PopoverMenu::new("agent-options-menu")
.trigger_with_tooltip(
IconButton::new("agent-options-menu", IconName::Ellipsis)
@ -2283,6 +2281,11 @@ impl AgentPanel {
.action("Settings", Box::new(OpenSettings))
.separator()
.action(full_screen_label, Box::new(ToggleZoom));
if selected_agent == AgentType::Gemini {
menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent))
}
menu
}))
}
@ -2317,6 +2320,8 @@ impl AgentPanel {
.menu({
let menu = self.assistant_navigation_menu.clone();
move |window, cx| {
telemetry::event!("View Thread History Clicked");
if let Some(menu) = menu.as_ref() {
menu.update(cx, |_, cx| {
cx.defer_in(window, |menu, window, cx| {
@ -2495,6 +2500,8 @@ impl AgentPanel {
let workspace = self.workspace.clone();
move |window, cx| {
telemetry::event!("New Thread Clicked");
let active_thread = active_thread.clone();
Some(ContextMenu::build(window, cx, |mut menu, _window, cx| {
menu = menu
@ -2536,7 +2543,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
panel.new_agent_thread(
AgentType::NativeAgent,
window,
cx,
@ -2562,7 +2569,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
panel.new_agent_thread(
AgentType::TextThread,
window,
cx,
@ -2590,7 +2597,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
panel.new_agent_thread(
AgentType::Gemini,
window,
cx,
@ -2617,7 +2624,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
panel.new_agent_thread(
AgentType::ClaudeCode,
window,
cx,
@ -2650,7 +2657,7 @@ impl AgentPanel {
workspace.panel::<AgentPanel>(cx)
{
panel.update(cx, |panel, cx| {
panel.set_selected_agent(
panel.new_agent_thread(
AgentType::Custom {
name: agent_name
.clone(),
@ -2671,6 +2678,15 @@ impl AgentPanel {
}
menu
})
.when(cx.has_flag::<GeminiAndNativeFeatureFlag>(), |menu| {
menu.separator().link(
"Add Other Agents",
OpenBrowser {
url: zed_urls::external_agents_docs(cx),
}
.boxed_clone(),
)
});
menu
}))
@ -3751,6 +3767,11 @@ impl Render for AgentPanel {
}
}))
.on_action(cx.listener(Self::toggle_burn_mode))
.on_action(cx.listener(|this, _: &ReauthenticateAgent, window, cx| {
if let Some(thread_view) = this.active_thread_view() {
thread_view.update(cx, |thread_view, cx| thread_view.reauthenticate(window, cx))
}
}))
.child(self.render_toolbar(window, cx))
.children(self.render_onboarding(window, cx))
.map(|parent| match &self.active_view {

View file

@ -160,6 +160,7 @@ pub struct NewNativeAgentThreadFromSummary {
from_session_id: agent_client_protocol::SessionId,
}
// TODO unify this with AgentType
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
enum ExternalAgent {
@ -174,6 +175,15 @@ enum ExternalAgent {
}
impl ExternalAgent {
fn name(&self) -> &'static str {
match self {
Self::NativeAgent => "zed",
Self::Gemini => "gemini-cli",
Self::ClaudeCode => "claude-code",
Self::Custom { .. } => "custom",
}
}
pub fn server(
&self,
fs: Arc<dyn fs::Fs>,

View file

@ -6,7 +6,8 @@ use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@ -76,6 +77,7 @@ pub struct LanguageModelPickerDelegate {
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
}
@ -96,6 +98,7 @@ impl LanguageModelPickerDelegate {
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
@ -139,6 +142,56 @@ impl LanguageModelPickerDelegate {
.unwrap_or(0)
}
/// Authenticates all providers in the [`LanguageModelRegistry`].
///
/// We do this so that we can populate the language selector with all of the
/// models from the configured providers.
fn authenticate_all_providers(cx: &mut App) -> Task<()> {
let authenticate_all_providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
cx.spawn(async move |_cx| {
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
if let Err(err) = authenticate_task.await {
if matches!(err, AuthenticateError::CredentialsNotFound) {
// Since we're authenticating these providers in the
// background for the purposes of populating the
// language selector, we don't care about providers
// where the credentials are not found.
} else {
// Some providers have noisy failure states that we
// don't want to spam the logs with every time the
// language model selector is initialized.
//
// Ideally these should have more clear failure modes
// that we know are safe to ignore here, like what we do
// with `CredentialsNotFound` above.
match provider_id.0.as_ref() {
"lmstudio" | "ollama" => {
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
//
// These fail noisily, so we don't log them.
}
"copilot_chat" => {
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
}
_ => {
log::error!(
"Failed to authenticate provider: {}: {err}",
provider_name.0
);
}
}
}
}
}
})
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.get_active_model)(cx)
}

View file

@ -361,6 +361,7 @@ impl TextThreadEditor {
if self.sending_disabled(cx) {
return;
}
telemetry::event!("Agent Message Sent", agent = "zed-text");
self.send_to_model(window, cx);
}

View file

@ -1,3 +1,4 @@
mod acp_onboarding_modal;
mod agent_notification;
mod burn_mode_tooltip;
mod context_pill;
@ -6,6 +7,7 @@ mod onboarding_modal;
pub mod preview;
mod unavailable_editing_tooltip;
pub use acp_onboarding_modal::*;
pub use agent_notification::*;
pub use burn_mode_tooltip::*;
pub use context_pill::*;

View file

@ -0,0 +1,254 @@
use client::zed_urls;
use gpui::{
ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render,
linear_color_stop, linear_gradient,
};
use ui::{TintColor, Vector, VectorName, prelude::*};
use workspace::{ModalView, Workspace};
use crate::agent_panel::{AgentPanel, AgentType};
macro_rules! acp_onboarding_event {
($name:expr) => {
telemetry::event!($name, source = "ACP Onboarding");
};
($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => {
telemetry::event!($name, source = "ACP Onboarding", $($key $(= $value)?),+);
};
}
pub struct AcpOnboardingModal {
focus_handle: FocusHandle,
workspace: Entity<Workspace>,
}
impl AcpOnboardingModal {
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
let workspace_entity = cx.entity();
workspace.toggle_modal(window, cx, |_window, cx| Self {
workspace: workspace_entity,
focus_handle: cx.focus_handle(),
});
}
fn open_panel(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
self.workspace.update(cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
panel.new_agent_thread(AgentType::Gemini, window, cx);
});
}
});
cx.emit(DismissEvent);
acp_onboarding_event!("Open Panel Clicked");
}
fn view_docs(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) {
cx.open_url(&zed_urls::external_agents_docs(cx));
cx.notify();
acp_onboarding_event!("Documentation Link Clicked");
}
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
}
impl EventEmitter<DismissEvent> for AcpOnboardingModal {}
impl Focusable for AcpOnboardingModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl ModalView for AcpOnboardingModal {}
impl Render for AcpOnboardingModal {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let illustration_element = |label: bool, opacity: f32| {
h_flex()
.px_1()
.py_0p5()
.gap_1()
.rounded_sm()
.bg(cx.theme().colors().element_active.opacity(0.05))
.border_1()
.border_color(cx.theme().colors().border)
.border_dashed()
.child(
Icon::new(IconName::Stop)
.size(IconSize::Small)
.color(Color::Custom(cx.theme().colors().text_muted.opacity(0.15))),
)
.map(|this| {
if label {
this.child(
Label::new("Your Agent Here")
.size(LabelSize::Small)
.color(Color::Muted),
)
} else {
this.child(
div().w_16().h_1().rounded_full().bg(cx
.theme()
.colors()
.element_active
.opacity(0.6)),
)
}
})
.opacity(opacity)
};
let illustration = h_flex()
.relative()
.h(rems_from_px(126.))
.bg(cx.theme().colors().editor_background)
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.justify_center()
.gap_8()
.rounded_t_md()
.overflow_hidden()
.child(
div().absolute().inset_0().w(px(515.)).h(px(126.)).child(
Vector::new(VectorName::AcpGrid, rems_from_px(515.), rems_from_px(126.))
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.02))),
),
)
.child(div().absolute().inset_0().size_full().bg(linear_gradient(
0.,
linear_color_stop(
cx.theme().colors().elevated_surface_background.opacity(0.1),
0.9,
),
linear_color_stop(
cx.theme().colors().elevated_surface_background.opacity(0.),
0.,
),
)))
.child(
div()
.absolute()
.inset_0()
.size_full()
.bg(gpui::black().opacity(0.15)),
)
.child(
h_flex()
.gap_4()
.child(
Vector::new(VectorName::AcpLogo, rems_from_px(106.), rems_from_px(40.))
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.8))),
)
.child(
Vector::new(
VectorName::AcpLogoSerif,
rems_from_px(111.),
rems_from_px(41.),
)
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.8))),
),
)
.child(
v_flex()
.gap_1p5()
.child(illustration_element(false, 0.15))
.child(illustration_element(true, 0.3))
.child(
h_flex()
.pl_1()
.pr_2()
.py_0p5()
.gap_1()
.rounded_sm()
.bg(cx.theme().colors().element_active.opacity(0.2))
.border_1()
.border_color(cx.theme().colors().border)
.child(
Icon::new(IconName::AiGemini)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("New Gemini CLI Thread").size(LabelSize::Small)),
)
.child(illustration_element(true, 0.3))
.child(illustration_element(false, 0.15)),
);
let heading = v_flex()
.w_full()
.gap_1()
.child(
Label::new("Now Available")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(Headline::new("Bring Your Own Agent to Zed").size(HeadlineSize::Large));
let copy = "Bring the agent of your choice to Zed via our new Agent Client Protocol (ACP), starting with Google's Gemini CLI integration.";
let open_panel_button = Button::new("open-panel", "Start with Gemini CLI")
.icon_size(IconSize::Indicator)
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::open_panel));
let docs_button = Button::new("add-other-agents", "Add Other Agents")
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Indicator)
.icon_color(Color::Muted)
.full_width()
.on_click(cx.listener(Self::view_docs));
let close_button = h_flex().absolute().top_2().right_2().child(
IconButton::new("cancel", IconName::Close).on_click(cx.listener(
|_, _: &ClickEvent, _window, cx| {
acp_onboarding_event!("Canceled", trigger = "X click");
cx.emit(DismissEvent);
},
)),
);
v_flex()
.id("acp-onboarding")
.key_context("AcpOnboardingModal")
.relative()
.w(rems(34.))
.h_full()
.elevation_3(cx)
.track_focus(&self.focus_handle(cx))
.overflow_hidden()
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| {
acp_onboarding_event!("Canceled", trigger = "Action");
cx.emit(DismissEvent);
}))
.on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| {
this.focus_handle.focus(window);
}))
.child(illustration)
.child(
v_flex()
.p_4()
.gap_2()
.child(heading)
.child(Label::new(copy).color(Color::Muted))
.child(
v_flex()
.w_full()
.mt_2()
.gap_1()
.child(open_panel_button)
.child(docs_button),
),
)
.child(close_button)
}
}

View file

@ -12,11 +12,11 @@ use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions}
#[derive(IntoElement, RegisterComponent)]
pub struct AiUpsellCard {
pub sign_in_status: SignInStatus,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
pub account_too_young: bool,
pub user_plan: Option<Plan>,
pub tab_index: Option<isize>,
sign_in_status: SignInStatus,
sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
account_too_young: bool,
user_plan: Option<Plan>,
tab_index: Option<isize>,
}
impl AiUpsellCard {
@ -43,6 +43,11 @@ impl AiUpsellCard {
tab_index: None,
}
}
pub fn tab_index(mut self, tab_index: Option<isize>) -> Self {
self.tab_index = tab_index;
self
}
}
impl RenderOnce for AiUpsellCard {

View file

@ -118,7 +118,7 @@ impl Tool for FetchTool {
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
false
true
}
fn may_perform_edits(&self) -> bool {

View file

@ -435,8 +435,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
PathBuf::from(path!("root/apple/banana/carrot")),
PathBuf::from(path!("root/apple/bandana/carbonara"))
]
);
@ -447,8 +447,8 @@ mod test {
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
PathBuf::from(path!("root/apple/banana/carrot")),
PathBuf::from(path!("root/apple/bandana/carbonara"))
]
);
}

View file

@ -68,7 +68,7 @@ impl Tool for ReadFileTool {
}
fn icon(&self) -> IconName {
IconName::ToolRead
IconName::ToolSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {

View file

@ -43,3 +43,11 @@ pub fn ai_privacy_and_security(cx: &App) -> String {
server_url = server_url(cx)
)
}
/// Returns the URL to Zed AI's external agents documentation.
pub fn external_agents_docs(cx: &App) -> String {
format!(
"{server_url}/docs/ai/external-agents",
server_url = server_url(cx)
)
}

View file

@ -1,7 +1,10 @@
use anyhow::Result;
use db::{
define_connection, query,
sqlez::{bindable::Column, statement::Statement},
query,
sqlez::{
bindable::Column, domain::Domain, statement::Statement,
thread_safe_connection::ThreadSafeConnection,
},
sqlez_macros::sql,
};
use serde::{Deserialize, Serialize};
@ -50,8 +53,11 @@ impl Column for SerializedCommandInvocation {
}
}
define_connection!(pub static ref COMMAND_PALETTE_HISTORY: CommandPaletteDB<()> =
&[sql!(
pub struct CommandPaletteDB(ThreadSafeConnection);
impl Domain for CommandPaletteDB {
const NAME: &str = stringify!(CommandPaletteDB);
const MIGRATIONS: &[&str] = &[sql!(
CREATE TABLE IF NOT EXISTS command_invocations(
id INTEGER PRIMARY KEY AUTOINCREMENT,
command_name TEXT NOT NULL,
@ -59,7 +65,9 @@ define_connection!(pub static ref COMMAND_PALETTE_HISTORY: CommandPaletteDB<()>
last_invoked INTEGER DEFAULT (unixepoch()) NOT NULL
) STRICT;
)];
);
}
db::static_connection!(COMMAND_PALETTE_HISTORY, CommandPaletteDB, []);
impl CommandPaletteDB {
pub async fn write_command_invocation(

View file

@ -110,11 +110,14 @@ pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
}
/// Implements a basic DB wrapper for a given domain
///
/// Arguments:
/// - static variable name for connection
/// - type of connection wrapper
/// - dependencies, whose migrations should be run prior to this domain's migrations
#[macro_export]
macro_rules! define_connection {
(pub static ref $id:ident: $t:ident<()> = $migrations:expr; $($global:ident)?) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection);
macro_rules! static_connection {
($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => {
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;
@ -123,16 +126,6 @@ macro_rules! define_connection {
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
impl $t {
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db(name: &'static str) -> Self {
@ -142,7 +135,8 @@ macro_rules! define_connection {
#[cfg(any(test, feature = "test-support"))]
pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
$t($crate::smol::block_on($crate::open_test_db::<$t>(stringify!($id))))
#[allow(unused_parens)]
$t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id))))
});
#[cfg(not(any(test, feature = "test-support")))]
@ -153,46 +147,10 @@ macro_rules! define_connection {
} else {
$crate::RELEASE_CHANNEL.dev_name()
};
$t($crate::smol::block_on($crate::open_db::<$t>(db_dir, scope)))
#[allow(unused_parens)]
$t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope)))
});
};
(pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr; $($global:ident)?) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
$t($crate::smol::block_on($crate::open_test_db::<($($d),+, $t)>(stringify!($id))))
});
#[cfg(not(any(test, feature = "test-support")))]
pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
let db_dir = $crate::database_dir();
let scope = if false $(|| stringify!($global) == "global")? {
"global"
} else {
$crate::RELEASE_CHANNEL.dev_name()
};
$t($crate::smol::block_on($crate::open_db::<($($d),+, $t)>(db_dir, scope)))
});
};
}
}
pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
@ -219,17 +177,12 @@ mod tests {
enum BadDB {}
impl Domain for BadDB {
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[
sql!(CREATE TABLE test(value);),
// failure because test already exists
sql!(CREATE TABLE test(value);),
]
}
const NAME: &str = "db_tests";
const MIGRATIONS: &[&str] = &[
sql!(CREATE TABLE test(value);),
// failure because test already exists
sql!(CREATE TABLE test(value);),
];
}
let tempdir = tempfile::Builder::new()
@ -251,25 +204,15 @@ mod tests {
enum CorruptedDB {}
impl Domain for CorruptedDB {
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test(value);)]
}
const NAME: &str = "db_tests";
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
}
enum GoodDB {}
impl Domain for GoodDB {
fn name() -> &'static str {
"db_tests" //Notice same name
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test2(value);)] //But different migration
}
const NAME: &str = "db_tests"; //Notice same name
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)];
}
let tempdir = tempfile::Builder::new()
@ -305,25 +248,16 @@ mod tests {
enum CorruptedDB {}
impl Domain for CorruptedDB {
fn name() -> &'static str {
"db_tests"
}
const NAME: &str = "db_tests";
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test(value);)]
}
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
}
enum GoodDB {}
impl Domain for GoodDB {
fn name() -> &'static str {
"db_tests" //Notice same name
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test2(value);)] //But different migration
}
const NAME: &str = "db_tests"; //Notice same name
const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration
}
let tempdir = tempfile::Builder::new()

View file

@ -2,16 +2,26 @@ use gpui::App;
use sqlez_macros::sql;
use util::ResultExt as _;
use crate::{define_connection, query, write_and_log};
use crate::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
write_and_log,
};
define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
&[sql!(
pub struct KeyValueStore(crate::sqlez::thread_safe_connection::ThreadSafeConnection);
impl Domain for KeyValueStore {
const NAME: &str = stringify!(KeyValueStore);
const MIGRATIONS: &[&str] = &[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)];
);
}
crate::static_connection!(KEY_VALUE_STORE, KeyValueStore, []);
pub trait Dismissable {
const KEY: &'static str;
@ -91,15 +101,19 @@ mod tests {
}
}
define_connection!(pub static ref GLOBAL_KEY_VALUE_STORE: GlobalKeyValueStore<()> =
&[sql!(
pub struct GlobalKeyValueStore(ThreadSafeConnection);
impl Domain for GlobalKeyValueStore {
const NAME: &str = stringify!(GlobalKeyValueStore);
const MIGRATIONS: &[&str] = &[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)];
global
);
}
crate::static_connection!(GLOBAL_KEY_VALUE_STORE, GlobalKeyValueStore, [], global);
impl GlobalKeyValueStore {
query! {

View file

@ -19,6 +19,10 @@ static KEYMAP_LINUX: LazyLock<KeymapFile> = LazyLock::new(|| {
load_keymap("keymaps/default-linux.json").expect("Failed to load Linux keymap")
});
static KEYMAP_WINDOWS: LazyLock<KeymapFile> = LazyLock::new(|| {
load_keymap("keymaps/default-windows.json").expect("Failed to load Windows keymap")
});
static ALL_ACTIONS: LazyLock<Vec<ActionDef>> = LazyLock::new(dump_all_gpui_actions);
const FRONT_MATTER_COMMENT: &str = "<!-- ZED_META {} -->";
@ -216,6 +220,7 @@ fn find_binding(os: &str, action: &str) -> Option<String> {
let keymap = match os {
"macos" => &KEYMAP_MACOS,
"linux" | "freebsd" => &KEYMAP_LINUX,
"windows" => &KEYMAP_WINDOWS,
_ => unreachable!("Not a valid OS: {}", os),
};

View file

@ -2588,7 +2588,7 @@ impl Editor {
|| binding
.keystrokes()
.first()
.is_some_and(|keystroke| keystroke.modifiers.modified())
.is_some_and(|keystroke| keystroke.display_modifiers.modified())
}))
}
@ -7686,16 +7686,16 @@ impl Editor {
.keystroke()
{
modifiers_held = modifiers_held
|| (&accept_keystroke.modifiers == modifiers
&& accept_keystroke.modifiers.modified());
|| (&accept_keystroke.display_modifiers == modifiers
&& accept_keystroke.display_modifiers.modified());
};
if let Some(accept_partial_keystroke) = self
.accept_edit_prediction_keybind(true, window, cx)
.keystroke()
{
modifiers_held = modifiers_held
|| (&accept_partial_keystroke.modifiers == modifiers
&& accept_partial_keystroke.modifiers.modified());
|| (&accept_partial_keystroke.display_modifiers == modifiers
&& accept_partial_keystroke.display_modifiers.modified());
}
if modifiers_held {
@ -9044,7 +9044,7 @@ impl Editor {
let is_platform_style_mac = PlatformStyle::platform() == PlatformStyle::Mac;
let modifiers_color = if accept_keystroke.modifiers == window.modifiers() {
let modifiers_color = if accept_keystroke.display_modifiers == window.modifiers() {
Color::Accent
} else {
Color::Muted
@ -9056,19 +9056,19 @@ impl Editor {
.font(theme::ThemeSettings::get_global(cx).buffer_font.clone())
.text_size(TextSize::XSmall.rems(cx))
.child(h_flex().children(ui::render_modifiers(
&accept_keystroke.modifiers,
&accept_keystroke.display_modifiers,
PlatformStyle::platform(),
Some(modifiers_color),
Some(IconSize::XSmall.rems().into()),
true,
)))
.when(is_platform_style_mac, |parent| {
parent.child(accept_keystroke.key.clone())
parent.child(accept_keystroke.display_key.clone())
})
.when(!is_platform_style_mac, |parent| {
parent.child(
Key::new(
util::capitalize(&accept_keystroke.key),
util::capitalize(&accept_keystroke.display_key),
Some(Color::Default),
)
.size(Some(IconSize::XSmall.rems().into())),
@ -9171,7 +9171,7 @@ impl Editor {
max_width: Pixels,
cursor_point: Point,
style: &EditorStyle,
accept_keystroke: Option<&gpui::Keystroke>,
accept_keystroke: Option<&gpui::KeybindingKeystroke>,
_window: &Window,
cx: &mut Context<Editor>,
) -> Option<AnyElement> {
@ -9249,7 +9249,7 @@ impl Editor {
accept_keystroke.as_ref(),
|el, accept_keystroke| {
el.child(h_flex().children(ui::render_modifiers(
&accept_keystroke.modifiers,
&accept_keystroke.display_modifiers,
PlatformStyle::platform(),
Some(Color::Default),
Some(IconSize::XSmall.rems().into()),
@ -9319,7 +9319,7 @@ impl Editor {
.child(completion),
)
.when_some(accept_keystroke, |el, accept_keystroke| {
if !accept_keystroke.modifiers.modified() {
if !accept_keystroke.display_modifiers.modified() {
return el;
}
@ -9338,7 +9338,7 @@ impl Editor {
.font(theme::ThemeSettings::get_global(cx).buffer_font.clone())
.when(is_platform_style_mac, |parent| parent.gap_1())
.child(h_flex().children(ui::render_modifiers(
&accept_keystroke.modifiers,
&accept_keystroke.display_modifiers,
PlatformStyle::platform(),
Some(if !has_completion {
Color::Muted

View file

@ -43,10 +43,10 @@ use gpui::{
Bounds, ClickEvent, ClipboardItem, ContentMask, Context, Corner, Corners, CursorStyle,
DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId,
GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero,
Keystroke, Length, ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent,
MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle,
ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled,
TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
KeybindingKeystroke, Length, ModifiersChangedEvent, MouseButton, MouseClickEvent,
MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta,
ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement,
Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background,
transparent_black,
};
@ -7150,7 +7150,7 @@ fn header_jump_data(
pub struct AcceptEditPredictionBinding(pub(crate) Option<gpui::KeyBinding>);
impl AcceptEditPredictionBinding {
pub fn keystroke(&self) -> Option<&Keystroke> {
pub fn keystroke(&self) -> Option<&KeybindingKeystroke> {
if let Some(binding) = self.0.as_ref() {
match &binding.keystrokes() {
[keystroke, ..] => Some(keystroke),

View file

@ -1404,7 +1404,7 @@ impl ProjectItem for Editor {
}
fn for_broken_project_item(
abs_path: PathBuf,
abs_path: &Path,
is_local: bool,
e: &anyhow::Error,
window: &mut Window,

View file

@ -1,13 +1,17 @@
use anyhow::Result;
use db::sqlez::bindable::{Bind, Column, StaticColumnCount};
use db::sqlez::statement::Statement;
use db::{
query,
sqlez::{
bindable::{Bind, Column, StaticColumnCount},
domain::Domain,
statement::Statement,
},
sqlez_macros::sql,
};
use fs::MTime;
use itertools::Itertools as _;
use std::path::PathBuf;
use db::sqlez_macros::sql;
use db::{define_connection, query};
use workspace::{ItemId, WorkspaceDb, WorkspaceId};
#[derive(Clone, Debug, PartialEq, Default)]
@ -83,7 +87,11 @@ impl Column for SerializedEditor {
}
}
define_connection!(
pub struct EditorDb(db::sqlez::thread_safe_connection::ThreadSafeConnection);
impl Domain for EditorDb {
const NAME: &str = stringify!(EditorDb);
// Current schema shape using pseudo-rust syntax:
// editors(
// item_id: usize,
@ -113,7 +121,8 @@ define_connection!(
// start: usize,
// end: usize,
// )
pub static ref DB: EditorDb<WorkspaceDb> = &[
const MIGRATIONS: &[&str] = &[
sql! (
CREATE TABLE editors(
item_id INTEGER NOT NULL,
@ -189,7 +198,9 @@ define_connection!(
) STRICT;
),
];
);
}
db::static_connection!(DB, EditorDb, [WorkspaceDb]);
// https://www.sqlite.org/limits.html
// > <..> the maximum value of a host parameter number is SQLITE_MAX_VARIABLE_NUMBER,

View file

@ -98,6 +98,10 @@ impl FeatureFlag for GeminiAndNativeFeatureFlag {
// integration too, and we'd like to turn Gemini/Native on in new builds
// without enabling Claude Code in old builds.
const NAME: &'static str = "gemini-and-native";
fn enabled_for_all() -> bool {
true
}
}
pub struct ClaudeCodeFeatureFlag;
@ -201,7 +205,7 @@ impl FeatureFlagAppExt for App {
fn has_flag<T: FeatureFlag>(&self) -> bool {
self.try_global::<FeatureFlags>()
.map(|flags| flags.has_flag::<T>())
.unwrap_or(false)
.unwrap_or(T::enabled_for_all())
}
fn is_staff(&self) -> bool {

View file

@ -4466,7 +4466,7 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
is_enabled
.then(|| {
let ConfiguredModel { provider, model } =
LanguageModelRegistry::read_global(cx).commit_message_model(cx)?;
LanguageModelRegistry::read_global(cx).commit_message_model()?;
provider.is_authenticated(cx).then(|| model)
})

View file

@ -37,10 +37,10 @@ use crate::{
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
PlatformDisplay, PlatformKeyboardLayout, Point, PromptBuilder, PromptButton, PromptHandle,
PromptLevel, Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource,
SubscriberSet, Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance,
WindowHandle, WindowId, WindowInvalidator,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
Reservation, ScreenCaptureSource, SubscriberSet, Subscription, SvgRenderer, Task, TextSystem,
Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
colors::{Colors, GlobalColors},
current_platform, hash, init_app_menus,
};
@ -263,6 +263,7 @@ pub struct App {
pub(crate) focus_handles: Arc<FocusMap>,
pub(crate) keymap: Rc<RefCell<Keymap>>,
pub(crate) keyboard_layout: Box<dyn PlatformKeyboardLayout>,
pub(crate) keyboard_mapper: Rc<dyn PlatformKeyboardMapper>,
pub(crate) global_action_listeners:
FxHashMap<TypeId, Vec<Rc<dyn Fn(&dyn Any, DispatchPhase, &mut Self)>>>,
pending_effects: VecDeque<Effect>,
@ -312,6 +313,7 @@ impl App {
let text_system = Arc::new(TextSystem::new(platform.text_system()));
let entities = EntityMap::new();
let keyboard_layout = platform.keyboard_layout();
let keyboard_mapper = platform.keyboard_mapper();
let app = Rc::new_cyclic(|this| AppCell {
app: RefCell::new(App {
@ -337,6 +339,7 @@ impl App {
focus_handles: Arc::new(RwLock::new(SlotMap::with_key())),
keymap: Rc::new(RefCell::new(Keymap::default())),
keyboard_layout,
keyboard_mapper,
global_action_listeners: FxHashMap::default(),
pending_effects: VecDeque::new(),
pending_notifications: FxHashSet::default(),
@ -376,6 +379,7 @@ impl App {
if let Some(app) = app.upgrade() {
let cx = &mut app.borrow_mut();
cx.keyboard_layout = cx.platform.keyboard_layout();
cx.keyboard_mapper = cx.platform.keyboard_mapper();
cx.keyboard_layout_observers
.clone()
.retain(&(), move |callback| (callback)(cx));
@ -424,6 +428,11 @@ impl App {
self.keyboard_layout.as_ref()
}
/// Get the current keyboard mapper.
pub fn keyboard_mapper(&self) -> &Rc<dyn PlatformKeyboardMapper> {
&self.keyboard_mapper
}
/// Invokes a handler when the current keyboard layout changes
pub fn on_keyboard_layout_change<F>(&self, mut callback: F) -> Subscription
where

View file

@ -4,7 +4,7 @@ mod context;
pub use binding::*;
pub use context::*;
use crate::{Action, Keystroke, is_no_action};
use crate::{Action, AsKeystroke, Keystroke, is_no_action};
use collections::{HashMap, HashSet};
use smallvec::SmallVec;
use std::any::TypeId;
@ -141,7 +141,7 @@ impl Keymap {
/// only.
pub fn bindings_for_input(
&self,
input: &[Keystroke],
input: &[impl AsKeystroke],
context_stack: &[KeyContext],
) -> (SmallVec<[KeyBinding; 1]>, bool) {
let mut matched_bindings = SmallVec::<[(usize, BindingIndex, &KeyBinding); 1]>::new();
@ -192,7 +192,6 @@ impl Keymap {
(bindings, !pending.is_empty())
}
/// Check if the given binding is enabled, given a certain key context.
/// Returns the deepest depth at which the binding matches, or None if it doesn't match.
fn binding_enabled(&self, binding: &KeyBinding, contexts: &[KeyContext]) -> Option<usize> {
@ -639,7 +638,7 @@ mod tests {
fn assert_bindings(keymap: &Keymap, action: &dyn Action, expected: &[&str]) {
let actual = keymap
.bindings_for_action(action)
.map(|binding| binding.keystrokes[0].unparse())
.map(|binding| binding.keystrokes[0].inner.unparse())
.collect::<Vec<_>>();
assert_eq!(actual, expected, "{:?}", action);
}

View file

@ -1,14 +1,15 @@
use std::rc::Rc;
use collections::HashMap;
use crate::{Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke, SharedString};
use crate::{
Action, AsKeystroke, DummyKeyboardMapper, InvalidKeystrokeError, KeyBindingContextPredicate,
KeybindingKeystroke, Keystroke, PlatformKeyboardMapper, SharedString,
};
use smallvec::SmallVec;
/// A keybinding and its associated metadata, from the keymap.
pub struct KeyBinding {
pub(crate) action: Box<dyn Action>,
pub(crate) keystrokes: SmallVec<[Keystroke; 2]>,
pub(crate) keystrokes: SmallVec<[KeybindingKeystroke; 2]>,
pub(crate) context_predicate: Option<Rc<KeyBindingContextPredicate>>,
pub(crate) meta: Option<KeyBindingMetaIndex>,
/// The json input string used when building the keybinding, if any
@ -32,7 +33,15 @@ impl KeyBinding {
pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
let context_predicate =
context.map(|context| KeyBindingContextPredicate::parse(context).unwrap().into());
Self::load(keystrokes, Box::new(action), context_predicate, None, None).unwrap()
Self::load(
keystrokes,
Box::new(action),
context_predicate,
false,
None,
&DummyKeyboardMapper,
)
.unwrap()
}
/// Load a keybinding from the given raw data.
@ -40,24 +49,22 @@ impl KeyBinding {
keystrokes: &str,
action: Box<dyn Action>,
context_predicate: Option<Rc<KeyBindingContextPredicate>>,
key_equivalents: Option<&HashMap<char, char>>,
use_key_equivalents: bool,
action_input: Option<SharedString>,
keyboard_mapper: &dyn PlatformKeyboardMapper,
) -> std::result::Result<Self, InvalidKeystrokeError> {
let mut keystrokes: SmallVec<[Keystroke; 2]> = keystrokes
let keystrokes: SmallVec<[KeybindingKeystroke; 2]> = keystrokes
.split_whitespace()
.map(Keystroke::parse)
.map(|source| {
let keystroke = Keystroke::parse(source)?;
Ok(KeybindingKeystroke::new(
keystroke,
use_key_equivalents,
keyboard_mapper,
))
})
.collect::<std::result::Result<_, _>>()?;
if let Some(equivalents) = key_equivalents {
for keystroke in keystrokes.iter_mut() {
if keystroke.key.chars().count() == 1
&& let Some(key) = equivalents.get(&keystroke.key.chars().next().unwrap())
{
keystroke.key = key.to_string();
}
}
}
Ok(Self {
keystrokes,
action,
@ -79,13 +86,13 @@ impl KeyBinding {
}
/// Check if the given keystrokes match this binding.
pub fn match_keystrokes(&self, typed: &[Keystroke]) -> Option<bool> {
pub fn match_keystrokes(&self, typed: &[impl AsKeystroke]) -> Option<bool> {
if self.keystrokes.len() < typed.len() {
return None;
}
for (target, typed) in self.keystrokes.iter().zip(typed.iter()) {
if !typed.should_match(target) {
if !typed.as_keystroke().should_match(target) {
return None;
}
}
@ -94,7 +101,7 @@ impl KeyBinding {
}
/// Get the keystrokes associated with this binding
pub fn keystrokes(&self) -> &[Keystroke] {
pub fn keystrokes(&self) -> &[KeybindingKeystroke] {
self.keystrokes.as_slice()
}

View file

@ -231,7 +231,6 @@ pub(crate) trait Platform: 'static {
fn on_quit(&self, callback: Box<dyn FnMut()>);
fn on_reopen(&self, callback: Box<dyn FnMut()>);
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>);
fn set_menus(&self, menus: Vec<Menu>, keymap: &Keymap);
fn get_menus(&self) -> Option<Vec<OwnedMenu>> {
@ -251,7 +250,6 @@ pub(crate) trait Platform: 'static {
fn on_app_menu_action(&self, callback: Box<dyn FnMut(&dyn Action)>);
fn on_will_open_app_menu(&self, callback: Box<dyn FnMut()>);
fn on_validate_app_menu_command(&self, callback: Box<dyn FnMut(&dyn Action) -> bool>);
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout>;
fn compositor_name(&self) -> &'static str {
""
@ -272,6 +270,10 @@ pub(crate) trait Platform: 'static {
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Task<Result<()>>;
fn read_credentials(&self, url: &str) -> Task<Result<Option<(String, Vec<u8>)>>>;
fn delete_credentials(&self, url: &str) -> Task<Result<()>>;
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout>;
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper>;
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>);
}
/// A handle to a platform's display, e.g. a monitor or laptop screen.

View file

@ -1,3 +1,7 @@
use collections::HashMap;
use crate::{KeybindingKeystroke, Keystroke};
/// A trait for platform-specific keyboard layouts
pub trait PlatformKeyboardLayout {
/// Get the keyboard layout ID, which should be unique to the layout
@ -5,3 +9,33 @@ pub trait PlatformKeyboardLayout {
/// Get the keyboard layout display name
fn name(&self) -> &str;
}
/// A trait for platform-specific keyboard mappings
pub trait PlatformKeyboardMapper {
/// Map a key equivalent to its platform-specific representation
fn map_key_equivalent(
&self,
keystroke: Keystroke,
use_key_equivalents: bool,
) -> KeybindingKeystroke;
/// Get the key equivalents for the current keyboard layout,
/// only used on macOS
fn get_key_equivalents(&self) -> Option<&HashMap<char, char>>;
}
/// A dummy implementation of the platform keyboard mapper
pub struct DummyKeyboardMapper;
impl PlatformKeyboardMapper for DummyKeyboardMapper {
fn map_key_equivalent(
&self,
keystroke: Keystroke,
_use_key_equivalents: bool,
) -> KeybindingKeystroke {
KeybindingKeystroke::from_keystroke(keystroke)
}
fn get_key_equivalents(&self) -> Option<&HashMap<char, char>> {
None
}
}

View file

@ -5,6 +5,14 @@ use std::{
fmt::{Display, Write},
};
use crate::PlatformKeyboardMapper;
/// This is a helper trait so that we can simplify the implementation of some functions
pub trait AsKeystroke {
/// Returns the GPUI representation of the keystroke.
fn as_keystroke(&self) -> &Keystroke;
}
/// A keystroke and associated metadata generated by the platform
#[derive(Clone, Debug, Eq, PartialEq, Default, Deserialize, Hash)]
pub struct Keystroke {
@ -24,6 +32,17 @@ pub struct Keystroke {
pub key_char: Option<String>,
}
/// Represents a keystroke that can be used in keybindings and displayed to the user.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct KeybindingKeystroke {
/// The GPUI representation of the keystroke.
pub inner: Keystroke,
/// The modifiers to display.
pub display_modifiers: Modifiers,
/// The key to display.
pub display_key: String,
}
/// Error type for `Keystroke::parse`. This is used instead of `anyhow::Error` so that Zed can use
/// markdown to display it.
#[derive(Debug)]
@ -58,7 +77,7 @@ impl Keystroke {
///
/// This method assumes that `self` was typed and `target' is in the keymap, and checks
/// both possibilities for self against the target.
pub fn should_match(&self, target: &Keystroke) -> bool {
pub fn should_match(&self, target: &KeybindingKeystroke) -> bool {
#[cfg(not(target_os = "windows"))]
if let Some(key_char) = self
.key_char
@ -71,7 +90,7 @@ impl Keystroke {
..Default::default()
};
if &target.key == key_char && target.modifiers == ime_modifiers {
if &target.inner.key == key_char && target.inner.modifiers == ime_modifiers {
return true;
}
}
@ -83,12 +102,12 @@ impl Keystroke {
.filter(|key_char| key_char != &&self.key)
{
// On Windows, if key_char is set, then the typed keystroke produced the key_char
if &target.key == key_char && target.modifiers == Modifiers::none() {
if &target.inner.key == key_char && target.inner.modifiers == Modifiers::none() {
return true;
}
}
target.modifiers == self.modifiers && target.key == self.key
target.inner.modifiers == self.modifiers && target.inner.key == self.key
}
/// key syntax is:
@ -200,31 +219,7 @@ impl Keystroke {
/// Produces a representation of this key that Parse can understand.
pub fn unparse(&self) -> String {
let mut str = String::new();
if self.modifiers.function {
str.push_str("fn-");
}
if self.modifiers.control {
str.push_str("ctrl-");
}
if self.modifiers.alt {
str.push_str("alt-");
}
if self.modifiers.platform {
#[cfg(target_os = "macos")]
str.push_str("cmd-");
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
str.push_str("super-");
#[cfg(target_os = "windows")]
str.push_str("win-");
}
if self.modifiers.shift {
str.push_str("shift-");
}
str.push_str(&self.key);
str
unparse(&self.modifiers, &self.key)
}
/// Returns true if this keystroke left
@ -266,6 +261,32 @@ impl Keystroke {
}
}
impl KeybindingKeystroke {
/// Create a new keybinding keystroke from the given keystroke
pub fn new(
inner: Keystroke,
use_key_equivalents: bool,
keyboard_mapper: &dyn PlatformKeyboardMapper,
) -> Self {
keyboard_mapper.map_key_equivalent(inner, use_key_equivalents)
}
pub(crate) fn from_keystroke(keystroke: Keystroke) -> Self {
let key = keystroke.key.clone();
let modifiers = keystroke.modifiers;
KeybindingKeystroke {
inner: keystroke,
display_modifiers: modifiers,
display_key: key,
}
}
/// Produces a representation of this key that Parse can understand.
pub fn unparse(&self) -> String {
unparse(&self.display_modifiers, &self.display_key)
}
}
fn is_printable_key(key: &str) -> bool {
!matches!(
key,
@ -322,65 +343,15 @@ fn is_printable_key(key: &str) -> bool {
impl std::fmt::Display for Keystroke {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.modifiers.control {
#[cfg(target_os = "macos")]
f.write_char('^')?;
display_modifiers(&self.modifiers, f)?;
display_key(&self.key, f)
}
}
#[cfg(not(target_os = "macos"))]
write!(f, "ctrl-")?;
}
if self.modifiers.alt {
#[cfg(target_os = "macos")]
f.write_char('⌥')?;
#[cfg(not(target_os = "macos"))]
write!(f, "alt-")?;
}
if self.modifiers.platform {
#[cfg(target_os = "macos")]
f.write_char('⌘')?;
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
f.write_char('❖')?;
#[cfg(target_os = "windows")]
f.write_char('⊞')?;
}
if self.modifiers.shift {
#[cfg(target_os = "macos")]
f.write_char('⇧')?;
#[cfg(not(target_os = "macos"))]
write!(f, "shift-")?;
}
let key = match self.key.as_str() {
#[cfg(target_os = "macos")]
"backspace" => '⌫',
#[cfg(target_os = "macos")]
"up" => '↑',
#[cfg(target_os = "macos")]
"down" => '↓',
#[cfg(target_os = "macos")]
"left" => '←',
#[cfg(target_os = "macos")]
"right" => '→',
#[cfg(target_os = "macos")]
"tab" => '⇥',
#[cfg(target_os = "macos")]
"escape" => '⎋',
#[cfg(target_os = "macos")]
"shift" => '⇧',
#[cfg(target_os = "macos")]
"control" => '⌃',
#[cfg(target_os = "macos")]
"alt" => '⌥',
#[cfg(target_os = "macos")]
"platform" => '⌘',
key if key.len() == 1 => key.chars().next().unwrap().to_ascii_uppercase(),
key => return f.write_str(key),
};
f.write_char(key)
impl std::fmt::Display for KeybindingKeystroke {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
display_modifiers(&self.display_modifiers, f)?;
display_key(&self.display_key, f)
}
}
@ -600,3 +571,110 @@ pub struct Capslock {
#[serde(default)]
pub on: bool,
}
impl AsKeystroke for Keystroke {
fn as_keystroke(&self) -> &Keystroke {
self
}
}
impl AsKeystroke for KeybindingKeystroke {
fn as_keystroke(&self) -> &Keystroke {
&self.inner
}
}
fn display_modifiers(modifiers: &Modifiers, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if modifiers.control {
#[cfg(target_os = "macos")]
f.write_char('^')?;
#[cfg(not(target_os = "macos"))]
write!(f, "ctrl-")?;
}
if modifiers.alt {
#[cfg(target_os = "macos")]
f.write_char('⌥')?;
#[cfg(not(target_os = "macos"))]
write!(f, "alt-")?;
}
if modifiers.platform {
#[cfg(target_os = "macos")]
f.write_char('⌘')?;
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
f.write_char('❖')?;
#[cfg(target_os = "windows")]
f.write_char('⊞')?;
}
if modifiers.shift {
#[cfg(target_os = "macos")]
f.write_char('⇧')?;
#[cfg(not(target_os = "macos"))]
write!(f, "shift-")?;
}
Ok(())
}
fn display_key(key: &str, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let key = match key {
#[cfg(target_os = "macos")]
"backspace" => '⌫',
#[cfg(target_os = "macos")]
"up" => '↑',
#[cfg(target_os = "macos")]
"down" => '↓',
#[cfg(target_os = "macos")]
"left" => '←',
#[cfg(target_os = "macos")]
"right" => '→',
#[cfg(target_os = "macos")]
"tab" => '⇥',
#[cfg(target_os = "macos")]
"escape" => '⎋',
#[cfg(target_os = "macos")]
"shift" => '⇧',
#[cfg(target_os = "macos")]
"control" => '⌃',
#[cfg(target_os = "macos")]
"alt" => '⌥',
#[cfg(target_os = "macos")]
"platform" => '⌘',
key if key.len() == 1 => key.chars().next().unwrap().to_ascii_uppercase(),
key => return f.write_str(key),
};
f.write_char(key)
}
#[inline]
fn unparse(modifiers: &Modifiers, key: &str) -> String {
let mut result = String::new();
if modifiers.function {
result.push_str("fn-");
}
if modifiers.control {
result.push_str("ctrl-");
}
if modifiers.alt {
result.push_str("alt-");
}
if modifiers.platform {
#[cfg(target_os = "macos")]
result.push_str("cmd-");
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
result.push_str("super-");
#[cfg(target_os = "windows")]
result.push_str("win-");
}
if modifiers.shift {
result.push_str("shift-");
}
result.push_str(&key);
result
}

View file

@ -25,8 +25,8 @@ use xkbcommon::xkb::{self, Keycode, Keysym, State};
use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow,
Point, Result, Task, WindowAppearance, WindowParams, px,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
PlatformTextSystem, PlatformWindow, Point, Result, Task, WindowAppearance, WindowParams, px,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
@ -144,6 +144,10 @@ impl<P: LinuxClient + 'static> Platform for P {
self.keyboard_layout()
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
Rc::new(crate::DummyKeyboardMapper)
}
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>) {
self.with_common(|common| common.callbacks.keyboard_layout_change = Some(callback));
}

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
use super::{
BoolExt, MacKeyboardLayout,
BoolExt, MacKeyboardLayout, MacKeyboardMapper,
attributed_string::{NSAttributedString, NSMutableAttributedString},
events::key_to_native,
renderer,
@ -8,8 +8,9 @@ use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString,
CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher,
MacDisplay, MacWindow, Menu, MenuItem, OsMenu, OwnedMenu, PathPromptOptions, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result,
SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, hash,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem,
PlatformWindow, Result, SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams,
hash,
};
use anyhow::{Context as _, anyhow};
use block::ConcreteBlock;
@ -171,6 +172,7 @@ pub(crate) struct MacPlatformState {
finish_launching: Option<Box<dyn FnOnce()>>,
dock_menu: Option<id>,
menus: Option<Vec<OwnedMenu>>,
keyboard_mapper: Rc<MacKeyboardMapper>,
}
impl Default for MacPlatform {
@ -189,6 +191,9 @@ impl MacPlatform {
#[cfg(not(feature = "font-kit"))]
let text_system = Arc::new(crate::NoopTextSystem::new());
let keyboard_layout = MacKeyboardLayout::new();
let keyboard_mapper = Rc::new(MacKeyboardMapper::new(keyboard_layout.id()));
Self(Mutex::new(MacPlatformState {
headless,
text_system,
@ -209,6 +214,7 @@ impl MacPlatform {
dock_menu: None,
on_keyboard_layout_change: None,
menus: None,
keyboard_mapper,
}))
}
@ -348,19 +354,19 @@ impl MacPlatform {
let mut mask = NSEventModifierFlags::empty();
for (modifier, flag) in &[
(
keystroke.modifiers.platform,
keystroke.display_modifiers.platform,
NSEventModifierFlags::NSCommandKeyMask,
),
(
keystroke.modifiers.control,
keystroke.display_modifiers.control,
NSEventModifierFlags::NSControlKeyMask,
),
(
keystroke.modifiers.alt,
keystroke.display_modifiers.alt,
NSEventModifierFlags::NSAlternateKeyMask,
),
(
keystroke.modifiers.shift,
keystroke.display_modifiers.shift,
NSEventModifierFlags::NSShiftKeyMask,
),
] {
@ -373,7 +379,7 @@ impl MacPlatform {
.initWithTitle_action_keyEquivalent_(
ns_string(name),
selector,
ns_string(key_to_native(&keystroke.key).as_ref()),
ns_string(key_to_native(&keystroke.display_key).as_ref()),
)
.autorelease();
if Self::os_version() >= SemanticVersion::new(12, 0, 0) {
@ -882,6 +888,10 @@ impl Platform for MacPlatform {
Box::new(MacKeyboardLayout::new())
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
self.0.lock().keyboard_mapper.clone()
}
fn app_path(&self) -> Result<PathBuf> {
unsafe {
let bundle: id = NSBundle::mainBundle();
@ -1393,6 +1403,8 @@ extern "C" fn will_terminate(this: &mut Object, _: Sel, _: id) {
extern "C" fn on_keyboard_layout_change(this: &mut Object, _: Sel, _: id) {
let platform = unsafe { get_mac_platform(this) };
let mut lock = platform.0.lock();
let keyboard_layout = MacKeyboardLayout::new();
lock.keyboard_mapper = Rc::new(MacKeyboardMapper::new(keyboard_layout.id()));
if let Some(mut callback) = lock.on_keyboard_layout_change.take() {
drop(lock);
callback();

View file

@ -1,8 +1,9 @@
use crate::{
AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels,
ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, PlatformKeyboardLayout,
PlatformTextSystem, PromptButton, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream,
SourceMetadata, Task, TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
DummyKeyboardMapper, ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay,
PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, PromptButton,
ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, Task,
TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
};
use anyhow::Result;
use collections::VecDeque;
@ -237,6 +238,10 @@ impl Platform for TestPlatform {
Box::new(TestKeyboardLayout)
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
Rc::new(DummyKeyboardMapper)
}
fn on_keyboard_layout_change(&self, _: Box<dyn FnMut()>) {}
fn run(&self, _on_finish_launching: Box<dyn FnOnce()>) {

View file

@ -9,10 +9,8 @@ use parking::Parker;
use parking_lot::Mutex;
use util::ResultExt;
use windows::{
Foundation::TimeSpan,
System::Threading::{
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemOptions,
WorkItemPriority,
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
},
Win32::{
Foundation::{LPARAM, WPARAM},
@ -56,12 +54,7 @@ impl WindowsDispatcher {
Ok(())
})
};
ThreadPool::RunWithPriorityAndOptionsAsync(
&handler,
WorkItemPriority::High,
WorkItemOptions::TimeSliced,
)
.log_err();
ThreadPool::RunWithPriorityAsync(&handler, WorkItemPriority::High).log_err();
}
fn dispatch_on_threadpool_after(&self, runnable: Runnable, duration: Duration) {
@ -72,12 +65,7 @@ impl WindowsDispatcher {
Ok(())
})
};
let delay = TimeSpan {
// A time period expressed in 100-nanosecond units.
// 10,000,000 ticks per second
Duration: (duration.as_nanos() / 100) as i64,
};
ThreadPoolTimer::CreateTimer(&handler, delay).log_err();
ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
}
}

View file

@ -1,22 +1,31 @@
use anyhow::Result;
use collections::HashMap;
use windows::Win32::UI::{
Input::KeyboardAndMouse::{
GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MapVirtualKeyW, ToUnicode, VIRTUAL_KEY, VK_0,
VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1, VK_CONTROL, VK_MENU,
VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7, VK_OEM_8, VK_OEM_102,
VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MAPVK_VK_TO_VSC, MapVirtualKeyW, ToUnicode,
VIRTUAL_KEY, VK_0, VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1,
VK_CONTROL, VK_MENU, VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7,
VK_OEM_8, VK_OEM_102, VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
},
WindowsAndMessaging::KL_NAMELENGTH,
};
use windows_core::HSTRING;
use crate::{Modifiers, PlatformKeyboardLayout};
use crate::{
KeybindingKeystroke, Keystroke, Modifiers, PlatformKeyboardLayout, PlatformKeyboardMapper,
};
pub(crate) struct WindowsKeyboardLayout {
id: String,
name: String,
}
pub(crate) struct WindowsKeyboardMapper {
key_to_vkey: HashMap<String, (u16, bool)>,
vkey_to_key: HashMap<u16, String>,
vkey_to_shifted: HashMap<u16, String>,
}
impl PlatformKeyboardLayout for WindowsKeyboardLayout {
fn id(&self) -> &str {
&self.id
@ -27,6 +36,65 @@ impl PlatformKeyboardLayout for WindowsKeyboardLayout {
}
}
impl PlatformKeyboardMapper for WindowsKeyboardMapper {
fn map_key_equivalent(
&self,
mut keystroke: Keystroke,
use_key_equivalents: bool,
) -> KeybindingKeystroke {
let Some((vkey, shifted_key)) = self.get_vkey_from_key(&keystroke.key, use_key_equivalents)
else {
return KeybindingKeystroke::from_keystroke(keystroke);
};
if shifted_key && keystroke.modifiers.shift {
log::warn!(
"Keystroke '{}' has both shift and a shifted key, this is likely a bug",
keystroke.key
);
}
let shift = shifted_key || keystroke.modifiers.shift;
keystroke.modifiers.shift = false;
let Some(key) = self.vkey_to_key.get(&vkey).cloned() else {
log::error!(
"Failed to map key equivalent '{:?}' to a valid key",
keystroke
);
return KeybindingKeystroke::from_keystroke(keystroke);
};
keystroke.key = if shift {
let Some(shifted_key) = self.vkey_to_shifted.get(&vkey).cloned() else {
log::error!(
"Failed to map keystroke {:?} with virtual key '{:?}' to a shifted key",
keystroke,
vkey
);
return KeybindingKeystroke::from_keystroke(keystroke);
};
shifted_key
} else {
key.clone()
};
let modifiers = Modifiers {
shift,
..keystroke.modifiers
};
KeybindingKeystroke {
inner: keystroke,
display_modifiers: modifiers,
display_key: key,
}
}
fn get_key_equivalents(&self) -> Option<&HashMap<char, char>> {
None
}
}
impl WindowsKeyboardLayout {
pub(crate) fn new() -> Result<Self> {
let mut buffer = [0u16; KL_NAMELENGTH as usize];
@ -48,6 +116,41 @@ impl WindowsKeyboardLayout {
}
}
impl WindowsKeyboardMapper {
pub(crate) fn new() -> Self {
let mut key_to_vkey = HashMap::default();
let mut vkey_to_key = HashMap::default();
let mut vkey_to_shifted = HashMap::default();
for vkey in CANDIDATE_VKEYS {
if let Some(key) = get_key_from_vkey(*vkey) {
key_to_vkey.insert(key.clone(), (vkey.0, false));
vkey_to_key.insert(vkey.0, key);
}
let scan_code = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_VSC) };
if scan_code == 0 {
continue;
}
if let Some(shifted_key) = get_shifted_key(*vkey, scan_code) {
key_to_vkey.insert(shifted_key.clone(), (vkey.0, true));
vkey_to_shifted.insert(vkey.0, shifted_key);
}
}
Self {
key_to_vkey,
vkey_to_key,
vkey_to_shifted,
}
}
fn get_vkey_from_key(&self, key: &str, use_key_equivalents: bool) -> Option<(u16, bool)> {
if use_key_equivalents {
get_vkey_from_key_with_us_layout(key)
} else {
self.key_to_vkey.get(key).cloned()
}
}
}
pub(crate) fn get_keystroke_key(
vkey: VIRTUAL_KEY,
scan_code: u32,
@ -140,3 +243,134 @@ pub(crate) fn generate_key_char(
_ => None,
}
}
fn get_vkey_from_key_with_us_layout(key: &str) -> Option<(u16, bool)> {
match key {
// ` => VK_OEM_3
"`" => Some((VK_OEM_3.0, false)),
"~" => Some((VK_OEM_3.0, true)),
"1" => Some((VK_1.0, false)),
"!" => Some((VK_1.0, true)),
"2" => Some((VK_2.0, false)),
"@" => Some((VK_2.0, true)),
"3" => Some((VK_3.0, false)),
"#" => Some((VK_3.0, true)),
"4" => Some((VK_4.0, false)),
"$" => Some((VK_4.0, true)),
"5" => Some((VK_5.0, false)),
"%" => Some((VK_5.0, true)),
"6" => Some((VK_6.0, false)),
"^" => Some((VK_6.0, true)),
"7" => Some((VK_7.0, false)),
"&" => Some((VK_7.0, true)),
"8" => Some((VK_8.0, false)),
"*" => Some((VK_8.0, true)),
"9" => Some((VK_9.0, false)),
"(" => Some((VK_9.0, true)),
"0" => Some((VK_0.0, false)),
")" => Some((VK_0.0, true)),
"-" => Some((VK_OEM_MINUS.0, false)),
"_" => Some((VK_OEM_MINUS.0, true)),
"=" => Some((VK_OEM_PLUS.0, false)),
"+" => Some((VK_OEM_PLUS.0, true)),
"[" => Some((VK_OEM_4.0, false)),
"{" => Some((VK_OEM_4.0, true)),
"]" => Some((VK_OEM_6.0, false)),
"}" => Some((VK_OEM_6.0, true)),
"\\" => Some((VK_OEM_5.0, false)),
"|" => Some((VK_OEM_5.0, true)),
";" => Some((VK_OEM_1.0, false)),
":" => Some((VK_OEM_1.0, true)),
"'" => Some((VK_OEM_7.0, false)),
"\"" => Some((VK_OEM_7.0, true)),
"," => Some((VK_OEM_COMMA.0, false)),
"<" => Some((VK_OEM_COMMA.0, true)),
"." => Some((VK_OEM_PERIOD.0, false)),
">" => Some((VK_OEM_PERIOD.0, true)),
"/" => Some((VK_OEM_2.0, false)),
"?" => Some((VK_OEM_2.0, true)),
_ => None,
}
}
const CANDIDATE_VKEYS: &[VIRTUAL_KEY] = &[
VK_OEM_3,
VK_OEM_MINUS,
VK_OEM_PLUS,
VK_OEM_4,
VK_OEM_5,
VK_OEM_6,
VK_OEM_1,
VK_OEM_7,
VK_OEM_COMMA,
VK_OEM_PERIOD,
VK_OEM_2,
VK_OEM_102,
VK_OEM_8,
VK_ABNT_C1,
VK_0,
VK_1,
VK_2,
VK_3,
VK_4,
VK_5,
VK_6,
VK_7,
VK_8,
VK_9,
];
#[cfg(test)]
mod tests {
use crate::{Keystroke, Modifiers, PlatformKeyboardMapper, WindowsKeyboardMapper};
#[test]
fn test_keyboard_mapper() {
let mapper = WindowsKeyboardMapper::new();
// Normal case
let keystroke = Keystroke {
modifiers: Modifiers::control(),
key: "a".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke.clone(), true);
assert_eq!(mapped.inner, keystroke);
assert_eq!(mapped.display_key, "a");
assert_eq!(mapped.display_modifiers, Modifiers::control());
// Shifted case, ctrl-$
let keystroke = Keystroke {
modifiers: Modifiers::control(),
key: "$".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke.clone(), true);
assert_eq!(mapped.inner, keystroke);
assert_eq!(mapped.display_key, "4");
assert_eq!(mapped.display_modifiers, Modifiers::control_shift());
// Shifted case, but shift is true
let keystroke = Keystroke {
modifiers: Modifiers::control_shift(),
key: "$".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke, true);
assert_eq!(mapped.inner.modifiers, Modifiers::control());
assert_eq!(mapped.display_key, "4");
assert_eq!(mapped.display_modifiers, Modifiers::control_shift());
// Windows style
let keystroke = Keystroke {
modifiers: Modifiers::control_shift(),
key: "4".to_string(),
key_char: None,
};
let mapped = mapper.map_key_equivalent(keystroke, true);
assert_eq!(mapped.inner.modifiers, Modifiers::control());
assert_eq!(mapped.inner.key, "$");
assert_eq!(mapped.display_key, "4");
assert_eq!(mapped.display_modifiers, Modifiers::control_shift());
}
}

View file

@ -351,6 +351,10 @@ impl Platform for WindowsPlatform {
)
}
fn keyboard_mapper(&self) -> Rc<dyn PlatformKeyboardMapper> {
Rc::new(WindowsKeyboardMapper::new())
}
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>) {
self.state.borrow_mut().callbacks.keyboard_layout_change = Some(callback);
}

View file

@ -215,6 +215,7 @@ pub enum IconName {
Tab,
Terminal,
TerminalAlt,
TerminalGhost,
TextSnippet,
TextThread,
Thread,

View file

@ -401,12 +401,19 @@ pub fn init(cx: &mut App) {
mod persistence {
use std::path::PathBuf;
use db::{define_connection, query, sqlez_macros::sql};
use db::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
sqlez_macros::sql,
};
use workspace::{ItemId, WorkspaceDb, WorkspaceId};
define_connection! {
pub static ref IMAGE_VIEWER: ImageViewerDb<WorkspaceDb> =
&[sql!(
pub struct ImageViewerDb(ThreadSafeConnection);
impl Domain for ImageViewerDb {
const NAME: &str = stringify!(ImageViewerDb);
const MIGRATIONS: &[&str] = &[sql!(
CREATE TABLE image_viewers (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
@ -417,9 +424,11 @@ mod persistence {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
)];
)];
}
db::static_connection!(IMAGE_VIEWER, ImageViewerDb, [WorkspaceDb]);
impl ImageViewerDb {
query! {
pub async fn save_image_path(

View file

@ -1569,11 +1569,21 @@ impl Buffer {
self.send_operation(op, true, cx);
}
pub fn get_diagnostics(&self, server_id: LanguageServerId) -> Option<&DiagnosticSet> {
let Ok(idx) = self.diagnostics.binary_search_by_key(&server_id, |v| v.0) else {
return None;
};
Some(&self.diagnostics[idx].1)
pub fn buffer_diagnostics(
&self,
for_server: Option<LanguageServerId>,
) -> Vec<&DiagnosticEntry<Anchor>> {
match for_server {
Some(server_id) => match self.diagnostics.binary_search_by_key(&server_id, |v| v.0) {
Ok(idx) => self.diagnostics[idx].1.iter().collect(),
Err(_) => Vec::new(),
},
None => self
.diagnostics
.iter()
.flat_map(|(_, diagnostic_set)| diagnostic_set.iter())
.collect(),
}
}
fn request_autoindent(&mut self, cx: &mut Context<Self>) {

View file

@ -4,12 +4,16 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice,
};
use anyhow::anyhow;
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use smol::stream::StreamExt;
use std::sync::Arc;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering::SeqCst},
};
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
>,
)>,
>,
forbid_requests: AtomicBool,
}
impl Default for FakeLanguageModel {
@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
provider_id: LanguageModelProviderId::from("fake".to_string()),
provider_name: LanguageModelProviderName::from("Fake".to_string()),
current_completion_txs: Mutex::new(Vec::new()),
forbid_requests: AtomicBool::new(false),
}
}
}
impl FakeLanguageModel {
pub fn allow_requests(&self) {
self.forbid_requests.store(false, SeqCst);
}
pub fn forbid_requests(&self) {
self.forbid_requests.store(true, SeqCst);
}
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
LanguageModelCompletionError,
>,
> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.boxed()) }.boxed()
if self.forbid_requests.load(SeqCst) {
async move {
Err(LanguageModelCompletionError::Other(anyhow!(
"requests are forbidden"
)))
}
.boxed()
} else {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.boxed()) }.boxed()
}
}
fn as_fake(&self) -> &Self {

View file

@ -6,6 +6,7 @@ use collections::BTreeMap;
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
@ -41,9 +42,7 @@ impl std::fmt::Debug for ConfigurationError {
#[derive(Default)]
pub struct LanguageModelRegistry {
default_model: Option<ConfiguredModel>,
/// This model is automatically configured by a user's environment after
/// authenticating all providers. It's only used when default_model is not available.
environment_fallback_model: Option<ConfiguredModel>,
default_fast_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
@ -99,6 +98,9 @@ impl ConfiguredModel {
pub enum Event {
DefaultModelChanged,
InlineAssistantModelChanged,
CommitMessageModelChanged,
ThreadSummaryModelChanged,
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@ -224,7 +226,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_inline_assistant_model(configured_model);
self.set_inline_assistant_model(configured_model, cx);
}
pub fn select_commit_message_model(
@ -233,7 +235,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_commit_message_model(configured_model);
self.set_commit_message_model(configured_model, cx);
}
pub fn select_thread_summary_model(
@ -242,7 +244,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_thread_summary_model(configured_model);
self.set_thread_summary_model(configured_model, cx);
}
/// Selects and sets the inline alternatives for language models based on
@ -276,60 +278,68 @@ impl LanguageModelRegistry {
}
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
match (self.default_model(), model.as_ref()) {
match (self.default_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
self.default_fast_model = maybe!({
let provider = &model.as_ref()?.provider;
let fast_model = provider.default_fast_model(cx)?;
Some(ConfiguredModel {
provider: provider.clone(),
model: fast_model,
})
});
self.default_model = model;
}
pub fn set_environment_fallback_model(
pub fn set_inline_assistant_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
if self.default_model.is_none() {
match (self.environment_fallback_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::InlineAssistantModelChanged),
}
self.environment_fallback_model = model;
}
pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
self.inline_assistant_model = model;
}
pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
pub fn set_commit_message_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.commit_message_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::CommitMessageModelChanged),
}
self.commit_message_model = model;
}
pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
pub fn set_thread_summary_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.thread_summary_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::ThreadSummaryModelChanged),
}
self.thread_summary_model = model;
}
#[track_caller]
pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
self.default_model
.clone()
.or_else(|| self.environment_fallback_model.clone())
}
pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
let provider = self.default_model()?.provider;
let fast_model = provider.default_fast_model(cx)?;
Some(ConfiguredModel {
provider,
model: fast_model,
})
self.default_model.clone()
}
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@ -343,7 +353,7 @@ impl LanguageModelRegistry {
.or_else(|| self.default_model.clone())
}
pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@ -351,11 +361,11 @@ impl LanguageModelRegistry {
self.commit_message_model
.clone()
.or_else(|| self.default_fast_model(cx))
.or_else(|| self.default_fast_model.clone())
.or_else(|| self.default_model.clone())
}
pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@ -363,7 +373,7 @@ impl LanguageModelRegistry {
self.thread_summary_model
.clone()
.or_else(|| self.default_fast_model(cx))
.or_else(|| self.default_fast_model.clone())
.or_else(|| self.default_model.clone())
}
@ -400,34 +410,4 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
#[gpui::test]
async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = FakeLanguageModelProvider::default();
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
});
cx.update(|cx| provider.authenticate(cx)).await.unwrap();
registry.update(cx, |registry, cx| {
let provider = registry.provider(&provider.id()).unwrap();
registry.set_environment_fallback_model(
Some(ConfiguredModel {
provider: provider.clone(),
model: provider.default_model(cx).unwrap(),
}),
cx,
);
let default_model = registry.default_model().unwrap();
let fallback_model = registry.environment_fallback_model.clone().unwrap();
assert_eq!(default_model.model.id(), fallback_model.model.id());
assert_eq!(default_model.provider.id(), fallback_model.provider.id());
});
}
}

View file

@ -44,7 +44,6 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
project.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true

View file

@ -3,12 +3,8 @@ use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
use futures::future;
use gpui::{App, AppContext as _, Context, Entity};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
};
use project::DisableAiSettings;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
@ -17,7 +13,7 @@ pub mod ui;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::{self, CloudLanguageModelProvider};
use crate::provider::cloud::CloudLanguageModelProvider;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
@ -52,13 +48,6 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
cx,
);
});
let mut already_authenticated = false;
if !DisableAiSettings::get_global(cx).disable_ai {
authenticate_all_providers(registry.clone(), cx);
already_authenticated = true;
}
cx.observe_global::<SettingsStore>(move |cx| {
let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
.openai_compatible
@ -76,12 +65,6 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
);
});
openai_compatible_providers = openai_compatible_providers_new;
already_authenticated = false;
}
if !DisableAiSettings::get_global(cx).disable_ai && !already_authenticated {
authenticate_all_providers(registry.clone(), cx);
already_authenticated = true;
}
})
.detach();
@ -168,83 +151,3 @@ fn register_language_model_providers(
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
/// Authenticates all providers in the [`LanguageModelRegistry`].
///
/// We do this so that we can populate the language selector with all of the
/// models from the configured providers.
///
/// This function won't do anything if AI is disabled.
fn authenticate_all_providers(registry: Entity<LanguageModelRegistry>, cx: &mut App) {
let providers_to_authenticate = registry
.read(cx)
.providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
let mut tasks = Vec::with_capacity(providers_to_authenticate.len());
for (provider_id, provider_name, authenticate_task) in providers_to_authenticate {
tasks.push(cx.background_spawn(async move {
if let Err(err) = authenticate_task.await {
if matches!(err, AuthenticateError::CredentialsNotFound) {
// Since we're authenticating these providers in the
// background for the purposes of populating the
// language selector, we don't care about providers
// where the credentials are not found.
} else {
// Some providers have noisy failure states that we
// don't want to spam the logs with every time the
// language model selector is initialized.
//
// Ideally these should have more clear failure modes
// that we know are safe to ignore here, like what we do
// with `CredentialsNotFound` above.
match provider_id.0.as_ref() {
"lmstudio" | "ollama" => {
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
//
// These fail noisily, so we don't log them.
}
"copilot_chat" => {
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
}
_ => {
log::error!(
"Failed to authenticate provider: {}: {err}",
provider_name.0
);
}
}
}
}
}));
}
let all_authenticated_future = future::join_all(tasks);
cx.spawn(async move |cx| {
all_authenticated_future.await;
registry
.update(cx, |registry, cx| {
let cloud_provider = registry.provider(&cloud::PROVIDER_ID);
let fallback_model = cloud_provider
.iter()
.chain(registry.providers().iter())
.find(|provider| provider.is_authenticated(cx))
.and_then(|provider| {
Some(ConfiguredModel {
provider: provider.clone(),
model: provider
.default_model(cx)
.or_else(|| provider.recommended_models(cx).first().cloned())?,
})
});
registry.set_environment_fallback_model(fallback_model, cx);
})
.ok();
})
.detach();
}

View file

@ -44,8 +44,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
pub const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@ -146,7 +146,7 @@ impl State {
default_fast_model: None,
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async {
maybe!(async move {
let (client, llm_api_token) = this
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

View file

@ -4,7 +4,6 @@ use gpui::{
};
use itertools::Itertools;
use serde_json::json;
use settings::get_key_equivalents;
use ui::{Button, ButtonStyle};
use ui::{
ButtonCommon, Clickable, Context, FluentBuilder, InteractiveElement, Label, LabelCommon,
@ -169,7 +168,8 @@ impl Item for KeyContextView {
impl Render for KeyContextView {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
use itertools::Itertools;
let key_equivalents = get_key_equivalents(cx.keyboard_layout().id());
let key_equivalents = cx.keyboard_mapper().get_key_equivalents();
v_flex()
.id("key-context-view")
.overflow_scroll()

View file

@ -1743,6 +1743,5 @@ pub enum Event {
}
impl EventEmitter<Event> for LogStore {}
impl EventEmitter<Event> for LspLogView {}
impl EventEmitter<EditorEvent> for LspLogView {}
impl EventEmitter<SearchEvent> for LspLogView {}

View file

@ -11,6 +11,21 @@
(#set! injection.language "css"))
)
(call_expression
function: (member_expression
object: (identifier) @_obj (#eq? @_obj "styled")
property: (property_identifier))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (call_expression
function: (identifier) @_name (#eq? @_name "styled"))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (identifier) @_name (#eq? @_name "html")
arguments: (template_string) @injection.content

View file

@ -6,9 +6,6 @@
(self) @variable.special
(field_identifier) @property
(shorthand_field_initializer
(identifier) @property)
(trait_item name: (type_identifier) @type.interface)
(impl_item trait: (type_identifier) @type.interface)
(abstract_type trait: (type_identifier) @type.interface)
@ -41,20 +38,11 @@
(identifier) @function.special
(scoped_identifier
name: (identifier) @function.special)
]
"!" @function.special)
])
(macro_definition
name: (identifier) @function.special.definition)
(mod_item
name: (identifier) @module)
(visibility_modifier [
(crate) @keyword
(super) @keyword
])
; Identifier conventions
; Assume uppercase names are types/enum-constructors
@ -127,7 +115,9 @@
"where"
"while"
"yield"
(crate)
(mutable_specifier)
(super)
] @keyword
[
@ -199,7 +189,6 @@
operator: "/" @operator
(lifetime) @lifetime
(lifetime (identifier) @lifetime)
(parameter (identifier) @variable.parameter)

View file

@ -11,6 +11,21 @@
(#set! injection.language "css"))
)
(call_expression
function: (member_expression
object: (identifier) @_obj (#eq? @_obj "styled")
property: (property_identifier))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (call_expression
function: (identifier) @_name (#eq? @_name "styled"))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (identifier) @_name (#eq? @_name "html")
arguments: (template_string (string_fragment) @injection.content

View file

@ -15,6 +15,21 @@
(#set! injection.language "css"))
)
(call_expression
function: (member_expression
object: (identifier) @_obj (#eq? @_obj "styled")
property: (property_identifier))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (call_expression
function: (identifier) @_name (#eq? @_name "styled"))
arguments: (template_string (string_fragment) @injection.content
(#set! injection.language "css"))
)
(call_expression
function: (identifier) @_name (#eq? @_name "html")
arguments: (template_string) @injection.content

View file

@ -1323,7 +1323,7 @@ fn render_copy_code_block_button(
.icon_size(IconSize::Small)
.style(ButtonStyle::Filled)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.tooltip(Tooltip::text("Copy"))
.on_click({
let markdown = markdown;
move |_event, _window, cx| {

View file

@ -283,17 +283,13 @@ pub(crate) fn render_ai_setup_page(
v_flex()
.mt_2()
.gap_6()
.child({
let mut ai_upsell_card =
AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx);
ai_upsell_card.tab_index = Some({
tab_index += 1;
tab_index - 1
});
ai_upsell_card
})
.child(
AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx)
.tab_index(Some({
tab_index += 1;
tab_index - 1
})),
)
.child(render_llm_provider_section(
&mut tab_index,
workspace,

View file

@ -850,13 +850,19 @@ impl workspace::SerializableItem for Onboarding {
}
mod persistence {
use db::{define_connection, query, sqlez_macros::sql};
use db::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
sqlez_macros::sql,
};
use workspace::WorkspaceDb;
define_connection! {
pub static ref ONBOARDING_PAGES: OnboardingPagesDb<WorkspaceDb> =
&[
sql!(
pub struct OnboardingPagesDb(ThreadSafeConnection);
impl Domain for OnboardingPagesDb {
const NAME: &str = stringify!(OnboardingPagesDb);
const MIGRATIONS: &[&str] = &[sql!(
CREATE TABLE onboarding_pages (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
@ -866,10 +872,11 @@ mod persistence {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
),
];
)];
}
db::static_connection!(ONBOARDING_PAGES, OnboardingPagesDb, [WorkspaceDb]);
impl OnboardingPagesDb {
query! {
pub async fn save_onboarding_page(

View file

@ -414,13 +414,19 @@ impl workspace::SerializableItem for WelcomePage {
}
mod persistence {
use db::{define_connection, query, sqlez_macros::sql};
use db::{
query,
sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection},
sqlez_macros::sql,
};
use workspace::WorkspaceDb;
define_connection! {
pub static ref WELCOME_PAGES: WelcomePagesDb<WorkspaceDb> =
&[
sql!(
pub struct WelcomePagesDb(ThreadSafeConnection);
impl Domain for WelcomePagesDb {
const NAME: &str = stringify!(WelcomePagesDb);
const MIGRATIONS: &[&str] = (&[sql!(
CREATE TABLE welcome_pages (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
@ -430,10 +436,11 @@ mod persistence {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
),
];
)]);
}
db::static_connection!(WELCOME_PAGES, WelcomePagesDb, [WorkspaceDb]);
impl WelcomePagesDb {
query! {
pub async fn save_welcome_page(

View file

@ -446,7 +446,6 @@ pub enum ResponseStreamResult {
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
pub model: String,
pub choices: Vec<ChoiceDelta>,
pub usage: Option<Usage>,
}

View file

@ -7588,19 +7588,16 @@ impl LspStore {
let snapshot = buffer_handle.read(cx).snapshot();
let buffer = buffer_handle.read(cx);
let reused_diagnostics = buffer
.get_diagnostics(server_id)
.into_iter()
.flat_map(|diag| {
diag.iter()
.filter(|v| merge(buffer, &v.diagnostic, cx))
.map(|v| {
let start = Unclipped(v.range.start.to_point_utf16(&snapshot));
let end = Unclipped(v.range.end.to_point_utf16(&snapshot));
DiagnosticEntry {
range: start..end,
diagnostic: v.diagnostic.clone(),
}
})
.buffer_diagnostics(Some(server_id))
.iter()
.filter(|v| merge(buffer, &v.diagnostic, cx))
.map(|v| {
let start = Unclipped(v.range.start.to_point_utf16(&snapshot));
let end = Unclipped(v.range.end.to_point_utf16(&snapshot));
DiagnosticEntry {
range: start..end,
diagnostic: v.diagnostic.clone(),
}
})
.collect::<Vec<_>>();
@ -11706,12 +11703,11 @@ impl LspStore {
// Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings.
}
"workspace/symbol" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.workspace_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.workspace_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"workspace/fileOperations" => {
if let Some(options) = reg.register_options {
@ -11735,12 +11731,11 @@ impl LspStore {
}
}
"textDocument/rangeFormatting" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.document_range_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.document_range_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"textDocument/onTypeFormatting" => {
if let Some(options) = reg
@ -11755,36 +11750,32 @@ impl LspStore {
}
}
"textDocument/formatting" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.document_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.document_formatting_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"textDocument/rename" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.rename_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.rename_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"textDocument/inlayHint" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.inlay_hint_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.inlay_hint_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"textDocument/documentSymbol" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.document_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.document_symbol_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"textDocument/codeAction" => {
if let Some(options) = reg
@ -11800,12 +11791,11 @@ impl LspStore {
}
}
"textDocument/definition" => {
if let Some(options) = parse_register_capabilities(reg)? {
server.update_capabilities(|capabilities| {
capabilities.definition_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
let options = parse_register_capabilities(reg)?;
server.update_capabilities(|capabilities| {
capabilities.definition_provider = Some(options);
});
notify_server_capabilities_updated(&server, cx);
}
"textDocument/completion" => {
if let Some(caps) = reg
@ -12184,10 +12174,10 @@ impl LspStore {
// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/client.ts#L2133
fn parse_register_capabilities<T: serde::de::DeserializeOwned>(
reg: lsp::Registration,
) -> anyhow::Result<Option<OneOf<bool, T>>> {
) -> Result<OneOf<bool, T>> {
Ok(match reg.register_options {
Some(options) => Some(OneOf::Right(serde_json::from_value::<T>(options)?)),
None => Some(OneOf::Left(true)),
Some(options) => OneOf::Right(serde_json::from_value::<T>(options)?),
None => OneOf::Left(true),
})
}

View file

@ -4089,6 +4089,7 @@ impl ProjectPanel {
.when(!is_sticky, |this| {
this
.when(is_highlighted && folded_directory_drag_target.is_none(), |this| this.border_color(transparent_white()).bg(item_colors.drag_over))
.when(settings.drag_and_drop, |this| this
.on_drag_move::<ExternalPaths>(cx.listener(
move |this, event: &DragMoveEvent<ExternalPaths>, _, cx| {
let is_current_target = this.drag_target_entry.as_ref()
@ -4222,7 +4223,7 @@ impl ProjectPanel {
}
this.drag_onto(selections, entry_id, kind.is_file(), window, cx);
}),
)
))
})
.on_mouse_down(
MouseButton::Left,
@ -4433,6 +4434,7 @@ impl ProjectPanel {
div()
.when(!is_sticky, |div| {
div
.when(settings.drag_and_drop, |div| div
.on_drop(cx.listener(move |this, selections: &DraggedSelection, window, cx| {
this.hover_scroll_task.take();
this.drag_target_entry = None;
@ -4464,7 +4466,7 @@ impl ProjectPanel {
}
},
))
)))
})
.child(
Label::new(DELIMITER.clone())
@ -4484,6 +4486,7 @@ impl ProjectPanel {
.when(index != components_len - 1, |div|{
let target_entry_id = folded_ancestors.ancestors.get(components_len - 1 - index).cloned();
div
.when(settings.drag_and_drop, |div| div
.on_drag_move(cx.listener(
move |this, event: &DragMoveEvent<DraggedSelection>, _, _| {
if event.bounds.contains(&event.event.position) {
@ -4521,7 +4524,7 @@ impl ProjectPanel {
target.index == index
), |this| {
this.bg(item_colors.drag_over)
})
}))
})
})
.on_click(cx.listener(move |this, _, _, cx| {
@ -5029,7 +5032,8 @@ impl ProjectPanel {
sticky_parents.reverse();
let git_status_enabled = ProjectPanelSettings::get_global(cx).git_status;
let panel_settings = ProjectPanelSettings::get_global(cx);
let git_status_enabled = panel_settings.git_status;
let root_name = OsStr::new(worktree.root_name());
let git_summaries_by_id = if git_status_enabled {
@ -5113,11 +5117,11 @@ impl Render for ProjectPanel {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let has_worktree = !self.visible_entries.is_empty();
let project = self.project.read(cx);
let indent_size = ProjectPanelSettings::get_global(cx).indent_size;
let show_indent_guides =
ProjectPanelSettings::get_global(cx).indent_guides.show == ShowIndentGuides::Always;
let panel_settings = ProjectPanelSettings::get_global(cx);
let indent_size = panel_settings.indent_size;
let show_indent_guides = panel_settings.indent_guides.show == ShowIndentGuides::Always;
let show_sticky_entries = {
if ProjectPanelSettings::get_global(cx).sticky_scroll {
if panel_settings.sticky_scroll {
let is_scrollable = self.scroll_handle.is_scrollable();
let is_scrolled = self.scroll_handle.offset().y < px(0.);
is_scrollable && is_scrolled
@ -5205,8 +5209,10 @@ impl Render for ProjectPanel {
h_flex()
.id("project-panel")
.group("project-panel")
.on_drag_move(cx.listener(handle_drag_move::<ExternalPaths>))
.on_drag_move(cx.listener(handle_drag_move::<DraggedSelection>))
.when(panel_settings.drag_and_drop, |this| {
this.on_drag_move(cx.listener(handle_drag_move::<ExternalPaths>))
.on_drag_move(cx.listener(handle_drag_move::<DraggedSelection>))
})
.size_full()
.relative()
.on_modifiers_changed(cx.listener(
@ -5544,30 +5550,32 @@ impl Render for ProjectPanel {
})),
)
.when(is_local, |div| {
div.drag_over::<ExternalPaths>(|style, _, _, cx| {
style.bg(cx.theme().colors().drop_target_background)
div.when(panel_settings.drag_and_drop, |div| {
div.drag_over::<ExternalPaths>(|style, _, _, cx| {
style.bg(cx.theme().colors().drop_target_background)
})
.on_drop(cx.listener(
move |this, external_paths: &ExternalPaths, window, cx| {
this.drag_target_entry = None;
this.hover_scroll_task.take();
if let Some(task) = this
.workspace
.update(cx, |workspace, cx| {
workspace.open_workspace_for_paths(
true,
external_paths.paths().to_owned(),
window,
cx,
)
})
.log_err()
{
task.detach_and_log_err(cx);
}
cx.stop_propagation();
},
))
})
.on_drop(cx.listener(
move |this, external_paths: &ExternalPaths, window, cx| {
this.drag_target_entry = None;
this.hover_scroll_task.take();
if let Some(task) = this
.workspace
.update(cx, |workspace, cx| {
workspace.open_workspace_for_paths(
true,
external_paths.paths().to_owned(),
window,
cx,
)
})
.log_err()
{
task.detach_and_log_err(cx);
}
cx.stop_propagation();
},
))
})
}
}

View file

@ -47,6 +47,7 @@ pub struct ProjectPanelSettings {
pub scrollbar: ScrollbarSettings,
pub show_diagnostics: ShowDiagnostics,
pub hide_root: bool,
pub drag_and_drop: bool,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
@ -160,6 +161,10 @@ pub struct ProjectPanelSettingsContent {
///
/// Default: true
pub sticky_scroll: Option<bool>,
/// Whether to enable drag-and-drop operations in the project panel.
///
/// Default: true
pub drag_and_drop: Option<bool>,
}
impl Settings for ProjectPanelSettings {

View file

@ -445,7 +445,7 @@ impl SshSocket {
}
async fn platform(&self) -> Result<SshPlatform> {
let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
let uname = self.run_command("sh", &["-lc", "uname -sm"]).await?;
let Some((os, arch)) = uname.split_once(" ") else {
anyhow::bail!("unknown uname: {uname:?}")
};
@ -476,7 +476,7 @@ impl SshSocket {
}
async fn shell(&self) -> String {
match self.run_command("sh", &["-c", "echo $SHELL"]).await {
match self.run_command("sh", &["-lc", "echo $SHELL"]).await {
Ok(shell) => shell.trim().to_owned(),
Err(e) => {
log::error!("Failed to get shell: {e}");
@ -1533,7 +1533,7 @@ impl RemoteConnection for SshRemoteConnection {
let ssh_proxy_process = match self
.socket
.ssh_command("sh", &["-c", &start_proxy_command])
.ssh_command("sh", &["-lc", &start_proxy_command])
// IMPORTANT: we kill this process when we drop the task that uses it.
.kill_on_drop(true)
.spawn()
@ -1910,7 +1910,7 @@ impl SshRemoteConnection {
.run_command(
"sh",
&[
"-c",
"-lc",
&shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
],
)
@ -1988,7 +1988,7 @@ impl SshRemoteConnection {
.run_command(
"sh",
&[
"-c",
"-lc",
&shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
],
)
@ -2036,7 +2036,7 @@ impl SshRemoteConnection {
dst_path = &dst_path.to_string()
)
};
self.socket.run_command("sh", &["-c", &script]).await?;
self.socket.run_command("sh", &["-lc", &script]).await?;
Ok(())
}

View file

@ -65,6 +65,7 @@ telemetry_events.workspace = true
util.workspace = true
watch.workspace = true
worktree.workspace = true
thiserror.workspace = true
[target.'cfg(not(windows))'.dependencies]
crashes.workspace = true

View file

@ -1,6 +1,7 @@
#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
use clap::{Parser, Subcommand};
use clap::Parser;
use remote_server::Commands;
use std::path::PathBuf;
#[derive(Parser)]
@ -21,105 +22,34 @@ struct Cli {
printenv: bool,
}
#[derive(Subcommand)]
enum Commands {
Run {
#[arg(long)]
log_file: PathBuf,
#[arg(long)]
pid_file: PathBuf,
#[arg(long)]
stdin_socket: PathBuf,
#[arg(long)]
stdout_socket: PathBuf,
#[arg(long)]
stderr_socket: PathBuf,
},
Proxy {
#[arg(long)]
reconnect: bool,
#[arg(long)]
identifier: String,
},
Version,
}
#[cfg(windows)]
fn main() {
unimplemented!()
}
#[cfg(not(windows))]
fn main() {
use release_channel::{RELEASE_CHANNEL, ReleaseChannel};
use remote::proxy::ProxyLaunchError;
use remote_server::unix::{execute_proxy, execute_run};
fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
if let Some(socket_path) = &cli.askpass {
askpass::main(socket_path);
return;
return Ok(());
}
if let Some(socket) = &cli.crash_handler {
crashes::crash_server(socket.as_path());
return;
return Ok(());
}
if cli.printenv {
util::shell_env::print_env();
return;
return Ok(());
}
let result = match cli.command {
Some(Commands::Run {
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
}) => execute_run(
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
),
Some(Commands::Proxy {
identifier,
reconnect,
}) => match execute_proxy(identifier, reconnect) {
Ok(_) => Ok(()),
Err(err) => {
if let Some(err) = err.downcast_ref::<ProxyLaunchError>() {
std::process::exit(err.to_exit_code());
}
Err(err)
}
},
Some(Commands::Version) => {
let release_channel = *RELEASE_CHANNEL;
match release_channel {
ReleaseChannel::Stable | ReleaseChannel::Preview => {
println!("{}", env!("ZED_PKG_VERSION"))
}
ReleaseChannel::Nightly | ReleaseChannel::Dev => {
println!(
"{}",
option_env!("ZED_COMMIT_SHA").unwrap_or(release_channel.dev_name())
)
}
};
std::process::exit(0);
}
None => {
eprintln!("usage: remote <run|proxy|version>");
std::process::exit(1);
}
};
if let Err(error) = result {
log::error!("exiting due to error: {}", error);
if let Some(command) = cli.command {
remote_server::run(command)
} else {
eprintln!("usage: remote <run|proxy|version>");
std::process::exit(1);
}
}

View file

@ -6,4 +6,78 @@ pub mod unix;
#[cfg(test)]
mod remote_editing_tests;
use clap::Subcommand;
use std::path::PathBuf;
pub use headless_project::{HeadlessAppState, HeadlessProject};
#[derive(Subcommand)]
pub enum Commands {
Run {
#[arg(long)]
log_file: PathBuf,
#[arg(long)]
pid_file: PathBuf,
#[arg(long)]
stdin_socket: PathBuf,
#[arg(long)]
stdout_socket: PathBuf,
#[arg(long)]
stderr_socket: PathBuf,
},
Proxy {
#[arg(long)]
reconnect: bool,
#[arg(long)]
identifier: String,
},
Version,
}
#[cfg(not(windows))]
pub fn run(command: Commands) -> anyhow::Result<()> {
use anyhow::Context;
use release_channel::{RELEASE_CHANNEL, ReleaseChannel};
use unix::{ExecuteProxyError, execute_proxy, execute_run};
match command {
Commands::Run {
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
} => execute_run(
log_file,
pid_file,
stdin_socket,
stdout_socket,
stderr_socket,
),
Commands::Proxy {
identifier,
reconnect,
} => execute_proxy(identifier, reconnect)
.inspect_err(|err| {
if let ExecuteProxyError::ServerNotRunning(err) = err {
std::process::exit(err.to_exit_code());
}
})
.context("running proxy on the remote server"),
Commands::Version => {
let release_channel = *RELEASE_CHANNEL;
match release_channel {
ReleaseChannel::Stable | ReleaseChannel::Preview => {
println!("{}", env!("ZED_PKG_VERSION"))
}
ReleaseChannel::Nightly | ReleaseChannel::Dev => {
println!(
"{}",
option_env!("ZED_COMMIT_SHA").unwrap_or(release_channel.dev_name())
)
}
};
Ok(())
}
}
}

View file

@ -36,6 +36,7 @@ use smol::Async;
use smol::{net::unix::UnixListener, stream::StreamExt as _};
use std::ffi::OsStr;
use std::ops::ControlFlow;
use std::process::ExitStatus;
use std::str::FromStr;
use std::sync::LazyLock;
use std::{env, thread};
@ -46,6 +47,7 @@ use std::{
sync::Arc,
};
use telemetry_events::LocationData;
use thiserror::Error;
use util::ResultExt;
pub static VERSION: LazyLock<&str> = LazyLock::new(|| match *RELEASE_CHANNEL {
@ -526,7 +528,23 @@ pub fn execute_run(
Ok(())
}
#[derive(Clone)]
#[derive(Debug, Error)]
pub(crate) enum ServerPathError {
#[error("Failed to create server_dir `{path}`")]
CreateServerDir {
#[source]
source: std::io::Error,
path: PathBuf,
},
#[error("Failed to create logs_dir `{path}`")]
CreateLogsDir {
#[source]
source: std::io::Error,
path: PathBuf,
},
}
#[derive(Clone, Debug)]
struct ServerPaths {
log_file: PathBuf,
pid_file: PathBuf,
@ -536,10 +554,19 @@ struct ServerPaths {
}
impl ServerPaths {
fn new(identifier: &str) -> Result<Self> {
fn new(identifier: &str) -> Result<Self, ServerPathError> {
let server_dir = paths::remote_server_state_dir().join(identifier);
std::fs::create_dir_all(&server_dir)?;
std::fs::create_dir_all(&logs_dir())?;
std::fs::create_dir_all(&server_dir).map_err(|source| {
ServerPathError::CreateServerDir {
source,
path: server_dir.clone(),
}
})?;
let log_dir = logs_dir();
std::fs::create_dir_all(log_dir).map_err(|source| ServerPathError::CreateLogsDir {
source: source,
path: log_dir.clone(),
})?;
let pid_file = server_dir.join("server.pid");
let stdin_socket = server_dir.join("stdin.sock");
@ -557,7 +584,43 @@ impl ServerPaths {
}
}
pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
#[derive(Debug, Error)]
pub(crate) enum ExecuteProxyError {
#[error("Failed to init server paths")]
ServerPath(#[from] ServerPathError),
#[error(transparent)]
ServerNotRunning(#[from] ProxyLaunchError),
#[error("Failed to check PidFile '{path}'")]
CheckPidFile {
#[source]
source: CheckPidError,
path: PathBuf,
},
#[error("Failed to kill existing server with pid '{pid}'")]
KillRunningServer {
#[source]
source: std::io::Error,
pid: u32,
},
#[error("failed to spawn server")]
SpawnServer(#[source] SpawnServerError),
#[error("stdin_task failed")]
StdinTask(#[source] anyhow::Error),
#[error("stdout_task failed")]
StdoutTask(#[source] anyhow::Error),
#[error("stderr_task failed")]
StderrTask(#[source] anyhow::Error),
}
pub(crate) fn execute_proxy(
identifier: String,
is_reconnecting: bool,
) -> Result<(), ExecuteProxyError> {
init_logging_proxy();
let server_paths = ServerPaths::new(&identifier)?;
@ -574,12 +637,19 @@ pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
log::info!("starting proxy process. PID: {}", std::process::id());
let server_pid = check_pid_file(&server_paths.pid_file)?;
let server_pid = check_pid_file(&server_paths.pid_file).map_err(|source| {
ExecuteProxyError::CheckPidFile {
source,
path: server_paths.pid_file.clone(),
}
})?;
let server_running = server_pid.is_some();
if is_reconnecting {
if !server_running {
log::error!("attempted to reconnect, but no server running");
anyhow::bail!(ProxyLaunchError::ServerNotRunning);
return Err(ExecuteProxyError::ServerNotRunning(
ProxyLaunchError::ServerNotRunning,
));
}
} else {
if let Some(pid) = server_pid {
@ -590,7 +660,7 @@ pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
kill_running_server(pid, &server_paths)?;
}
spawn_server(&server_paths)?;
spawn_server(&server_paths).map_err(ExecuteProxyError::SpawnServer)?;
};
let stdin_task = smol::spawn(async move {
@ -630,9 +700,9 @@ pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
if let Err(forwarding_result) = smol::block_on(async move {
futures::select! {
result = stdin_task.fuse() => result.context("stdin_task failed"),
result = stdout_task.fuse() => result.context("stdout_task failed"),
result = stderr_task.fuse() => result.context("stderr_task failed"),
result = stdin_task.fuse() => result.map_err(ExecuteProxyError::StdinTask),
result = stdout_task.fuse() => result.map_err(ExecuteProxyError::StdoutTask),
result = stderr_task.fuse() => result.map_err(ExecuteProxyError::StderrTask),
}
}) {
log::error!(
@ -645,12 +715,12 @@ pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
Ok(())
}
fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> {
fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<(), ExecuteProxyError> {
log::info!("killing existing server with PID {}", pid);
std::process::Command::new("kill")
.arg(pid.to_string())
.output()
.context("failed to kill existing server")?;
.map_err(|source| ExecuteProxyError::KillRunningServer { source, pid })?;
for file in [
&paths.pid_file,
@ -664,18 +734,39 @@ fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> {
Ok(())
}
fn spawn_server(paths: &ServerPaths) -> Result<()> {
#[derive(Debug, Error)]
pub(crate) enum SpawnServerError {
#[error("failed to remove stdin socket")]
RemoveStdinSocket(#[source] std::io::Error),
#[error("failed to remove stdout socket")]
RemoveStdoutSocket(#[source] std::io::Error),
#[error("failed to remove stderr socket")]
RemoveStderrSocket(#[source] std::io::Error),
#[error("failed to get current_exe")]
CurrentExe(#[source] std::io::Error),
#[error("failed to launch server process")]
ProcessStatus(#[source] std::io::Error),
#[error("failed to launch and detach server process: {status}\n{paths}")]
LaunchStatus { status: ExitStatus, paths: String },
}
fn spawn_server(paths: &ServerPaths) -> Result<(), SpawnServerError> {
if paths.stdin_socket.exists() {
std::fs::remove_file(&paths.stdin_socket)?;
std::fs::remove_file(&paths.stdin_socket).map_err(SpawnServerError::RemoveStdinSocket)?;
}
if paths.stdout_socket.exists() {
std::fs::remove_file(&paths.stdout_socket)?;
std::fs::remove_file(&paths.stdout_socket).map_err(SpawnServerError::RemoveStdoutSocket)?;
}
if paths.stderr_socket.exists() {
std::fs::remove_file(&paths.stderr_socket)?;
std::fs::remove_file(&paths.stderr_socket).map_err(SpawnServerError::RemoveStderrSocket)?;
}
let binary_name = std::env::current_exe()?;
let binary_name = std::env::current_exe().map_err(SpawnServerError::CurrentExe)?;
let mut server_process = std::process::Command::new(binary_name);
server_process
.arg("run")
@ -692,11 +783,17 @@ fn spawn_server(paths: &ServerPaths) -> Result<()> {
let status = server_process
.status()
.context("failed to launch server process")?;
anyhow::ensure!(
status.success(),
"failed to launch and detach server process"
);
.map_err(SpawnServerError::ProcessStatus)?;
if !status.success() {
return Err(SpawnServerError::LaunchStatus {
status,
paths: format!(
"log file: {:?}, pid file: {:?}",
paths.log_file, paths.pid_file,
),
});
}
let mut total_time_waited = std::time::Duration::from_secs(0);
let wait_duration = std::time::Duration::from_millis(20);
@ -717,7 +814,15 @@ fn spawn_server(paths: &ServerPaths) -> Result<()> {
Ok(())
}
fn check_pid_file(path: &Path) -> Result<Option<u32>> {
#[derive(Debug, Error)]
#[error("Failed to remove PID file for missing process (pid `{pid}`")]
pub(crate) struct CheckPidError {
#[source]
source: std::io::Error,
pid: u32,
}
fn check_pid_file(path: &Path) -> Result<Option<u32>, CheckPidError> {
let Some(pid) = std::fs::read_to_string(&path)
.ok()
.and_then(|contents| contents.parse::<u32>().ok())
@ -742,7 +847,7 @@ fn check_pid_file(path: &Path) -> Result<Option<u32>> {
log::debug!(
"Found PID file, but process with that PID does not exist. Removing PID file."
);
std::fs::remove_file(&path).context("Failed to remove PID file")?;
std::fs::remove_file(&path).map_err(|source| CheckPidError { source, pid })?;
Ok(None)
}
}

File diff suppressed because it is too large Load diff

Some files were not shown because too many files have changed in this diff Show more