Compare commits
32 commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a3677f4002 | ||
![]() |
4af18d4e83 | ||
![]() |
7c0e8e64b7 | ||
![]() |
72015b2f66 | ||
![]() |
730960f483 | ||
![]() |
71de16c9d2 | ||
![]() |
d90380e3ca | ||
![]() |
e19ed37bd1 | ||
![]() |
9bb4c657e3 | ||
![]() |
28bb50798b | ||
![]() |
9394a698f7 | ||
![]() |
97948bf613 | ||
![]() |
b67f775c0e | ||
![]() |
43e2c92910 | ||
![]() |
cc5aaf765f | ||
![]() |
395cd47164 | ||
![]() |
6ae05c95d3 | ||
![]() |
662a13f034 | ||
![]() |
fa815dbf70 | ||
![]() |
a394df5c0c | ||
![]() |
76d78e8a14 | ||
![]() |
3a7871d248 | ||
![]() |
7f2283749b | ||
![]() |
2d724520bc | ||
![]() |
473062aeef | ||
![]() |
612c9addff | ||
![]() |
19a60dbf9c | ||
![]() |
acba38dabd | ||
![]() |
c1b3111c15 | ||
![]() |
623388ad80 | ||
![]() |
eb89e9a572 | ||
![]() |
e306a55073 |
69 changed files with 3407 additions and 1231 deletions
7
.github/workflows/ci.yml
vendored
7
.github/workflows/ci.yml
vendored
|
@ -686,8 +686,10 @@ jobs:
|
|||
timeout-minutes: 60
|
||||
runs-on: github-8vcpu-ubuntu-2404
|
||||
if: |
|
||||
false && (
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
)
|
||||
needs: [linux_tests]
|
||||
name: Build Zed on FreeBSD
|
||||
# env:
|
||||
|
@ -757,7 +759,7 @@ jobs:
|
|||
timeout-minutes: 120
|
||||
name: Create a Windows installer
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
needs: [windows_tests]
|
||||
env:
|
||||
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
|
||||
|
@ -800,6 +802,7 @@ jobs:
|
|||
|
||||
- name: Upload Artifacts to release
|
||||
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
# Re-enable when we are ready to publish windows preview releases
|
||||
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
|
||||
with:
|
||||
draft: true
|
||||
|
@ -813,7 +816,7 @@ jobs:
|
|||
if: |
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
&& endsWith(github.ref, '-pre') && !endsWith(github.ref, '.0-pre')
|
||||
needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64, freebsd]
|
||||
needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64]
|
||||
runs-on:
|
||||
- self-hosted
|
||||
- bundle
|
||||
|
|
2
.github/workflows/release_nightly.yml
vendored
2
.github/workflows/release_nightly.yml
vendored
|
@ -195,7 +195,7 @@ jobs:
|
|||
|
||||
freebsd:
|
||||
timeout-minutes: 60
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
if: false && github.repository_owner == 'zed-industries'
|
||||
runs-on: github-8vcpu-ubuntu-2404
|
||||
needs: tests
|
||||
env:
|
||||
|
|
21
Cargo.lock
generated
21
Cargo.lock
generated
|
@ -2078,7 +2078,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "blade-graphics"
|
||||
version = "0.6.0"
|
||||
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
|
||||
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
|
||||
dependencies = [
|
||||
"ash",
|
||||
"ash-window",
|
||||
|
@ -2111,7 +2111,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "blade-macros"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
|
||||
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -2121,7 +2121,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "blade-util"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
|
||||
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
|
||||
dependencies = [
|
||||
"blade-graphics",
|
||||
"bytemuck",
|
||||
|
@ -3043,6 +3043,7 @@ dependencies = [
|
|||
"context_server",
|
||||
"ctor",
|
||||
"dap",
|
||||
"dap-types",
|
||||
"dap_adapters",
|
||||
"dashmap 6.1.0",
|
||||
"debugger_ui",
|
||||
|
@ -9000,6 +9001,7 @@ dependencies = [
|
|||
"util",
|
||||
"vercel",
|
||||
"workspace-hack",
|
||||
"x_ai",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
|
@ -19731,6 +19733,17 @@ version = "0.13.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d"
|
||||
|
||||
[[package]]
|
||||
name = "x_ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"schemars",
|
||||
"serde",
|
||||
"strum 0.27.1",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "0.2.3"
|
||||
|
@ -19972,7 +19985,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.195.0"
|
||||
version = "0.195.5"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
|
|
14
Cargo.toml
14
Cargo.toml
|
@ -177,6 +177,7 @@ members = [
|
|||
"crates/welcome",
|
||||
"crates/workspace",
|
||||
"crates/worktree",
|
||||
"crates/x_ai",
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zeta",
|
||||
|
@ -390,6 +391,7 @@ web_search_providers = { path = "crates/web_search_providers" }
|
|||
welcome = { path = "crates/welcome" }
|
||||
workspace = { path = "crates/workspace" }
|
||||
worktree = { path = "crates/worktree" }
|
||||
x_ai = { path = "crates/x_ai" }
|
||||
zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
zeta = { path = "crates/zeta" }
|
||||
|
@ -427,9 +429,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
|
|||
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
|
||||
base64 = "0.22"
|
||||
bitflags = "2.6.0"
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
|
||||
blake3 = "1.5.3"
|
||||
bytes = "1.0"
|
||||
cargo_metadata = "0.19"
|
||||
|
@ -482,7 +484,7 @@ json_dotpath = "1.1"
|
|||
jsonschema = "0.30.0"
|
||||
jsonwebtoken = "9.3"
|
||||
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
libc = "0.2"
|
||||
libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
|
||||
linkify = "0.10.0"
|
||||
|
@ -493,7 +495,7 @@ metal = "0.29"
|
|||
moka = { version = "0.12.10", features = ["sync"] }
|
||||
naga = { version = "25.0", features = ["wgsl-in"] }
|
||||
nanoid = "0.4"
|
||||
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
nix = "0.29"
|
||||
num-format = "0.4.4"
|
||||
objc = "0.2"
|
||||
|
@ -533,7 +535,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
|
|||
"stream",
|
||||
] }
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
|
||||
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
] }
|
||||
rust-embed = { version = "8.4", features = ["include-exclude"] }
|
||||
|
|
3
assets/icons/ai_x_ai.svg
Normal file
3
assets/icons/ai_x_ai.svg
Normal file
|
@ -0,0 +1,3 @@
|
|||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="m12.414 5.47.27 9.641h2.157l.27-13.15zM15.11.889h-3.293L6.651 7.613l1.647 2.142zM.889 15.11H4.18l1.647-2.142-1.647-2.143zm0-9.641 7.409 9.641h3.292L4.181 5.47z" fill="#000"/>
|
||||
</svg>
|
After Width: | Height: | Size: 289 B |
|
@ -21,6 +21,7 @@ use gpui::{
|
|||
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
|
||||
WeakEntity, Window,
|
||||
};
|
||||
use http_client::StatusCode;
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
|
||||
|
@ -51,7 +52,19 @@ use uuid::Uuid;
|
|||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
|
||||
const MAX_RETRY_ATTEMPTS: u8 = 3;
|
||||
const BASE_RETRY_DELAY_SECS: u64 = 5;
|
||||
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum RetryStrategy {
|
||||
ExponentialBackoff {
|
||||
initial_delay: Duration,
|
||||
max_attempts: u8,
|
||||
},
|
||||
Fixed {
|
||||
delay: Duration,
|
||||
max_attempts: u8,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
|
||||
|
@ -383,6 +396,7 @@ pub struct Thread {
|
|||
remaining_turns: u32,
|
||||
configured_model: Option<ConfiguredModel>,
|
||||
profile: AgentProfile,
|
||||
last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -476,10 +490,11 @@ impl Thread {
|
|||
retry_state: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
last_error_context: None,
|
||||
last_received_chunk_at: None,
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_model: configured_model.clone(),
|
||||
profile: AgentProfile::new(profile_id, tools),
|
||||
}
|
||||
}
|
||||
|
@ -600,6 +615,7 @@ impl Thread {
|
|||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
last_error_context: None,
|
||||
last_received_chunk_at: None,
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
|
@ -1251,9 +1267,58 @@ impl Thread {
|
|||
|
||||
self.flush_notifications(model.clone(), intent, cx);
|
||||
|
||||
let request = self.to_completion_request(model.clone(), intent, cx);
|
||||
let _checkpoint = self.finalize_pending_checkpoint(cx);
|
||||
self.stream_completion(
|
||||
self.to_completion_request(model.clone(), intent, cx),
|
||||
model,
|
||||
intent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
self.stream_completion(request, model, intent, window, cx);
|
||||
pub fn retry_last_completion(
|
||||
&mut self,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// Clear any existing error state
|
||||
self.retry_state = None;
|
||||
|
||||
// Use the last error context if available, otherwise fall back to configured model
|
||||
let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() {
|
||||
(model, intent)
|
||||
} else if let Some(configured_model) = self.configured_model.as_ref() {
|
||||
let model = configured_model.model.clone();
|
||||
let intent = if self.has_pending_tool_uses() {
|
||||
CompletionIntent::ToolResults
|
||||
} else {
|
||||
CompletionIntent::UserPrompt
|
||||
};
|
||||
(model, intent)
|
||||
} else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
|
||||
let model = configured_model.model.clone();
|
||||
let intent = if self.has_pending_tool_uses() {
|
||||
CompletionIntent::ToolResults
|
||||
} else {
|
||||
CompletionIntent::UserPrompt
|
||||
};
|
||||
(model, intent)
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.send_to_model(model, intent, window, cx);
|
||||
}
|
||||
|
||||
pub fn enable_burn_mode_and_retry(
|
||||
&mut self,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.completion_mode = CompletionMode::Burn;
|
||||
cx.emit(ThreadEvent::ProfileChanged);
|
||||
self.retry_last_completion(window, cx);
|
||||
}
|
||||
|
||||
pub fn used_tools_since_last_user_message(&self) -> bool {
|
||||
|
@ -1931,18 +1996,6 @@ impl Thread {
|
|||
project.set_agent_location(None, cx);
|
||||
});
|
||||
|
||||
fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
|
||||
let error_message = error
|
||||
.chain()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
|
||||
header: "Error interacting with language model".into(),
|
||||
message: SharedString::from(error_message.clone()),
|
||||
}));
|
||||
}
|
||||
|
||||
if error.is::<PaymentRequiredError>() {
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
|
||||
} else if let Some(error) =
|
||||
|
@ -1954,9 +2007,10 @@ impl Thread {
|
|||
} else if let Some(completion_error) =
|
||||
error.downcast_ref::<LanguageModelCompletionError>()
|
||||
{
|
||||
use LanguageModelCompletionError::*;
|
||||
match &completion_error {
|
||||
PromptTooLarge { tokens, .. } => {
|
||||
LanguageModelCompletionError::PromptTooLarge {
|
||||
tokens, ..
|
||||
} => {
|
||||
let tokens = tokens.unwrap_or_else(|| {
|
||||
// We didn't get an exact token count from the API, so fall back on our estimate.
|
||||
thread
|
||||
|
@ -1977,63 +2031,22 @@ impl Thread {
|
|||
});
|
||||
cx.notify();
|
||||
}
|
||||
RateLimitExceeded {
|
||||
retry_after: Some(retry_after),
|
||||
..
|
||||
}
|
||||
| ServerOverloaded {
|
||||
retry_after: Some(retry_after),
|
||||
..
|
||||
} => {
|
||||
thread.handle_rate_limit_error(
|
||||
&completion_error,
|
||||
*retry_after,
|
||||
model.clone(),
|
||||
intent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
retry_scheduled = true;
|
||||
}
|
||||
RateLimitExceeded { .. } | ServerOverloaded { .. } => {
|
||||
retry_scheduled = thread.handle_retryable_error(
|
||||
&completion_error,
|
||||
model.clone(),
|
||||
intent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
if !retry_scheduled {
|
||||
emit_generic_error(error, cx);
|
||||
_ => {
|
||||
if let Some(retry_strategy) =
|
||||
Thread::get_retry_strategy(completion_error)
|
||||
{
|
||||
retry_scheduled = thread
|
||||
.handle_retryable_error_with_delay(
|
||||
&completion_error,
|
||||
Some(retry_strategy),
|
||||
model.clone(),
|
||||
intent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
ApiInternalServerError { .. }
|
||||
| ApiReadResponseError { .. }
|
||||
| HttpSend { .. } => {
|
||||
retry_scheduled = thread.handle_retryable_error(
|
||||
&completion_error,
|
||||
model.clone(),
|
||||
intent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
if !retry_scheduled {
|
||||
emit_generic_error(error, cx);
|
||||
}
|
||||
}
|
||||
NoApiKey { .. }
|
||||
| HttpResponseError { .. }
|
||||
| BadRequestFormat { .. }
|
||||
| AuthenticationError { .. }
|
||||
| PermissionError { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| SerializeRequest { .. }
|
||||
| BuildRequestBody { .. }
|
||||
| DeserializeResponse { .. }
|
||||
| Other { .. } => emit_generic_error(error, cx),
|
||||
}
|
||||
} else {
|
||||
emit_generic_error(error, cx);
|
||||
}
|
||||
|
||||
if !retry_scheduled {
|
||||
|
@ -2160,73 +2173,132 @@ impl Thread {
|
|||
});
|
||||
}
|
||||
|
||||
fn handle_rate_limit_error(
|
||||
&mut self,
|
||||
error: &LanguageModelCompletionError,
|
||||
retry_after: Duration,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// For rate limit errors, we only retry once with the specified duration
|
||||
let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
|
||||
log::warn!(
|
||||
"Retrying completion request in {} seconds: {error:?}",
|
||||
retry_after.as_secs(),
|
||||
);
|
||||
fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
|
||||
use LanguageModelCompletionError::*;
|
||||
|
||||
// Add a UI-only message instead of a regular message
|
||||
let id = self.next_message_id.post_inc();
|
||||
self.messages.push(Message {
|
||||
id,
|
||||
role: Role::System,
|
||||
segments: vec![MessageSegment::Text(retry_message)],
|
||||
loaded_context: LoadedContext::default(),
|
||||
creases: Vec::new(),
|
||||
is_hidden: false,
|
||||
ui_only: true,
|
||||
});
|
||||
cx.emit(ThreadEvent::MessageAdded(id));
|
||||
// Schedule the retry
|
||||
let thread_handle = cx.entity().downgrade();
|
||||
|
||||
cx.spawn(async move |_thread, cx| {
|
||||
cx.background_executor().timer(retry_after).await;
|
||||
|
||||
thread_handle
|
||||
.update(cx, |thread, cx| {
|
||||
// Retry the completion
|
||||
thread.send_to_model(model, intent, window, cx);
|
||||
// General strategy here:
|
||||
// - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
|
||||
// - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), try multiple times with exponential backoff.
|
||||
// - If it's an issue that *might* be fixed by retrying (e.g. internal server error), just retry once.
|
||||
match error {
|
||||
HttpResponseError {
|
||||
status_code: StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => Some(RetryStrategy::ExponentialBackoff {
|
||||
initial_delay: BASE_RETRY_DELAY,
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
}),
|
||||
ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn handle_retryable_error(
|
||||
&mut self,
|
||||
error: &LanguageModelCompletionError,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
|
||||
}
|
||||
UpstreamProviderError {
|
||||
status,
|
||||
retry_after,
|
||||
..
|
||||
} => match *status {
|
||||
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
// Internal Server Error could be anything, so only retry once.
|
||||
max_attempts: 1,
|
||||
}),
|
||||
status => {
|
||||
// There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
|
||||
// but we frequently get them in practice. See https://http.dev/529
|
||||
if status.as_u16() == 529 {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
},
|
||||
ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 1,
|
||||
}),
|
||||
ApiReadResponseError { .. }
|
||||
| HttpSend { .. }
|
||||
| DeserializeResponse { .. }
|
||||
| BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 1,
|
||||
}),
|
||||
// Retrying these errors definitely shouldn't help.
|
||||
HttpResponseError {
|
||||
status_code:
|
||||
StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
|
||||
..
|
||||
}
|
||||
| SerializeRequest { .. }
|
||||
| BuildRequestBody { .. }
|
||||
| PromptTooLarge { .. }
|
||||
| AuthenticationError { .. }
|
||||
| PermissionError { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| NoApiKey { .. } => None,
|
||||
// Retry all other 4xx and 5xx errors once.
|
||||
HttpResponseError { status_code, .. }
|
||||
if status_code.is_client_error() || status_code.is_server_error() =>
|
||||
{
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 1,
|
||||
})
|
||||
}
|
||||
// Conservatively assume that any other errors are non-retryable
|
||||
HttpResponseError { .. } | Other(..) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_retryable_error_with_delay(
|
||||
&mut self,
|
||||
error: &LanguageModelCompletionError,
|
||||
custom_delay: Option<Duration>,
|
||||
strategy: Option<RetryStrategy>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
// Store context for the Retry button
|
||||
self.last_error_context = Some((model.clone(), intent));
|
||||
|
||||
// Only auto-retry if Burn Mode is enabled
|
||||
if self.completion_mode != CompletionMode::Burn {
|
||||
// Show error with retry options
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
|
||||
message: format!(
|
||||
"{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
|
||||
error
|
||||
)
|
||||
.into(),
|
||||
can_enable_burn_mode: true,
|
||||
}));
|
||||
return false;
|
||||
}
|
||||
|
||||
let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let max_attempts = match &strategy {
|
||||
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
|
||||
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
|
||||
};
|
||||
|
||||
let retry_state = self.retry_state.get_or_insert(RetryState {
|
||||
attempt: 0,
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
max_attempts,
|
||||
intent,
|
||||
});
|
||||
|
||||
|
@ -2236,20 +2308,24 @@ impl Thread {
|
|||
let intent = retry_state.intent;
|
||||
|
||||
if attempt <= max_attempts {
|
||||
// Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
|
||||
let delay = if let Some(custom_delay) = custom_delay {
|
||||
custom_delay
|
||||
} else {
|
||||
let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
|
||||
Duration::from_secs(delay_secs)
|
||||
let delay = match &strategy {
|
||||
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
|
||||
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
|
||||
Duration::from_secs(delay_secs)
|
||||
}
|
||||
RetryStrategy::Fixed { delay, .. } => *delay,
|
||||
};
|
||||
|
||||
// Add a transient message to inform the user
|
||||
let delay_secs = delay.as_secs();
|
||||
let retry_message = format!(
|
||||
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
|
||||
in {delay_secs} seconds..."
|
||||
);
|
||||
let retry_message = if max_attempts == 1 {
|
||||
format!("{error}. Retrying in {delay_secs} seconds...")
|
||||
} else {
|
||||
format!(
|
||||
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
|
||||
in {delay_secs} seconds..."
|
||||
)
|
||||
};
|
||||
log::warn!(
|
||||
"Retrying completion request (attempt {attempt} of {max_attempts}) \
|
||||
in {delay_secs} seconds: {error:?}",
|
||||
|
@ -2288,18 +2364,15 @@ impl Thread {
|
|||
// Max retries exceeded
|
||||
self.retry_state = None;
|
||||
|
||||
let notification_text = if max_attempts == 1 {
|
||||
"Failed after retrying.".into()
|
||||
} else {
|
||||
format!("Failed after retrying {} times.", max_attempts).into()
|
||||
};
|
||||
|
||||
// Stop generating since we're giving up on retrying.
|
||||
self.pending_completions.clear();
|
||||
|
||||
cx.emit(ThreadEvent::RetriesFailed {
|
||||
message: notification_text,
|
||||
});
|
||||
// Show error alongside a Retry button, but no
|
||||
// Enable Burn Mode button (since it's already enabled)
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
|
||||
message: format!("Failed after retrying: {}", error).into(),
|
||||
can_enable_burn_mode: false,
|
||||
}));
|
||||
|
||||
false
|
||||
}
|
||||
|
@ -3211,6 +3284,11 @@ pub enum ThreadError {
|
|||
header: SharedString,
|
||||
message: SharedString,
|
||||
},
|
||||
#[error("Retryable error: {message}")]
|
||||
RetryableError {
|
||||
message: SharedString,
|
||||
can_enable_burn_mode: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -3256,9 +3334,6 @@ pub enum ThreadEvent {
|
|||
CancelEditing,
|
||||
CompletionCanceled,
|
||||
ProfileChanged,
|
||||
RetriesFailed {
|
||||
message: SharedString,
|
||||
},
|
||||
}
|
||||
|
||||
impl EventEmitter<ThreadEvent> for Thread {}
|
||||
|
@ -4169,6 +4244,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create model that returns overloaded error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
|
||||
|
||||
|
@ -4190,7 +4270,7 @@ fn main() {{
|
|||
assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
|
||||
assert_eq!(
|
||||
retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
|
||||
"Should have default max attempts"
|
||||
"Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -4242,6 +4322,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create model that returns internal server error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
|
||||
|
||||
|
@ -4263,7 +4348,7 @@ fn main() {{
|
|||
let retry_state = thread.retry_state.as_ref().unwrap();
|
||||
assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
|
||||
assert_eq!(
|
||||
retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
|
||||
retry_state.max_attempts, 1,
|
||||
"Should have correct max attempts"
|
||||
);
|
||||
});
|
||||
|
@ -4279,8 +4364,8 @@ fn main() {{
|
|||
if let MessageSegment::Text(text) = seg {
|
||||
text.contains("internal")
|
||||
&& text.contains("Fake")
|
||||
&& text
|
||||
.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
|
||||
&& text.contains("Retrying in")
|
||||
&& !text.contains("attempt")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
@ -4318,8 +4403,13 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Create model that returns overloaded error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create model that returns internal server error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
|
||||
|
||||
// Insert a user message
|
||||
thread.update(cx, |thread, cx| {
|
||||
|
@ -4369,11 +4459,14 @@ fn main() {{
|
|||
assert!(thread.retry_state.is_some(), "Should have retry state");
|
||||
let retry_state = thread.retry_state.as_ref().unwrap();
|
||||
assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
|
||||
assert_eq!(
|
||||
retry_state.max_attempts, 1,
|
||||
"Internal server errors should only retry once"
|
||||
);
|
||||
});
|
||||
|
||||
// Advance clock for first retry
|
||||
cx.executor()
|
||||
.advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
|
||||
cx.executor().advance_clock(BASE_RETRY_DELAY);
|
||||
cx.run_until_parked();
|
||||
|
||||
// Should have scheduled second retry - count retry messages
|
||||
|
@ -4393,93 +4486,25 @@ fn main() {{
|
|||
})
|
||||
.count()
|
||||
});
|
||||
assert_eq!(retry_count, 2, "Should have scheduled second retry");
|
||||
|
||||
// Check retry state updated
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(thread.retry_state.is_some(), "Should have retry state");
|
||||
let retry_state = thread.retry_state.as_ref().unwrap();
|
||||
assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
|
||||
assert_eq!(
|
||||
retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
|
||||
"Should have correct max attempts"
|
||||
);
|
||||
});
|
||||
|
||||
// Advance clock for second retry (exponential backoff)
|
||||
cx.executor()
|
||||
.advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
|
||||
cx.run_until_parked();
|
||||
|
||||
// Should have scheduled third retry
|
||||
// Count all retry messages now
|
||||
let retry_count = thread.update(cx, |thread, _| {
|
||||
thread
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.ui_only
|
||||
&& m.segments.iter().any(|s| {
|
||||
if let MessageSegment::Text(text) = s {
|
||||
text.contains("Retrying") && text.contains("seconds")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
})
|
||||
.count()
|
||||
});
|
||||
assert_eq!(
|
||||
retry_count, MAX_RETRY_ATTEMPTS as usize,
|
||||
"Should have scheduled third retry"
|
||||
retry_count, 1,
|
||||
"Should have only one retry for internal server errors"
|
||||
);
|
||||
|
||||
// Check retry state updated
|
||||
// For internal server errors, we only retry once and then give up
|
||||
// Check that retry_state is cleared after the single retry
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(thread.retry_state.is_some(), "Should have retry state");
|
||||
let retry_state = thread.retry_state.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
retry_state.attempt, MAX_RETRY_ATTEMPTS,
|
||||
"Should be at max retry attempt"
|
||||
);
|
||||
assert_eq!(
|
||||
retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
|
||||
"Should have correct max attempts"
|
||||
assert!(
|
||||
thread.retry_state.is_none(),
|
||||
"Retry state should be cleared after single retry"
|
||||
);
|
||||
});
|
||||
|
||||
// Advance clock for third retry (exponential backoff)
|
||||
cx.executor()
|
||||
.advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
|
||||
cx.run_until_parked();
|
||||
|
||||
// No more retries should be scheduled after clock was advanced.
|
||||
let retry_count = thread.update(cx, |thread, _| {
|
||||
thread
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.ui_only
|
||||
&& m.segments.iter().any(|s| {
|
||||
if let MessageSegment::Text(text) = s {
|
||||
text.contains("Retrying") && text.contains("seconds")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
})
|
||||
.count()
|
||||
});
|
||||
assert_eq!(
|
||||
retry_count, MAX_RETRY_ATTEMPTS as usize,
|
||||
"Should not exceed max retries"
|
||||
);
|
||||
|
||||
// Final completion count should be initial + max retries
|
||||
// Verify total attempts (1 initial + 1 retry)
|
||||
assert_eq!(
|
||||
*completion_count.lock(),
|
||||
(MAX_RETRY_ATTEMPTS + 1) as usize,
|
||||
"Should have made initial + max retry attempts"
|
||||
2,
|
||||
"Should have attempted once plus 1 retry"
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -4490,6 +4515,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create model that returns overloaded error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
|
||||
|
||||
|
@ -4499,13 +4529,13 @@ fn main() {{
|
|||
});
|
||||
|
||||
// Track events
|
||||
let retries_failed = Arc::new(Mutex::new(false));
|
||||
let retries_failed_clone = retries_failed.clone();
|
||||
let stopped_with_error = Arc::new(Mutex::new(false));
|
||||
let stopped_with_error_clone = stopped_with_error.clone();
|
||||
|
||||
let _subscription = thread.update(cx, |_, cx| {
|
||||
cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
|
||||
if let ThreadEvent::RetriesFailed { .. } = event {
|
||||
*retries_failed_clone.lock() = true;
|
||||
if let ThreadEvent::Stopped(Err(_)) = event {
|
||||
*stopped_with_error_clone.lock() = true;
|
||||
}
|
||||
})
|
||||
});
|
||||
|
@ -4517,23 +4547,11 @@ fn main() {{
|
|||
cx.run_until_parked();
|
||||
|
||||
// Advance through all retries
|
||||
for i in 0..MAX_RETRY_ATTEMPTS {
|
||||
let delay = if i == 0 {
|
||||
BASE_RETRY_DELAY_SECS
|
||||
} else {
|
||||
BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
|
||||
};
|
||||
cx.executor().advance_clock(Duration::from_secs(delay));
|
||||
for _ in 0..MAX_RETRY_ATTEMPTS {
|
||||
cx.executor().advance_clock(BASE_RETRY_DELAY);
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
||||
// After the 3rd retry is scheduled, we need to wait for it to execute and fail
|
||||
// The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
|
||||
let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
|
||||
cx.executor()
|
||||
.advance_clock(Duration::from_secs(final_delay));
|
||||
cx.run_until_parked();
|
||||
|
||||
let retry_count = thread.update(cx, |thread, _| {
|
||||
thread
|
||||
.messages
|
||||
|
@ -4551,14 +4569,14 @@ fn main() {{
|
|||
.count()
|
||||
});
|
||||
|
||||
// After max retries, should emit RetriesFailed event
|
||||
// After max retries, should emit Stopped(Err(...)) event
|
||||
assert_eq!(
|
||||
retry_count, MAX_RETRY_ATTEMPTS as usize,
|
||||
"Should have attempted max retries"
|
||||
"Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
|
||||
);
|
||||
assert!(
|
||||
*retries_failed.lock(),
|
||||
"Should emit RetriesFailed event after max retries exceeded"
|
||||
*stopped_with_error.lock(),
|
||||
"Should emit Stopped(Err(...)) event after max retries exceeded"
|
||||
);
|
||||
|
||||
// Retry state should be cleared
|
||||
|
@ -4576,7 +4594,7 @@ fn main() {{
|
|||
.count();
|
||||
assert_eq!(
|
||||
retry_messages, MAX_RETRY_ATTEMPTS as usize,
|
||||
"Should have one retry message per attempt"
|
||||
"Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
|
||||
);
|
||||
});
|
||||
}
|
||||
|
@ -4588,6 +4606,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// We'll use a wrapper to switch behavior after first failure
|
||||
struct RetryTestModel {
|
||||
inner: Arc<FakeLanguageModel>,
|
||||
|
@ -4714,8 +4737,7 @@ fn main() {{
|
|||
});
|
||||
|
||||
// Wait for retry
|
||||
cx.executor()
|
||||
.advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
|
||||
cx.executor().advance_clock(BASE_RETRY_DELAY);
|
||||
cx.run_until_parked();
|
||||
|
||||
// Stream some successful content
|
||||
|
@ -4757,6 +4779,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create a model that fails once then succeeds
|
||||
struct FailOnceModel {
|
||||
inner: Arc<FakeLanguageModel>,
|
||||
|
@ -4877,8 +4904,7 @@ fn main() {{
|
|||
});
|
||||
|
||||
// Wait for retry delay
|
||||
cx.executor()
|
||||
.advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
|
||||
cx.executor().advance_clock(BASE_RETRY_DELAY);
|
||||
cx.run_until_parked();
|
||||
|
||||
// The retry should now use our FailOnceModel which should succeed
|
||||
|
@ -4919,6 +4945,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create a model that returns rate limit error with retry_after
|
||||
struct RateLimitModel {
|
||||
inner: Arc<FakeLanguageModel>,
|
||||
|
@ -5037,9 +5068,15 @@ fn main() {{
|
|||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(
|
||||
thread.retry_state.is_none(),
|
||||
"Rate limit errors should not set retry_state"
|
||||
thread.retry_state.is_some(),
|
||||
"Rate limit errors should set retry_state"
|
||||
);
|
||||
if let Some(retry_state) = &thread.retry_state {
|
||||
assert_eq!(
|
||||
retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
|
||||
"Rate limit errors should use MAX_RETRY_ATTEMPTS"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// Verify we have one retry message
|
||||
|
@ -5072,18 +5109,15 @@ fn main() {{
|
|||
.find(|msg| msg.role == Role::System && msg.ui_only)
|
||||
.expect("Should have a retry message");
|
||||
|
||||
// Check that the message doesn't contain attempt count
|
||||
// Check that the message contains attempt count since we use retry_state
|
||||
if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
|
||||
assert!(
|
||||
!text.contains("attempt"),
|
||||
"Rate limit retry message should not contain attempt count"
|
||||
text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
|
||||
"Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
|
||||
);
|
||||
assert!(
|
||||
text.contains(&format!(
|
||||
"Retrying in {} seconds",
|
||||
TEST_RATE_LIMIT_RETRY_SECS
|
||||
)),
|
||||
"Rate limit retry message should contain retry delay"
|
||||
text.contains("Retrying"),
|
||||
"Rate limit retry message should contain retry text"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
@ -5189,6 +5223,79 @@ fn main() {{
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Ensure we're in Normal mode (not Burn mode)
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Normal);
|
||||
});
|
||||
|
||||
// Track error events
|
||||
let error_events = Arc::new(Mutex::new(Vec::new()));
|
||||
let error_events_clone = error_events.clone();
|
||||
|
||||
let _subscription = thread.update(cx, |_, cx| {
|
||||
cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
|
||||
if let ThreadEvent::ShowError(error) = event {
|
||||
error_events_clone.lock().push(error.clone());
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
// Create model that returns overloaded error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
|
||||
|
||||
// Insert a user message
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
|
||||
});
|
||||
|
||||
// Start completion
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Verify no retry state was created
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(
|
||||
thread.retry_state.is_none(),
|
||||
"Should not have retry state in Normal mode"
|
||||
);
|
||||
});
|
||||
|
||||
// Check that a retryable error was reported
|
||||
let errors = error_events.lock();
|
||||
assert!(!errors.is_empty(), "Should have received an error event");
|
||||
|
||||
if let ThreadError::RetryableError {
|
||||
message: _,
|
||||
can_enable_burn_mode,
|
||||
} = &errors[0]
|
||||
{
|
||||
assert!(
|
||||
*can_enable_burn_mode,
|
||||
"Error should indicate burn mode can be enabled"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected RetryableError, got {:?}", errors[0]);
|
||||
}
|
||||
|
||||
// Verify the thread is no longer generating
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(
|
||||
!thread.is_generating(),
|
||||
"Should not be generating after error without retry"
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
@ -5196,6 +5303,11 @@ fn main() {{
|
|||
let project = create_test_project(cx, json!({})).await;
|
||||
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Enable Burn Mode to allow retries
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
});
|
||||
|
||||
// Create model that returns overloaded error
|
||||
let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
|
||||
|
||||
|
|
|
@ -983,30 +983,57 @@ impl ActiveThread {
|
|||
| ThreadEvent::SummaryChanged => {
|
||||
self.save_thread(cx);
|
||||
}
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
|
||||
let used_tools = self.thread.read(cx).used_tools_since_last_user_message();
|
||||
self.play_notification_sound(window, cx);
|
||||
self.show_notification(
|
||||
if used_tools {
|
||||
"Finished running tools"
|
||||
} else {
|
||||
"New message"
|
||||
},
|
||||
IconName::ZedAssistant,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
ThreadEvent::Stopped(reason) => {
|
||||
match reason {
|
||||
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
|
||||
let used_tools = self.thread.read(cx).used_tools_since_last_user_message();
|
||||
self.notify_with_sound(
|
||||
if used_tools {
|
||||
"Finished running tools"
|
||||
} else {
|
||||
"New message"
|
||||
},
|
||||
IconName::ZedAssistant,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Ok(StopReason::ToolUse) => {
|
||||
// Don't notify for intermediate tool use
|
||||
}
|
||||
Ok(StopReason::Refusal) => {
|
||||
self.notify_with_sound(
|
||||
"Language model refused to respond",
|
||||
IconName::Warning,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Err(error) => {
|
||||
self.notify_with_sound(
|
||||
"Agent stopped due to an error",
|
||||
IconName::Warning,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let error_message = error
|
||||
.chain()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
self.last_error = Some(ThreadError::Message {
|
||||
header: "Error".into(),
|
||||
message: error_message.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
self.play_notification_sound(window, cx);
|
||||
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
}
|
||||
ThreadEvent::ToolUseLimitReached => {
|
||||
self.play_notification_sound(window, cx);
|
||||
self.show_notification(
|
||||
self.notify_with_sound(
|
||||
"Consecutive tool use limit reached.",
|
||||
IconName::Warning,
|
||||
window,
|
||||
|
@ -1149,9 +1176,6 @@ impl ActiveThread {
|
|||
self.save_thread(cx);
|
||||
cx.notify();
|
||||
}
|
||||
ThreadEvent::RetriesFailed { message } => {
|
||||
self.show_notification(message, ui::IconName::Warning, window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1206,6 +1230,17 @@ impl ActiveThread {
|
|||
}
|
||||
}
|
||||
|
||||
fn notify_with_sound(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
icon: IconName,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<ActiveThread>,
|
||||
) {
|
||||
self.play_notification_sound(window, cx);
|
||||
self.show_notification(caption, icon, window, cx);
|
||||
}
|
||||
|
||||
fn pop_up(
|
||||
&mut self,
|
||||
icon: IconName,
|
||||
|
|
|
@ -491,6 +491,7 @@ impl AgentConfiguration {
|
|||
category_filter: Some(
|
||||
ExtensionCategoryFilter::ContextServers,
|
||||
),
|
||||
id: None,
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
|
|
|
@ -1375,7 +1375,6 @@ impl AgentDiff {
|
|||
| ThreadEvent::ToolConfirmationNeeded
|
||||
| ThreadEvent::ToolUseLimitReached
|
||||
| ThreadEvent::CancelEditing
|
||||
| ThreadEvent::RetriesFailed { .. }
|
||||
| ThreadEvent::ProfileChanged => {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,8 +59,9 @@ use theme::ThemeSettings;
|
|||
use time::UtcOffset;
|
||||
use ui::utils::WithRemSize;
|
||||
use ui::{
|
||||
Banner, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
|
||||
PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*,
|
||||
Banner, Button, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, IconPosition,
|
||||
KeyBinding, PopoverMenu, PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName,
|
||||
prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{
|
||||
|
@ -1778,6 +1779,7 @@ impl AgentPanel {
|
|||
category_filter: Some(
|
||||
zed_actions::ExtensionCategoryFilter::ContextServers,
|
||||
),
|
||||
id: None,
|
||||
}),
|
||||
)
|
||||
.action("Add Custom Server…", Box::new(AddContextServer))
|
||||
|
@ -1887,45 +1889,45 @@ impl AgentPanel {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &App) -> Option<AnyElement> {
|
||||
let (active_thread, message_editor) = match &self.active_view {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread {
|
||||
thread,
|
||||
message_editor,
|
||||
..
|
||||
} => (thread.read(cx), message_editor.read(cx)),
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
} => {
|
||||
let active_thread = thread.read(cx);
|
||||
let message_editor = message_editor.read(cx);
|
||||
|
||||
let editor_empty = message_editor.is_editor_fully_empty(cx);
|
||||
let editor_empty = message_editor.is_editor_fully_empty(cx);
|
||||
|
||||
if active_thread.is_empty() && editor_empty {
|
||||
return None;
|
||||
}
|
||||
if active_thread.is_empty() && editor_empty {
|
||||
return None;
|
||||
}
|
||||
|
||||
let thread = active_thread.thread().read(cx);
|
||||
let is_generating = thread.is_generating();
|
||||
let conversation_token_usage = thread.total_token_usage()?;
|
||||
let thread = active_thread.thread().read(cx);
|
||||
let is_generating = thread.is_generating();
|
||||
let conversation_token_usage = thread.total_token_usage()?;
|
||||
|
||||
let (total_token_usage, is_estimating) =
|
||||
if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() {
|
||||
let combined = thread
|
||||
.token_usage_up_to_message(editing_message_id)
|
||||
.add(unsent_tokens);
|
||||
let (total_token_usage, is_estimating) =
|
||||
if let Some((editing_message_id, unsent_tokens)) =
|
||||
active_thread.editing_message_id()
|
||||
{
|
||||
let combined = thread
|
||||
.token_usage_up_to_message(editing_message_id)
|
||||
.add(unsent_tokens);
|
||||
|
||||
(combined, unsent_tokens > 0)
|
||||
} else {
|
||||
let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0);
|
||||
let combined = conversation_token_usage.add(unsent_tokens);
|
||||
(combined, unsent_tokens > 0)
|
||||
} else {
|
||||
let unsent_tokens =
|
||||
message_editor.last_estimated_token_count().unwrap_or(0);
|
||||
let combined = conversation_token_usage.add(unsent_tokens);
|
||||
|
||||
(combined, unsent_tokens > 0)
|
||||
};
|
||||
(combined, unsent_tokens > 0)
|
||||
};
|
||||
|
||||
let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count();
|
||||
let is_waiting_to_update_token_count =
|
||||
message_editor.is_waiting_to_update_token_count();
|
||||
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { .. } => {
|
||||
if total_token_usage.total == 0 {
|
||||
return None;
|
||||
}
|
||||
|
@ -2819,6 +2821,21 @@ impl AgentPanel {
|
|||
.size(IconSize::Small)
|
||||
.color(Color::Error);
|
||||
|
||||
let retry_button = Button::new("retry", "Retry")
|
||||
.icon(IconName::RotateCw)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click({
|
||||
let thread = thread.clone();
|
||||
move |_, window, cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.clear_last_error();
|
||||
thread.thread().update(cx, |thread, cx| {
|
||||
thread.retry_last_completion(Some(window.window_handle()), cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
div()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
|
@ -2827,13 +2844,72 @@ impl AgentPanel {
|
|||
.icon(icon)
|
||||
.title(header)
|
||||
.description(message.clone())
|
||||
.primary_action(self.dismiss_error_button(thread, cx))
|
||||
.secondary_action(self.create_copy_button(message_with_header))
|
||||
.primary_action(retry_button)
|
||||
.secondary_action(self.dismiss_error_button(thread, cx))
|
||||
.tertiary_action(self.create_copy_button(message_with_header))
|
||||
.bg_color(self.error_callout_bg(cx)),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_retryable_error(
|
||||
&self,
|
||||
message: SharedString,
|
||||
can_enable_burn_mode: bool,
|
||||
thread: &Entity<ActiveThread>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let icon = Icon::new(IconName::XCircle)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Error);
|
||||
|
||||
let retry_button = Button::new("retry", "Retry")
|
||||
.icon(IconName::RotateCw)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click({
|
||||
let thread = thread.clone();
|
||||
move |_, window, cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.clear_last_error();
|
||||
thread.thread().update(cx, |thread, cx| {
|
||||
thread.retry_last_completion(Some(window.window_handle()), cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let mut callout = Callout::new()
|
||||
.icon(icon)
|
||||
.title("Error")
|
||||
.description(message.clone())
|
||||
.bg_color(self.error_callout_bg(cx))
|
||||
.primary_action(retry_button);
|
||||
|
||||
if can_enable_burn_mode {
|
||||
let burn_mode_button = Button::new("enable_burn_retry", "Enable Burn Mode and Retry")
|
||||
.icon(IconName::ZedBurnMode)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click({
|
||||
let thread = thread.clone();
|
||||
move |_, window, cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.clear_last_error();
|
||||
thread.thread().update(cx, |thread, cx| {
|
||||
thread.enable_burn_mode_and_retry(Some(window.window_handle()), cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
callout = callout.secondary_action(burn_mode_button);
|
||||
}
|
||||
|
||||
div()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(callout)
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_prompt_editor(
|
||||
&self,
|
||||
context_editor: &Entity<TextThreadEditor>,
|
||||
|
@ -3069,6 +3145,15 @@ impl Render for AgentPanel {
|
|||
ThreadError::Message { header, message } => {
|
||||
self.render_error_message(header, message, thread, cx)
|
||||
}
|
||||
ThreadError::RetryableError {
|
||||
message,
|
||||
can_enable_burn_mode,
|
||||
} => self.render_retryable_error(
|
||||
message,
|
||||
can_enable_burn_mode,
|
||||
thread,
|
||||
cx,
|
||||
),
|
||||
})
|
||||
.into_any(),
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@ use collections::HashMap;
|
|||
use fs::FakeFs;
|
||||
use futures::{FutureExt, future::LocalBoxFuture};
|
||||
use gpui::{AppContext, TestAppContext, Timer};
|
||||
use http_client::StatusCode;
|
||||
use indoc::{formatdoc, indoc};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
|
@ -1671,6 +1672,30 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
|
|||
Timer::after(retry_after + jitter).await;
|
||||
continue;
|
||||
}
|
||||
LanguageModelCompletionError::UpstreamProviderError {
|
||||
status,
|
||||
retry_after,
|
||||
..
|
||||
} => {
|
||||
// Only retry for specific status codes
|
||||
let should_retry = matches!(
|
||||
*status,
|
||||
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE
|
||||
) || status.as_u16() == 529;
|
||||
|
||||
if !should_retry {
|
||||
return Err(err.into());
|
||||
}
|
||||
|
||||
// Use server-provided retry_after if available, otherwise use default
|
||||
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
|
||||
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
|
||||
eprintln!(
|
||||
"Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
|
||||
);
|
||||
Timer::after(retry_after + jitter).await;
|
||||
continue;
|
||||
}
|
||||
_ => return Err(err.into()),
|
||||
},
|
||||
Err(err) => return Err(err),
|
||||
|
|
|
@ -94,6 +94,7 @@ context_server.workspace = true
|
|||
ctor.workspace = true
|
||||
dap = { workspace = true, features = ["test-support"] }
|
||||
dap_adapters = { workspace = true, features = ["test-support"] }
|
||||
dap-types.workspace = true
|
||||
debugger_ui = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
extension.workspace = true
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::tests::TestServer;
|
|||
use call::ActiveCall;
|
||||
use collections::{HashMap, HashSet};
|
||||
|
||||
use dap::{Capabilities, adapters::DebugTaskDefinition, transport::RequestHandling};
|
||||
use debugger_ui::debugger_panel::DebugPanel;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs as _, RemoveOptions};
|
||||
|
@ -22,6 +23,7 @@ use language::{
|
|||
use node_runtime::NodeRuntime;
|
||||
use project::{
|
||||
ProjectPath,
|
||||
debugger::session::ThreadId,
|
||||
lsp_store::{FormatTrigger, LspFormatTarget},
|
||||
};
|
||||
use remote::SshRemoteClient;
|
||||
|
@ -29,7 +31,11 @@ use remote_server::{HeadlessAppState, HeadlessProject};
|
|||
use rpc::proto;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::{Arc, atomic::AtomicUsize},
|
||||
};
|
||||
use task::TcpArgumentsTemplate;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
|
@ -688,3 +694,162 @@ async fn test_remote_server_debugger(
|
|||
|
||||
shutdown_session.await.unwrap();
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_slow_adapter_startup_retries(
|
||||
cx_a: &mut TestAppContext,
|
||||
server_cx: &mut TestAppContext,
|
||||
executor: BackgroundExecutor,
|
||||
) {
|
||||
cx_a.update(|cx| {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
command_palette_hooks::init(cx);
|
||||
zlog::init_test();
|
||||
dap_adapters::init(cx);
|
||||
});
|
||||
server_cx.update(|cx| {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
dap_adapters::init(cx);
|
||||
});
|
||||
let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
|
||||
let remote_fs = FakeFs::new(server_cx.executor());
|
||||
remote_fs
|
||||
.insert_tree(
|
||||
path!("/code"),
|
||||
json!({
|
||||
"lib.rs": "fn one() -> usize { 1 }"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// User A connects to the remote project via SSH.
|
||||
server_cx.update(HeadlessProject::init);
|
||||
let remote_http_client = Arc::new(BlockedHttpClient);
|
||||
let node = NodeRuntime::unavailable();
|
||||
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
|
||||
let _headless_project = server_cx.new(|cx| {
|
||||
client::init_settings(cx);
|
||||
HeadlessProject::new(
|
||||
HeadlessAppState {
|
||||
session: server_ssh,
|
||||
fs: remote_fs.clone(),
|
||||
http_client: remote_http_client,
|
||||
node_runtime: node,
|
||||
languages,
|
||||
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await;
|
||||
let mut server = TestServer::start(server_cx.executor()).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
cx_a.update(|cx| {
|
||||
debugger_ui::init(cx);
|
||||
command_palette_hooks::init(cx);
|
||||
});
|
||||
let (project_a, _) = client_a
|
||||
.build_ssh_project(path!("/code"), client_ssh.clone(), cx_a)
|
||||
.await;
|
||||
|
||||
let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a);
|
||||
|
||||
let debugger_panel = workspace
|
||||
.update_in(cx_a, |_workspace, window, cx| {
|
||||
cx.spawn_in(window, DebugPanel::load)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
workspace.update_in(cx_a, |workspace, window, cx| {
|
||||
workspace.add_panel(debugger_panel, window, cx);
|
||||
});
|
||||
|
||||
cx_a.run_until_parked();
|
||||
let debug_panel = workspace
|
||||
.update(cx_a, |workspace, cx| workspace.panel::<DebugPanel>(cx))
|
||||
.unwrap();
|
||||
|
||||
let workspace_window = cx_a
|
||||
.window_handle()
|
||||
.downcast::<workspace::Workspace>()
|
||||
.unwrap();
|
||||
|
||||
let count = Arc::new(AtomicUsize::new(0));
|
||||
let session = debugger_ui::tests::start_debug_session_with(
|
||||
&workspace_window,
|
||||
cx_a,
|
||||
DebugTaskDefinition {
|
||||
adapter: "fake-adapter".into(),
|
||||
label: "test".into(),
|
||||
config: json!({
|
||||
"request": "launch"
|
||||
}),
|
||||
tcp_connection: Some(TcpArgumentsTemplate {
|
||||
port: None,
|
||||
host: None,
|
||||
timeout: None,
|
||||
}),
|
||||
},
|
||||
move |client| {
|
||||
let count = count.clone();
|
||||
client.on_request_ext::<dap::requests::Initialize, _>(move |_seq, _request| {
|
||||
if count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 5 {
|
||||
return RequestHandling::Exit;
|
||||
}
|
||||
RequestHandling::Respond(Ok(Capabilities::default()))
|
||||
});
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
cx_a.run_until_parked();
|
||||
|
||||
let client = session.update(cx_a, |session, _| session.adapter_client().unwrap());
|
||||
client
|
||||
.fake_event(dap::messages::Events::Stopped(dap::StoppedEvent {
|
||||
reason: dap::StoppedEventReason::Pause,
|
||||
description: None,
|
||||
thread_id: Some(1),
|
||||
preserve_focus_hint: None,
|
||||
text: None,
|
||||
all_threads_stopped: None,
|
||||
hit_breakpoint_ids: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
cx_a.run_until_parked();
|
||||
|
||||
let active_session = debug_panel
|
||||
.update(cx_a, |this, _| this.active_session())
|
||||
.unwrap();
|
||||
|
||||
let running_state = active_session.update(cx_a, |active_session, _| {
|
||||
active_session.running_state().clone()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
client.id(),
|
||||
running_state.read_with(cx_a, |running_state, _| running_state.session_id())
|
||||
);
|
||||
assert_eq!(
|
||||
ThreadId(1),
|
||||
running_state.read_with(cx_a, |running_state, _| running_state
|
||||
.selected_thread_id()
|
||||
.unwrap())
|
||||
);
|
||||
|
||||
let shutdown_session = workspace.update(cx_a, |workspace, cx| {
|
||||
workspace.project().update(cx, |project, cx| {
|
||||
project.dap_store().update(cx, |dap_store, cx| {
|
||||
dap_store.shutdown_session(session.read(cx).session_id(), cx)
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
client_ssh.update(cx_a, |a, _| {
|
||||
a.shutdown_processes(Some(proto::ShutdownRemoteServer {}), executor)
|
||||
});
|
||||
|
||||
shutdown_session.await.unwrap();
|
||||
}
|
||||
|
|
|
@ -442,10 +442,18 @@ impl DebugAdapter for FakeAdapter {
|
|||
_: Option<Vec<String>>,
|
||||
_: &mut AsyncApp,
|
||||
) -> Result<DebugAdapterBinary> {
|
||||
let connection = task_definition
|
||||
.tcp_connection
|
||||
.as_ref()
|
||||
.map(|connection| TcpArguments {
|
||||
host: connection.host(),
|
||||
port: connection.port.unwrap_or(17),
|
||||
timeout: connection.timeout,
|
||||
});
|
||||
Ok(DebugAdapterBinary {
|
||||
command: Some("command".into()),
|
||||
arguments: vec![],
|
||||
connection: None,
|
||||
connection,
|
||||
envs: HashMap::default(),
|
||||
cwd: None,
|
||||
request_args: StartDebuggingRequestArguments {
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
adapters::DebugAdapterBinary,
|
||||
transport::{IoKind, LogKind, TransportDelegate},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use anyhow::Result;
|
||||
use dap_types::{
|
||||
messages::{Message, Response},
|
||||
requests::Request,
|
||||
|
@ -110,9 +110,7 @@ impl DebugAdapterClient {
|
|||
self.transport_delegate
|
||||
.pending_requests
|
||||
.lock()
|
||||
.as_mut()
|
||||
.context("client is closed")?
|
||||
.insert(sequence_id, callback_tx);
|
||||
.insert(sequence_id, callback_tx)?;
|
||||
|
||||
log::debug!(
|
||||
"Client {} send `{}` request with sequence_id: {}",
|
||||
|
@ -170,6 +168,7 @@ impl DebugAdapterClient {
|
|||
pub fn kill(&self) {
|
||||
log::debug!("Killing DAP process");
|
||||
self.transport_delegate.transport.lock().kill();
|
||||
self.transport_delegate.pending_requests.lock().shutdown();
|
||||
}
|
||||
|
||||
pub fn has_adapter_logs(&self) -> bool {
|
||||
|
@ -184,11 +183,34 @@ impl DebugAdapterClient {
|
|||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
|
||||
pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
|
||||
{
|
||||
use crate::transport::RequestHandling;
|
||||
|
||||
self.transport_delegate
|
||||
.transport
|
||||
.lock()
|
||||
.as_fake()
|
||||
.on_request::<R, _>(move |seq, request| {
|
||||
RequestHandling::Respond(handler(seq, request))
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn on_request_ext<R: dap_types::requests::Request, F>(&self, handler: F)
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ FnMut(
|
||||
u64,
|
||||
R::Arguments,
|
||||
) -> crate::transport::RequestHandling<
|
||||
Result<R::Response, dap_types::ErrorResponse>,
|
||||
>,
|
||||
{
|
||||
self.transport_delegate
|
||||
.transport
|
||||
|
|
|
@ -49,6 +49,12 @@ pub enum IoKind {
|
|||
StdErr,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub enum RequestHandling<T> {
|
||||
Respond(T),
|
||||
Exit,
|
||||
}
|
||||
|
||||
type LogHandlers = Arc<Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
|
||||
|
||||
pub trait Transport: Send + Sync {
|
||||
|
@ -76,7 +82,11 @@ async fn start(
|
|||
) -> Result<Box<dyn Transport>> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
if cfg!(any(test, feature = "test-support")) {
|
||||
return Ok(Box::new(FakeTransport::start(cx).await?));
|
||||
if let Some(connection) = binary.connection.clone() {
|
||||
return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?));
|
||||
} else {
|
||||
return Ok(Box::new(FakeTransport::start_stdio(cx).await?));
|
||||
}
|
||||
}
|
||||
|
||||
if binary.connection.is_some() {
|
||||
|
@ -90,11 +100,57 @@ async fn start(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PendingRequests {
|
||||
inner: Option<HashMap<u64, oneshot::Sender<Result<Response>>>>,
|
||||
}
|
||||
|
||||
impl PendingRequests {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
inner: Some(HashMap::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self, e: anyhow::Error) {
|
||||
let Some(inner) = self.inner.as_mut() else {
|
||||
return;
|
||||
};
|
||||
for (_, sender) in inner.drain() {
|
||||
sender.send(Err(e.cloned())).ok();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert(
|
||||
&mut self,
|
||||
sequence_id: u64,
|
||||
callback_tx: oneshot::Sender<Result<Response>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(inner) = self.inner.as_mut() else {
|
||||
bail!("client is closed")
|
||||
};
|
||||
inner.insert(sequence_id, callback_tx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn remove(
|
||||
&mut self,
|
||||
sequence_id: u64,
|
||||
) -> anyhow::Result<Option<oneshot::Sender<Result<Response>>>> {
|
||||
let Some(inner) = self.inner.as_mut() else {
|
||||
bail!("client is closed");
|
||||
};
|
||||
Ok(inner.remove(&sequence_id))
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&mut self) {
|
||||
self.flush(anyhow!("transport shutdown"));
|
||||
self.inner = None;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct TransportDelegate {
|
||||
log_handlers: LogHandlers,
|
||||
// TODO this should really be some kind of associative channel
|
||||
pub(crate) pending_requests:
|
||||
Arc<Mutex<Option<HashMap<u64, oneshot::Sender<Result<Response>>>>>>,
|
||||
pub(crate) pending_requests: Arc<Mutex<PendingRequests>>,
|
||||
pub(crate) transport: Mutex<Box<dyn Transport>>,
|
||||
pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
|
||||
tasks: Mutex<Vec<Task<()>>>,
|
||||
|
@ -108,7 +164,7 @@ impl TransportDelegate {
|
|||
transport: Mutex::new(transport),
|
||||
log_handlers,
|
||||
server_tx: Default::default(),
|
||||
pending_requests: Arc::new(Mutex::new(Some(HashMap::default()))),
|
||||
pending_requests: Arc::new(Mutex::new(PendingRequests::new())),
|
||||
tasks: Default::default(),
|
||||
})
|
||||
}
|
||||
|
@ -151,24 +207,10 @@ impl TransportDelegate {
|
|||
Ok(()) => {
|
||||
pending_requests
|
||||
.lock()
|
||||
.take()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.for_each(|(_, request)| {
|
||||
request
|
||||
.send(Err(anyhow!("debugger shutdown unexpectedly")))
|
||||
.ok();
|
||||
});
|
||||
.flush(anyhow!("debugger shutdown unexpectedly"));
|
||||
}
|
||||
Err(e) => {
|
||||
pending_requests
|
||||
.lock()
|
||||
.take()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.for_each(|(_, request)| {
|
||||
request.send(Err(e.cloned())).ok();
|
||||
});
|
||||
pending_requests.lock().flush(e);
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
@ -286,7 +328,7 @@ impl TransportDelegate {
|
|||
async fn recv_from_server<Stdout>(
|
||||
server_stdout: Stdout,
|
||||
mut message_handler: DapMessageHandler,
|
||||
pending_requests: Arc<Mutex<Option<HashMap<u64, oneshot::Sender<Result<Response>>>>>>,
|
||||
pending_requests: Arc<Mutex<PendingRequests>>,
|
||||
log_handlers: Option<LogHandlers>,
|
||||
) -> Result<()>
|
||||
where
|
||||
|
@ -303,14 +345,10 @@ impl TransportDelegate {
|
|||
ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::info!("Debugger closed the connection");
|
||||
break Ok(());
|
||||
return Ok(());
|
||||
}
|
||||
ConnectionResult::Result(Ok(Message::Response(res))) => {
|
||||
let tx = pending_requests
|
||||
.lock()
|
||||
.as_mut()
|
||||
.context("client is closed")?
|
||||
.remove(&res.request_seq);
|
||||
let tx = pending_requests.lock().remove(res.request_seq)?;
|
||||
if let Some(tx) = tx {
|
||||
if let Err(e) = tx.send(Self::process_response(res)) {
|
||||
log::trace!("Did not send response `{:?}` for a cancelled", e);
|
||||
|
@ -704,8 +742,7 @@ impl Drop for StdioTransport {
|
|||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
type RequestHandler =
|
||||
Box<dyn Send + FnMut(u64, serde_json::Value) -> dap_types::messages::Response>;
|
||||
type RequestHandler = Box<dyn Send + FnMut(u64, serde_json::Value) -> RequestHandling<Response>>;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
type ResponseHandler = Box<dyn Send + Fn(Response)>;
|
||||
|
@ -716,23 +753,38 @@ pub struct FakeTransport {
|
|||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
// for reverse request responses
|
||||
response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
|
||||
|
||||
stdin_writer: Option<PipeWriter>,
|
||||
stdout_reader: Option<PipeReader>,
|
||||
message_handler: Option<Task<Result<()>>>,
|
||||
kind: FakeTransportKind,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub enum FakeTransportKind {
|
||||
Stdio {
|
||||
stdin_writer: Option<PipeWriter>,
|
||||
stdout_reader: Option<PipeReader>,
|
||||
},
|
||||
Tcp {
|
||||
connection: TcpArguments,
|
||||
executor: BackgroundExecutor,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl FakeTransport {
|
||||
pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(u64, R::Arguments) -> Result<R::Response, ErrorResponse>,
|
||||
F: 'static
|
||||
+ Send
|
||||
+ FnMut(u64, R::Arguments) -> RequestHandling<Result<R::Response, ErrorResponse>>,
|
||||
{
|
||||
self.request_handlers.lock().insert(
|
||||
R::COMMAND,
|
||||
Box::new(move |seq, args| {
|
||||
let result = handler(seq, serde_json::from_value(args).unwrap());
|
||||
let response = match result {
|
||||
let RequestHandling::Respond(response) = result else {
|
||||
return RequestHandling::Exit;
|
||||
};
|
||||
let response = match response {
|
||||
Ok(response) => Response {
|
||||
seq: seq + 1,
|
||||
request_seq: seq,
|
||||
|
@ -750,7 +802,7 @@ impl FakeTransport {
|
|||
message: None,
|
||||
},
|
||||
};
|
||||
response
|
||||
RequestHandling::Respond(response)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
@ -764,86 +816,75 @@ impl FakeTransport {
|
|||
.insert(R::COMMAND, Box::new(handler));
|
||||
}
|
||||
|
||||
async fn start(cx: &mut AsyncApp) -> Result<Self> {
|
||||
async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result<Self> {
|
||||
Ok(Self {
|
||||
request_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
response_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
message_handler: None,
|
||||
kind: FakeTransportKind::Tcp {
|
||||
connection,
|
||||
executor: cx.background_executor().clone(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_messages(
|
||||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
|
||||
stdin_reader: PipeReader,
|
||||
stdout_writer: PipeWriter,
|
||||
) -> Result<()> {
|
||||
use dap_types::requests::{Request, RunInTerminal, StartDebugging};
|
||||
use serde_json::json;
|
||||
|
||||
let (stdin_writer, stdin_reader) = async_pipe::pipe();
|
||||
let (stdout_writer, stdout_reader) = async_pipe::pipe();
|
||||
|
||||
let mut this = Self {
|
||||
request_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
response_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
stdin_writer: Some(stdin_writer),
|
||||
stdout_reader: Some(stdout_reader),
|
||||
message_handler: None,
|
||||
};
|
||||
|
||||
let request_handlers = this.request_handlers.clone();
|
||||
let response_handlers = this.response_handlers.clone();
|
||||
let mut reader = BufReader::new(stdin_reader);
|
||||
let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
|
||||
let mut buffer = String::new();
|
||||
|
||||
this.message_handler = Some(cx.background_spawn(async move {
|
||||
let mut reader = BufReader::new(stdin_reader);
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
|
||||
.await
|
||||
{
|
||||
ConnectionResult::Timeout => {
|
||||
anyhow::bail!("Timed out when connecting to debugger");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::info!("Debugger closed the connection");
|
||||
break Ok(());
|
||||
}
|
||||
ConnectionResult::Result(Err(e)) => break Err(e),
|
||||
ConnectionResult::Result(Ok(message)) => {
|
||||
match message {
|
||||
Message::Request(request) => {
|
||||
// redirect reverse requests to stdout writer/reader
|
||||
if request.command == RunInTerminal::COMMAND
|
||||
|| request.command == StartDebugging::COMMAND
|
||||
{
|
||||
let message =
|
||||
serde_json::to_string(&Message::Request(request)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
} else {
|
||||
let response = if let Some(handle) =
|
||||
request_handlers.lock().get_mut(request.command.as_str())
|
||||
{
|
||||
handle(request.seq, request.arguments.unwrap_or(json!({})))
|
||||
} else {
|
||||
panic!("No request handler for {}", request.command);
|
||||
};
|
||||
let message =
|
||||
serde_json::to_string(&Message::Response(response))
|
||||
.unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
}
|
||||
Message::Event(event) => {
|
||||
loop {
|
||||
match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await {
|
||||
ConnectionResult::Timeout => {
|
||||
anyhow::bail!("Timed out when connecting to debugger");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::info!("Debugger closed the connection");
|
||||
break Ok(());
|
||||
}
|
||||
ConnectionResult::Result(Err(e)) => break Err(e),
|
||||
ConnectionResult::Result(Ok(message)) => {
|
||||
match message {
|
||||
Message::Request(request) => {
|
||||
// redirect reverse requests to stdout writer/reader
|
||||
if request.command == RunInTerminal::COMMAND
|
||||
|| request.command == StartDebugging::COMMAND
|
||||
{
|
||||
let message =
|
||||
serde_json::to_string(&Message::Event(event)).unwrap();
|
||||
serde_json::to_string(&Message::Request(request)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message).as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
} else {
|
||||
let response = if let Some(handle) =
|
||||
request_handlers.lock().get_mut(request.command.as_str())
|
||||
{
|
||||
handle(request.seq, request.arguments.unwrap_or(json!({})))
|
||||
} else {
|
||||
panic!("No request handler for {}", request.command);
|
||||
};
|
||||
let response = match response {
|
||||
RequestHandling::Respond(response) => response,
|
||||
RequestHandling::Exit => {
|
||||
break Err(anyhow!("exit in response to request"));
|
||||
}
|
||||
};
|
||||
let message =
|
||||
serde_json::to_string(&Message::Response(response)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
|
@ -854,20 +895,56 @@ impl FakeTransport {
|
|||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
Message::Response(response) => {
|
||||
if let Some(handle) =
|
||||
response_handlers.lock().get(response.command.as_str())
|
||||
{
|
||||
handle(response);
|
||||
} else {
|
||||
log::error!("No response handler for {}", response.command);
|
||||
}
|
||||
}
|
||||
Message::Event(event) => {
|
||||
let message = serde_json::to_string(&Message::Event(event)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(TransportDelegate::build_rpc_message(message).as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
Message::Response(response) => {
|
||||
if let Some(handle) =
|
||||
response_handlers.lock().get(response.command.as_str())
|
||||
{
|
||||
handle(response);
|
||||
} else {
|
||||
log::error!("No response handler for {}", response.command);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_stdio(cx: &mut AsyncApp) -> Result<Self> {
|
||||
let (stdin_writer, stdin_reader) = async_pipe::pipe();
|
||||
let (stdout_writer, stdout_reader) = async_pipe::pipe();
|
||||
let kind = FakeTransportKind::Stdio {
|
||||
stdin_writer: Some(stdin_writer),
|
||||
stdout_reader: Some(stdout_reader),
|
||||
};
|
||||
|
||||
let mut this = Self {
|
||||
request_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
response_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
message_handler: None,
|
||||
kind,
|
||||
};
|
||||
|
||||
let request_handlers = this.request_handlers.clone();
|
||||
let response_handlers = this.response_handlers.clone();
|
||||
|
||||
this.message_handler = Some(cx.background_spawn(Self::handle_messages(
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
stdin_reader,
|
||||
stdout_writer,
|
||||
)));
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
@ -876,7 +953,10 @@ impl FakeTransport {
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl Transport for FakeTransport {
|
||||
fn tcp_arguments(&self) -> Option<TcpArguments> {
|
||||
None
|
||||
match &self.kind {
|
||||
FakeTransportKind::Stdio { .. } => None,
|
||||
FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn connect(
|
||||
|
@ -887,12 +967,33 @@ impl Transport for FakeTransport {
|
|||
Box<dyn AsyncRead + Unpin + Send + 'static>,
|
||||
)>,
|
||||
> {
|
||||
let result = util::maybe!({
|
||||
Ok((
|
||||
Box::new(self.stdin_writer.take().context("Cannot reconnect")?) as _,
|
||||
Box::new(self.stdout_reader.take().context("Cannot reconnect")?) as _,
|
||||
))
|
||||
});
|
||||
let result = match &mut self.kind {
|
||||
FakeTransportKind::Stdio {
|
||||
stdin_writer,
|
||||
stdout_reader,
|
||||
} => util::maybe!({
|
||||
Ok((
|
||||
Box::new(stdin_writer.take().context("Cannot reconnect")?) as _,
|
||||
Box::new(stdout_reader.take().context("Cannot reconnect")?) as _,
|
||||
))
|
||||
}),
|
||||
FakeTransportKind::Tcp { executor, .. } => {
|
||||
let (stdin_writer, stdin_reader) = async_pipe::pipe();
|
||||
let (stdout_writer, stdout_reader) = async_pipe::pipe();
|
||||
|
||||
let request_handlers = self.request_handlers.clone();
|
||||
let response_handlers = self.response_handlers.clone();
|
||||
|
||||
self.message_handler = Some(executor.spawn(Self::handle_messages(
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
stdin_reader,
|
||||
stdout_writer,
|
||||
)));
|
||||
|
||||
Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _))
|
||||
}
|
||||
};
|
||||
Task::ready(result)
|
||||
}
|
||||
|
||||
|
|
|
@ -1694,6 +1694,7 @@ impl Render for DebugPanel {
|
|||
category_filter: Some(
|
||||
zed_actions::ExtensionCategoryFilter::DebugAdapters,
|
||||
),
|
||||
id: None,
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
|
|
|
@ -122,7 +122,7 @@ impl DebugSession {
|
|||
.to_owned()
|
||||
}
|
||||
|
||||
pub(crate) fn running_state(&self) -> &Entity<RunningState> {
|
||||
pub fn running_state(&self) -> &Entity<RunningState> {
|
||||
&self.running_state
|
||||
}
|
||||
|
||||
|
|
|
@ -1459,7 +1459,7 @@ impl RunningState {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn selected_thread_id(&self) -> Option<ThreadId> {
|
||||
pub fn selected_thread_id(&self) -> Option<ThreadId> {
|
||||
self.thread_id
|
||||
}
|
||||
|
||||
|
|
|
@ -482,9 +482,7 @@ pub enum SelectMode {
|
|||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub enum EditorMode {
|
||||
SingleLine {
|
||||
auto_width: bool,
|
||||
},
|
||||
SingleLine,
|
||||
AutoHeight {
|
||||
min_lines: usize,
|
||||
max_lines: Option<usize>,
|
||||
|
@ -1662,13 +1660,7 @@ impl Editor {
|
|||
pub fn single_line(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
Self::new(
|
||||
EditorMode::SingleLine { auto_width: false },
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
Self::new(EditorMode::SingleLine, buffer, None, window, cx)
|
||||
}
|
||||
|
||||
pub fn multi_line(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
|
@ -1677,18 +1669,6 @@ impl Editor {
|
|||
Self::new(EditorMode::full(), buffer, None, window, cx)
|
||||
}
|
||||
|
||||
pub fn auto_width(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
Self::new(
|
||||
EditorMode::SingleLine { auto_width: true },
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn auto_height(
|
||||
min_lines: usize,
|
||||
max_lines: usize,
|
||||
|
@ -20487,6 +20467,7 @@ impl Editor {
|
|||
if event.blurred != self.focus_handle {
|
||||
self.last_focused_descendant = Some(event.blurred);
|
||||
}
|
||||
self.selection_drag_state = SelectionDragState::None;
|
||||
self.refresh_inlay_hints(InlayHintRefreshReason::ModifiersChanged(false), cx);
|
||||
}
|
||||
|
||||
|
|
|
@ -7777,46 +7777,13 @@ impl Element for EditorElement {
|
|||
editor.set_style(self.style.clone(), window, cx);
|
||||
|
||||
let layout_id = match editor.mode {
|
||||
EditorMode::SingleLine { auto_width } => {
|
||||
EditorMode::SingleLine => {
|
||||
let rem_size = window.rem_size();
|
||||
|
||||
let height = self.style.text.line_height_in_pixels(rem_size);
|
||||
if auto_width {
|
||||
let editor_handle = cx.entity().clone();
|
||||
let style = self.style.clone();
|
||||
window.request_measured_layout(
|
||||
Style::default(),
|
||||
move |_, _, window, cx| {
|
||||
let editor_snapshot = editor_handle
|
||||
.update(cx, |editor, cx| editor.snapshot(window, cx));
|
||||
let line = Self::layout_lines(
|
||||
DisplayRow(0)..DisplayRow(1),
|
||||
&editor_snapshot,
|
||||
&style,
|
||||
px(f32::MAX),
|
||||
|_| false, // Single lines never soft wrap
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.pop()
|
||||
.unwrap();
|
||||
|
||||
let font_id =
|
||||
window.text_system().resolve_font(&style.text.font());
|
||||
let font_size =
|
||||
style.text.font_size.to_pixels(window.rem_size());
|
||||
let em_width =
|
||||
window.text_system().em_width(font_id, font_size).unwrap();
|
||||
|
||||
size(line.width + em_width, height)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut style = Style::default();
|
||||
style.size.height = height.into();
|
||||
style.size.width = relative(1.).into();
|
||||
window.request_layout(style, None, cx)
|
||||
}
|
||||
let mut style = Style::default();
|
||||
style.size.height = height.into();
|
||||
style.size.width = relative(1.).into();
|
||||
window.request_layout(style, None, cx)
|
||||
}
|
||||
EditorMode::AutoHeight {
|
||||
min_lines,
|
||||
|
@ -10388,7 +10355,7 @@ mod tests {
|
|||
});
|
||||
|
||||
for editor_mode_without_invisibles in [
|
||||
EditorMode::SingleLine { auto_width: false },
|
||||
EditorMode::SingleLine,
|
||||
EditorMode::AutoHeight {
|
||||
min_lines: 1,
|
||||
max_lines: Some(100),
|
||||
|
|
|
@ -12,7 +12,7 @@ use crate::{
|
|||
};
|
||||
pub use autoscroll::{Autoscroll, AutoscrollStrategy};
|
||||
use core::fmt::Debug;
|
||||
use gpui::{App, Axis, Context, Global, Pixels, Task, Window, point, px};
|
||||
use gpui::{Along, App, Axis, Context, Global, Pixels, Task, Window, point, px};
|
||||
use language::language_settings::{AllLanguageSettings, SoftWrap};
|
||||
use language::{Bias, Point};
|
||||
pub use scroll_amount::ScrollAmount;
|
||||
|
@ -47,14 +47,14 @@ impl ScrollAnchor {
|
|||
}
|
||||
|
||||
pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> gpui::Point<f32> {
|
||||
let mut scroll_position = self.offset;
|
||||
if self.anchor == Anchor::min() {
|
||||
scroll_position.y = 0.;
|
||||
} else {
|
||||
let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32();
|
||||
scroll_position.y += scroll_top;
|
||||
}
|
||||
scroll_position
|
||||
self.offset.apply_along(Axis::Vertical, |offset| {
|
||||
if self.anchor == Anchor::min() {
|
||||
0.
|
||||
} else {
|
||||
let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32();
|
||||
(offset + scroll_top).max(0.)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn top_row(&self, buffer: &MultiBufferSnapshot) -> u32 {
|
||||
|
|
|
@ -221,9 +221,6 @@ impl ExampleContext {
|
|||
ThreadEvent::ShowError(thread_error) => {
|
||||
tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
|
||||
}
|
||||
ThreadEvent::RetriesFailed { .. } => {
|
||||
// Ignore retries failed events
|
||||
}
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn) => {
|
||||
tx.close_channel();
|
||||
|
|
|
@ -6,6 +6,7 @@ use std::sync::OnceLock;
|
|||
use std::time::Duration;
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use client::{ExtensionMetadata, ExtensionProvides};
|
||||
use collections::{BTreeMap, BTreeSet};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
|
@ -80,16 +81,24 @@ pub fn init(cx: &mut App) {
|
|||
.find_map(|item| item.downcast::<ExtensionsPage>());
|
||||
|
||||
if let Some(existing) = existing {
|
||||
if provides_filter.is_some() {
|
||||
existing.update(cx, |extensions_page, cx| {
|
||||
existing.update(cx, |extensions_page, cx| {
|
||||
if provides_filter.is_some() {
|
||||
extensions_page.change_provides_filter(provides_filter, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(id) = action.id.as_ref() {
|
||||
extensions_page.focus_extension(id, window, cx);
|
||||
}
|
||||
});
|
||||
|
||||
workspace.activate_item(&existing, true, true, window, cx);
|
||||
} else {
|
||||
let extensions_page =
|
||||
ExtensionsPage::new(workspace, provides_filter, window, cx);
|
||||
let extensions_page = ExtensionsPage::new(
|
||||
workspace,
|
||||
provides_filter,
|
||||
action.id.as_deref(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
workspace.add_item_to_active_pane(
|
||||
Box::new(extensions_page),
|
||||
None,
|
||||
|
@ -287,6 +296,7 @@ impl ExtensionsPage {
|
|||
pub fn new(
|
||||
workspace: &Workspace,
|
||||
provides_filter: Option<ExtensionProvides>,
|
||||
focus_extension_id: Option<&str>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) -> Entity<Self> {
|
||||
|
@ -317,6 +327,9 @@ impl ExtensionsPage {
|
|||
let query_editor = cx.new(|cx| {
|
||||
let mut input = Editor::single_line(window, cx);
|
||||
input.set_placeholder_text("Search extensions...", cx);
|
||||
if let Some(id) = focus_extension_id {
|
||||
input.set_text(format!("id:{id}"), window, cx);
|
||||
}
|
||||
input
|
||||
});
|
||||
cx.subscribe(&query_editor, Self::on_query_change).detach();
|
||||
|
@ -340,7 +353,7 @@ impl ExtensionsPage {
|
|||
scrollbar_state: ScrollbarState::new(scroll_handle),
|
||||
};
|
||||
this.fetch_extensions(
|
||||
None,
|
||||
this.search_query(cx),
|
||||
Some(BTreeSet::from_iter(this.provides_filter)),
|
||||
None,
|
||||
cx,
|
||||
|
@ -464,9 +477,23 @@ impl ExtensionsPage {
|
|||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let remote_extensions = extension_store.update(cx, |store, cx| {
|
||||
store.fetch_extensions(search.as_deref(), provides_filter.as_ref(), cx)
|
||||
});
|
||||
let remote_extensions =
|
||||
if let Some(id) = search.as_ref().and_then(|s| s.strip_prefix("id:")) {
|
||||
let versions =
|
||||
extension_store.update(cx, |store, cx| store.fetch_extension_versions(id, cx));
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let versions = versions.await?;
|
||||
let latest = versions
|
||||
.into_iter()
|
||||
.max_by_key(|v| v.published_at)
|
||||
.context("no extension found")?;
|
||||
Ok(vec![latest])
|
||||
})
|
||||
} else {
|
||||
extension_store.update(cx, |store, cx| {
|
||||
store.fetch_extensions(search.as_deref(), provides_filter.as_ref(), cx)
|
||||
})
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let dev_extensions = if let Some(search) = search {
|
||||
|
@ -1156,6 +1183,13 @@ impl ExtensionsPage {
|
|||
self.refresh_feature_upsells(cx);
|
||||
}
|
||||
|
||||
pub fn focus_extension(&mut self, id: &str, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.query_editor.update(cx, |editor, cx| {
|
||||
editor.set_text(format!("id:{id}"), window, cx)
|
||||
});
|
||||
self.refresh_search(cx);
|
||||
}
|
||||
|
||||
pub fn change_provides_filter(
|
||||
&mut self,
|
||||
provides_filter: Option<ExtensionProvides>,
|
||||
|
|
|
@ -126,7 +126,7 @@ mod macos {
|
|||
"ContentMask".into(),
|
||||
"Uniforms".into(),
|
||||
"AtlasTile".into(),
|
||||
"PathInputIndex".into(),
|
||||
"PathRasterizationInputIndex".into(),
|
||||
"PathVertex_ScaledPixels".into(),
|
||||
"ShadowInputIndex".into(),
|
||||
"Shadow".into(),
|
||||
|
|
|
@ -1,13 +1,9 @@
|
|||
use gpui::{
|
||||
Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder,
|
||||
PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowBounds,
|
||||
WindowOptions, canvas, div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb,
|
||||
size,
|
||||
PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas,
|
||||
div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, size,
|
||||
};
|
||||
|
||||
const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0);
|
||||
const DEFAULT_WINDOW_HEIGHT: Pixels = px(768.0);
|
||||
|
||||
struct PaintingViewer {
|
||||
default_lines: Vec<(Path<Pixels>, Background)>,
|
||||
lines: Vec<Vec<Point<Pixels>>>,
|
||||
|
@ -151,6 +147,8 @@ impl PaintingViewer {
|
|||
px(320.0 + (i as f32 * 10.0).sin() * 40.0),
|
||||
));
|
||||
}
|
||||
let path = builder.build().unwrap();
|
||||
lines.push((path, gpui::green().into()));
|
||||
|
||||
Self {
|
||||
default_lines: lines.clone(),
|
||||
|
@ -185,13 +183,9 @@ fn button(
|
|||
}
|
||||
|
||||
impl Render for PaintingViewer {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
window.request_animation_frame();
|
||||
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let default_lines = self.default_lines.clone();
|
||||
let lines = self.lines.clone();
|
||||
let window_size = window.bounds().size;
|
||||
let scale = window_size.width / DEFAULT_WINDOW_WIDTH;
|
||||
let dashed = self.dashed;
|
||||
|
||||
div()
|
||||
|
@ -228,7 +222,7 @@ impl Render for PaintingViewer {
|
|||
move |_, _, _| {},
|
||||
move |_, _, window, _| {
|
||||
for (path, color) in default_lines {
|
||||
window.paint_path(path.clone().scale(scale), color);
|
||||
window.paint_path(path, color);
|
||||
}
|
||||
|
||||
for points in lines {
|
||||
|
@ -304,11 +298,6 @@ fn main() {
|
|||
cx.open_window(
|
||||
WindowOptions {
|
||||
focus: true,
|
||||
window_bounds: Some(WindowBounds::Windowed(Bounds::centered(
|
||||
None,
|
||||
size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT),
|
||||
cx,
|
||||
))),
|
||||
..Default::default()
|
||||
},
|
||||
|window, cx| cx.new(|cx| PaintingViewer::new(window, cx)),
|
||||
|
|
|
@ -336,7 +336,10 @@ impl PathBuilder {
|
|||
let v1 = buf.vertices[i1];
|
||||
let v2 = buf.vertices[i2];
|
||||
|
||||
path.push_triangle((v0.into(), v1.into(), v2.into()));
|
||||
path.push_triangle(
|
||||
(v0.into(), v1.into(), v2.into()),
|
||||
(point(0., 1.), point(0., 1.), point(0., 1.)),
|
||||
);
|
||||
}
|
||||
|
||||
path
|
||||
|
|
|
@ -789,6 +789,7 @@ pub(crate) struct AtlasTextureId {
|
|||
pub(crate) enum AtlasTextureKind {
|
||||
Monochrome = 0,
|
||||
Polychrome = 1,
|
||||
Path = 2,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
|
|
|
@ -10,6 +10,8 @@ use etagere::BucketedAtlasAllocator;
|
|||
use parking_lot::Mutex;
|
||||
use std::{borrow::Cow, ops, sync::Arc};
|
||||
|
||||
pub(crate) const PATH_TEXTURE_FORMAT: gpu::TextureFormat = gpu::TextureFormat::R16Float;
|
||||
|
||||
pub(crate) struct BladeAtlas(Mutex<BladeAtlasState>);
|
||||
|
||||
struct PendingUpload {
|
||||
|
@ -25,6 +27,7 @@ struct BladeAtlasState {
|
|||
tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
|
||||
initializations: Vec<AtlasTextureId>,
|
||||
uploads: Vec<PendingUpload>,
|
||||
path_sample_count: u32,
|
||||
}
|
||||
|
||||
#[cfg(gles)]
|
||||
|
@ -38,13 +41,13 @@ impl BladeAtlasState {
|
|||
}
|
||||
|
||||
pub struct BladeTextureInfo {
|
||||
#[allow(dead_code)]
|
||||
pub size: gpu::Extent,
|
||||
pub raw_view: gpu::TextureView,
|
||||
pub msaa_view: Option<gpu::TextureView>,
|
||||
}
|
||||
|
||||
impl BladeAtlas {
|
||||
pub(crate) fn new(gpu: &Arc<gpu::Context>) -> Self {
|
||||
pub(crate) fn new(gpu: &Arc<gpu::Context>, path_sample_count: u32) -> Self {
|
||||
BladeAtlas(Mutex::new(BladeAtlasState {
|
||||
gpu: Arc::clone(gpu),
|
||||
upload_belt: BufferBelt::new(BufferBeltDescriptor {
|
||||
|
@ -56,6 +59,7 @@ impl BladeAtlas {
|
|||
tiles_by_key: Default::default(),
|
||||
initializations: Vec::new(),
|
||||
uploads: Vec::new(),
|
||||
path_sample_count,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -63,7 +67,6 @@ impl BladeAtlas {
|
|||
self.0.lock().destroy();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) {
|
||||
let mut lock = self.0.lock();
|
||||
let textures = &mut lock.storage[texture_kind];
|
||||
|
@ -72,6 +75,19 @@ impl BladeAtlas {
|
|||
}
|
||||
}
|
||||
|
||||
/// Allocate a rectangle and make it available for rendering immediately (without waiting for `before_frame`)
|
||||
pub fn allocate_for_rendering(
|
||||
&self,
|
||||
size: Size<DevicePixels>,
|
||||
texture_kind: AtlasTextureKind,
|
||||
gpu_encoder: &mut gpu::CommandEncoder,
|
||||
) -> AtlasTile {
|
||||
let mut lock = self.0.lock();
|
||||
let tile = lock.allocate(size, texture_kind);
|
||||
lock.flush_initializations(gpu_encoder);
|
||||
tile
|
||||
}
|
||||
|
||||
pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) {
|
||||
let mut lock = self.0.lock();
|
||||
lock.flush(gpu_encoder);
|
||||
|
@ -93,6 +109,7 @@ impl BladeAtlas {
|
|||
depth: 1,
|
||||
},
|
||||
raw_view: texture.raw_view,
|
||||
msaa_view: texture.msaa_view,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -183,8 +200,48 @@ impl BladeAtlasState {
|
|||
format = gpu::TextureFormat::Bgra8UnormSrgb;
|
||||
usage = gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE;
|
||||
}
|
||||
AtlasTextureKind::Path => {
|
||||
format = PATH_TEXTURE_FORMAT;
|
||||
usage = gpu::TextureUsage::COPY
|
||||
| gpu::TextureUsage::RESOURCE
|
||||
| gpu::TextureUsage::TARGET;
|
||||
}
|
||||
}
|
||||
|
||||
// We currently only enable MSAA for path textures.
|
||||
let (msaa, msaa_view) = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path {
|
||||
let msaa = self.gpu.create_texture(gpu::TextureDesc {
|
||||
name: "msaa path texture",
|
||||
format,
|
||||
size: gpu::Extent {
|
||||
width: size.width.into(),
|
||||
height: size.height.into(),
|
||||
depth: 1,
|
||||
},
|
||||
array_layer_count: 1,
|
||||
mip_level_count: 1,
|
||||
sample_count: self.path_sample_count,
|
||||
dimension: gpu::TextureDimension::D2,
|
||||
usage: gpu::TextureUsage::TARGET,
|
||||
external: None,
|
||||
});
|
||||
|
||||
(
|
||||
Some(msaa),
|
||||
Some(self.gpu.create_texture_view(
|
||||
msaa,
|
||||
gpu::TextureViewDesc {
|
||||
name: "msaa texture view",
|
||||
format,
|
||||
dimension: gpu::ViewDimension::D2,
|
||||
subresources: &Default::default(),
|
||||
},
|
||||
)),
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let raw = self.gpu.create_texture(gpu::TextureDesc {
|
||||
name: "atlas",
|
||||
format,
|
||||
|
@ -222,6 +279,8 @@ impl BladeAtlasState {
|
|||
format,
|
||||
raw,
|
||||
raw_view,
|
||||
msaa,
|
||||
msaa_view,
|
||||
live_atlas_keys: 0,
|
||||
};
|
||||
|
||||
|
@ -281,6 +340,7 @@ impl BladeAtlasState {
|
|||
struct BladeAtlasStorage {
|
||||
monochrome_textures: AtlasTextureList<BladeAtlasTexture>,
|
||||
polychrome_textures: AtlasTextureList<BladeAtlasTexture>,
|
||||
path_textures: AtlasTextureList<BladeAtlasTexture>,
|
||||
}
|
||||
|
||||
impl ops::Index<AtlasTextureKind> for BladeAtlasStorage {
|
||||
|
@ -289,6 +349,7 @@ impl ops::Index<AtlasTextureKind> for BladeAtlasStorage {
|
|||
match kind {
|
||||
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
|
||||
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
|
||||
crate::AtlasTextureKind::Path => &self.path_textures,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -298,6 +359,7 @@ impl ops::IndexMut<AtlasTextureKind> for BladeAtlasStorage {
|
|||
match kind {
|
||||
crate::AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
|
||||
crate::AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
|
||||
crate::AtlasTextureKind::Path => &mut self.path_textures,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -308,6 +370,7 @@ impl ops::Index<AtlasTextureId> for BladeAtlasStorage {
|
|||
let textures = match id.kind {
|
||||
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
|
||||
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
|
||||
crate::AtlasTextureKind::Path => &self.path_textures,
|
||||
};
|
||||
textures[id.index as usize].as_ref().unwrap()
|
||||
}
|
||||
|
@ -321,6 +384,9 @@ impl BladeAtlasStorage {
|
|||
for mut texture in self.polychrome_textures.drain().flatten() {
|
||||
texture.destroy(gpu);
|
||||
}
|
||||
for mut texture in self.path_textures.drain().flatten() {
|
||||
texture.destroy(gpu);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -329,6 +395,8 @@ struct BladeAtlasTexture {
|
|||
allocator: BucketedAtlasAllocator,
|
||||
raw: gpu::Texture,
|
||||
raw_view: gpu::TextureView,
|
||||
msaa: Option<gpu::Texture>,
|
||||
msaa_view: Option<gpu::TextureView>,
|
||||
format: gpu::TextureFormat,
|
||||
live_atlas_keys: u32,
|
||||
}
|
||||
|
@ -356,6 +424,12 @@ impl BladeAtlasTexture {
|
|||
fn destroy(&mut self, gpu: &gpu::Context) {
|
||||
gpu.destroy_texture(self.raw);
|
||||
gpu.destroy_texture_view(self.raw_view);
|
||||
if let Some(msaa) = self.msaa {
|
||||
gpu.destroy_texture(msaa);
|
||||
}
|
||||
if let Some(msaa_view) = self.msaa_view {
|
||||
gpu.destroy_texture_view(msaa_view);
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_per_pixel(&self) -> u8 {
|
||||
|
|
|
@ -1,19 +1,24 @@
|
|||
// Doing `if let` gives you nice scoping with passes/encoders
|
||||
#![allow(irrefutable_let_patterns)]
|
||||
|
||||
use super::{BladeAtlas, BladeContext};
|
||||
use super::{BladeAtlas, BladeContext, PATH_TEXTURE_FORMAT};
|
||||
use crate::{
|
||||
Background, Bounds, ContentMask, DevicePixels, GpuSpecs, MonochromeSprite, PathVertex,
|
||||
PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline,
|
||||
AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, GpuSpecs,
|
||||
MonochromeSprite, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, Quad,
|
||||
ScaledPixels, Scene, Shadow, Size, Underline,
|
||||
};
|
||||
use blade_graphics::{self as gpu};
|
||||
use blade_graphics as gpu;
|
||||
use blade_util::{BufferBelt, BufferBeltDescriptor};
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use collections::HashMap;
|
||||
#[cfg(target_os = "macos")]
|
||||
use media::core_video::CVMetalTextureCache;
|
||||
use std::{mem, sync::Arc};
|
||||
|
||||
const MAX_FRAME_TIME_MS: u32 = 10000;
|
||||
// Use 4x MSAA, all devices support it.
|
||||
// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount
|
||||
const DEFAULT_PATH_SAMPLE_COUNT: u32 = 4;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Pod, Zeroable)]
|
||||
|
@ -61,9 +66,16 @@ struct ShaderShadowsData {
|
|||
}
|
||||
|
||||
#[derive(blade_macros::ShaderData)]
|
||||
struct ShaderPathsData {
|
||||
struct ShaderPathRasterizationData {
|
||||
globals: GlobalParams,
|
||||
b_path_vertices: gpu::BufferPiece,
|
||||
}
|
||||
|
||||
#[derive(blade_macros::ShaderData)]
|
||||
struct ShaderPathsData {
|
||||
globals: GlobalParams,
|
||||
t_sprite: gpu::TextureView,
|
||||
s_sprite: gpu::Sampler,
|
||||
b_path_sprites: gpu::BufferPiece,
|
||||
}
|
||||
|
||||
|
@ -103,27 +115,13 @@ struct ShaderSurfacesData {
|
|||
struct PathSprite {
|
||||
bounds: Bounds<ScaledPixels>,
|
||||
color: Background,
|
||||
}
|
||||
|
||||
/// Argument buffer layout for `draw_indirect` commands.
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone, Debug, Default, Pod, Zeroable)]
|
||||
pub struct DrawIndirectArgs {
|
||||
/// The number of vertices to draw.
|
||||
pub vertex_count: u32,
|
||||
/// The number of instances to draw.
|
||||
pub instance_count: u32,
|
||||
/// The Index of the first vertex to draw.
|
||||
pub first_vertex: u32,
|
||||
/// The instance ID of the first instance to draw.
|
||||
///
|
||||
/// Has to be 0, unless [`Features::INDIRECT_FIRST_INSTANCE`](crate::Features::INDIRECT_FIRST_INSTANCE) is enabled.
|
||||
pub first_instance: u32,
|
||||
tile: AtlasTile,
|
||||
}
|
||||
|
||||
struct BladePipelines {
|
||||
quads: gpu::RenderPipeline,
|
||||
shadows: gpu::RenderPipeline,
|
||||
path_rasterization: gpu::RenderPipeline,
|
||||
paths: gpu::RenderPipeline,
|
||||
underlines: gpu::RenderPipeline,
|
||||
mono_sprites: gpu::RenderPipeline,
|
||||
|
@ -132,7 +130,7 @@ struct BladePipelines {
|
|||
}
|
||||
|
||||
impl BladePipelines {
|
||||
fn new(gpu: &gpu::Context, surface_info: gpu::SurfaceInfo, sample_count: u32) -> Self {
|
||||
fn new(gpu: &gpu::Context, surface_info: gpu::SurfaceInfo, path_sample_count: u32) -> Self {
|
||||
use gpu::ShaderData as _;
|
||||
|
||||
log::info!(
|
||||
|
@ -180,10 +178,7 @@ impl BladePipelines {
|
|||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_quad")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
shadows: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
|
||||
name: "shadows",
|
||||
|
@ -197,8 +192,26 @@ impl BladePipelines {
|
|||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_shadow")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
path_rasterization: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
|
||||
name: "path_rasterization",
|
||||
data_layouts: &[&ShaderPathRasterizationData::layout()],
|
||||
vertex: shader.at("vs_path_rasterization"),
|
||||
vertex_fetches: &[],
|
||||
primitive: gpu::PrimitiveState {
|
||||
topology: gpu::PrimitiveTopology::TriangleList,
|
||||
..Default::default()
|
||||
},
|
||||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_path_rasterization")),
|
||||
color_targets: &[gpu::ColorTargetState {
|
||||
format: PATH_TEXTURE_FORMAT,
|
||||
blend: Some(gpu::BlendState::ADDITIVE),
|
||||
write_mask: gpu::ColorWrites::default(),
|
||||
}],
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
sample_count: path_sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
}),
|
||||
|
@ -208,16 +221,13 @@ impl BladePipelines {
|
|||
vertex: shader.at("vs_path"),
|
||||
vertex_fetches: &[],
|
||||
primitive: gpu::PrimitiveState {
|
||||
topology: gpu::PrimitiveTopology::TriangleList,
|
||||
topology: gpu::PrimitiveTopology::TriangleStrip,
|
||||
..Default::default()
|
||||
},
|
||||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_path")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
|
||||
name: "underlines",
|
||||
|
@ -231,10 +241,7 @@ impl BladePipelines {
|
|||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_underline")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
mono_sprites: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
|
||||
name: "mono-sprites",
|
||||
|
@ -248,10 +255,7 @@ impl BladePipelines {
|
|||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_mono_sprite")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
poly_sprites: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
|
||||
name: "poly-sprites",
|
||||
|
@ -265,10 +269,7 @@ impl BladePipelines {
|
|||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_poly_sprite")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
surfaces: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
|
||||
name: "surfaces",
|
||||
|
@ -282,10 +283,7 @@ impl BladePipelines {
|
|||
depth_stencil: None,
|
||||
fragment: Some(shader.at("fs_surface")),
|
||||
color_targets,
|
||||
multisample_state: gpu::MultisampleState {
|
||||
sample_count,
|
||||
..Default::default()
|
||||
},
|
||||
multisample_state: gpu::MultisampleState::default(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
@ -293,6 +291,7 @@ impl BladePipelines {
|
|||
fn destroy(&mut self, gpu: &gpu::Context) {
|
||||
gpu.destroy_render_pipeline(&mut self.quads);
|
||||
gpu.destroy_render_pipeline(&mut self.shadows);
|
||||
gpu.destroy_render_pipeline(&mut self.path_rasterization);
|
||||
gpu.destroy_render_pipeline(&mut self.paths);
|
||||
gpu.destroy_render_pipeline(&mut self.underlines);
|
||||
gpu.destroy_render_pipeline(&mut self.mono_sprites);
|
||||
|
@ -318,13 +317,12 @@ pub struct BladeRenderer {
|
|||
last_sync_point: Option<gpu::SyncPoint>,
|
||||
pipelines: BladePipelines,
|
||||
instance_belt: BufferBelt,
|
||||
path_tiles: HashMap<PathId, AtlasTile>,
|
||||
atlas: Arc<BladeAtlas>,
|
||||
atlas_sampler: gpu::Sampler,
|
||||
#[cfg(target_os = "macos")]
|
||||
core_video_texture_cache: CVMetalTextureCache,
|
||||
sample_count: u32,
|
||||
texture_msaa: Option<gpu::Texture>,
|
||||
texture_view_msaa: Option<gpu::TextureView>,
|
||||
path_sample_count: u32,
|
||||
}
|
||||
|
||||
impl BladeRenderer {
|
||||
|
@ -333,18 +331,6 @@ impl BladeRenderer {
|
|||
window: &I,
|
||||
config: BladeSurfaceConfig,
|
||||
) -> anyhow::Result<Self> {
|
||||
// workaround for https://github.com/zed-industries/zed/issues/26143
|
||||
let sample_count = std::env::var("ZED_SAMPLE_COUNT")
|
||||
.ok()
|
||||
.or_else(|| std::env::var("ZED_PATH_SAMPLE_COUNT").ok())
|
||||
.and_then(|v| v.parse().ok())
|
||||
.or_else(|| {
|
||||
[4, 2, 1]
|
||||
.into_iter()
|
||||
.find(|count| context.gpu.supports_texture_sample_count(*count))
|
||||
})
|
||||
.unwrap_or(1);
|
||||
|
||||
let surface_config = gpu::SurfaceConfig {
|
||||
size: config.size,
|
||||
usage: gpu::TextureUsage::TARGET,
|
||||
|
@ -358,27 +344,22 @@ impl BladeRenderer {
|
|||
.create_surface_configured(window, surface_config)
|
||||
.map_err(|err| anyhow::anyhow!("Failed to create surface: {err:?}"))?;
|
||||
|
||||
let (texture_msaa, texture_view_msaa) = create_msaa_texture_if_needed(
|
||||
&context.gpu,
|
||||
surface.info().format,
|
||||
config.size.width,
|
||||
config.size.height,
|
||||
sample_count,
|
||||
)
|
||||
.unzip();
|
||||
|
||||
let command_encoder = context.gpu.create_command_encoder(gpu::CommandEncoderDesc {
|
||||
name: "main",
|
||||
buffer_count: 2,
|
||||
});
|
||||
|
||||
let pipelines = BladePipelines::new(&context.gpu, surface.info(), sample_count);
|
||||
// workaround for https://github.com/zed-industries/zed/issues/26143
|
||||
let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(DEFAULT_PATH_SAMPLE_COUNT);
|
||||
let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count);
|
||||
let instance_belt = BufferBelt::new(BufferBeltDescriptor {
|
||||
memory: gpu::Memory::Shared,
|
||||
min_chunk_size: 0x1000,
|
||||
alignment: 0x40, // Vulkan `minStorageBufferOffsetAlignment` on Intel Xe
|
||||
});
|
||||
let atlas = Arc::new(BladeAtlas::new(&context.gpu));
|
||||
let atlas = Arc::new(BladeAtlas::new(&context.gpu, path_sample_count));
|
||||
let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc {
|
||||
name: "atlas",
|
||||
mag_filter: gpu::FilterMode::Linear,
|
||||
|
@ -402,13 +383,12 @@ impl BladeRenderer {
|
|||
last_sync_point: None,
|
||||
pipelines,
|
||||
instance_belt,
|
||||
path_tiles: HashMap::default(),
|
||||
atlas,
|
||||
atlas_sampler,
|
||||
#[cfg(target_os = "macos")]
|
||||
core_video_texture_cache,
|
||||
sample_count,
|
||||
texture_msaa,
|
||||
texture_view_msaa,
|
||||
path_sample_count,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -461,24 +441,6 @@ impl BladeRenderer {
|
|||
self.surface_config.size = gpu_size;
|
||||
self.gpu
|
||||
.reconfigure_surface(&mut self.surface, self.surface_config);
|
||||
|
||||
if let Some(texture_msaa) = self.texture_msaa {
|
||||
self.gpu.destroy_texture(texture_msaa);
|
||||
}
|
||||
if let Some(texture_view_msaa) = self.texture_view_msaa {
|
||||
self.gpu.destroy_texture_view(texture_view_msaa);
|
||||
}
|
||||
|
||||
let (texture_msaa, texture_view_msaa) = create_msaa_texture_if_needed(
|
||||
&self.gpu,
|
||||
self.surface.info().format,
|
||||
gpu_size.width,
|
||||
gpu_size.height,
|
||||
self.sample_count,
|
||||
)
|
||||
.unzip();
|
||||
self.texture_msaa = texture_msaa;
|
||||
self.texture_view_msaa = texture_view_msaa;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -489,7 +451,8 @@ impl BladeRenderer {
|
|||
self.gpu
|
||||
.reconfigure_surface(&mut self.surface, self.surface_config);
|
||||
self.pipelines.destroy(&self.gpu);
|
||||
self.pipelines = BladePipelines::new(&self.gpu, self.surface.info(), self.sample_count);
|
||||
self.pipelines =
|
||||
BladePipelines::new(&self.gpu, self.surface.info(), self.path_sample_count);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -527,6 +490,80 @@ impl BladeRenderer {
|
|||
objc2::rc::Retained::as_ptr(&self.surface.metal_layer()) as *mut _
|
||||
}
|
||||
|
||||
#[profiling::function]
|
||||
fn rasterize_paths(&mut self, paths: &[Path<ScaledPixels>]) {
|
||||
self.path_tiles.clear();
|
||||
let mut vertices_by_texture_id = HashMap::default();
|
||||
|
||||
for path in paths {
|
||||
let clipped_bounds = path
|
||||
.bounds
|
||||
.intersect(&path.content_mask.bounds)
|
||||
.map_origin(|origin| origin.floor())
|
||||
.map_size(|size| size.ceil());
|
||||
let tile = self.atlas.allocate_for_rendering(
|
||||
clipped_bounds.size.map(Into::into),
|
||||
AtlasTextureKind::Path,
|
||||
&mut self.command_encoder,
|
||||
);
|
||||
vertices_by_texture_id
|
||||
.entry(tile.texture_id)
|
||||
.or_insert(Vec::new())
|
||||
.extend(path.vertices.iter().map(|vertex| PathVertex {
|
||||
xy_position: vertex.xy_position - clipped_bounds.origin
|
||||
+ tile.bounds.origin.map(Into::into),
|
||||
st_position: vertex.st_position,
|
||||
content_mask: ContentMask {
|
||||
bounds: tile.bounds.map(Into::into),
|
||||
},
|
||||
}));
|
||||
self.path_tiles.insert(path.id, tile);
|
||||
}
|
||||
|
||||
for (texture_id, vertices) in vertices_by_texture_id {
|
||||
let tex_info = self.atlas.get_texture_info(texture_id);
|
||||
let globals = GlobalParams {
|
||||
viewport_size: [tex_info.size.width as f32, tex_info.size.height as f32],
|
||||
premultiplied_alpha: 0,
|
||||
pad: 0,
|
||||
};
|
||||
|
||||
let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) };
|
||||
let frame_view = tex_info.raw_view;
|
||||
let color_target = if let Some(msaa_view) = tex_info.msaa_view {
|
||||
gpu::RenderTarget {
|
||||
view: msaa_view,
|
||||
init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack),
|
||||
finish_op: gpu::FinishOp::ResolveTo(frame_view),
|
||||
}
|
||||
} else {
|
||||
gpu::RenderTarget {
|
||||
view: frame_view,
|
||||
init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack),
|
||||
finish_op: gpu::FinishOp::Store,
|
||||
}
|
||||
};
|
||||
|
||||
if let mut pass = self.command_encoder.render(
|
||||
"paths",
|
||||
gpu::RenderTargetSet {
|
||||
colors: &[color_target],
|
||||
depth_stencil: None,
|
||||
},
|
||||
) {
|
||||
let mut encoder = pass.with(&self.pipelines.path_rasterization);
|
||||
encoder.bind(
|
||||
0,
|
||||
&ShaderPathRasterizationData {
|
||||
globals,
|
||||
b_path_vertices: vertex_buf,
|
||||
},
|
||||
);
|
||||
encoder.draw(0, vertices.len() as u32, 0, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn destroy(&mut self) {
|
||||
self.wait_for_gpu();
|
||||
self.atlas.destroy();
|
||||
|
@ -535,26 +572,17 @@ impl BladeRenderer {
|
|||
self.gpu.destroy_command_encoder(&mut self.command_encoder);
|
||||
self.pipelines.destroy(&self.gpu);
|
||||
self.gpu.destroy_surface(&mut self.surface);
|
||||
if let Some(texture_msaa) = self.texture_msaa {
|
||||
self.gpu.destroy_texture(texture_msaa);
|
||||
}
|
||||
if let Some(texture_view_msaa) = self.texture_view_msaa {
|
||||
self.gpu.destroy_texture_view(texture_view_msaa);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn draw(&mut self, scene: &Scene) {
|
||||
self.command_encoder.start();
|
||||
self.atlas.before_frame(&mut self.command_encoder);
|
||||
self.rasterize_paths(scene.paths());
|
||||
|
||||
let frame = {
|
||||
profiling::scope!("acquire frame");
|
||||
self.surface.acquire_frame()
|
||||
};
|
||||
let frame_view = frame.texture_view();
|
||||
if let Some(texture_msaa) = self.texture_msaa {
|
||||
self.command_encoder.init_texture(texture_msaa);
|
||||
}
|
||||
self.command_encoder.init_texture(frame.texture());
|
||||
|
||||
let globals = GlobalParams {
|
||||
|
@ -569,25 +597,14 @@ impl BladeRenderer {
|
|||
pad: 0,
|
||||
};
|
||||
|
||||
let target = if let Some(texture_view_msaa) = self.texture_view_msaa {
|
||||
gpu::RenderTarget {
|
||||
view: texture_view_msaa,
|
||||
init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack),
|
||||
finish_op: gpu::FinishOp::ResolveTo(frame_view),
|
||||
}
|
||||
} else {
|
||||
gpu::RenderTarget {
|
||||
view: frame_view,
|
||||
init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack),
|
||||
finish_op: gpu::FinishOp::Store,
|
||||
}
|
||||
};
|
||||
|
||||
// draw to the target texture
|
||||
if let mut pass = self.command_encoder.render(
|
||||
"main",
|
||||
gpu::RenderTargetSet {
|
||||
colors: &[target],
|
||||
colors: &[gpu::RenderTarget {
|
||||
view: frame.texture_view(),
|
||||
init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack),
|
||||
finish_op: gpu::FinishOp::Store,
|
||||
}],
|
||||
depth_stencil: None,
|
||||
},
|
||||
) {
|
||||
|
@ -622,55 +639,32 @@ impl BladeRenderer {
|
|||
}
|
||||
PrimitiveBatch::Paths(paths) => {
|
||||
let mut encoder = pass.with(&self.pipelines.paths);
|
||||
|
||||
let mut vertices = Vec::new();
|
||||
let mut sprites = Vec::with_capacity(paths.len());
|
||||
let mut draw_indirect_commands = Vec::with_capacity(paths.len());
|
||||
let mut first_vertex = 0;
|
||||
|
||||
for (i, path) in paths.iter().enumerate() {
|
||||
draw_indirect_commands.push(DrawIndirectArgs {
|
||||
vertex_count: path.vertices.len() as u32,
|
||||
instance_count: 1,
|
||||
first_vertex,
|
||||
first_instance: i as u32,
|
||||
});
|
||||
first_vertex += path.vertices.len() as u32;
|
||||
|
||||
vertices.extend(path.vertices.iter().map(|v| PathVertex {
|
||||
xy_position: v.xy_position,
|
||||
content_mask: ContentMask {
|
||||
bounds: path.content_mask.bounds,
|
||||
// todo(linux): group by texture ID
|
||||
for path in paths {
|
||||
let tile = &self.path_tiles[&path.id];
|
||||
let tex_info = self.atlas.get_texture_info(tile.texture_id);
|
||||
let origin = path.bounds.intersect(&path.content_mask.bounds).origin;
|
||||
let sprites = [PathSprite {
|
||||
bounds: Bounds {
|
||||
origin: origin.map(|p| p.floor()),
|
||||
size: tile.bounds.size.map(Into::into),
|
||||
},
|
||||
}));
|
||||
|
||||
sprites.push(PathSprite {
|
||||
bounds: path.bounds,
|
||||
color: path.color,
|
||||
});
|
||||
}
|
||||
tile: (*tile).clone(),
|
||||
}];
|
||||
|
||||
let b_path_vertices =
|
||||
unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) };
|
||||
let instance_buf =
|
||||
unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) };
|
||||
let indirect_buf = unsafe {
|
||||
self.instance_belt
|
||||
.alloc_typed(&draw_indirect_commands, &self.gpu)
|
||||
};
|
||||
|
||||
encoder.bind(
|
||||
0,
|
||||
&ShaderPathsData {
|
||||
globals,
|
||||
b_path_vertices,
|
||||
b_path_sprites: instance_buf,
|
||||
},
|
||||
);
|
||||
|
||||
for i in 0..paths.len() {
|
||||
encoder.draw_indirect(indirect_buf.buffer.at(indirect_buf.offset
|
||||
+ (i * mem::size_of::<DrawIndirectArgs>()) as u64));
|
||||
let instance_buf =
|
||||
unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) };
|
||||
encoder.bind(
|
||||
0,
|
||||
&ShaderPathsData {
|
||||
globals,
|
||||
t_sprite: tex_info.raw_view,
|
||||
s_sprite: self.atlas_sampler,
|
||||
b_path_sprites: instance_buf,
|
||||
},
|
||||
);
|
||||
encoder.draw(0, 4, 0, sprites.len() as u32);
|
||||
}
|
||||
}
|
||||
PrimitiveBatch::Underlines(underlines) => {
|
||||
|
@ -823,47 +817,9 @@ impl BladeRenderer {
|
|||
profiling::scope!("finish");
|
||||
self.instance_belt.flush(&sync_point);
|
||||
self.atlas.after_frame(&sync_point);
|
||||
self.atlas.clear_textures(AtlasTextureKind::Path);
|
||||
|
||||
self.wait_for_gpu();
|
||||
self.last_sync_point = Some(sync_point);
|
||||
}
|
||||
}
|
||||
|
||||
fn create_msaa_texture_if_needed(
|
||||
gpu: &gpu::Context,
|
||||
format: gpu::TextureFormat,
|
||||
width: u32,
|
||||
height: u32,
|
||||
sample_count: u32,
|
||||
) -> Option<(gpu::Texture, gpu::TextureView)> {
|
||||
if sample_count <= 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let texture_msaa = gpu.create_texture(gpu::TextureDesc {
|
||||
name: "msaa",
|
||||
format,
|
||||
size: gpu::Extent {
|
||||
width,
|
||||
height,
|
||||
depth: 1,
|
||||
},
|
||||
array_layer_count: 1,
|
||||
mip_level_count: 1,
|
||||
sample_count,
|
||||
dimension: gpu::TextureDimension::D2,
|
||||
usage: gpu::TextureUsage::TARGET,
|
||||
external: None,
|
||||
});
|
||||
let texture_view_msaa = gpu.create_texture_view(
|
||||
texture_msaa,
|
||||
gpu::TextureViewDesc {
|
||||
name: "msaa view",
|
||||
format,
|
||||
dimension: gpu::ViewDimension::D2,
|
||||
subresources: &Default::default(),
|
||||
},
|
||||
);
|
||||
|
||||
Some((texture_msaa, texture_view_msaa))
|
||||
}
|
||||
|
|
|
@ -922,23 +922,59 @@ fn fs_shadow(input: ShadowVarying) -> @location(0) vec4<f32> {
|
|||
return blend_color(input.color, alpha);
|
||||
}
|
||||
|
||||
// --- paths --- //
|
||||
// --- path rasterization --- //
|
||||
|
||||
struct PathVertex {
|
||||
xy_position: vec2<f32>,
|
||||
st_position: vec2<f32>,
|
||||
content_mask: Bounds,
|
||||
}
|
||||
var<storage, read> b_path_vertices: array<PathVertex>;
|
||||
|
||||
struct PathRasterizationVarying {
|
||||
@builtin(position) position: vec4<f32>,
|
||||
@location(0) st_position: vec2<f32>,
|
||||
//TODO: use `clip_distance` once Naga supports it
|
||||
@location(3) clip_distances: vec4<f32>,
|
||||
}
|
||||
|
||||
@vertex
|
||||
fn vs_path_rasterization(@builtin(vertex_index) vertex_id: u32) -> PathRasterizationVarying {
|
||||
let v = b_path_vertices[vertex_id];
|
||||
|
||||
var out = PathRasterizationVarying();
|
||||
out.position = to_device_position_impl(v.xy_position);
|
||||
out.st_position = v.st_position;
|
||||
out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask);
|
||||
return out;
|
||||
}
|
||||
|
||||
@fragment
|
||||
fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) f32 {
|
||||
let dx = dpdx(input.st_position);
|
||||
let dy = dpdy(input.st_position);
|
||||
if (any(input.clip_distances < vec4<f32>(0.0))) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y);
|
||||
let f = input.st_position.x * input.st_position.x - input.st_position.y;
|
||||
let distance = f / length(gradient);
|
||||
return saturate(0.5 - distance);
|
||||
}
|
||||
|
||||
// --- paths --- //
|
||||
|
||||
struct PathSprite {
|
||||
bounds: Bounds,
|
||||
color: Background,
|
||||
tile: AtlasTile,
|
||||
}
|
||||
var<storage, read> b_path_vertices: array<PathVertex>;
|
||||
var<storage, read> b_path_sprites: array<PathSprite>;
|
||||
|
||||
struct PathVarying {
|
||||
@builtin(position) position: vec4<f32>,
|
||||
@location(0) clip_distances: vec4<f32>,
|
||||
@location(0) tile_position: vec2<f32>,
|
||||
@location(1) @interpolate(flat) instance_id: u32,
|
||||
@location(2) @interpolate(flat) color_solid: vec4<f32>,
|
||||
@location(3) @interpolate(flat) color0: vec4<f32>,
|
||||
|
@ -947,12 +983,13 @@ struct PathVarying {
|
|||
|
||||
@vertex
|
||||
fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) instance_id: u32) -> PathVarying {
|
||||
let v = b_path_vertices[vertex_id];
|
||||
let unit_vertex = vec2<f32>(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u));
|
||||
let sprite = b_path_sprites[instance_id];
|
||||
// Don't apply content mask because it was already accounted for when rasterizing the path.
|
||||
|
||||
var out = PathVarying();
|
||||
out.position = to_device_position_impl(v.xy_position);
|
||||
out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask);
|
||||
out.position = to_device_position(unit_vertex, sprite.bounds);
|
||||
out.tile_position = to_tile_position(unit_vertex, sprite.tile);
|
||||
out.instance_id = instance_id;
|
||||
|
||||
let gradient = prepare_gradient_color(
|
||||
|
@ -969,15 +1006,13 @@ fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) insta
|
|||
|
||||
@fragment
|
||||
fn fs_path(input: PathVarying) -> @location(0) vec4<f32> {
|
||||
if any(input.clip_distances < vec4<f32>(0.0)) {
|
||||
return vec4<f32>(0.0);
|
||||
}
|
||||
|
||||
let sample = textureSample(t_sprite, s_sprite, input.tile_position).r;
|
||||
let mask = 1.0 - abs(1.0 - sample % 2.0);
|
||||
let sprite = b_path_sprites[input.instance_id];
|
||||
let background = sprite.color;
|
||||
let color = gradient_color(background, input.position.xy, sprite.bounds,
|
||||
input.color_solid, input.color0, input.color1);
|
||||
return blend_color(color, 1.0);
|
||||
return blend_color(color, mask);
|
||||
}
|
||||
|
||||
// --- underlines --- //
|
||||
|
|
|
@ -13,12 +13,14 @@ use std::borrow::Cow;
|
|||
pub(crate) struct MetalAtlas(Mutex<MetalAtlasState>);
|
||||
|
||||
impl MetalAtlas {
|
||||
pub(crate) fn new(device: Device) -> Self {
|
||||
pub(crate) fn new(device: Device, path_sample_count: u32) -> Self {
|
||||
MetalAtlas(Mutex::new(MetalAtlasState {
|
||||
device: AssertSend(device),
|
||||
monochrome_textures: Default::default(),
|
||||
polychrome_textures: Default::default(),
|
||||
path_textures: Default::default(),
|
||||
tiles_by_key: Default::default(),
|
||||
path_sample_count,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -26,7 +28,10 @@ impl MetalAtlas {
|
|||
self.0.lock().texture(id).metal_texture.clone()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn msaa_texture(&self, id: AtlasTextureId) -> Option<metal::Texture> {
|
||||
self.0.lock().texture(id).msaa_texture.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn allocate(
|
||||
&self,
|
||||
size: Size<DevicePixels>,
|
||||
|
@ -35,12 +40,12 @@ impl MetalAtlas {
|
|||
self.0.lock().allocate(size, texture_kind)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) {
|
||||
let mut lock = self.0.lock();
|
||||
let textures = match texture_kind {
|
||||
AtlasTextureKind::Monochrome => &mut lock.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut lock.polychrome_textures,
|
||||
AtlasTextureKind::Path => &mut lock.path_textures,
|
||||
};
|
||||
for texture in textures.iter_mut() {
|
||||
texture.clear();
|
||||
|
@ -52,7 +57,9 @@ struct MetalAtlasState {
|
|||
device: AssertSend<Device>,
|
||||
monochrome_textures: AtlasTextureList<MetalAtlasTexture>,
|
||||
polychrome_textures: AtlasTextureList<MetalAtlasTexture>,
|
||||
path_textures: AtlasTextureList<MetalAtlasTexture>,
|
||||
tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
|
||||
path_sample_count: u32,
|
||||
}
|
||||
|
||||
impl PlatformAtlas for MetalAtlas {
|
||||
|
@ -87,6 +94,7 @@ impl PlatformAtlas for MetalAtlas {
|
|||
let textures = match id.kind {
|
||||
AtlasTextureKind::Monochrome => &mut lock.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut lock.polychrome_textures,
|
||||
AtlasTextureKind::Path => &mut lock.polychrome_textures,
|
||||
};
|
||||
|
||||
let Some(texture_slot) = textures
|
||||
|
@ -120,6 +128,7 @@ impl MetalAtlasState {
|
|||
let textures = match texture_kind {
|
||||
AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
|
||||
AtlasTextureKind::Path => &mut self.path_textures,
|
||||
};
|
||||
|
||||
if let Some(tile) = textures
|
||||
|
@ -164,14 +173,31 @@ impl MetalAtlasState {
|
|||
pixel_format = metal::MTLPixelFormat::BGRA8Unorm;
|
||||
usage = metal::MTLTextureUsage::ShaderRead;
|
||||
}
|
||||
AtlasTextureKind::Path => {
|
||||
pixel_format = metal::MTLPixelFormat::R16Float;
|
||||
usage = metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead;
|
||||
}
|
||||
}
|
||||
texture_descriptor.set_pixel_format(pixel_format);
|
||||
texture_descriptor.set_usage(usage);
|
||||
let metal_texture = self.device.new_texture(&texture_descriptor);
|
||||
|
||||
// We currently only enable MSAA for path textures.
|
||||
let msaa_texture = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path {
|
||||
let mut descriptor = texture_descriptor.clone();
|
||||
descriptor.set_texture_type(metal::MTLTextureType::D2Multisample);
|
||||
descriptor.set_storage_mode(metal::MTLStorageMode::Private);
|
||||
descriptor.set_sample_count(self.path_sample_count as _);
|
||||
let msaa_texture = self.device.new_texture(&descriptor);
|
||||
Some(msaa_texture)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let texture_list = match kind {
|
||||
AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
|
||||
AtlasTextureKind::Path => &mut self.path_textures,
|
||||
};
|
||||
|
||||
let index = texture_list.free_list.pop();
|
||||
|
@ -183,6 +209,7 @@ impl MetalAtlasState {
|
|||
},
|
||||
allocator: etagere::BucketedAtlasAllocator::new(size.into()),
|
||||
metal_texture: AssertSend(metal_texture),
|
||||
msaa_texture: AssertSend(msaa_texture),
|
||||
live_atlas_keys: 0,
|
||||
};
|
||||
|
||||
|
@ -199,6 +226,7 @@ impl MetalAtlasState {
|
|||
let textures = match id.kind {
|
||||
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
|
||||
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
|
||||
crate::AtlasTextureKind::Path => &self.path_textures,
|
||||
};
|
||||
textures[id.index as usize].as_ref().unwrap()
|
||||
}
|
||||
|
@ -208,6 +236,7 @@ struct MetalAtlasTexture {
|
|||
id: AtlasTextureId,
|
||||
allocator: BucketedAtlasAllocator,
|
||||
metal_texture: AssertSend<metal::Texture>,
|
||||
msaa_texture: AssertSend<Option<metal::Texture>>,
|
||||
live_atlas_keys: u32,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,28 +1,27 @@
|
|||
use super::metal_atlas::MetalAtlas;
|
||||
use crate::{
|
||||
AtlasTextureId, Background, Bounds, ContentMask, DevicePixels, MonochromeSprite, PaintSurface,
|
||||
Path, PathVertex, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size,
|
||||
Surface, Underline, point, size,
|
||||
AtlasTextureId, AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels,
|
||||
MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch,
|
||||
Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline, point, size,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context as _, Result};
|
||||
use block::ConcreteBlock;
|
||||
use cocoa::{
|
||||
base::{NO, YES},
|
||||
foundation::{NSSize, NSUInteger},
|
||||
quartzcore::AutoresizingMask,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use core_foundation::base::TCFType;
|
||||
use core_video::{
|
||||
metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache,
|
||||
pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange,
|
||||
};
|
||||
use foreign_types::{ForeignType, ForeignTypeRef};
|
||||
use metal::{
|
||||
CAMetalLayer, CommandQueue, MTLDrawPrimitivesIndirectArguments, MTLPixelFormat,
|
||||
MTLResourceOptions, NSRange,
|
||||
};
|
||||
use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange};
|
||||
use objc::{self, msg_send, sel, sel_impl};
|
||||
use parking_lot::Mutex;
|
||||
use smallvec::SmallVec;
|
||||
use std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc};
|
||||
|
||||
// Exported to metal
|
||||
|
@ -32,6 +31,9 @@ pub(crate) type PointF = crate::Point<f32>;
|
|||
const SHADERS_METALLIB: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/shaders.metallib"));
|
||||
#[cfg(feature = "runtime_shaders")]
|
||||
const SHADERS_SOURCE_FILE: &str = include_str!(concat!(env!("OUT_DIR"), "/stitched_shaders.metal"));
|
||||
// Use 4x MSAA, all devices support it.
|
||||
// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount
|
||||
const PATH_SAMPLE_COUNT: u32 = 4;
|
||||
|
||||
pub type Context = Arc<Mutex<InstanceBufferPool>>;
|
||||
pub type Renderer = MetalRenderer;
|
||||
|
@ -96,7 +98,8 @@ pub(crate) struct MetalRenderer {
|
|||
layer: metal::MetalLayer,
|
||||
presents_with_transaction: bool,
|
||||
command_queue: CommandQueue,
|
||||
path_pipeline_state: metal::RenderPipelineState,
|
||||
paths_rasterization_pipeline_state: metal::RenderPipelineState,
|
||||
path_sprites_pipeline_state: metal::RenderPipelineState,
|
||||
shadows_pipeline_state: metal::RenderPipelineState,
|
||||
quads_pipeline_state: metal::RenderPipelineState,
|
||||
underlines_pipeline_state: metal::RenderPipelineState,
|
||||
|
@ -108,8 +111,6 @@ pub(crate) struct MetalRenderer {
|
|||
instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>,
|
||||
sprite_atlas: Arc<MetalAtlas>,
|
||||
core_video_texture_cache: core_video::metal_texture_cache::CVMetalTextureCache,
|
||||
sample_count: u64,
|
||||
msaa_texture: Option<metal::Texture>,
|
||||
}
|
||||
|
||||
impl MetalRenderer {
|
||||
|
@ -168,19 +169,22 @@ impl MetalRenderer {
|
|||
MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
|
||||
let sample_count = [4, 2, 1]
|
||||
.into_iter()
|
||||
.find(|count| device.supports_texture_sample_count(*count))
|
||||
.unwrap_or(1);
|
||||
|
||||
let path_pipeline_state = build_pipeline_state(
|
||||
let paths_rasterization_pipeline_state = build_path_rasterization_pipeline_state(
|
||||
&device,
|
||||
&library,
|
||||
"paths",
|
||||
"path_vertex",
|
||||
"path_fragment",
|
||||
"paths_rasterization",
|
||||
"path_rasterization_vertex",
|
||||
"path_rasterization_fragment",
|
||||
MTLPixelFormat::R16Float,
|
||||
PATH_SAMPLE_COUNT,
|
||||
);
|
||||
let path_sprites_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
&library,
|
||||
"path_sprites",
|
||||
"path_sprite_vertex",
|
||||
"path_sprite_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
let shadows_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
|
@ -189,7 +193,6 @@ impl MetalRenderer {
|
|||
"shadow_vertex",
|
||||
"shadow_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
let quads_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
|
@ -198,7 +201,6 @@ impl MetalRenderer {
|
|||
"quad_vertex",
|
||||
"quad_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
let underlines_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
|
@ -207,7 +209,6 @@ impl MetalRenderer {
|
|||
"underline_vertex",
|
||||
"underline_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
let monochrome_sprites_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
|
@ -216,7 +217,6 @@ impl MetalRenderer {
|
|||
"monochrome_sprite_vertex",
|
||||
"monochrome_sprite_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
let polychrome_sprites_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
|
@ -225,7 +225,6 @@ impl MetalRenderer {
|
|||
"polychrome_sprite_vertex",
|
||||
"polychrome_sprite_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
let surfaces_pipeline_state = build_pipeline_state(
|
||||
&device,
|
||||
|
@ -234,21 +233,20 @@ impl MetalRenderer {
|
|||
"surface_vertex",
|
||||
"surface_fragment",
|
||||
MTLPixelFormat::BGRA8Unorm,
|
||||
sample_count,
|
||||
);
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let sprite_atlas = Arc::new(MetalAtlas::new(device.clone()));
|
||||
let sprite_atlas = Arc::new(MetalAtlas::new(device.clone(), PATH_SAMPLE_COUNT));
|
||||
let core_video_texture_cache =
|
||||
CVMetalTextureCache::new(None, device.clone(), None).unwrap();
|
||||
let msaa_texture = create_msaa_texture(&device, &layer, sample_count);
|
||||
|
||||
Self {
|
||||
device,
|
||||
layer,
|
||||
presents_with_transaction: false,
|
||||
command_queue,
|
||||
path_pipeline_state,
|
||||
paths_rasterization_pipeline_state,
|
||||
path_sprites_pipeline_state,
|
||||
shadows_pipeline_state,
|
||||
quads_pipeline_state,
|
||||
underlines_pipeline_state,
|
||||
|
@ -259,8 +257,6 @@ impl MetalRenderer {
|
|||
instance_buffer_pool,
|
||||
sprite_atlas,
|
||||
core_video_texture_cache,
|
||||
sample_count,
|
||||
msaa_texture,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -293,8 +289,6 @@ impl MetalRenderer {
|
|||
setDrawableSize: size
|
||||
];
|
||||
}
|
||||
|
||||
self.msaa_texture = create_msaa_texture(&self.device, &self.layer, self.sample_count);
|
||||
}
|
||||
|
||||
pub fn update_transparency(&self, _transparent: bool) {
|
||||
|
@ -381,23 +375,25 @@ impl MetalRenderer {
|
|||
let command_queue = self.command_queue.clone();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let mut instance_offset = 0;
|
||||
|
||||
let path_tiles = self
|
||||
.rasterize_paths(
|
||||
scene.paths(),
|
||||
instance_buffer,
|
||||
&mut instance_offset,
|
||||
command_buffer,
|
||||
)
|
||||
.with_context(|| format!("rasterizing {} paths", scene.paths().len()))?;
|
||||
|
||||
let render_pass_descriptor = metal::RenderPassDescriptor::new();
|
||||
let color_attachment = render_pass_descriptor
|
||||
.color_attachments()
|
||||
.object_at(0)
|
||||
.unwrap();
|
||||
|
||||
if let Some(msaa_texture_ref) = self.msaa_texture.as_deref() {
|
||||
color_attachment.set_texture(Some(msaa_texture_ref));
|
||||
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
|
||||
color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve);
|
||||
color_attachment.set_resolve_texture(Some(drawable.texture()));
|
||||
} else {
|
||||
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
|
||||
color_attachment.set_texture(Some(drawable.texture()));
|
||||
color_attachment.set_store_action(metal::MTLStoreAction::Store);
|
||||
}
|
||||
|
||||
color_attachment.set_texture(Some(drawable.texture()));
|
||||
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
|
||||
color_attachment.set_store_action(metal::MTLStoreAction::Store);
|
||||
let alpha = if self.layer.is_opaque() { 1. } else { 0. };
|
||||
color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha));
|
||||
let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor);
|
||||
|
@ -429,6 +425,7 @@ impl MetalRenderer {
|
|||
),
|
||||
PrimitiveBatch::Paths(paths) => self.draw_paths(
|
||||
paths,
|
||||
&path_tiles,
|
||||
instance_buffer,
|
||||
&mut instance_offset,
|
||||
viewport_size,
|
||||
|
@ -496,6 +493,106 @@ impl MetalRenderer {
|
|||
Ok(command_buffer.to_owned())
|
||||
}
|
||||
|
||||
fn rasterize_paths(
|
||||
&self,
|
||||
paths: &[Path<ScaledPixels>],
|
||||
instance_buffer: &mut InstanceBuffer,
|
||||
instance_offset: &mut usize,
|
||||
command_buffer: &metal::CommandBufferRef,
|
||||
) -> Option<HashMap<PathId, AtlasTile>> {
|
||||
self.sprite_atlas.clear_textures(AtlasTextureKind::Path);
|
||||
|
||||
let mut tiles = HashMap::default();
|
||||
let mut vertices_by_texture_id = HashMap::default();
|
||||
for path in paths {
|
||||
let clipped_bounds = path.bounds.intersect(&path.content_mask.bounds);
|
||||
|
||||
let tile = self
|
||||
.sprite_atlas
|
||||
.allocate(clipped_bounds.size.map(Into::into), AtlasTextureKind::Path)?;
|
||||
vertices_by_texture_id
|
||||
.entry(tile.texture_id)
|
||||
.or_insert(Vec::new())
|
||||
.extend(path.vertices.iter().map(|vertex| PathVertex {
|
||||
xy_position: vertex.xy_position - clipped_bounds.origin
|
||||
+ tile.bounds.origin.map(Into::into),
|
||||
st_position: vertex.st_position,
|
||||
content_mask: ContentMask {
|
||||
bounds: tile.bounds.map(Into::into),
|
||||
},
|
||||
}));
|
||||
tiles.insert(path.id, tile);
|
||||
}
|
||||
|
||||
for (texture_id, vertices) in vertices_by_texture_id {
|
||||
align_offset(instance_offset);
|
||||
let vertices_bytes_len = mem::size_of_val(vertices.as_slice());
|
||||
let next_offset = *instance_offset + vertices_bytes_len;
|
||||
if next_offset > instance_buffer.size {
|
||||
return None;
|
||||
}
|
||||
|
||||
let render_pass_descriptor = metal::RenderPassDescriptor::new();
|
||||
let color_attachment = render_pass_descriptor
|
||||
.color_attachments()
|
||||
.object_at(0)
|
||||
.unwrap();
|
||||
|
||||
let texture = self.sprite_atlas.metal_texture(texture_id);
|
||||
let msaa_texture = self.sprite_atlas.msaa_texture(texture_id);
|
||||
|
||||
if let Some(msaa_texture) = msaa_texture {
|
||||
color_attachment.set_texture(Some(&msaa_texture));
|
||||
color_attachment.set_resolve_texture(Some(&texture));
|
||||
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
|
||||
color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve);
|
||||
} else {
|
||||
color_attachment.set_texture(Some(&texture));
|
||||
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
|
||||
color_attachment.set_store_action(metal::MTLStoreAction::Store);
|
||||
}
|
||||
color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 1.));
|
||||
|
||||
let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor);
|
||||
command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state);
|
||||
command_encoder.set_vertex_buffer(
|
||||
PathRasterizationInputIndex::Vertices as u64,
|
||||
Some(&instance_buffer.metal_buffer),
|
||||
*instance_offset as u64,
|
||||
);
|
||||
let texture_size = Size {
|
||||
width: DevicePixels::from(texture.width()),
|
||||
height: DevicePixels::from(texture.height()),
|
||||
};
|
||||
command_encoder.set_vertex_bytes(
|
||||
PathRasterizationInputIndex::AtlasTextureSize as u64,
|
||||
mem::size_of_val(&texture_size) as u64,
|
||||
&texture_size as *const Size<DevicePixels> as *const _,
|
||||
);
|
||||
|
||||
let buffer_contents = unsafe {
|
||||
(instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset)
|
||||
};
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(
|
||||
vertices.as_ptr() as *const u8,
|
||||
buffer_contents,
|
||||
vertices_bytes_len,
|
||||
);
|
||||
}
|
||||
|
||||
command_encoder.draw_primitives(
|
||||
metal::MTLPrimitiveType::Triangle,
|
||||
0,
|
||||
vertices.len() as u64,
|
||||
);
|
||||
command_encoder.end_encoding();
|
||||
*instance_offset = next_offset;
|
||||
}
|
||||
|
||||
Some(tiles)
|
||||
}
|
||||
|
||||
fn draw_shadows(
|
||||
&self,
|
||||
shadows: &[Shadow],
|
||||
|
@ -621,6 +718,7 @@ impl MetalRenderer {
|
|||
fn draw_paths(
|
||||
&self,
|
||||
paths: &[Path<ScaledPixels>],
|
||||
tiles_by_path_id: &HashMap<PathId, AtlasTile>,
|
||||
instance_buffer: &mut InstanceBuffer,
|
||||
instance_offset: &mut usize,
|
||||
viewport_size: Size<DevicePixels>,
|
||||
|
@ -630,108 +728,100 @@ impl MetalRenderer {
|
|||
return true;
|
||||
}
|
||||
|
||||
command_encoder.set_render_pipeline_state(&self.path_pipeline_state);
|
||||
command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state);
|
||||
command_encoder.set_vertex_buffer(
|
||||
SpriteInputIndex::Vertices as u64,
|
||||
Some(&self.unit_vertices),
|
||||
0,
|
||||
);
|
||||
command_encoder.set_vertex_bytes(
|
||||
SpriteInputIndex::ViewportSize as u64,
|
||||
mem::size_of_val(&viewport_size) as u64,
|
||||
&viewport_size as *const Size<DevicePixels> as *const _,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
let base_addr = instance_buffer.metal_buffer.contents();
|
||||
let mut p = (base_addr as *mut u8).add(*instance_offset);
|
||||
let mut draw_indirect_commands = Vec::with_capacity(paths.len());
|
||||
let mut prev_texture_id = None;
|
||||
let mut sprites = SmallVec::<[_; 1]>::new();
|
||||
let mut paths_and_tiles = paths
|
||||
.iter()
|
||||
.map(|path| (path, tiles_by_path_id.get(&path.id).unwrap()))
|
||||
.peekable();
|
||||
|
||||
// copy vertices
|
||||
let vertices_offset = (p as usize) - (base_addr as usize);
|
||||
let mut first_vertex = 0;
|
||||
for (i, path) in paths.iter().enumerate() {
|
||||
if (p as usize) - (base_addr as usize)
|
||||
+ (mem::size_of::<PathVertex<ScaledPixels>>() * path.vertices.len())
|
||||
> instance_buffer.size
|
||||
{
|
||||
loop {
|
||||
if let Some((path, tile)) = paths_and_tiles.peek() {
|
||||
if prev_texture_id.map_or(true, |texture_id| texture_id == tile.texture_id) {
|
||||
prev_texture_id = Some(tile.texture_id);
|
||||
let origin = path.bounds.intersect(&path.content_mask.bounds).origin;
|
||||
sprites.push(PathSprite {
|
||||
bounds: Bounds {
|
||||
origin: origin.map(|p| p.floor()),
|
||||
size: tile.bounds.size.map(Into::into),
|
||||
},
|
||||
color: path.color,
|
||||
tile: (*tile).clone(),
|
||||
});
|
||||
paths_and_tiles.next();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if sprites.is_empty() {
|
||||
break;
|
||||
} else {
|
||||
align_offset(instance_offset);
|
||||
let texture_id = prev_texture_id.take().unwrap();
|
||||
let texture: metal::Texture = self.sprite_atlas.metal_texture(texture_id);
|
||||
let texture_size = size(
|
||||
DevicePixels(texture.width() as i32),
|
||||
DevicePixels(texture.height() as i32),
|
||||
);
|
||||
|
||||
command_encoder.set_vertex_buffer(
|
||||
SpriteInputIndex::Sprites as u64,
|
||||
Some(&instance_buffer.metal_buffer),
|
||||
*instance_offset as u64,
|
||||
);
|
||||
command_encoder.set_vertex_bytes(
|
||||
SpriteInputIndex::AtlasTextureSize as u64,
|
||||
mem::size_of_val(&texture_size) as u64,
|
||||
&texture_size as *const Size<DevicePixels> as *const _,
|
||||
);
|
||||
command_encoder.set_fragment_buffer(
|
||||
SpriteInputIndex::Sprites as u64,
|
||||
Some(&instance_buffer.metal_buffer),
|
||||
*instance_offset as u64,
|
||||
);
|
||||
command_encoder
|
||||
.set_fragment_texture(SpriteInputIndex::AtlasTexture as u64, Some(&texture));
|
||||
|
||||
let sprite_bytes_len = mem::size_of_val(sprites.as_slice());
|
||||
let next_offset = *instance_offset + sprite_bytes_len;
|
||||
if next_offset > instance_buffer.size {
|
||||
return false;
|
||||
}
|
||||
|
||||
for v in &path.vertices {
|
||||
*(p as *mut PathVertex<ScaledPixels>) = PathVertex {
|
||||
xy_position: v.xy_position,
|
||||
content_mask: ContentMask {
|
||||
bounds: path.content_mask.bounds,
|
||||
},
|
||||
};
|
||||
p = p.add(mem::size_of::<PathVertex<ScaledPixels>>());
|
||||
let buffer_contents = unsafe {
|
||||
(instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset)
|
||||
};
|
||||
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(
|
||||
sprites.as_ptr() as *const u8,
|
||||
buffer_contents,
|
||||
sprite_bytes_len,
|
||||
);
|
||||
}
|
||||
|
||||
draw_indirect_commands.push(MTLDrawPrimitivesIndirectArguments {
|
||||
vertexCount: path.vertices.len() as u32,
|
||||
instanceCount: 1,
|
||||
vertexStart: first_vertex,
|
||||
baseInstance: i as u32,
|
||||
});
|
||||
first_vertex += path.vertices.len() as u32;
|
||||
}
|
||||
|
||||
// copy sprites
|
||||
let sprites_offset = (p as u64) - (base_addr as u64);
|
||||
if (p as usize) - (base_addr as usize) + (mem::size_of::<PathSprite>() * paths.len())
|
||||
> instance_buffer.size
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for path in paths {
|
||||
*(p as *mut PathSprite) = PathSprite {
|
||||
bounds: path.bounds,
|
||||
color: path.color,
|
||||
};
|
||||
p = p.add(mem::size_of::<PathSprite>());
|
||||
}
|
||||
|
||||
// copy indirect commands
|
||||
let icb_bytes_len = mem::size_of_val(draw_indirect_commands.as_slice());
|
||||
let icb_offset = (p as u64) - (base_addr as u64);
|
||||
if (p as usize) - (base_addr as usize) + icb_bytes_len > instance_buffer.size {
|
||||
return false;
|
||||
}
|
||||
ptr::copy_nonoverlapping(
|
||||
draw_indirect_commands.as_ptr() as *const u8,
|
||||
p,
|
||||
icb_bytes_len,
|
||||
);
|
||||
p = p.add(icb_bytes_len);
|
||||
|
||||
// draw path
|
||||
command_encoder.set_vertex_buffer(
|
||||
PathInputIndex::Vertices as u64,
|
||||
Some(&instance_buffer.metal_buffer),
|
||||
vertices_offset as u64,
|
||||
);
|
||||
|
||||
command_encoder.set_vertex_bytes(
|
||||
PathInputIndex::ViewportSize as u64,
|
||||
mem::size_of_val(&viewport_size) as u64,
|
||||
&viewport_size as *const Size<DevicePixels> as *const _,
|
||||
);
|
||||
|
||||
command_encoder.set_vertex_buffer(
|
||||
PathInputIndex::Sprites as u64,
|
||||
Some(&instance_buffer.metal_buffer),
|
||||
sprites_offset,
|
||||
);
|
||||
|
||||
command_encoder.set_fragment_buffer(
|
||||
PathInputIndex::Sprites as u64,
|
||||
Some(&instance_buffer.metal_buffer),
|
||||
sprites_offset,
|
||||
);
|
||||
|
||||
for i in 0..paths.len() {
|
||||
command_encoder.draw_primitives_indirect(
|
||||
command_encoder.draw_primitives_instanced(
|
||||
metal::MTLPrimitiveType::Triangle,
|
||||
&instance_buffer.metal_buffer,
|
||||
icb_offset
|
||||
+ (i * std::mem::size_of::<MTLDrawPrimitivesIndirectArguments>()) as u64,
|
||||
0,
|
||||
6,
|
||||
sprites.len() as u64,
|
||||
);
|
||||
*instance_offset = next_offset;
|
||||
sprites.clear();
|
||||
}
|
||||
|
||||
*instance_offset = (p as usize) - (base_addr as usize);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
|
@ -1053,7 +1143,6 @@ fn build_pipeline_state(
|
|||
vertex_fn_name: &str,
|
||||
fragment_fn_name: &str,
|
||||
pixel_format: metal::MTLPixelFormat,
|
||||
sample_count: u64,
|
||||
) -> metal::RenderPipelineState {
|
||||
let vertex_fn = library
|
||||
.get_function(vertex_fn_name, None)
|
||||
|
@ -1066,7 +1155,6 @@ fn build_pipeline_state(
|
|||
descriptor.set_label(label);
|
||||
descriptor.set_vertex_function(Some(vertex_fn.as_ref()));
|
||||
descriptor.set_fragment_function(Some(fragment_fn.as_ref()));
|
||||
descriptor.set_sample_count(sample_count);
|
||||
let color_attachment = descriptor.color_attachments().object_at(0).unwrap();
|
||||
color_attachment.set_pixel_format(pixel_format);
|
||||
color_attachment.set_blending_enabled(true);
|
||||
|
@ -1082,45 +1170,50 @@ fn build_pipeline_state(
|
|||
.expect("could not create render pipeline state")
|
||||
}
|
||||
|
||||
fn build_path_rasterization_pipeline_state(
|
||||
device: &metal::DeviceRef,
|
||||
library: &metal::LibraryRef,
|
||||
label: &str,
|
||||
vertex_fn_name: &str,
|
||||
fragment_fn_name: &str,
|
||||
pixel_format: metal::MTLPixelFormat,
|
||||
path_sample_count: u32,
|
||||
) -> metal::RenderPipelineState {
|
||||
let vertex_fn = library
|
||||
.get_function(vertex_fn_name, None)
|
||||
.expect("error locating vertex function");
|
||||
let fragment_fn = library
|
||||
.get_function(fragment_fn_name, None)
|
||||
.expect("error locating fragment function");
|
||||
|
||||
let descriptor = metal::RenderPipelineDescriptor::new();
|
||||
descriptor.set_label(label);
|
||||
descriptor.set_vertex_function(Some(vertex_fn.as_ref()));
|
||||
descriptor.set_fragment_function(Some(fragment_fn.as_ref()));
|
||||
if path_sample_count > 1 {
|
||||
descriptor.set_raster_sample_count(path_sample_count as _);
|
||||
descriptor.set_alpha_to_coverage_enabled(true);
|
||||
}
|
||||
let color_attachment = descriptor.color_attachments().object_at(0).unwrap();
|
||||
color_attachment.set_pixel_format(pixel_format);
|
||||
color_attachment.set_blending_enabled(true);
|
||||
color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add);
|
||||
color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add);
|
||||
color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One);
|
||||
color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One);
|
||||
color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::One);
|
||||
color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One);
|
||||
|
||||
device
|
||||
.new_render_pipeline_state(&descriptor)
|
||||
.expect("could not create render pipeline state")
|
||||
}
|
||||
|
||||
// Align to multiples of 256 make Metal happy.
|
||||
fn align_offset(offset: &mut usize) {
|
||||
*offset = (*offset).div_ceil(256) * 256;
|
||||
}
|
||||
|
||||
fn create_msaa_texture(
|
||||
device: &metal::Device,
|
||||
layer: &metal::MetalLayer,
|
||||
sample_count: u64,
|
||||
) -> Option<metal::Texture> {
|
||||
let viewport_size = layer.drawable_size();
|
||||
let width = viewport_size.width.ceil() as u64;
|
||||
let height = viewport_size.height.ceil() as u64;
|
||||
|
||||
if width == 0 || height == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
if sample_count <= 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let texture_descriptor = metal::TextureDescriptor::new();
|
||||
texture_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample);
|
||||
|
||||
// MTLStorageMode default is `shared` only for Apple silicon GPUs. Use `private` for Apple and Intel GPUs both.
|
||||
// Reference: https://developer.apple.com/documentation/metal/choosing-a-resource-storage-mode-for-apple-gpus
|
||||
texture_descriptor.set_storage_mode(metal::MTLStorageMode::Private);
|
||||
|
||||
texture_descriptor.set_width(width);
|
||||
texture_descriptor.set_height(height);
|
||||
texture_descriptor.set_pixel_format(layer.pixel_format());
|
||||
texture_descriptor.set_usage(metal::MTLTextureUsage::RenderTarget);
|
||||
texture_descriptor.set_sample_count(sample_count);
|
||||
|
||||
let metal_texture = device.new_texture(&texture_descriptor);
|
||||
Some(metal_texture)
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
enum ShadowInputIndex {
|
||||
Vertices = 0,
|
||||
|
@ -1162,10 +1255,9 @@ enum SurfaceInputIndex {
|
|||
}
|
||||
|
||||
#[repr(C)]
|
||||
enum PathInputIndex {
|
||||
enum PathRasterizationInputIndex {
|
||||
Vertices = 0,
|
||||
ViewportSize = 1,
|
||||
Sprites = 2,
|
||||
AtlasTextureSize = 1,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
|
@ -1173,6 +1265,7 @@ enum PathInputIndex {
|
|||
pub struct PathSprite {
|
||||
pub bounds: Bounds<ScaledPixels>,
|
||||
pub color: Background,
|
||||
pub tile: AtlasTile,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
|
|
|
@ -698,27 +698,76 @@ fragment float4 polychrome_sprite_fragment(
|
|||
return color;
|
||||
}
|
||||
|
||||
struct PathVertexOutput {
|
||||
struct PathRasterizationVertexOutput {
|
||||
float4 position [[position]];
|
||||
float2 st_position;
|
||||
float clip_rect_distance [[clip_distance]][4];
|
||||
};
|
||||
|
||||
struct PathRasterizationFragmentInput {
|
||||
float4 position [[position]];
|
||||
float2 st_position;
|
||||
};
|
||||
|
||||
vertex PathRasterizationVertexOutput path_rasterization_vertex(
|
||||
uint vertex_id [[vertex_id]],
|
||||
constant PathVertex_ScaledPixels *vertices
|
||||
[[buffer(PathRasterizationInputIndex_Vertices)]],
|
||||
constant Size_DevicePixels *atlas_size
|
||||
[[buffer(PathRasterizationInputIndex_AtlasTextureSize)]]) {
|
||||
PathVertex_ScaledPixels v = vertices[vertex_id];
|
||||
float2 vertex_position = float2(v.xy_position.x, v.xy_position.y);
|
||||
float2 viewport_size = float2(atlas_size->width, atlas_size->height);
|
||||
return PathRasterizationVertexOutput{
|
||||
float4(vertex_position / viewport_size * float2(2., -2.) +
|
||||
float2(-1., 1.),
|
||||
0., 1.),
|
||||
float2(v.st_position.x, v.st_position.y),
|
||||
{v.xy_position.x - v.content_mask.bounds.origin.x,
|
||||
v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width -
|
||||
v.xy_position.x,
|
||||
v.xy_position.y - v.content_mask.bounds.origin.y,
|
||||
v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height -
|
||||
v.xy_position.y}};
|
||||
}
|
||||
|
||||
fragment float4 path_rasterization_fragment(PathRasterizationFragmentInput input
|
||||
[[stage_in]]) {
|
||||
float2 dx = dfdx(input.st_position);
|
||||
float2 dy = dfdy(input.st_position);
|
||||
float2 gradient = float2((2. * input.st_position.x) * dx.x - dx.y,
|
||||
(2. * input.st_position.x) * dy.x - dy.y);
|
||||
float f = (input.st_position.x * input.st_position.x) - input.st_position.y;
|
||||
float distance = f / length(gradient);
|
||||
float alpha = saturate(0.5 - distance);
|
||||
return float4(alpha, 0., 0., 1.);
|
||||
}
|
||||
|
||||
struct PathSpriteVertexOutput {
|
||||
float4 position [[position]];
|
||||
float2 tile_position;
|
||||
uint sprite_id [[flat]];
|
||||
float4 solid_color [[flat]];
|
||||
float4 color0 [[flat]];
|
||||
float4 color1 [[flat]];
|
||||
float4 clip_distance;
|
||||
};
|
||||
|
||||
vertex PathVertexOutput path_vertex(
|
||||
uint vertex_id [[vertex_id]],
|
||||
constant PathVertex_ScaledPixels *vertices [[buffer(PathInputIndex_Vertices)]],
|
||||
uint sprite_id [[instance_id]],
|
||||
constant PathSprite *sprites [[buffer(PathInputIndex_Sprites)]],
|
||||
constant Size_DevicePixels *input_viewport_size [[buffer(PathInputIndex_ViewportSize)]]) {
|
||||
PathVertex_ScaledPixels v = vertices[vertex_id];
|
||||
float2 vertex_position = float2(v.xy_position.x, v.xy_position.y);
|
||||
float2 viewport_size = float2((float)input_viewport_size->width,
|
||||
(float)input_viewport_size->height);
|
||||
vertex PathSpriteVertexOutput path_sprite_vertex(
|
||||
uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]],
|
||||
constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]],
|
||||
constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
|
||||
constant Size_DevicePixels *viewport_size
|
||||
[[buffer(SpriteInputIndex_ViewportSize)]],
|
||||
constant Size_DevicePixels *atlas_size
|
||||
[[buffer(SpriteInputIndex_AtlasTextureSize)]]) {
|
||||
|
||||
float2 unit_vertex = unit_vertices[unit_vertex_id];
|
||||
PathSprite sprite = sprites[sprite_id];
|
||||
float4 device_position = float4(vertex_position / viewport_size * float2(2., -2.) + float2(-1., 1.), 0., 1.);
|
||||
// Don't apply content mask because it was already accounted for when
|
||||
// rasterizing the path.
|
||||
float4 device_position =
|
||||
to_device_position(unit_vertex, sprite.bounds, viewport_size);
|
||||
float2 tile_position = to_tile_position(unit_vertex, sprite.tile, atlas_size);
|
||||
|
||||
GradientColor gradient = prepare_fill_color(
|
||||
sprite.color.tag,
|
||||
|
@ -728,32 +777,30 @@ vertex PathVertexOutput path_vertex(
|
|||
sprite.color.colors[1].color
|
||||
);
|
||||
|
||||
return PathVertexOutput{
|
||||
return PathSpriteVertexOutput{
|
||||
device_position,
|
||||
tile_position,
|
||||
sprite_id,
|
||||
gradient.solid,
|
||||
gradient.color0,
|
||||
gradient.color1,
|
||||
{v.xy_position.x - v.content_mask.bounds.origin.x,
|
||||
v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width -
|
||||
v.xy_position.x,
|
||||
v.xy_position.y - v.content_mask.bounds.origin.y,
|
||||
v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height -
|
||||
v.xy_position.y}
|
||||
gradient.color1
|
||||
};
|
||||
}
|
||||
|
||||
fragment float4 path_fragment(
|
||||
PathVertexOutput input [[stage_in]],
|
||||
constant PathSprite *sprites [[buffer(PathInputIndex_Sprites)]]) {
|
||||
if (any(input.clip_distance < float4(0.0))) {
|
||||
return float4(0.0);
|
||||
}
|
||||
|
||||
fragment float4 path_sprite_fragment(
|
||||
PathSpriteVertexOutput input [[stage_in]],
|
||||
constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
|
||||
texture2d<float> atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) {
|
||||
constexpr sampler atlas_texture_sampler(mag_filter::linear,
|
||||
min_filter::linear);
|
||||
float4 sample =
|
||||
atlas_texture.sample(atlas_texture_sampler, input.tile_position);
|
||||
float mask = 1. - abs(1. - fmod(sample.r, 2.));
|
||||
PathSprite sprite = sprites[input.sprite_id];
|
||||
Background background = sprite.color;
|
||||
float4 color = fill_color(background, input.position.xy, sprite.bounds,
|
||||
input.solid_color, input.color0, input.color1);
|
||||
color.a *= mask;
|
||||
return color;
|
||||
}
|
||||
|
||||
|
|
|
@ -341,7 +341,7 @@ impl PlatformAtlas for TestAtlas {
|
|||
crate::AtlasTile {
|
||||
texture_id: AtlasTextureId {
|
||||
index: texture_id,
|
||||
kind: crate::AtlasTextureKind::Polychrome,
|
||||
kind: crate::AtlasTextureKind::Path,
|
||||
},
|
||||
tile_id: TileId(tile_id),
|
||||
padding: 0,
|
||||
|
|
|
@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
use crate::{
|
||||
AtlasTextureId, AtlasTile, Background, Bounds, ContentMask, Corners, Edges, Hsla, Pixels,
|
||||
Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree,
|
||||
Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, point,
|
||||
};
|
||||
use std::{fmt::Debug, iter::Peekable, ops::Range, slice};
|
||||
|
||||
|
@ -43,7 +43,13 @@ impl Scene {
|
|||
self.surfaces.clear();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[cfg_attr(
|
||||
all(
|
||||
any(target_os = "linux", target_os = "freebsd"),
|
||||
not(any(feature = "x11", feature = "wayland"))
|
||||
),
|
||||
allow(dead_code)
|
||||
)]
|
||||
pub fn paths(&self) -> &[Path<ScaledPixels>] {
|
||||
&self.paths
|
||||
}
|
||||
|
@ -683,7 +689,6 @@ pub struct Path<P: Clone + Debug + Default + PartialEq> {
|
|||
start: Point<P>,
|
||||
current: Point<P>,
|
||||
contour_count: usize,
|
||||
base_scale: f32,
|
||||
}
|
||||
|
||||
impl Path<Pixels> {
|
||||
|
@ -702,35 +707,25 @@ impl Path<Pixels> {
|
|||
content_mask: Default::default(),
|
||||
color: Default::default(),
|
||||
contour_count: 0,
|
||||
base_scale: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the base scale of the path.
|
||||
pub fn scale(mut self, factor: f32) -> Self {
|
||||
self.base_scale = factor;
|
||||
self
|
||||
}
|
||||
|
||||
/// Apply a scale to the path.
|
||||
pub(crate) fn apply_scale(&self, factor: f32) -> Path<ScaledPixels> {
|
||||
/// Scale this path by the given factor.
|
||||
pub fn scale(&self, factor: f32) -> Path<ScaledPixels> {
|
||||
Path {
|
||||
id: self.id,
|
||||
order: self.order,
|
||||
bounds: self.bounds.scale(self.base_scale * factor),
|
||||
content_mask: self.content_mask.scale(self.base_scale * factor),
|
||||
bounds: self.bounds.scale(factor),
|
||||
content_mask: self.content_mask.scale(factor),
|
||||
vertices: self
|
||||
.vertices
|
||||
.iter()
|
||||
.map(|vertex| vertex.scale(self.base_scale * factor))
|
||||
.map(|vertex| vertex.scale(factor))
|
||||
.collect(),
|
||||
start: self
|
||||
.start
|
||||
.map(|start| start.scale(self.base_scale * factor)),
|
||||
current: self.current.scale(self.base_scale * factor),
|
||||
start: self.start.map(|start| start.scale(factor)),
|
||||
current: self.current.scale(factor),
|
||||
contour_count: self.contour_count,
|
||||
color: self.color,
|
||||
base_scale: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -745,7 +740,10 @@ impl Path<Pixels> {
|
|||
pub fn line_to(&mut self, to: Point<Pixels>) {
|
||||
self.contour_count += 1;
|
||||
if self.contour_count > 1 {
|
||||
self.push_triangle((self.start, self.current, to));
|
||||
self.push_triangle(
|
||||
(self.start, self.current, to),
|
||||
(point(0., 1.), point(0., 1.), point(0., 1.)),
|
||||
);
|
||||
}
|
||||
self.current = to;
|
||||
}
|
||||
|
@ -754,15 +752,25 @@ impl Path<Pixels> {
|
|||
pub fn curve_to(&mut self, to: Point<Pixels>, ctrl: Point<Pixels>) {
|
||||
self.contour_count += 1;
|
||||
if self.contour_count > 1 {
|
||||
self.push_triangle((self.start, self.current, to));
|
||||
self.push_triangle(
|
||||
(self.start, self.current, to),
|
||||
(point(0., 1.), point(0., 1.), point(0., 1.)),
|
||||
);
|
||||
}
|
||||
|
||||
self.push_triangle((self.current, ctrl, to));
|
||||
self.push_triangle(
|
||||
(self.current, ctrl, to),
|
||||
(point(0., 0.), point(0.5, 0.), point(1., 1.)),
|
||||
);
|
||||
self.current = to;
|
||||
}
|
||||
|
||||
/// Push a triangle to the Path.
|
||||
pub fn push_triangle(&mut self, xy: (Point<Pixels>, Point<Pixels>, Point<Pixels>)) {
|
||||
pub fn push_triangle(
|
||||
&mut self,
|
||||
xy: (Point<Pixels>, Point<Pixels>, Point<Pixels>),
|
||||
st: (Point<f32>, Point<f32>, Point<f32>),
|
||||
) {
|
||||
self.bounds = self
|
||||
.bounds
|
||||
.union(&Bounds {
|
||||
|
@ -780,14 +788,17 @@ impl Path<Pixels> {
|
|||
|
||||
self.vertices.push(PathVertex {
|
||||
xy_position: xy.0,
|
||||
st_position: st.0,
|
||||
content_mask: Default::default(),
|
||||
});
|
||||
self.vertices.push(PathVertex {
|
||||
xy_position: xy.1,
|
||||
st_position: st.1,
|
||||
content_mask: Default::default(),
|
||||
});
|
||||
self.vertices.push(PathVertex {
|
||||
xy_position: xy.2,
|
||||
st_position: st.2,
|
||||
content_mask: Default::default(),
|
||||
});
|
||||
}
|
||||
|
@ -803,6 +814,7 @@ impl From<Path<ScaledPixels>> for Primitive {
|
|||
#[repr(C)]
|
||||
pub(crate) struct PathVertex<P: Clone + Debug + Default + PartialEq> {
|
||||
pub(crate) xy_position: Point<P>,
|
||||
pub(crate) st_position: Point<f32>,
|
||||
pub(crate) content_mask: ContentMask<P>,
|
||||
}
|
||||
|
||||
|
@ -810,6 +822,7 @@ impl PathVertex<Pixels> {
|
|||
pub fn scale(&self, factor: f32) -> PathVertex<ScaledPixels> {
|
||||
PathVertex {
|
||||
xy_position: self.xy_position.scale(factor),
|
||||
st_position: self.st_position,
|
||||
content_mask: self.content_mask.scale(factor),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2633,7 +2633,7 @@ impl Window {
|
|||
path.color = color.opacity(opacity);
|
||||
self.next_frame
|
||||
.scene
|
||||
.insert_primitive(path.apply_scale(scale_factor));
|
||||
.insert_primitive(path.scale(scale_factor));
|
||||
}
|
||||
|
||||
/// Paint an underline into the scene for the next frame at the current z-index.
|
||||
|
|
|
@ -20,6 +20,7 @@ pub enum IconName {
|
|||
AiOpenAi,
|
||||
AiOpenRouter,
|
||||
AiVZero,
|
||||
AiXAi,
|
||||
AiZed,
|
||||
ArrowCircle,
|
||||
ArrowDown,
|
||||
|
|
|
@ -116,6 +116,12 @@ pub enum LanguageModelCompletionError {
|
|||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("{message}")]
|
||||
UpstreamProviderError {
|
||||
message: String,
|
||||
status: StatusCode,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||
HttpResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
|
@ -178,6 +184,21 @@ pub enum LanguageModelCompletionError {
|
|||
}
|
||||
|
||||
impl LanguageModelCompletionError {
|
||||
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
|
||||
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
|
||||
let upstream_status = error_json
|
||||
.get("upstream_status")
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(|status| u16::try_from(status).ok())
|
||||
.and_then(|status| StatusCode::from_u16(status).ok())?;
|
||||
let inner_message = error_json
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(message)
|
||||
.to_string();
|
||||
Some((upstream_status, inner_message))
|
||||
}
|
||||
|
||||
pub fn from_cloud_failure(
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
code: String,
|
||||
|
@ -191,6 +212,18 @@ impl LanguageModelCompletionError {
|
|||
Self::PromptTooLarge {
|
||||
tokens: Some(tokens),
|
||||
}
|
||||
} else if code == "upstream_http_error" {
|
||||
if let Some((upstream_status, inner_message)) =
|
||||
Self::parse_upstream_error_json(&message)
|
||||
{
|
||||
return Self::from_http_status(
|
||||
upstream_provider,
|
||||
upstream_status,
|
||||
inner_message,
|
||||
retry_after,
|
||||
);
|
||||
}
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
|
@ -701,3 +734,104 @@ impl From<String> for LanguageModelProviderName {
|
|||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_from_cloud_failure_with_upstream_http_error() {
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ServerOverloaded error for 503 status, got: {:?}",
|
||||
error
|
||||
),
|
||||
}
|
||||
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
assert_eq!(message, "Internal server error");
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ApiInternalServerError for 500 status, got: {:?}",
|
||||
error
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_cloud_failure_with_standard_format() {
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_503".to_string(),
|
||||
"Service unavailable".to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
}
|
||||
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upstream_http_error_connection_timeout() {
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
|
||||
error
|
||||
),
|
||||
}
|
||||
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
assert_eq!(
|
||||
message,
|
||||
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
|
||||
);
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
|
||||
error
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] }
|
|||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
open_router = { workspace = true, features = ["schemars"] }
|
||||
vercel = { workspace = true, features = ["schemars"] }
|
||||
x_ai = { workspace = true, features = ["schemars"] }
|
||||
partial-json-fixer.workspace = true
|
||||
proto.workspace = true
|
||||
release_channel.workspace = true
|
||||
|
|
|
@ -20,6 +20,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider;
|
|||
use crate::provider::open_ai::OpenAiLanguageModelProvider;
|
||||
use crate::provider::open_router::OpenRouterLanguageModelProvider;
|
||||
use crate::provider::vercel::VercelLanguageModelProvider;
|
||||
use crate::provider::x_ai::XAiLanguageModelProvider;
|
||||
pub use crate::settings::*;
|
||||
|
||||
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
||||
|
@ -81,5 +82,6 @@ fn register_language_model_providers(
|
|||
VercelLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
|
||||
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
|
||||
}
|
||||
|
|
|
@ -10,3 +10,4 @@ pub mod ollama;
|
|||
pub mod open_ai;
|
||||
pub mod open_router;
|
||||
pub mod vercel;
|
||||
pub mod x_ai;
|
||||
|
|
|
@ -166,46 +166,9 @@ impl State {
|
|||
}
|
||||
|
||||
let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
let mut models = Vec::new();
|
||||
|
||||
for model in response.models {
|
||||
models.push(Arc::new(model.clone()));
|
||||
|
||||
// Right now we represent thinking variants of models as separate models on the client,
|
||||
// so we need to insert variants for any model that supports thinking.
|
||||
if model.supports_thinking {
|
||||
models.push(Arc::new(zed_llm_client::LanguageModel {
|
||||
id: zed_llm_client::LanguageModelId(
|
||||
format!("{}-thinking", model.id).into(),
|
||||
),
|
||||
display_name: format!("{} Thinking", model.display_name),
|
||||
..model
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
this.default_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_model)
|
||||
.cloned();
|
||||
this.default_fast_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_fast_model)
|
||||
.cloned();
|
||||
this.recommended_models = response
|
||||
.recommended_models
|
||||
.iter()
|
||||
.filter_map(|id| models.iter().find(|model| &model.id == id))
|
||||
.cloned()
|
||||
.collect();
|
||||
this.models = models;
|
||||
cx.notify();
|
||||
})
|
||||
})??;
|
||||
|
||||
anyhow::Ok(())
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_models(response, cx);
|
||||
})
|
||||
})
|
||||
.await
|
||||
.context("failed to fetch Zed models")
|
||||
|
@ -216,12 +179,15 @@ impl State {
|
|||
}),
|
||||
_llm_token_subscription: cx.subscribe(
|
||||
&refresh_llm_token_listener,
|
||||
|this, _listener, _event, cx| {
|
||||
move |this, _listener, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
cx.spawn(async move |_this, _cx| {
|
||||
cx.spawn(async move |this, cx| {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
anyhow::Ok(())
|
||||
let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_models(response, cx);
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
|
@ -264,6 +230,41 @@ impl State {
|
|||
}));
|
||||
}
|
||||
|
||||
fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
|
||||
let mut models = Vec::new();
|
||||
|
||||
for model in response.models {
|
||||
models.push(Arc::new(model.clone()));
|
||||
|
||||
// Right now we represent thinking variants of models as separate models on the client,
|
||||
// so we need to insert variants for any model that supports thinking.
|
||||
if model.supports_thinking {
|
||||
models.push(Arc::new(zed_llm_client::LanguageModel {
|
||||
id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
|
||||
display_name: format!("{} Thinking", model.display_name),
|
||||
..model
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
self.default_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_model)
|
||||
.cloned();
|
||||
self.default_fast_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_fast_model)
|
||||
.cloned();
|
||||
self.recommended_models = response
|
||||
.recommended_models
|
||||
.iter()
|
||||
.filter_map(|id| models.iter().find(|model| &model.id == id))
|
||||
.cloned()
|
||||
.collect();
|
||||
self.models = models;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
async fn fetch_models(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
|
@ -653,8 +654,62 @@ struct ApiError {
|
|||
headers: HeaderMap<HeaderValue>,
|
||||
}
|
||||
|
||||
/// Represents error responses from Zed's cloud API.
|
||||
///
|
||||
/// Example JSON for an upstream HTTP error:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "code": "upstream_http_error",
|
||||
/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
|
||||
/// "upstream_status": 503
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct CloudApiError {
|
||||
code: String,
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "deserialize_optional_status_code")]
|
||||
upstream_status: Option<StatusCode>,
|
||||
#[serde(default)]
|
||||
retry_after: Option<f64>,
|
||||
}
|
||||
|
||||
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let opt: Option<u16> = Option::deserialize(deserializer)?;
|
||||
Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
|
||||
}
|
||||
|
||||
impl From<ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: ApiError) -> Self {
|
||||
if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
|
||||
if cloud_error.code.starts_with("upstream_http_") {
|
||||
let status = if let Some(status) = cloud_error.upstream_status {
|
||||
status
|
||||
} else if cloud_error.code.ends_with("_error") {
|
||||
error.status
|
||||
} else {
|
||||
// If there's a status code in the code string (e.g. "upstream_http_429")
|
||||
// then use that; otherwise, see if the JSON contains a status code.
|
||||
cloud_error
|
||||
.code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code_str| code_str.parse::<u16>().ok())
|
||||
.and_then(|code| StatusCode::from_u16(code).ok())
|
||||
.unwrap_or(error.status)
|
||||
};
|
||||
|
||||
return LanguageModelCompletionError::UpstreamProviderError {
|
||||
message: cloud_error.message,
|
||||
status,
|
||||
retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let retry_after = None;
|
||||
LanguageModelCompletionError::from_http_status(
|
||||
PROVIDER_NAME,
|
||||
|
@ -1294,3 +1349,155 @@ impl Component for ZedAiConfiguration {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use http_client::http::{HeaderMap, StatusCode};
|
||||
use language_model::LanguageModelCompletionError;
|
||||
|
||||
#[test]
|
||||
fn test_api_error_conversion_with_upstream_http_error() {
|
||||
// upstream_http_error with 503 status should become ServerOverloaded
|
||||
let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
|
||||
|
||||
let api_error = ApiError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
body: error_body.to_string(),
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
|
||||
);
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected UpstreamProviderError for upstream 503, got: {:?}",
|
||||
completion_error
|
||||
),
|
||||
}
|
||||
|
||||
// upstream_http_error with 500 status should become ApiInternalServerError
|
||||
let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
|
||||
|
||||
let api_error = ApiError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
body: error_body.to_string(),
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Received an error from the OpenAI API: internal server error"
|
||||
);
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected UpstreamProviderError for upstream 500, got: {:?}",
|
||||
completion_error
|
||||
),
|
||||
}
|
||||
|
||||
// upstream_http_error with 429 status should become RateLimitExceeded
|
||||
let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
|
||||
|
||||
let api_error = ApiError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
body: error_body.to_string(),
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Received an error from the Google API: rate limit exceeded"
|
||||
);
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected UpstreamProviderError for upstream 429, got: {:?}",
|
||||
completion_error
|
||||
),
|
||||
}
|
||||
|
||||
// Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
|
||||
let error_body = "Regular internal server error";
|
||||
|
||||
let api_error = ApiError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
body: error_body.to_string(),
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
||||
assert_eq!(provider, PROVIDER_NAME);
|
||||
assert_eq!(message, "Regular internal server error");
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ApiInternalServerError for regular 500, got: {:?}",
|
||||
completion_error
|
||||
),
|
||||
}
|
||||
|
||||
// upstream_http_429 format should be converted to UpstreamProviderError
|
||||
let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
|
||||
|
||||
let api_error = ApiError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
body: error_body.to_string(),
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::UpstreamProviderError {
|
||||
message,
|
||||
status,
|
||||
retry_after,
|
||||
} => {
|
||||
assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
|
||||
assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
|
||||
assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected UpstreamProviderError for upstream_http_429, got: {:?}",
|
||||
completion_error
|
||||
),
|
||||
}
|
||||
|
||||
// Invalid JSON in error body should fall back to regular error handling
|
||||
let error_body = "Not JSON at all";
|
||||
|
||||
let api_error = ApiError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
body: error_body.to_string(),
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
|
||||
assert_eq!(provider, PROVIDER_NAME);
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
|
||||
completion_error
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -376,7 +376,7 @@ impl LanguageModel for OpenRouterLanguageModel {
|
|||
|
||||
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
|
||||
let model_id = self.model.id().trim().to_lowercase();
|
||||
if model_id.contains("gemini") {
|
||||
if model_id.contains("gemini") || model_id.contains("grok-4") {
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset
|
||||
} else {
|
||||
LanguageModelToolSchemaFormat::JsonSchema
|
||||
|
|
571
crates/language_models/src/provider/x_ai.rs
Normal file
571
crates/language_models/src/provider/x_ai.rs
Normal file
|
@ -0,0 +1,571 @@
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::ResponseStreamEvent;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
use x_ai::Model;
|
||||
|
||||
use ui::{ElevationIndex, List, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: &str = "x_ai";
|
||||
const PROVIDER_NAME: &str = "xAI";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct XAiSettings {
|
||||
pub api_url: String,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub display_name: Option<String>,
|
||||
pub max_tokens: u64,
|
||||
pub max_output_tokens: Option<u64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
pub struct XAiLanguageModelProvider {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
state: gpui::Entity<State>,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
const XAI_API_KEY_VAR: &str = "XAI_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, &cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, &cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl XAiLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
|
||||
fn create_language_model(&self, model: x_ai::Model) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(XAiLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for XAiLanguageModelProvider {
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for XAiLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::AiXAi
|
||||
}
|
||||
|
||||
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
Some(self.create_language_model(x_ai::Model::default()))
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
Some(self.create_language_model(x_ai::Model::default_fast()))
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
for model in x_ai::Model::iter() {
|
||||
if !matches!(model, x_ai::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), model);
|
||||
}
|
||||
}
|
||||
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.x_ai
|
||||
.available_models
|
||||
{
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
x_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
models
|
||||
.into_values()
|
||||
.map(|model| self.create_language_model(model))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &App) -> bool {
|
||||
self.state.read(cx).is_authenticated()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct XAiLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: x_ai::Model,
|
||||
state: gpui::Entity<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl XAiLanguageModel {
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: open_ai::Request,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
(state.api_key.clone(), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.context("Missing xAI API Key")?;
|
||||
let request =
|
||||
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for XAiLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
self.model.supports_tool()
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
self.model.supports_images()
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto
|
||||
| LanguageModelToolChoice::Any
|
||||
| LanguageModelToolChoice::None => true,
|
||||
}
|
||||
}
|
||||
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
|
||||
let model_id = self.model.id().trim().to_lowercase();
|
||||
if model_id.eq(x_ai::Model::Grok4.id()) {
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset
|
||||
} else {
|
||||
LanguageModelToolSchemaFormat::JsonSchema
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("x_ai/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn max_output_tokens(&self) -> Option<u64> {
|
||||
self.model.max_output_tokens()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
count_xai_tokens(request, self.model.clone(), cx)
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let request = crate::provider::open_ai::into_open_ai(
|
||||
request,
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.max_output_tokens(),
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
|
||||
Ok(mapper.map_stream(completions.await?).boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_xai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
model: Model,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
cx.background_spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let model_name = if model.max_token_count() >= 100_000 {
|
||||
"gpt-4o"
|
||||
} else {
|
||||
"gpt-4"
|
||||
};
|
||||
tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<SingleLineInput>,
|
||||
state: gpui::Entity<State>,
|
||||
load_credentials_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let api_key_editor = cx.new(|cx| {
|
||||
SingleLineInput::new(
|
||||
window,
|
||||
cx,
|
||||
"xai-0000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
.label("API key")
|
||||
});
|
||||
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let load_credentials_task = Some(cx.spawn_in(window, {
|
||||
let state = state.clone();
|
||||
async move |this, cx| {
|
||||
if let Some(task) = state
|
||||
.update(cx, |state, cx| state.authenticate(cx))
|
||||
.log_err()
|
||||
{
|
||||
// We don't log an error, because "not signed in" is also an error.
|
||||
let _ = task.await;
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.load_credentials_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}));
|
||||
|
||||
Self {
|
||||
api_key_editor,
|
||||
state,
|
||||
load_credentials_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
return;
|
||||
}
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
!self.state.read(cx).is_authenticated()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:"))
|
||||
.child(
|
||||
List::new()
|
||||
.child(InstructionListItem::new(
|
||||
"Create one by visiting",
|
||||
Some("xAI console"),
|
||||
Some("https://console.x.ai/team/default/api-keys"),
|
||||
))
|
||||
.child(InstructionListItem::text_only(
|
||||
"Paste your API key below and hit enter to start using the agent",
|
||||
)),
|
||||
)
|
||||
.child(self.api_key_editor.clone())
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed."
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new("Note that xAI is a custom OpenAI-compatible provider.")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
h_flex()
|
||||
.mt_1()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {XAI_API_KEY_VAR} environment variable.")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-key", "Reset API Key")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
.into_any()
|
||||
};
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials…")).into_any()
|
||||
} else {
|
||||
v_flex().size_full().child(api_key_section).into_any()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,6 +17,7 @@ use crate::provider::{
|
|||
open_ai::OpenAiSettings,
|
||||
open_router::OpenRouterSettings,
|
||||
vercel::VercelSettings,
|
||||
x_ai::XAiSettings,
|
||||
};
|
||||
|
||||
/// Initializes the language model settings.
|
||||
|
@ -28,33 +29,33 @@ pub fn init(cx: &mut App) {
|
|||
pub struct AllLanguageModelSettings {
|
||||
pub anthropic: AnthropicSettings,
|
||||
pub bedrock: AmazonBedrockSettings,
|
||||
pub ollama: OllamaSettings,
|
||||
pub openai: OpenAiSettings,
|
||||
pub open_router: OpenRouterSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
pub google: GoogleSettings,
|
||||
pub vercel: VercelSettings,
|
||||
|
||||
pub lmstudio: LmStudioSettings,
|
||||
pub deepseek: DeepSeekSettings,
|
||||
pub google: GoogleSettings,
|
||||
pub lmstudio: LmStudioSettings,
|
||||
pub mistral: MistralSettings,
|
||||
pub ollama: OllamaSettings,
|
||||
pub open_router: OpenRouterSettings,
|
||||
pub openai: OpenAiSettings,
|
||||
pub vercel: VercelSettings,
|
||||
pub x_ai: XAiSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AllLanguageModelSettingsContent {
|
||||
pub anthropic: Option<AnthropicSettingsContent>,
|
||||
pub bedrock: Option<AmazonBedrockSettingsContent>,
|
||||
pub ollama: Option<OllamaSettingsContent>,
|
||||
pub deepseek: Option<DeepseekSettingsContent>,
|
||||
pub google: Option<GoogleSettingsContent>,
|
||||
pub lmstudio: Option<LmStudioSettingsContent>,
|
||||
pub openai: Option<OpenAiSettingsContent>,
|
||||
pub mistral: Option<MistralSettingsContent>,
|
||||
pub ollama: Option<OllamaSettingsContent>,
|
||||
pub open_router: Option<OpenRouterSettingsContent>,
|
||||
pub openai: Option<OpenAiSettingsContent>,
|
||||
pub vercel: Option<VercelSettingsContent>,
|
||||
pub x_ai: Option<XAiSettingsContent>,
|
||||
#[serde(rename = "zed.dev")]
|
||||
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
|
||||
pub google: Option<GoogleSettingsContent>,
|
||||
pub deepseek: Option<DeepseekSettingsContent>,
|
||||
pub vercel: Option<VercelSettingsContent>,
|
||||
|
||||
pub mistral: Option<MistralSettingsContent>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
|
@ -114,6 +115,12 @@ pub struct GoogleSettingsContent {
|
|||
pub available_models: Option<Vec<provider::google::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct XAiSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub available_models: Option<Vec<provider::x_ai::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct ZedDotDevSettingsContent {
|
||||
available_models: Option<Vec<cloud::AvailableModel>>,
|
||||
|
@ -230,6 +237,18 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
vercel.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
// XAI
|
||||
let x_ai = value.x_ai.clone();
|
||||
merge(
|
||||
&mut settings.x_ai.api_url,
|
||||
x_ai.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
merge(
|
||||
&mut settings.x_ai.available_models,
|
||||
x_ai.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
// ZedDotDev
|
||||
merge(
|
||||
&mut settings.zed_dot_dev.available_models,
|
||||
value
|
||||
|
|
|
@ -3362,8 +3362,14 @@ impl Project {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Vec<LocationLink>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
let guard = self.retain_remotely_created_models(cx);
|
||||
let task = self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store.definitions(buffer, position, cx)
|
||||
});
|
||||
cx.spawn(async move |_, _| {
|
||||
let result = task.await;
|
||||
drop(guard);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3374,8 +3380,14 @@ impl Project {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Vec<LocationLink>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
let guard = self.retain_remotely_created_models(cx);
|
||||
let task = self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store.declarations(buffer, position, cx)
|
||||
});
|
||||
cx.spawn(async move |_, _| {
|
||||
let result = task.await;
|
||||
drop(guard);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3386,8 +3398,14 @@ impl Project {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Vec<LocationLink>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
let guard = self.retain_remotely_created_models(cx);
|
||||
let task = self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store.type_definitions(buffer, position, cx)
|
||||
});
|
||||
cx.spawn(async move |_, _| {
|
||||
let result = task.await;
|
||||
drop(guard);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3398,8 +3416,14 @@ impl Project {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Vec<LocationLink>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
let guard = self.retain_remotely_created_models(cx);
|
||||
let task = self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store.implementations(buffer, position, cx)
|
||||
});
|
||||
cx.spawn(async move |_, _| {
|
||||
let result = task.await;
|
||||
drop(guard);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3410,8 +3434,14 @@ impl Project {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Vec<Location>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
let guard = self.retain_remotely_created_models(cx);
|
||||
let task = self.lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store.references(buffer, position, cx)
|
||||
});
|
||||
cx.spawn(async move |_, _| {
|
||||
let result = task.await;
|
||||
drop(guard);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -186,7 +186,6 @@ struct EntryDetails {
|
|||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
struct StickyDetails {
|
||||
sticky_index: usize,
|
||||
is_last: bool,
|
||||
}
|
||||
|
||||
/// Permanently deletes the selected file or directory.
|
||||
|
@ -342,12 +341,20 @@ struct ItemColors {
|
|||
focused: Hsla,
|
||||
}
|
||||
|
||||
fn get_item_color(cx: &App) -> ItemColors {
|
||||
fn get_item_color(is_sticky: bool, cx: &App) -> ItemColors {
|
||||
let colors = cx.theme().colors();
|
||||
|
||||
ItemColors {
|
||||
default: colors.panel_background,
|
||||
hover: colors.element_hover,
|
||||
default: if is_sticky {
|
||||
colors.panel_overlay_background
|
||||
} else {
|
||||
colors.panel_background
|
||||
},
|
||||
hover: if is_sticky {
|
||||
colors.panel_overlay_hover
|
||||
} else {
|
||||
colors.element_hover
|
||||
},
|
||||
marked: colors.element_selected,
|
||||
focused: colors.panel_focused_border,
|
||||
drag_over: colors.drop_target_background,
|
||||
|
@ -3850,7 +3857,7 @@ impl ProjectPanel {
|
|||
|
||||
let filename_text_color = details.filename_text_color;
|
||||
let diagnostic_severity = details.diagnostic_severity;
|
||||
let item_colors = get_item_color(cx);
|
||||
let item_colors = get_item_color(is_sticky, cx);
|
||||
|
||||
let canonical_path = details
|
||||
.canonical_path
|
||||
|
@ -3938,31 +3945,14 @@ impl ProjectPanel {
|
|||
}
|
||||
};
|
||||
|
||||
let show_sticky_shadow = details.sticky.as_ref().map_or(false, |item| {
|
||||
if item.is_last {
|
||||
let is_scrollable = self.scroll_handle.is_scrollable();
|
||||
let is_scrolled = self.scroll_handle.offset().y < px(0.);
|
||||
is_scrollable && is_scrolled
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
let shadow_color_top = hsla(0.0, 0.0, 0.0, 0.1);
|
||||
let shadow_color_bottom = hsla(0.0, 0.0, 0.0, 0.);
|
||||
let sticky_shadow = div()
|
||||
.absolute()
|
||||
.left_0()
|
||||
.bottom_neg_1p5()
|
||||
.h_1p5()
|
||||
.w_full()
|
||||
.bg(linear_gradient(
|
||||
0.,
|
||||
linear_color_stop(shadow_color_top, 1.),
|
||||
linear_color_stop(shadow_color_bottom, 0.),
|
||||
));
|
||||
let id: ElementId = if is_sticky {
|
||||
SharedString::from(format!("project_panel_sticky_item_{}", entry_id.to_usize())).into()
|
||||
} else {
|
||||
(entry_id.to_proto() as usize).into()
|
||||
};
|
||||
|
||||
div()
|
||||
.id(entry_id.to_proto() as usize)
|
||||
.id(id.clone())
|
||||
.relative()
|
||||
.group(GROUP_NAME)
|
||||
.cursor_pointer()
|
||||
|
@ -3972,7 +3962,9 @@ impl ProjectPanel {
|
|||
.border_r_2()
|
||||
.border_color(border_color)
|
||||
.hover(|style| style.bg(bg_hover_color).border_color(border_hover_color))
|
||||
.when(show_sticky_shadow, |this| this.child(sticky_shadow))
|
||||
.when(is_sticky, |this| {
|
||||
this.block_mouse_except_scroll()
|
||||
})
|
||||
.when(!is_sticky, |this| {
|
||||
this
|
||||
.when(is_highlighted && folded_directory_drag_target.is_none(), |this| this.border_color(transparent_white()).bg(item_colors.drag_over))
|
||||
|
@ -4183,6 +4175,16 @@ impl ProjectPanel {
|
|||
.unwrap_or(ScrollStrategy::Top);
|
||||
this.scroll_handle.scroll_to_item(index, strategy);
|
||||
cx.notify();
|
||||
// move down by 1px so that clicked item
|
||||
// don't count as sticky anymore
|
||||
cx.on_next_frame(window, |_, window, cx| {
|
||||
cx.on_next_frame(window, |this, _, cx| {
|
||||
let mut offset = this.scroll_handle.offset();
|
||||
offset.y += px(1.);
|
||||
this.scroll_handle.set_offset(offset);
|
||||
cx.notify();
|
||||
});
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -4201,7 +4203,7 @@ impl ProjectPanel {
|
|||
}),
|
||||
)
|
||||
.child(
|
||||
ListItem::new(entry_id.to_proto() as usize)
|
||||
ListItem::new(id)
|
||||
.indent_level(depth)
|
||||
.indent_step_size(px(settings.indent_size))
|
||||
.spacing(match settings.entry_spacing {
|
||||
|
@ -4924,7 +4926,6 @@ impl ProjectPanel {
|
|||
.unwrap_or_default();
|
||||
let sticky_details = Some(StickyDetails {
|
||||
sticky_index: index,
|
||||
is_last: index == last_item_index,
|
||||
});
|
||||
let details = self.details_for_entry(
|
||||
entry,
|
||||
|
@ -4936,7 +4937,24 @@ impl ProjectPanel {
|
|||
window,
|
||||
cx,
|
||||
);
|
||||
self.render_entry(entry.id, details, window, cx).into_any()
|
||||
self.render_entry(entry.id, details, window, cx)
|
||||
.when(index == last_item_index, |this| {
|
||||
let shadow_color_top = hsla(0.0, 0.0, 0.0, 0.1);
|
||||
let shadow_color_bottom = hsla(0.0, 0.0, 0.0, 0.);
|
||||
let sticky_shadow = div()
|
||||
.absolute()
|
||||
.left_0()
|
||||
.bottom_neg_1p5()
|
||||
.h_1p5()
|
||||
.w_full()
|
||||
.bg(linear_gradient(
|
||||
0.,
|
||||
linear_color_stop(shadow_color_top, 1.),
|
||||
linear_color_stop(shadow_color_bottom, 0.),
|
||||
));
|
||||
this.child(sticky_shadow)
|
||||
})
|
||||
.into_any()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -4970,7 +4988,16 @@ impl Render for ProjectPanel {
|
|||
let indent_size = ProjectPanelSettings::get_global(cx).indent_size;
|
||||
let show_indent_guides =
|
||||
ProjectPanelSettings::get_global(cx).indent_guides.show == ShowIndentGuides::Always;
|
||||
let show_sticky_scroll = ProjectPanelSettings::get_global(cx).sticky_scroll;
|
||||
let show_sticky_entries = {
|
||||
if ProjectPanelSettings::get_global(cx).sticky_scroll {
|
||||
let is_scrollable = self.scroll_handle.is_scrollable();
|
||||
let is_scrolled = self.scroll_handle.offset().y < px(0.);
|
||||
is_scrollable && is_scrolled
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
let is_local = project.is_local();
|
||||
|
||||
if has_worktree {
|
||||
|
@ -5262,7 +5289,7 @@ impl Render for ProjectPanel {
|
|||
}),
|
||||
)
|
||||
})
|
||||
.when(show_sticky_scroll, |list| {
|
||||
.when(show_sticky_entries, |list| {
|
||||
let sticky_items = ui::sticky_items(
|
||||
cx.entity().clone(),
|
||||
|this, range, window, cx| {
|
||||
|
|
|
@ -611,7 +611,7 @@ impl RulesLibrary {
|
|||
this.update_in(cx, |this, window, cx| match rule {
|
||||
Ok(rule) => {
|
||||
let title_editor = cx.new(|cx| {
|
||||
let mut editor = Editor::auto_width(window, cx);
|
||||
let mut editor = Editor::single_line(window, cx);
|
||||
editor.set_placeholder_text("Untitled", cx);
|
||||
editor.set_text(rule_metadata.title.unwrap_or_default(), window, cx);
|
||||
if prompt_id.is_built_in() {
|
||||
|
|
|
@ -127,7 +127,7 @@ impl BatchedTextRun {
|
|||
cx: &mut App,
|
||||
) {
|
||||
let pos = Point::new(
|
||||
(origin.x + self.start_point.column as f32 * dimensions.cell_width).floor(),
|
||||
origin.x + self.start_point.column as f32 * dimensions.cell_width,
|
||||
origin.y + self.start_point.line as f32 * dimensions.line_height,
|
||||
);
|
||||
|
||||
|
@ -494,6 +494,30 @@ impl TerminalElement {
|
|||
}
|
||||
}
|
||||
|
||||
/// Checks if a character is a decorative block/box-like character that should
|
||||
/// preserve its exact colors without contrast adjustment.
|
||||
///
|
||||
/// This specifically targets characters used as visual connectors, separators,
|
||||
/// and borders where color matching with adjacent backgrounds is critical.
|
||||
/// Regular icons (git, folders, etc.) are excluded as they need to remain readable.
|
||||
///
|
||||
/// Fixes https://github.com/zed-industries/zed/issues/34234
|
||||
fn is_decorative_character(ch: char) -> bool {
|
||||
matches!(
|
||||
ch as u32,
|
||||
// Unicode Box Drawing and Block Elements
|
||||
0x2500..=0x257F // Box Drawing (└ ┐ ─ │ etc.)
|
||||
| 0x2580..=0x259F // Block Elements (▀ ▄ █ ░ ▒ ▓ etc.)
|
||||
| 0x25A0..=0x25FF // Geometric Shapes (■ ▶ ● etc. - includes triangular/circular separators)
|
||||
|
||||
// Private Use Area - Powerline separator symbols only
|
||||
| 0xE0B0..=0xE0B7 // Powerline separators: triangles (E0B0-E0B3) and half circles (E0B4-E0B7)
|
||||
| 0xE0B8..=0xE0BF // Additional Powerline separators: angles, flames, etc.
|
||||
| 0xE0C0..=0xE0C8 // Powerline separators: pixelated triangles, curves
|
||||
| 0xE0CC..=0xE0D4 // Powerline separators: rounded triangles, ice/lego style
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts the Alacritty cell styles to GPUI text styles and background color.
|
||||
fn cell_style(
|
||||
indexed: &IndexedCell,
|
||||
|
@ -508,7 +532,10 @@ impl TerminalElement {
|
|||
let mut fg = convert_color(&fg, colors);
|
||||
let bg = convert_color(&bg, colors);
|
||||
|
||||
fg = color_contrast::ensure_minimum_contrast(fg, bg, minimum_contrast);
|
||||
// Only apply contrast adjustment to non-decorative characters
|
||||
if !Self::is_decorative_character(indexed.c) {
|
||||
fg = color_contrast::ensure_minimum_contrast(fg, bg, minimum_contrast);
|
||||
}
|
||||
|
||||
// Ghostty uses (175/255) as the multiplier (~0.69), Alacritty uses 0.66, Kitty
|
||||
// uses 0.75. We're using 0.7 because it's pretty well in the middle of that.
|
||||
|
@ -1575,6 +1602,101 @@ mod tests {
|
|||
use super::*;
|
||||
use gpui::{AbsoluteLength, Hsla, font};
|
||||
|
||||
#[test]
|
||||
fn test_is_decorative_character() {
|
||||
// Box Drawing characters (U+2500 to U+257F)
|
||||
assert!(TerminalElement::is_decorative_character('─')); // U+2500
|
||||
assert!(TerminalElement::is_decorative_character('│')); // U+2502
|
||||
assert!(TerminalElement::is_decorative_character('┌')); // U+250C
|
||||
assert!(TerminalElement::is_decorative_character('┐')); // U+2510
|
||||
assert!(TerminalElement::is_decorative_character('└')); // U+2514
|
||||
assert!(TerminalElement::is_decorative_character('┘')); // U+2518
|
||||
assert!(TerminalElement::is_decorative_character('┼')); // U+253C
|
||||
|
||||
// Block Elements (U+2580 to U+259F)
|
||||
assert!(TerminalElement::is_decorative_character('▀')); // U+2580
|
||||
assert!(TerminalElement::is_decorative_character('▄')); // U+2584
|
||||
assert!(TerminalElement::is_decorative_character('█')); // U+2588
|
||||
assert!(TerminalElement::is_decorative_character('░')); // U+2591
|
||||
assert!(TerminalElement::is_decorative_character('▒')); // U+2592
|
||||
assert!(TerminalElement::is_decorative_character('▓')); // U+2593
|
||||
|
||||
// Geometric Shapes - block/box-like subset (U+25A0 to U+25D7)
|
||||
assert!(TerminalElement::is_decorative_character('■')); // U+25A0
|
||||
assert!(TerminalElement::is_decorative_character('□')); // U+25A1
|
||||
assert!(TerminalElement::is_decorative_character('▲')); // U+25B2
|
||||
assert!(TerminalElement::is_decorative_character('▼')); // U+25BC
|
||||
assert!(TerminalElement::is_decorative_character('◆')); // U+25C6
|
||||
assert!(TerminalElement::is_decorative_character('●')); // U+25CF
|
||||
|
||||
// The specific character from the issue
|
||||
assert!(TerminalElement::is_decorative_character('◗')); // U+25D7
|
||||
assert!(TerminalElement::is_decorative_character('◘')); // U+25D8 (now included in Geometric Shapes)
|
||||
assert!(TerminalElement::is_decorative_character('◙')); // U+25D9 (now included in Geometric Shapes)
|
||||
|
||||
// Powerline symbols (Private Use Area)
|
||||
assert!(TerminalElement::is_decorative_character('\u{E0B0}')); // Powerline right triangle
|
||||
assert!(TerminalElement::is_decorative_character('\u{E0B2}')); // Powerline left triangle
|
||||
assert!(TerminalElement::is_decorative_character('\u{E0B4}')); // Powerline right half circle (the actual issue!)
|
||||
assert!(TerminalElement::is_decorative_character('\u{E0B6}')); // Powerline left half circle
|
||||
|
||||
// Characters that should NOT be considered decorative
|
||||
assert!(!TerminalElement::is_decorative_character('A')); // Regular letter
|
||||
assert!(!TerminalElement::is_decorative_character('$')); // Symbol
|
||||
assert!(!TerminalElement::is_decorative_character(' ')); // Space
|
||||
assert!(!TerminalElement::is_decorative_character('←')); // U+2190 (Arrow, not in our ranges)
|
||||
assert!(!TerminalElement::is_decorative_character('→')); // U+2192 (Arrow, not in our ranges)
|
||||
assert!(!TerminalElement::is_decorative_character('\u{F00C}')); // Font Awesome check (icon, needs contrast)
|
||||
assert!(!TerminalElement::is_decorative_character('\u{E711}')); // Devicons (icon, needs contrast)
|
||||
assert!(!TerminalElement::is_decorative_character('\u{EA71}')); // Codicons folder (icon, needs contrast)
|
||||
assert!(!TerminalElement::is_decorative_character('\u{F401}')); // Octicons (icon, needs contrast)
|
||||
assert!(!TerminalElement::is_decorative_character('\u{1F600}')); // Emoji (not in our ranges)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decorative_character_boundary_cases() {
|
||||
// Test exact boundaries of our ranges
|
||||
// Box Drawing range boundaries
|
||||
assert!(TerminalElement::is_decorative_character('\u{2500}')); // First char
|
||||
assert!(TerminalElement::is_decorative_character('\u{257F}')); // Last char
|
||||
assert!(!TerminalElement::is_decorative_character('\u{24FF}')); // Just before
|
||||
|
||||
// Block Elements range boundaries
|
||||
assert!(TerminalElement::is_decorative_character('\u{2580}')); // First char
|
||||
assert!(TerminalElement::is_decorative_character('\u{259F}')); // Last char
|
||||
|
||||
// Geometric Shapes subset boundaries
|
||||
assert!(TerminalElement::is_decorative_character('\u{25A0}')); // First char
|
||||
assert!(TerminalElement::is_decorative_character('\u{25FF}')); // Last char
|
||||
assert!(!TerminalElement::is_decorative_character('\u{2600}')); // Just after
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decorative_characters_bypass_contrast_adjustment() {
|
||||
// Decorative characters should not be affected by contrast adjustment
|
||||
|
||||
// The specific character from issue #34234
|
||||
let problematic_char = '◗'; // U+25D7
|
||||
assert!(
|
||||
TerminalElement::is_decorative_character(problematic_char),
|
||||
"Character ◗ (U+25D7) should be recognized as decorative"
|
||||
);
|
||||
|
||||
// Verify some other commonly used decorative characters
|
||||
assert!(TerminalElement::is_decorative_character('│')); // Vertical line
|
||||
assert!(TerminalElement::is_decorative_character('─')); // Horizontal line
|
||||
assert!(TerminalElement::is_decorative_character('█')); // Full block
|
||||
assert!(TerminalElement::is_decorative_character('▓')); // Dark shade
|
||||
assert!(TerminalElement::is_decorative_character('■')); // Black square
|
||||
assert!(TerminalElement::is_decorative_character('●')); // Black circle
|
||||
|
||||
// Verify normal text characters are NOT decorative
|
||||
assert!(!TerminalElement::is_decorative_character('A'));
|
||||
assert!(!TerminalElement::is_decorative_character('1'));
|
||||
assert!(!TerminalElement::is_decorative_character('$'));
|
||||
assert!(!TerminalElement::is_decorative_character(' '));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contrast_adjustment_logic() {
|
||||
// Test the core contrast adjustment logic without needing full app context
|
||||
|
|
|
@ -83,6 +83,8 @@ impl ThemeColors {
|
|||
panel_indent_guide: neutral().light_alpha().step_5(),
|
||||
panel_indent_guide_hover: neutral().light_alpha().step_6(),
|
||||
panel_indent_guide_active: neutral().light_alpha().step_6(),
|
||||
panel_overlay_background: neutral().light().step_2(),
|
||||
panel_overlay_hover: neutral().light_alpha().step_4(),
|
||||
pane_focused_border: blue().light().step_5(),
|
||||
pane_group_border: neutral().light().step_6(),
|
||||
scrollbar_thumb_background: neutral().light_alpha().step_3(),
|
||||
|
@ -206,6 +208,8 @@ impl ThemeColors {
|
|||
panel_indent_guide: neutral().dark_alpha().step_4(),
|
||||
panel_indent_guide_hover: neutral().dark_alpha().step_6(),
|
||||
panel_indent_guide_active: neutral().dark_alpha().step_6(),
|
||||
panel_overlay_background: neutral().dark().step_2(),
|
||||
panel_overlay_hover: neutral().dark_alpha().step_4(),
|
||||
pane_focused_border: blue().dark().step_5(),
|
||||
pane_group_border: neutral().dark().step_6(),
|
||||
scrollbar_thumb_background: neutral().dark_alpha().step_3(),
|
||||
|
|
|
@ -59,6 +59,7 @@ pub(crate) fn zed_default_dark() -> Theme {
|
|||
let bg = hsla(215. / 360., 12. / 100., 15. / 100., 1.);
|
||||
let editor = hsla(220. / 360., 12. / 100., 18. / 100., 1.);
|
||||
let elevated_surface = hsla(225. / 360., 12. / 100., 17. / 100., 1.);
|
||||
let hover = hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0);
|
||||
|
||||
let blue = hsla(207.8 / 360., 81. / 100., 66. / 100., 1.0);
|
||||
let gray = hsla(218.8 / 360., 10. / 100., 40. / 100., 1.0);
|
||||
|
@ -108,14 +109,14 @@ pub(crate) fn zed_default_dark() -> Theme {
|
|||
surface_background: bg,
|
||||
background: bg,
|
||||
element_background: hsla(223.0 / 360., 13. / 100., 21. / 100., 1.0),
|
||||
element_hover: hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0),
|
||||
element_hover: hover,
|
||||
element_active: hsla(220.0 / 360., 11.8 / 100., 20.0 / 100., 1.0),
|
||||
element_selected: hsla(224.0 / 360., 11.3 / 100., 26.1 / 100., 1.0),
|
||||
element_disabled: SystemColors::default().transparent,
|
||||
element_selection_background: player.local().selection.alpha(0.25),
|
||||
drop_target_background: hsla(220.0 / 360., 8.3 / 100., 21.4 / 100., 1.0),
|
||||
ghost_element_background: SystemColors::default().transparent,
|
||||
ghost_element_hover: hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0),
|
||||
ghost_element_hover: hover,
|
||||
ghost_element_active: hsla(220.0 / 360., 11.8 / 100., 20.0 / 100., 1.0),
|
||||
ghost_element_selected: hsla(224.0 / 360., 11.3 / 100., 26.1 / 100., 1.0),
|
||||
ghost_element_disabled: SystemColors::default().transparent,
|
||||
|
@ -202,10 +203,12 @@ pub(crate) fn zed_default_dark() -> Theme {
|
|||
panel_indent_guide: hsla(228. / 360., 8. / 100., 25. / 100., 1.),
|
||||
panel_indent_guide_hover: hsla(225. / 360., 13. / 100., 12. / 100., 1.),
|
||||
panel_indent_guide_active: hsla(225. / 360., 13. / 100., 12. / 100., 1.),
|
||||
panel_overlay_background: bg,
|
||||
panel_overlay_hover: hover,
|
||||
pane_focused_border: blue,
|
||||
pane_group_border: hsla(225. / 360., 13. / 100., 12. / 100., 1.),
|
||||
scrollbar_thumb_background: gpui::transparent_black(),
|
||||
scrollbar_thumb_hover_background: hsla(225.0 / 360., 11.8 / 100., 26.7 / 100., 1.0),
|
||||
scrollbar_thumb_hover_background: hover,
|
||||
scrollbar_thumb_active_background: hsla(
|
||||
225.0 / 360.,
|
||||
11.8 / 100.,
|
||||
|
|
|
@ -352,6 +352,12 @@ pub struct ThemeColorsContent {
|
|||
#[serde(rename = "panel.indent_guide_active")]
|
||||
pub panel_indent_guide_active: Option<String>,
|
||||
|
||||
#[serde(rename = "panel.overlay_background")]
|
||||
pub panel_overlay_background: Option<String>,
|
||||
|
||||
#[serde(rename = "panel.overlay_hover")]
|
||||
pub panel_overlay_hover: Option<String>,
|
||||
|
||||
#[serde(rename = "pane.focused_border")]
|
||||
pub pane_focused_border: Option<String>,
|
||||
|
||||
|
@ -675,6 +681,14 @@ impl ThemeColorsContent {
|
|||
.scrollbar_thumb_border
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok());
|
||||
let element_hover = self
|
||||
.element_hover
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok());
|
||||
let panel_background = self
|
||||
.panel_background
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok());
|
||||
ThemeColorsRefinement {
|
||||
border,
|
||||
border_variant: self
|
||||
|
@ -713,10 +727,7 @@ impl ThemeColorsContent {
|
|||
.element_background
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok()),
|
||||
element_hover: self
|
||||
.element_hover
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok()),
|
||||
element_hover,
|
||||
element_active: self
|
||||
.element_active
|
||||
.as_ref()
|
||||
|
@ -833,10 +844,7 @@ impl ThemeColorsContent {
|
|||
.search_match_background
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok()),
|
||||
panel_background: self
|
||||
.panel_background
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok()),
|
||||
panel_background,
|
||||
panel_focused_border: self
|
||||
.panel_focused_border
|
||||
.as_ref()
|
||||
|
@ -853,6 +861,16 @@ impl ThemeColorsContent {
|
|||
.panel_indent_guide_active
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok()),
|
||||
panel_overlay_background: self
|
||||
.panel_overlay_background
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok())
|
||||
.or(panel_background),
|
||||
panel_overlay_hover: self
|
||||
.panel_overlay_hover
|
||||
.as_ref()
|
||||
.and_then(|color| try_parse_color(color).ok())
|
||||
.or(element_hover),
|
||||
pane_focused_border: self
|
||||
.pane_focused_border
|
||||
.as_ref()
|
||||
|
|
|
@ -131,6 +131,12 @@ pub struct ThemeColors {
|
|||
pub panel_indent_guide: Hsla,
|
||||
pub panel_indent_guide_hover: Hsla,
|
||||
pub panel_indent_guide_active: Hsla,
|
||||
|
||||
/// The color of the overlay surface on top of panel.
|
||||
pub panel_overlay_background: Hsla,
|
||||
/// The color of the overlay surface on top of panel when hovered over.
|
||||
pub panel_overlay_hover: Hsla,
|
||||
|
||||
pub pane_focused_border: Hsla,
|
||||
pub pane_group_border: Hsla,
|
||||
/// The color of the scrollbar thumb.
|
||||
|
@ -326,6 +332,8 @@ pub enum ThemeColorField {
|
|||
PanelIndentGuide,
|
||||
PanelIndentGuideHover,
|
||||
PanelIndentGuideActive,
|
||||
PanelOverlayBackground,
|
||||
PanelOverlayHover,
|
||||
PaneFocusedBorder,
|
||||
PaneGroupBorder,
|
||||
ScrollbarThumbBackground,
|
||||
|
@ -438,6 +446,8 @@ impl ThemeColors {
|
|||
ThemeColorField::PanelIndentGuide => self.panel_indent_guide,
|
||||
ThemeColorField::PanelIndentGuideHover => self.panel_indent_guide_hover,
|
||||
ThemeColorField::PanelIndentGuideActive => self.panel_indent_guide_active,
|
||||
ThemeColorField::PanelOverlayBackground => self.panel_overlay_background,
|
||||
ThemeColorField::PanelOverlayHover => self.panel_overlay_hover,
|
||||
ThemeColorField::PaneFocusedBorder => self.pane_focused_border,
|
||||
ThemeColorField::PaneGroupBorder => self.pane_group_border,
|
||||
ThemeColorField::ScrollbarThumbBackground => self.scrollbar_thumb_background,
|
||||
|
|
|
@ -327,6 +327,7 @@ impl PickerDelegate for IconThemeSelectorDelegate {
|
|||
window.dispatch_action(
|
||||
Box::new(Extensions {
|
||||
category_filter: Some(ExtensionCategoryFilter::IconThemes),
|
||||
id: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
|
|
@ -385,6 +385,7 @@ impl PickerDelegate for ThemeSelectorDelegate {
|
|||
window.dispatch_action(
|
||||
Box::new(Extensions {
|
||||
category_filter: Some(ExtensionCategoryFilter::Themes),
|
||||
id: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
|
|
@ -149,47 +149,7 @@ where
|
|||
) -> AnyElement {
|
||||
let entries = (self.compute_fn)(visible_range.clone(), window, cx);
|
||||
|
||||
struct StickyAnchor<T> {
|
||||
entry: T,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
let mut sticky_anchor = None;
|
||||
let mut last_item_is_drifting = false;
|
||||
|
||||
let mut iter = entries.iter().enumerate().peekable();
|
||||
while let Some((ix, current_entry)) = iter.next() {
|
||||
let depth = current_entry.depth();
|
||||
|
||||
if depth < ix {
|
||||
sticky_anchor = Some(StickyAnchor {
|
||||
entry: current_entry.clone(),
|
||||
index: visible_range.start + ix,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(&(_next_ix, next_entry)) = iter.peek() {
|
||||
let next_depth = next_entry.depth();
|
||||
let next_item_outdented = next_depth + 1 == depth;
|
||||
|
||||
let depth_same_as_index = depth == ix;
|
||||
let depth_greater_than_index = depth == ix + 1;
|
||||
|
||||
if next_item_outdented && (depth_same_as_index || depth_greater_than_index) {
|
||||
if depth_greater_than_index {
|
||||
last_item_is_drifting = true;
|
||||
}
|
||||
sticky_anchor = Some(StickyAnchor {
|
||||
entry: current_entry.clone(),
|
||||
index: visible_range.start + ix,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let Some(sticky_anchor) = sticky_anchor else {
|
||||
let Some(sticky_anchor) = find_sticky_anchor(&entries, visible_range.start) else {
|
||||
return StickyItemsElement {
|
||||
drifting_element: None,
|
||||
drifting_decoration: None,
|
||||
|
@ -203,23 +163,21 @@ where
|
|||
let mut elements = (self.render_fn)(sticky_anchor.entry, window, cx);
|
||||
let items_count = elements.len();
|
||||
|
||||
let indents: SmallVec<[usize; 8]> = {
|
||||
elements
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, _)| anchor_depth.saturating_sub(items_count.saturating_sub(ix)))
|
||||
.collect()
|
||||
};
|
||||
let indents: SmallVec<[usize; 8]> = (0..items_count)
|
||||
.map(|ix| anchor_depth.saturating_sub(items_count.saturating_sub(ix)))
|
||||
.collect();
|
||||
|
||||
let mut last_decoration_element = None;
|
||||
let mut rest_decoration_elements = SmallVec::new();
|
||||
|
||||
let available_space = size(
|
||||
AvailableSpace::Definite(bounds.size.width),
|
||||
let expanded_width = bounds.size.width + scroll_offset.x.abs();
|
||||
|
||||
let decor_available_space = size(
|
||||
AvailableSpace::Definite(expanded_width),
|
||||
AvailableSpace::Definite(bounds.size.height),
|
||||
);
|
||||
|
||||
let drifting_y_offset = if last_item_is_drifting {
|
||||
let drifting_y_offset = if sticky_anchor.drifting {
|
||||
let scroll_top = -scroll_offset.y;
|
||||
let anchor_top = item_height * (sticky_anchor.index + 1);
|
||||
let sticky_area_height = item_height * items_count;
|
||||
|
@ -228,7 +186,7 @@ where
|
|||
Pixels::ZERO
|
||||
};
|
||||
|
||||
let (drifting_indent, rest_indents) = if last_item_is_drifting && !indents.is_empty() {
|
||||
let (drifting_indent, rest_indents) = if sticky_anchor.drifting && !indents.is_empty() {
|
||||
let last = indents[indents.len() - 1];
|
||||
let rest: SmallVec<[usize; 8]> = indents[..indents.len() - 1].iter().copied().collect();
|
||||
(Some(last), rest)
|
||||
|
@ -236,11 +194,14 @@ where
|
|||
(None, indents)
|
||||
};
|
||||
|
||||
let base_origin = bounds.origin - point(px(0.), scroll_offset.y);
|
||||
|
||||
for decoration in &self.decorations {
|
||||
if let Some(drifting_indent) = drifting_indent {
|
||||
let drifting_indent_vec: SmallVec<[usize; 8]> =
|
||||
[drifting_indent].into_iter().collect();
|
||||
let sticky_origin = bounds.origin - scroll_offset
|
||||
|
||||
let sticky_origin = base_origin
|
||||
+ point(px(0.), item_height * rest_indents.len() + drifting_y_offset);
|
||||
let decoration_bounds = Bounds::new(sticky_origin, bounds.size);
|
||||
|
||||
|
@ -252,13 +213,13 @@ where
|
|||
window,
|
||||
cx,
|
||||
);
|
||||
drifting_dec.layout_as_root(available_space, window, cx);
|
||||
drifting_dec.layout_as_root(decor_available_space, window, cx);
|
||||
drifting_dec.prepaint_at(sticky_origin, window, cx);
|
||||
last_decoration_element = Some(drifting_dec);
|
||||
}
|
||||
|
||||
if !rest_indents.is_empty() {
|
||||
let decoration_bounds = Bounds::new(bounds.origin - scroll_offset, bounds.size);
|
||||
let decoration_bounds = Bounds::new(base_origin, bounds.size);
|
||||
let mut rest_dec = decoration.as_ref().compute(
|
||||
&rest_indents,
|
||||
decoration_bounds,
|
||||
|
@ -267,46 +228,45 @@ where
|
|||
window,
|
||||
cx,
|
||||
);
|
||||
rest_dec.layout_as_root(available_space, window, cx);
|
||||
rest_dec.layout_as_root(decor_available_space, window, cx);
|
||||
rest_dec.prepaint_at(bounds.origin, window, cx);
|
||||
rest_decoration_elements.push(rest_dec);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut drifting_element, mut rest_elements) =
|
||||
if last_item_is_drifting && !elements.is_empty() {
|
||||
if sticky_anchor.drifting && !elements.is_empty() {
|
||||
let last = elements.pop().unwrap();
|
||||
(Some(last), elements)
|
||||
} else {
|
||||
(None, elements)
|
||||
};
|
||||
|
||||
for (ix, element) in rest_elements.iter_mut().enumerate() {
|
||||
let sticky_origin = bounds.origin - scroll_offset + point(px(0.), item_height * ix);
|
||||
let element_available_space = size(
|
||||
AvailableSpace::Definite(bounds.size.width),
|
||||
AvailableSpace::Definite(item_height),
|
||||
);
|
||||
|
||||
element.layout_as_root(element_available_space, window, cx);
|
||||
element.prepaint_at(sticky_origin, window, cx);
|
||||
}
|
||||
let element_available_space = size(
|
||||
AvailableSpace::Definite(expanded_width),
|
||||
AvailableSpace::Definite(item_height),
|
||||
);
|
||||
|
||||
// order of prepaint is important here
|
||||
// mouse events checks hitboxes in reverse insertion order
|
||||
if let Some(ref mut drifting_element) = drifting_element {
|
||||
let sticky_origin = bounds.origin - scroll_offset
|
||||
let sticky_origin = base_origin
|
||||
+ point(
|
||||
px(0.),
|
||||
item_height * rest_elements.len() + drifting_y_offset,
|
||||
);
|
||||
let element_available_space = size(
|
||||
AvailableSpace::Definite(bounds.size.width),
|
||||
AvailableSpace::Definite(item_height),
|
||||
);
|
||||
|
||||
drifting_element.layout_as_root(element_available_space, window, cx);
|
||||
drifting_element.prepaint_at(sticky_origin, window, cx);
|
||||
}
|
||||
|
||||
for (ix, element) in rest_elements.iter_mut().enumerate() {
|
||||
let sticky_origin = base_origin + point(px(0.), item_height * ix);
|
||||
|
||||
element.layout_as_root(element_available_space, window, cx);
|
||||
element.prepaint_at(sticky_origin, window, cx);
|
||||
}
|
||||
|
||||
StickyItemsElement {
|
||||
drifting_element,
|
||||
drifting_decoration: last_decoration_element,
|
||||
|
@ -317,6 +277,48 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct StickyAnchor<T> {
|
||||
entry: T,
|
||||
index: usize,
|
||||
drifting: bool,
|
||||
}
|
||||
|
||||
fn find_sticky_anchor<T: StickyCandidate + Clone>(
|
||||
entries: &SmallVec<[T; 8]>,
|
||||
visible_range_start: usize,
|
||||
) -> Option<StickyAnchor<T>> {
|
||||
let mut iter = entries.iter().enumerate().peekable();
|
||||
while let Some((ix, current_entry)) = iter.next() {
|
||||
let depth = current_entry.depth();
|
||||
|
||||
if depth < ix {
|
||||
return Some(StickyAnchor {
|
||||
entry: current_entry.clone(),
|
||||
index: visible_range_start + ix,
|
||||
drifting: false,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(&(_next_ix, next_entry)) = iter.peek() {
|
||||
let next_depth = next_entry.depth();
|
||||
let next_item_outdented = next_depth + 1 == depth;
|
||||
|
||||
let depth_same_as_index = depth == ix;
|
||||
let depth_greater_than_index = depth == ix + 1;
|
||||
|
||||
if next_item_outdented && (depth_same_as_index || depth_greater_than_index) {
|
||||
return Some(StickyAnchor {
|
||||
entry: current_entry.clone(),
|
||||
index: visible_range_start + ix,
|
||||
drifting: depth_greater_than_index,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// A decoration for a [`StickyItems`]. This can be used for various things,
|
||||
/// such as rendering indent guides, or other visual effects.
|
||||
pub trait StickyItemsDecoration {
|
||||
|
|
|
@ -230,7 +230,11 @@ fn scroll_editor(
|
|||
// column position, or the right-most column in the current
|
||||
// line, seeing as the cursor might be in a short line, in which
|
||||
// case we don't want to go past its last column.
|
||||
let max_row_column = map.line_len(new_row);
|
||||
let max_row_column = if new_row <= map.max_point().row() {
|
||||
map.line_len(new_row)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let max_column = match min_column + visible_column_count as u32 {
|
||||
max_column if max_column >= max_row_column => max_row_column,
|
||||
max_column => max_column,
|
||||
|
|
23
crates/x_ai/Cargo.toml
Normal file
23
crates/x_ai/Cargo.toml
Normal file
|
@ -0,0 +1,23 @@
|
|||
[package]
|
||||
name = "x_ai"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/x_ai.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
strum.workspace = true
|
||||
workspace-hack.workspace = true
|
1
crates/x_ai/LICENSE-GPL
Symbolic link
1
crates/x_ai/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
126
crates/x_ai/src/x_ai.rs
Normal file
126
crates/x_ai/src/x_ai.rs
Normal file
|
@ -0,0 +1,126 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
||||
pub const XAI_API_URL: &str = "https://api.x.ai/v1";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "grok-2-vision-latest")]
|
||||
Grok2Vision,
|
||||
#[default]
|
||||
#[serde(rename = "grok-3-latest")]
|
||||
Grok3,
|
||||
#[serde(rename = "grok-3-mini-latest")]
|
||||
Grok3Mini,
|
||||
#[serde(rename = "grok-3-fast-latest")]
|
||||
Grok3Fast,
|
||||
#[serde(rename = "grok-3-mini-fast-latest")]
|
||||
Grok3MiniFast,
|
||||
#[serde(rename = "grok-4-latest")]
|
||||
Grok4,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
|
||||
display_name: Option<String>,
|
||||
max_tokens: u64,
|
||||
max_output_tokens: Option<u64>,
|
||||
max_completion_tokens: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn default_fast() -> Self {
|
||||
Self::Grok3Fast
|
||||
}
|
||||
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
match id {
|
||||
"grok-2-vision" => Ok(Self::Grok2Vision),
|
||||
"grok-3" => Ok(Self::Grok3),
|
||||
"grok-3-mini" => Ok(Self::Grok3Mini),
|
||||
"grok-3-fast" => Ok(Self::Grok3Fast),
|
||||
"grok-3-mini-fast" => Ok(Self::Grok3MiniFast),
|
||||
_ => anyhow::bail!("invalid model id '{id}'"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::Grok2Vision => "grok-2-vision",
|
||||
Self::Grok3 => "grok-3",
|
||||
Self::Grok3Mini => "grok-3-mini",
|
||||
Self::Grok3Fast => "grok-3-fast",
|
||||
Self::Grok3MiniFast => "grok-3-mini-fast",
|
||||
Self::Grok4 => "grok-4",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::Grok2Vision => "Grok 2 Vision",
|
||||
Self::Grok3 => "Grok 3",
|
||||
Self::Grok3Mini => "Grok 3 Mini",
|
||||
Self::Grok3Fast => "Grok 3 Fast",
|
||||
Self::Grok3MiniFast => "Grok 3 Mini Fast",
|
||||
Self::Grok4 => "Grok 4",
|
||||
Self::Custom {
|
||||
name, display_name, ..
|
||||
} => display_name.as_ref().unwrap_or(name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> u64 {
|
||||
match self {
|
||||
Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => 131_072,
|
||||
Self::Grok4 => 256_000,
|
||||
Self::Grok2Vision => 8_192,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_output_tokens(&self) -> Option<u64> {
|
||||
match self {
|
||||
Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => Some(8_192),
|
||||
Self::Grok4 => Some(64_000),
|
||||
Self::Grok2Vision => Some(4_096),
|
||||
Self::Custom {
|
||||
max_output_tokens, ..
|
||||
} => *max_output_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_parallel_tool_calls(&self) -> bool {
|
||||
match self {
|
||||
Self::Grok2Vision
|
||||
| Self::Grok3
|
||||
| Self::Grok3Mini
|
||||
| Self::Grok3Fast
|
||||
| Self::Grok3MiniFast
|
||||
| Self::Grok4 => true,
|
||||
Model::Custom { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_tool(&self) -> bool {
|
||||
match self {
|
||||
Self::Grok2Vision
|
||||
| Self::Grok3
|
||||
| Self::Grok3Mini
|
||||
| Self::Grok3Fast
|
||||
| Self::Grok3MiniFast
|
||||
| Self::Grok4 => true,
|
||||
Model::Custom { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_images(&self) -> bool {
|
||||
match self {
|
||||
Self::Grok2Vision => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
description = "The fast, collaborative code editor."
|
||||
edition.workspace = true
|
||||
name = "zed"
|
||||
version = "0.195.0"
|
||||
version = "0.195.5"
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
authors = ["Zed Team <hi@zed.dev>"]
|
||||
|
|
|
@ -1 +1 @@
|
|||
dev
|
||||
stable
|
|
@ -725,6 +725,23 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
|
|||
return;
|
||||
}
|
||||
|
||||
if let Some(extension) = request.extension_id {
|
||||
cx.spawn(async move |cx| {
|
||||
let workspace = workspace::get_any_active_workspace(app_state, cx.clone()).await?;
|
||||
workspace.update(cx, |_, window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(zed_actions::Extensions {
|
||||
category_filter: None,
|
||||
id: Some(extension),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(connection_options) = request.ssh_connection {
|
||||
cx.spawn(async move |mut cx| {
|
||||
let paths: Vec<PathBuf> = request.open_paths.into_iter().map(PathBuf::from).collect();
|
||||
|
|
|
@ -37,6 +37,7 @@ pub struct OpenRequest {
|
|||
pub join_channel: Option<u64>,
|
||||
pub ssh_connection: Option<SshConnectionOptions>,
|
||||
pub dock_menu_action: Option<usize>,
|
||||
pub extension_id: Option<String>,
|
||||
}
|
||||
|
||||
impl OpenRequest {
|
||||
|
@ -54,6 +55,8 @@ impl OpenRequest {
|
|||
} else if let Some(file) = url.strip_prefix("zed://ssh") {
|
||||
let ssh_url = "ssh:/".to_string() + file;
|
||||
this.parse_ssh_file_path(&ssh_url, cx)?
|
||||
} else if let Some(file) = url.strip_prefix("zed://extension/") {
|
||||
this.extension_id = Some(file.to_string())
|
||||
} else if url.starts_with("ssh://") {
|
||||
this.parse_ssh_file_path(&url, cx)?
|
||||
} else if let Some(request_path) = parse_zed_link(&url, cx) {
|
||||
|
|
|
@ -76,6 +76,9 @@ pub struct Extensions {
|
|||
/// Filters the extensions page down to extensions that are in the specified category.
|
||||
#[serde(default)]
|
||||
pub category_filter: Option<ExtensionCategoryFilter>,
|
||||
/// Focuses just the extension with the specified ID.
|
||||
#[serde(default)]
|
||||
pub id: Option<String>,
|
||||
}
|
||||
|
||||
/// Decreases the font size in the editor buffer.
|
||||
|
|
|
@ -23,6 +23,8 @@ Here's an overview of the supported providers and tool call support:
|
|||
| [OpenAI](#openai) | ✅ |
|
||||
| [OpenAI API Compatible](#openai-api-compatible) | 🚫 |
|
||||
| [OpenRouter](#openrouter) | ✅ |
|
||||
| [Vercel](#vercel-v0) | ✅ |
|
||||
| [xAI](#xai) | ✅ |
|
||||
|
||||
## Use Your Own Keys {#use-your-own-keys}
|
||||
|
||||
|
@ -442,27 +444,30 @@ Custom models will be listed in the model dropdown in the Agent Panel.
|
|||
|
||||
Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider.
|
||||
|
||||
You can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`.
|
||||
Here are a few model examples you can plug in by using this feature:
|
||||
Zed supports using OpenAI compatible APIs by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
|
||||
|
||||
#### X.ai Grok
|
||||
To configure a compatible API, you can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. For example, to connect to [Together AI](https://www.together.ai/):
|
||||
|
||||
Example configuration for using X.ai Grok with Zed:
|
||||
1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys).
|
||||
2. Add the following to your `settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"language_models": {
|
||||
"openai": {
|
||||
"api_url": "https://api.x.ai/v1",
|
||||
"api_url": "https://api.together.xyz/v1",
|
||||
"api_key": "YOUR_TOGETHER_AI_API_KEY",
|
||||
"available_models": [
|
||||
{
|
||||
"name": "grok-beta",
|
||||
"display_name": "X.ai Grok (Beta)",
|
||||
"max_tokens": 131072
|
||||
"name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"display_name": "Together Mixtral 8x7B",
|
||||
"max_tokens": 32768,
|
||||
"supports_tools": true
|
||||
}
|
||||
],
|
||||
"version": "1"
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### OpenRouter {#openrouter}
|
||||
|
@ -523,7 +528,9 @@ You can find available models and their specifications on the [OpenRouter models
|
|||
|
||||
Custom models will be listed in the model dropdown in the Agent Panel.
|
||||
|
||||
### Vercel v0
|
||||
### Vercel v0 {#vercel-v0}
|
||||
|
||||
> ✅ Supports tool use
|
||||
|
||||
[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel.
|
||||
It supports text and image inputs and provides fast streaming responses.
|
||||
|
@ -535,6 +542,49 @@ Once you have it, paste it directly into the Vercel provider section in the pane
|
|||
|
||||
You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel.
|
||||
|
||||
### xAI {#xai}
|
||||
|
||||
> ✅ Supports tool use
|
||||
|
||||
Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models.
|
||||
|
||||
1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys)
|
||||
2. Open the settings view (`agent: open configuration`) and go to the **xAI** section
|
||||
3. Enter your xAI API key
|
||||
|
||||
The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined.
|
||||
|
||||
> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method.
|
||||
|
||||
#### Custom Models {#xai-custom-models}
|
||||
|
||||
The Zed agent comes pre-configured with common Grok models. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"language_models": {
|
||||
"x_ai": {
|
||||
"api_url": "https://api.x.ai/v1",
|
||||
"available_models": [
|
||||
{
|
||||
"name": "grok-1.5",
|
||||
"display_name": "Grok 1.5",
|
||||
"max_tokens": 131072,
|
||||
"max_output_tokens": 8192
|
||||
},
|
||||
{
|
||||
"name": "grok-1.5v",
|
||||
"display_name": "Grok 1.5V (Vision)",
|
||||
"max_tokens": 131072,
|
||||
"max_output_tokens": 8192,
|
||||
"supports_images": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Configuration {#advanced-configuration}
|
||||
|
||||
### Custom Provider Endpoints {#custom-provider-endpoint}
|
||||
|
|
|
@ -148,7 +148,7 @@ On some systems the file `/etc/prime-discrete` can be used to enforce the use of
|
|||
|
||||
On others, you may be able to the environment variable `DRI_PRIME=1` when running Zed to force the use of the discrete GPU.
|
||||
|
||||
If you're using an AMD GPU and Zed crashes when selecting long lines, try setting the `ZED_SAMPLE_COUNT=0` environment variable. (See [#26143](https://github.com/zed-industries/zed/issues/26143))
|
||||
If you're using an AMD GPU and Zed crashes when selecting long lines, try setting the `ZED_PATH_SAMPLE_COUNT=0` environment variable. (See [#26143](https://github.com/zed-industries/zed/issues/26143))
|
||||
|
||||
If you're using an AMD GPU, you might get a 'Broken Pipe' error. Try using the RADV or Mesa drivers. (See [#13880](https://github.com/zed-industries/zed/issues/13880))
|
||||
|
||||
|
|
|
@ -44,8 +44,6 @@ function CheckEnvironmentVariables {
|
|||
}
|
||||
}
|
||||
|
||||
$innoDir = "$env:ZED_WORKSPACE\inno"
|
||||
|
||||
function PrepareForBundle {
|
||||
if (Test-Path "$innoDir") {
|
||||
Remove-Item -Path "$innoDir" -Recurse -Force
|
||||
|
@ -236,6 +234,8 @@ function BuildInstaller {
|
|||
}
|
||||
|
||||
ParseZedWorkspace
|
||||
$innoDir = "$env:ZED_WORKSPACE\inno"
|
||||
|
||||
CheckEnvironmentVariables
|
||||
PrepareForBundle
|
||||
BuildZedAndItsFriends
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue