diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 25a1ed8670..3cf89066c1 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -686,8 +686,10 @@ jobs:
     timeout-minutes: 60
     runs-on: github-8vcpu-ubuntu-2404
     if: |
+      false && (
       startsWith(github.ref, 'refs/tags/v') ||
       contains(github.event.pull_request.labels.*.name, 'run-bundling')
+      )
     needs: [linux_tests]
     name: Build Zed on FreeBSD
     # env:
@@ -757,7 +759,7 @@ jobs:
     timeout-minutes: 120
     name: Create a Windows installer
     runs-on: [self-hosted, Windows, X64]
-    if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
+    if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
     needs: [windows_tests]
     env:
       AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
@@ -800,6 +802,7 @@ jobs:

       - name: Upload Artifacts to release
         uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
+        # Re-enable when we are ready to publish windows preview releases
         if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
         with:
           draft: true
@@ -813,7 +816,7 @@
     if: |
       startsWith(github.ref, 'refs/tags/v') &&
       endsWith(github.ref, '-pre') && !endsWith(github.ref, '.0-pre')
-    needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64, freebsd]
+    needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64]
     runs-on:
       - self-hosted
       - bundle
diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml
index df9f6ef40f..600d14dc9f 100644
--- a/.github/workflows/release_nightly.yml
+++ b/.github/workflows/release_nightly.yml
@@ -195,7 +195,7 @@ jobs:

   freebsd:
     timeout-minutes: 60
-    if: github.repository_owner == 'zed-industries'
+    if: false && github.repository_owner == 'zed-industries'
     runs-on: github-8vcpu-ubuntu-2404
     needs: tests
     env:
diff --git a/Cargo.lock b/Cargo.lock
index 38bb7819ca..de03330371 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2078,7 +2078,7 @@ dependencies = [
 [[package]]
 name = "blade-graphics"
 version = "0.6.0"
-source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
+source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
 dependencies = [
  "ash",
  "ash-window",
@@ -2111,7 +2111,7 @@ dependencies = [
 [[package]]
 name = "blade-macros"
 version = "0.3.0"
-source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
+source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2121,7 +2121,7 @@ dependencies = [
 [[package]]
 name = "blade-util"
 version = "0.2.0"
-source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
+source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
 dependencies = [
  "blade-graphics",
  "bytemuck",
@@ -3043,6 +3043,7 @@ dependencies = [
 "context_server",
 "ctor",
 "dap",
+ "dap-types",
 "dap_adapters",
 "dashmap 6.1.0",
 "debugger_ui",
@@ -9000,6 +9001,7 @@ dependencies = [
 "util",
 "vercel",
 "workspace-hack",
+ "x_ai",
 "zed_llm_client",
 ]

@@ -19731,6 +19733,17 @@ version = "0.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d"

+[[package]]
+name = "x_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "schemars",
+ "serde",
+ "strum 0.27.1",
+ "workspace-hack",
+]
+
 [[package]]
 name = "xattr"
 version = "0.2.3"
@@ -19972,7 +19985,7 @@ dependencies = [

 [[package]]
 name = "zed"
-version = "0.195.0"
+version = "0.195.5"
 dependencies = [
  "activity_indicator",
  "agent",
diff --git a/Cargo.toml b/Cargo.toml
index a4d8b3cb95..d4e6bb09bd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -177,6 +177,7 @@ members = [
     "crates/welcome",
     "crates/workspace",
     "crates/worktree",
+    "crates/x_ai",
     "crates/zed",
     "crates/zed_actions",
     "crates/zeta",
@@ -390,6 +391,7 @@ web_search_providers = { path = "crates/web_search_providers" }
 welcome = { path = "crates/welcome" }
 workspace = { path = "crates/workspace" }
 worktree = { path = "crates/worktree" }
+x_ai = { path = "crates/x_ai" }
 zed = { path = "crates/zed" }
 zed_actions = { path = "crates/zed_actions" }
 zeta = { path = "crates/zeta" }
@@ -427,9 +429,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
 aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
 base64 = "0.22"
 bitflags = "2.6.0"
-blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
-blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
-blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
+blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
+blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
+blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
 blake3 = "1.5.3"
 bytes = "1.0"
 cargo_metadata = "0.19"
@@ -482,7 +484,7 @@ json_dotpath = "1.1"
 jsonschema = "0.30.0"
 jsonwebtoken = "9.3"
 jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
-jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
+jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
 libc = "0.2"
 libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
 linkify = "0.10.0"
@@ -493,7 +495,7 @@ metal = "0.29"
 moka = { version = "0.12.10", features = ["sync"] }
 naga = { version = "25.0", features = ["wgsl-in"] }
 nanoid = "0.4"
-nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
+nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
 nix = "0.29"
 num-format = "0.4.4"
 objc = "0.2"
@@ -533,7 +535,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
     "stream",
 ] }
 rsa = "0.9.6"
-runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
+runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
     "async-dispatcher-runtime",
 ] }
 rust-embed = {
version = "8.4", features = ["include-exclude"] } diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg new file mode 100644 index 0000000000..289525c8ef --- /dev/null +++ b/assets/icons/ai_x_ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1f2654dac5..83ed90e18d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -21,6 +21,7 @@ use gpui::{ AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, Window, }; +use http_client::StatusCode; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest, @@ -51,7 +52,19 @@ use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; const MAX_RETRY_ATTEMPTS: u8 = 3; -const BASE_RETRY_DELAY_SECS: u64 = 5; +const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); + +#[derive(Debug, Clone)] +enum RetryStrategy { + ExponentialBackoff { + initial_delay: Duration, + max_attempts: u8, + }, + Fixed { + delay: Duration, + max_attempts: u8, + }, +} #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, @@ -383,6 +396,7 @@ pub struct Thread { remaining_turns: u32, configured_model: Option, profile: AgentProfile, + last_error_context: Option<(Arc, CompletionIntent)>, } #[derive(Clone, Debug)] @@ -476,10 +490,11 @@ impl Thread { retry_state: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, - configured_model, + configured_model: configured_model.clone(), profile: AgentProfile::new(profile_id, tools), } } @@ -600,6 +615,7 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, @@ -1251,9 +1267,58 @@ impl Thread { self.flush_notifications(model.clone(), intent, cx); - let request = self.to_completion_request(model.clone(), intent, cx); + let _checkpoint = self.finalize_pending_checkpoint(cx); + self.stream_completion( + self.to_completion_request(model.clone(), intent, cx), + model, + intent, + window, + cx, + ); + } - self.stream_completion(request, model, intent, window, cx); + pub fn retry_last_completion( + &mut self, + window: Option, + cx: &mut Context, + ) { + // Clear any existing error state + self.retry_state = None; + + // Use the last error context if available, otherwise fall back to configured model + let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() { + (model, intent) + } else if let Some(configured_model) = self.configured_model.as_ref() { + let model = configured_model.model.clone(); + let intent = if self.has_pending_tool_uses() { + CompletionIntent::ToolResults + } else { + CompletionIntent::UserPrompt + }; + (model, intent) + } else if let Some(configured_model) = self.get_or_init_configured_model(cx) { + let model = configured_model.model.clone(); + let intent = if self.has_pending_tool_uses() { + CompletionIntent::ToolResults + } else { + CompletionIntent::UserPrompt + }; + (model, intent) + } else { + return; + }; + + self.send_to_model(model, intent, window, cx); + } + + pub fn enable_burn_mode_and_retry( + &mut self, + window: Option, + cx: &mut Context, + ) { + 
self.completion_mode = CompletionMode::Burn; + cx.emit(ThreadEvent::ProfileChanged); + self.retry_last_completion(window, cx); } pub fn used_tools_since_last_user_message(&self) -> bool { @@ -1931,18 +1996,6 @@ impl Thread { project.set_agent_location(None, cx); }); - fn emit_generic_error(error: &anyhow::Error, cx: &mut Context) { - let error_message = error - .chain() - .map(|err| err.to_string()) - .collect::>() - .join("\n"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error interacting with language model".into(), - message: SharedString::from(error_message.clone()), - })); - } - if error.is::() { cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); } else if let Some(error) = @@ -1954,9 +2007,10 @@ impl Thread { } else if let Some(completion_error) = error.downcast_ref::() { - use LanguageModelCompletionError::*; match &completion_error { - PromptTooLarge { tokens, .. } => { + LanguageModelCompletionError::PromptTooLarge { + tokens, .. + } => { let tokens = tokens.unwrap_or_else(|| { // We didn't get an exact token count from the API, so fall back on our estimate. thread @@ -1977,63 +2031,22 @@ impl Thread { }); cx.notify(); } - RateLimitExceeded { - retry_after: Some(retry_after), - .. - } - | ServerOverloaded { - retry_after: Some(retry_after), - .. - } => { - thread.handle_rate_limit_error( - &completion_error, - *retry_after, - model.clone(), - intent, - window, - cx, - ); - retry_scheduled = true; - } - RateLimitExceeded { .. } | ServerOverloaded { .. } => { - retry_scheduled = thread.handle_retryable_error( - &completion_error, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); + _ => { + if let Some(retry_strategy) = + Thread::get_retry_strategy(completion_error) + { + retry_scheduled = thread + .handle_retryable_error_with_delay( + &completion_error, + Some(retry_strategy), + model.clone(), + intent, + window, + cx, + ); } } - ApiInternalServerError { .. } - | ApiReadResponseError { .. } - | HttpSend { .. } => { - retry_scheduled = thread.handle_retryable_error( - &completion_error, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); - } - } - NoApiKey { .. } - | HttpResponseError { .. } - | BadRequestFormat { .. } - | AuthenticationError { .. } - | PermissionError { .. } - | ApiEndpointNotFound { .. } - | SerializeRequest { .. } - | BuildRequestBody { .. } - | DeserializeResponse { .. } - | Other { .. } => emit_generic_error(error, cx), } - } else { - emit_generic_error(error, cx); } if !retry_scheduled { @@ -2160,73 +2173,132 @@ impl Thread { }); } - fn handle_rate_limit_error( - &mut self, - error: &LanguageModelCompletionError, - retry_after: Duration, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) { - // For rate limit errors, we only retry once with the specified duration - let retry_message = format!("{error}. 
Retrying in {} seconds…", retry_after.as_secs()); - log::warn!( - "Retrying completion request in {} seconds: {error:?}", - retry_after.as_secs(), - ); + fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option { + use LanguageModelCompletionError::*; - // Add a UI-only message instead of a regular message - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text(retry_message)], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: false, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - // Schedule the retry - let thread_handle = cx.entity().downgrade(); - - cx.spawn(async move |_thread, cx| { - cx.background_executor().timer(retry_after).await; - - thread_handle - .update(cx, |thread, cx| { - // Retry the completion - thread.send_to_model(model, intent, window, cx); + // General strategy here: + // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all. + // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), try multiple times with exponential backoff. + // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), just retry once. + match error { + HttpResponseError { + status_code: StatusCode::TOO_MANY_REQUESTS, + .. + } => Some(RetryStrategy::ExponentialBackoff { + initial_delay: BASE_RETRY_DELAY, + max_attempts: MAX_RETRY_ATTEMPTS, + }), + ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, }) - .log_err(); - }) - .detach(); - } - - fn handle_retryable_error( - &mut self, - error: &LanguageModelCompletionError, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) -> bool { - self.handle_retryable_error_with_delay(error, None, model, intent, window, cx) + } + UpstreamProviderError { + status, + retry_after, + .. + } => match *status { + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + // Internal Server Error could be anything, so only retry once. + max_attempts: 1, + }), + status => { + // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"), + // but we frequently get them in practice. See https://http.dev/529 + if status.as_u16() == 529 { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } else { + None + } + } + }, + ApiInternalServerError { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }), + ApiReadResponseError { .. } + | HttpSend { .. } + | DeserializeResponse { .. } + | BadRequestFormat { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }), + // Retrying these errors definitely shouldn't help. + HttpResponseError { + status_code: + StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED, + .. + } + | SerializeRequest { .. } + | BuildRequestBody { .. } + | PromptTooLarge { .. } + | AuthenticationError { .. } + | PermissionError { .. } + | ApiEndpointNotFound { .. } + | NoApiKey { .. 
} => None, + // Retry all other 4xx and 5xx errors once. + HttpResponseError { status_code, .. } + if status_code.is_client_error() || status_code.is_server_error() => + { + Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }) + } + // Conservatively assume that any other errors are non-retryable + HttpResponseError { .. } | Other(..) => None, + } } fn handle_retryable_error_with_delay( &mut self, error: &LanguageModelCompletionError, - custom_delay: Option, + strategy: Option, model: Arc, intent: CompletionIntent, window: Option, cx: &mut Context, ) -> bool { + // Store context for the Retry button + self.last_error_context = Some((model.clone(), intent)); + + // Only auto-retry if Burn Mode is enabled + if self.completion_mode != CompletionMode::Burn { + // Show error with retry options + cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { + message: format!( + "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.", + error + ) + .into(), + can_enable_burn_mode: true, + })); + return false; + } + + let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else { + return false; + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + let retry_state = self.retry_state.get_or_insert(RetryState { attempt: 0, - max_attempts: MAX_RETRY_ATTEMPTS, + max_attempts, intent, }); @@ -2236,20 +2308,24 @@ impl Thread { let intent = retry_state.intent; if attempt <= max_attempts { - // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff - let delay = if let Some(custom_delay) = custom_delay { - custom_delay - } else { - let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, }; // Add a transient message to inform the user let delay_secs = delay.as_secs(); - let retry_message = format!( - "{error}. Retrying (attempt {attempt} of {max_attempts}) \ - in {delay_secs} seconds..." - ); + let retry_message = if max_attempts == 1 { + format!("{error}. Retrying in {delay_secs} seconds...") + } else { + format!( + "{error}. Retrying (attempt {attempt} of {max_attempts}) \ + in {delay_secs} seconds..." + ) + }; log::warn!( "Retrying completion request (attempt {attempt} of {max_attempts}) \ in {delay_secs} seconds: {error:?}", @@ -2288,18 +2364,15 @@ impl Thread { // Max retries exceeded self.retry_state = None; - let notification_text = if max_attempts == 1 { - "Failed after retrying.".into() - } else { - format!("Failed after retrying {} times.", max_attempts).into() - }; - // Stop generating since we're giving up on retrying. 
self.pending_completions.clear(); - cx.emit(ThreadEvent::RetriesFailed { - message: notification_text, - }); + // Show error alongside a Retry button, but no + // Enable Burn Mode button (since it's already enabled) + cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { + message: format!("Failed after retrying: {}", error).into(), + can_enable_burn_mode: false, + })); false } @@ -3211,6 +3284,11 @@ pub enum ThreadError { header: SharedString, message: SharedString, }, + #[error("Retryable error: {message}")] + RetryableError { + message: SharedString, + can_enable_burn_mode: bool, + }, } #[derive(Debug, Clone)] @@ -3256,9 +3334,6 @@ pub enum ThreadEvent { CancelEditing, CompletionCanceled, ProfileChanged, - RetriesFailed { - message: SharedString, - }, } impl EventEmitter for Thread {} @@ -4169,6 +4244,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -4190,7 +4270,7 @@ fn main() {{ assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); assert_eq!( retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have default max attempts" + "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors" ); }); @@ -4242,6 +4322,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns internal server error let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); @@ -4263,7 +4348,7 @@ fn main() {{ let retry_state = thread.retry_state.as_ref().unwrap(); assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + retry_state.max_attempts, 1, "Should have correct max attempts" ); }); @@ -4279,8 +4364,8 @@ fn main() {{ if let MessageSegment::Text(text) = seg { text.contains("internal") && text.contains("Fake") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + && text.contains("Retrying in") + && !text.contains("attempt") } else { false } @@ -4318,8 +4403,13 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + + // Create model that returns internal server error + let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); // Insert a user message thread.update(cx, |thread, cx| { @@ -4369,11 +4459,14 @@ fn main() {{ assert!(thread.retry_state.is_some(), "Should have retry state"); let retry_state = thread.retry_state.as_ref().unwrap(); assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, 1, + "Internal server errors should only retry once" + ); }); // Advance clock for first retry - cx.executor() - 
.advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // Should have scheduled second retry - count retry messages @@ -4393,93 +4486,25 @@ fn main() {{ }) .count() }); - assert_eq!(retry_count, 2, "Should have scheduled second retry"); - - // Check retry state updated - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Advance clock for second retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); - cx.run_until_parked(); - - // Should have scheduled third retry - // Count all retry messages now - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have scheduled third retry" + retry_count, 1, + "Should have only one retry for internal server errors" ); - // Check retry state updated + // For internal server errors, we only retry once and then give up + // Check that retry_state is cleared after the single retry thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!( - retry_state.attempt, MAX_RETRY_ATTEMPTS, - "Should be at max retry attempt" - ); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after single retry" ); }); - // Advance clock for third retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); - cx.run_until_parked(); - - // No more retries should be scheduled after clock was advanced. 
- let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should not exceed max retries" - ); - - // Final completion count should be initial + max retries + // Verify total attempts (1 initial + 1 retry) assert_eq!( *completion_count.lock(), - (MAX_RETRY_ATTEMPTS + 1) as usize, - "Should have made initial + max retry attempts" + 2, + "Should have attempted once plus 1 retry" ); } @@ -4490,6 +4515,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -4499,13 +4529,13 @@ fn main() {{ }); // Track events - let retries_failed = Arc::new(Mutex::new(false)); - let retries_failed_clone = retries_failed.clone(); + let stopped_with_error = Arc::new(Mutex::new(false)); + let stopped_with_error_clone = stopped_with_error.clone(); let _subscription = thread.update(cx, |_, cx| { cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::RetriesFailed { .. } = event { - *retries_failed_clone.lock() = true; + if let ThreadEvent::Stopped(Err(_)) = event { + *stopped_with_error_clone.lock() = true; } }) }); @@ -4517,23 +4547,11 @@ fn main() {{ cx.run_until_parked(); // Advance through all retries - for i in 0..MAX_RETRY_ATTEMPTS { - let delay = if i == 0 { - BASE_RETRY_DELAY_SECS - } else { - BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) - }; - cx.executor().advance_clock(Duration::from_secs(delay)); + for _ in 0..MAX_RETRY_ATTEMPTS { + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); } - // After the 3rd retry is scheduled, we need to wait for it to execute and fail - // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) - let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); - cx.executor() - .advance_clock(Duration::from_secs(final_delay)); - cx.run_until_parked(); - let retry_count = thread.update(cx, |thread, _| { thread .messages @@ -4551,14 +4569,14 @@ fn main() {{ .count() }); - // After max retries, should emit RetriesFailed event + // After max retries, should emit Stopped(Err(...)) event assert_eq!( retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have attempted max retries" + "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors" ); assert!( - *retries_failed.lock(), - "Should emit RetriesFailed event after max retries exceeded" + *stopped_with_error.lock(), + "Should emit Stopped(Err(...)) event after max retries exceeded" ); // Retry state should be cleared @@ -4576,7 +4594,7 @@ fn main() {{ .count(); assert_eq!( retry_messages, MAX_RETRY_ATTEMPTS as usize, - "Should have one retry message per attempt" + "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors" ); }); } @@ -4588,6 +4606,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + 
thread.set_completion_mode(CompletionMode::Burn); + }); + // We'll use a wrapper to switch behavior after first failure struct RetryTestModel { inner: Arc, @@ -4714,8 +4737,7 @@ fn main() {{ }); // Wait for retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // Stream some successful content @@ -4757,6 +4779,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create a model that fails once then succeeds struct FailOnceModel { inner: Arc, @@ -4877,8 +4904,7 @@ fn main() {{ }); // Wait for retry delay - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // The retry should now use our FailOnceModel which should succeed @@ -4919,6 +4945,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create a model that returns rate limit error with retry_after struct RateLimitModel { inner: Arc, @@ -5037,9 +5068,15 @@ fn main() {{ thread.read_with(cx, |thread, _| { assert!( - thread.retry_state.is_none(), - "Rate limit errors should not set retry_state" + thread.retry_state.is_some(), + "Rate limit errors should set retry_state" ); + if let Some(retry_state) = &thread.retry_state { + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Rate limit errors should use MAX_RETRY_ATTEMPTS" + ); + } }); // Verify we have one retry message @@ -5072,18 +5109,15 @@ fn main() {{ .find(|msg| msg.role == Role::System && msg.ui_only) .expect("Should have a retry message"); - // Check that the message doesn't contain attempt count + // Check that the message contains attempt count since we use retry_state if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { assert!( - !text.contains("attempt"), - "Rate limit retry message should not contain attempt count" + text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)), + "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS" ); assert!( - text.contains(&format!( - "Retrying in {} seconds", - TEST_RATE_LIMIT_RETRY_SECS - )), - "Rate limit retry message should contain retry delay" + text.contains("Retrying"), + "Rate limit retry message should contain retry text" ); } }); @@ -5189,6 +5223,79 @@ fn main() {{ ); } + #[gpui::test] + async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Ensure we're in Normal mode (not Burn mode) + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Normal); + }); + + // Track error events + let error_events = Arc::new(Mutex::new(Vec::new())); + let error_events_clone = error_events.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::ShowError(error) = event { + 
error_events_clone.lock().push(error.clone()); + } + }) + }); + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + // Verify no retry state was created + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Should not have retry state in Normal mode" + ); + }); + + // Check that a retryable error was reported + let errors = error_events.lock(); + assert!(!errors.is_empty(), "Should have received an error event"); + + if let ThreadError::RetryableError { + message: _, + can_enable_burn_mode, + } = &errors[0] + { + assert!( + *can_enable_burn_mode, + "Error should indicate burn mode can be enabled" + ); + } else { + panic!("Expected RetryableError, got {:?}", errors[0]); + } + + // Verify the thread is no longer generating + thread.read_with(cx, |thread, _| { + assert!( + !thread.is_generating(), + "Should not be generating after error without retry" + ); + }); + } + #[gpui::test] async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { init_test_settings(cx); @@ -5196,6 +5303,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index a4553fc901..93736ac169 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -983,30 +983,57 @@ impl ActiveThread { | ThreadEvent::SummaryChanged => { self.save_thread(cx); } - ThreadEvent::Stopped(reason) => match reason { - Ok(StopReason::EndTurn | StopReason::MaxTokens) => { - let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); - self.play_notification_sound(window, cx); - self.show_notification( - if used_tools { - "Finished running tools" - } else { - "New message" - }, - IconName::ZedAssistant, - window, - cx, - ); + ThreadEvent::Stopped(reason) => { + match reason { + Ok(StopReason::EndTurn | StopReason::MaxTokens) => { + let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + Ok(StopReason::ToolUse) => { + // Don't notify for intermediate tool use + } + Ok(StopReason::Refusal) => { + self.notify_with_sound( + "Language model refused to respond", + IconName::Warning, + window, + cx, + ); + } + Err(error) => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + + let error_message = error + .chain() + .map(|err| err.to_string()) + .collect::>() + .join("\n"); + self.last_error = Some(ThreadError::Message { + header: "Error".into(), + message: error_message.into(), + }); + } } - _ => {} - }, + } ThreadEvent::ToolConfirmationNeeded => { - self.play_notification_sound(window, cx); - 
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx); + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); } ThreadEvent::ToolUseLimitReached => { - self.play_notification_sound(window, cx); - self.show_notification( + self.notify_with_sound( "Consecutive tool use limit reached.", IconName::Warning, window, @@ -1149,9 +1176,6 @@ impl ActiveThread { self.save_thread(cx); cx.notify(); } - ThreadEvent::RetriesFailed { message } => { - self.show_notification(message, ui::IconName::Warning, window, cx); - } } } @@ -1206,6 +1230,17 @@ impl ActiveThread { } } + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + fn pop_up( &mut self, icon: IconName, @@ -3673,8 +3708,11 @@ pub(crate) fn open_context( AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| { if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| { - panel.open_thread(thread_context.thread.clone(), window, cx); + let thread = thread_context.thread.clone(); + window.defer(cx, move |window, cx| { + panel.update(cx, |panel, cx| { + panel.open_thread(thread, window, cx); + }); }); } }), @@ -3682,8 +3720,11 @@ pub(crate) fn open_context( AgentContextHandle::TextThread(text_thread_context) => { workspace.update(cx, |workspace, cx| { if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| { - panel.open_prompt_editor(text_thread_context.context.clone(), window, cx) + let context = text_thread_context.context.clone(); + window.defer(cx, move |window, cx| { + panel.update(cx, |panel, cx| { + panel.open_prompt_editor(context, window, cx) + }); }); } }) diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 8bfdd50761..579331c9ac 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -491,6 +491,7 @@ impl AgentConfiguration { category_filter: Some( ExtensionCategoryFilter::ContextServers, ), + id: None, } .boxed_clone(), cx, diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 1a0f3ff27d..f6cbe984f8 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1375,7 +1375,6 @@ impl AgentDiff { | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached | ThreadEvent::CancelEditing - | ThreadEvent::RetriesFailed { .. 
} | ThreadEvent::ProfileChanged => {} } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 5f58e0bd8d..828591b106 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -59,8 +59,9 @@ use theme::ThemeSettings; use time::UtcOffset; use ui::utils::WithRemSize; use ui::{ - Banner, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, - PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, + Banner, Button, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, IconPosition, + KeyBinding, PopoverMenu, PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, + prelude::*, }; use util::ResultExt as _; use workspace::{ @@ -1778,6 +1779,7 @@ impl AgentPanel { category_filter: Some( zed_actions::ExtensionCategoryFilter::ContextServers, ), + id: None, }), ) .action("Add Custom Server…", Box::new(AddContextServer)) @@ -1887,45 +1889,45 @@ impl AgentPanel { } fn render_token_count(&self, cx: &App) -> Option { - let (active_thread, message_editor) = match &self.active_view { + match &self.active_view { ActiveView::Thread { thread, message_editor, .. - } => (thread.read(cx), message_editor.read(cx)), - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { - return None; - } - }; + } => { + let active_thread = thread.read(cx); + let message_editor = message_editor.read(cx); - let editor_empty = message_editor.is_editor_fully_empty(cx); + let editor_empty = message_editor.is_editor_fully_empty(cx); - if active_thread.is_empty() && editor_empty { - return None; - } + if active_thread.is_empty() && editor_empty { + return None; + } - let thread = active_thread.thread().read(cx); - let is_generating = thread.is_generating(); - let conversation_token_usage = thread.total_token_usage()?; + let thread = active_thread.thread().read(cx); + let is_generating = thread.is_generating(); + let conversation_token_usage = thread.total_token_usage()?; - let (total_token_usage, is_estimating) = - if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() { - let combined = thread - .token_usage_up_to_message(editing_message_id) - .add(unsent_tokens); + let (total_token_usage, is_estimating) = + if let Some((editing_message_id, unsent_tokens)) = + active_thread.editing_message_id() + { + let combined = thread + .token_usage_up_to_message(editing_message_id) + .add(unsent_tokens); - (combined, unsent_tokens > 0) - } else { - let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0); - let combined = conversation_token_usage.add(unsent_tokens); + (combined, unsent_tokens > 0) + } else { + let unsent_tokens = + message_editor.last_estimated_token_count().unwrap_or(0); + let combined = conversation_token_usage.add(unsent_tokens); - (combined, unsent_tokens > 0) - }; + (combined, unsent_tokens > 0) + }; - let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count(); + let is_waiting_to_update_token_count = + message_editor.is_waiting_to_update_token_count(); - match &self.active_view { - ActiveView::Thread { .. 
} => { if total_token_usage.total == 0 { return None; } @@ -2819,6 +2821,21 @@ impl AgentPanel { .size(IconSize::Small) .color(Color::Error); + let retry_button = Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.retry_last_completion(Some(window.window_handle()), cx); + }); + }); + } + }); + div() .border_t_1() .border_color(cx.theme().colors().border) @@ -2827,13 +2844,72 @@ impl AgentPanel { .icon(icon) .title(header) .description(message.clone()) - .primary_action(self.dismiss_error_button(thread, cx)) - .secondary_action(self.create_copy_button(message_with_header)) + .primary_action(retry_button) + .secondary_action(self.dismiss_error_button(thread, cx)) + .tertiary_action(self.create_copy_button(message_with_header)) .bg_color(self.error_callout_bg(cx)), ) .into_any_element() } + fn render_retryable_error( + &self, + message: SharedString, + can_enable_burn_mode: bool, + thread: &Entity, + cx: &mut Context, + ) -> AnyElement { + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + let retry_button = Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.retry_last_completion(Some(window.window_handle()), cx); + }); + }); + } + }); + + let mut callout = Callout::new() + .icon(icon) + .title("Error") + .description(message.clone()) + .bg_color(self.error_callout_bg(cx)) + .primary_action(retry_button); + + if can_enable_burn_mode { + let burn_mode_button = Button::new("enable_burn_retry", "Enable Burn Mode and Retry") + .icon(IconName::ZedBurnMode) + .icon_position(IconPosition::Start) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.enable_burn_mode_and_retry(Some(window.window_handle()), cx); + }); + }); + } + }); + callout = callout.secondary_action(burn_mode_button); + } + + div() + .border_t_1() + .border_color(cx.theme().colors().border) + .child(callout) + .into_any_element() + } + fn render_prompt_editor( &self, context_editor: &Entity, @@ -3069,6 +3145,15 @@ impl Render for AgentPanel { ThreadError::Message { header, message } => { self.render_error_message(header, message, thread, cx) } + ThreadError::RetryableError { + message, + can_enable_burn_mode, + } => self.render_retryable_error( + message, + can_enable_burn_mode, + thread, + cx, + ), }) .into_any(), ) diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 8df8f677f2..8d125d4a15 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -12,6 +12,7 @@ use collections::HashMap; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; use gpui::{AppContext, TestAppContext, Timer}; +use http_client::StatusCode; use indoc::{formatdoc, indoc}; use language_model::{ LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, @@ -1671,6 +1672,30 @@ async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Timer::after(retry_after + 
jitter).await; continue; } + LanguageModelCompletionError::UpstreamProviderError { + status, + retry_after, + .. + } => { + // Only retry for specific status codes + let should_retry = matches!( + *status, + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE + ) || status.as_u16() == 529; + + if !should_retry { + return Err(err.into()); + } + + // Use server-provided retry_after if available, otherwise use default + let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!( + "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" + ); + Timer::after(retry_after + jitter).await; + continue; + } _ => return Err(err.into()), }, Err(err) => return Err(err), diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 55c15cac5a..7b536a2d24 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -94,6 +94,7 @@ context_server.workspace = true ctor.workspace = true dap = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] } +dap-types.workspace = true debugger_ui = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } extension.workspace = true diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index 7aeb381c02..8ab6e6910c 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -2,6 +2,7 @@ use crate::tests::TestServer; use call::ActiveCall; use collections::{HashMap, HashSet}; +use dap::{Capabilities, adapters::DebugTaskDefinition, transport::RequestHandling}; use debugger_ui::debugger_panel::DebugPanel; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs as _, RemoveOptions}; @@ -22,6 +23,7 @@ use language::{ use node_runtime::NodeRuntime; use project::{ ProjectPath, + debugger::session::ThreadId, lsp_store::{FormatTrigger, LspFormatTarget}, }; use remote::SshRemoteClient; @@ -29,7 +31,11 @@ use remote_server::{HeadlessAppState, HeadlessProject}; use rpc::proto; use serde_json::json; use settings::SettingsStore; -use std::{path::Path, sync::Arc}; +use std::{ + path::Path, + sync::{Arc, atomic::AtomicUsize}, +}; +use task::TcpArgumentsTemplate; use util::path; #[gpui::test(iterations = 10)] @@ -688,3 +694,162 @@ async fn test_remote_server_debugger( shutdown_session.await.unwrap(); } + +#[gpui::test] +async fn test_slow_adapter_startup_retries( + cx_a: &mut TestAppContext, + server_cx: &mut TestAppContext, + executor: BackgroundExecutor, +) { + cx_a.update(|cx| { + release_channel::init(SemanticVersion::default(), cx); + command_palette_hooks::init(cx); + zlog::init_test(); + dap_adapters::init(cx); + }); + server_cx.update(|cx| { + release_channel::init(SemanticVersion::default(), cx); + dap_adapters::init(cx); + }); + let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let remote_fs = FakeFs::new(server_cx.executor()); + remote_fs + .insert_tree( + path!("/code"), + json!({ + "lib.rs": "fn one() -> usize { 1 }" + }), + ) + .await; + + // User A connects to the remote project via SSH. 
+ server_cx.update(HeadlessProject::init); + let remote_http_client = Arc::new(BlockedHttpClient); + let node = NodeRuntime::unavailable(); + let languages = Arc::new(LanguageRegistry::new(server_cx.executor())); + let _headless_project = server_cx.new(|cx| { + client::init_settings(cx); + HeadlessProject::new( + HeadlessAppState { + session: server_ssh, + fs: remote_fs.clone(), + http_client: remote_http_client, + node_runtime: node, + languages, + extension_host_proxy: Arc::new(ExtensionHostProxy::new()), + }, + cx, + ) + }); + + let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let mut server = TestServer::start(server_cx.executor()).await; + let client_a = server.create_client(cx_a, "user_a").await; + cx_a.update(|cx| { + debugger_ui::init(cx); + command_palette_hooks::init(cx); + }); + let (project_a, _) = client_a + .build_ssh_project(path!("/code"), client_ssh.clone(), cx_a) + .await; + + let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a); + + let debugger_panel = workspace + .update_in(cx_a, |_workspace, window, cx| { + cx.spawn_in(window, DebugPanel::load) + }) + .await + .unwrap(); + + workspace.update_in(cx_a, |workspace, window, cx| { + workspace.add_panel(debugger_panel, window, cx); + }); + + cx_a.run_until_parked(); + let debug_panel = workspace + .update(cx_a, |workspace, cx| workspace.panel::(cx)) + .unwrap(); + + let workspace_window = cx_a + .window_handle() + .downcast::() + .unwrap(); + + let count = Arc::new(AtomicUsize::new(0)); + let session = debugger_ui::tests::start_debug_session_with( + &workspace_window, + cx_a, + DebugTaskDefinition { + adapter: "fake-adapter".into(), + label: "test".into(), + config: json!({ + "request": "launch" + }), + tcp_connection: Some(TcpArgumentsTemplate { + port: None, + host: None, + timeout: None, + }), + }, + move |client| { + let count = count.clone(); + client.on_request_ext::(move |_seq, _request| { + if count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 5 { + return RequestHandling::Exit; + } + RequestHandling::Respond(Ok(Capabilities::default())) + }); + }, + ) + .unwrap(); + cx_a.run_until_parked(); + + let client = session.update(cx_a, |session, _| session.adapter_client().unwrap()); + client + .fake_event(dap::messages::Events::Stopped(dap::StoppedEvent { + reason: dap::StoppedEventReason::Pause, + description: None, + thread_id: Some(1), + preserve_focus_hint: None, + text: None, + all_threads_stopped: None, + hit_breakpoint_ids: None, + })) + .await; + + cx_a.run_until_parked(); + + let active_session = debug_panel + .update(cx_a, |this, _| this.active_session()) + .unwrap(); + + let running_state = active_session.update(cx_a, |active_session, _| { + active_session.running_state().clone() + }); + + assert_eq!( + client.id(), + running_state.read_with(cx_a, |running_state, _| running_state.session_id()) + ); + assert_eq!( + ThreadId(1), + running_state.read_with(cx_a, |running_state, _| running_state + .selected_thread_id() + .unwrap()) + ); + + let shutdown_session = workspace.update(cx_a, |workspace, cx| { + workspace.project().update(cx, |project, cx| { + project.dap_store().update(cx, |dap_store, cx| { + dap_store.shutdown_session(session.read(cx).session_id(), cx) + }) + }) + }); + + client_ssh.update(cx_a, |a, _| { + a.shutdown_processes(Some(proto::ShutdownRemoteServer {}), executor) + }); + + shutdown_session.await.unwrap(); +} diff --git a/crates/dap/src/adapters.rs b/crates/dap/src/adapters.rs index d9f26b3b34..bd36b07387 100644 --- a/crates/dap/src/adapters.rs +++ 
b/crates/dap/src/adapters.rs @@ -442,10 +442,18 @@ impl DebugAdapter for FakeAdapter { _: Option>, _: &mut AsyncApp, ) -> Result { + let connection = task_definition + .tcp_connection + .as_ref() + .map(|connection| TcpArguments { + host: connection.host(), + port: connection.port.unwrap_or(17), + timeout: connection.timeout, + }); Ok(DebugAdapterBinary { command: Some("command".into()), arguments: vec![], - connection: None, + connection, envs: HashMap::default(), cwd: None, request_args: StartDebuggingRequestArguments { diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index ff082e3b76..86a15b2d8a 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -2,7 +2,7 @@ use crate::{ adapters::DebugAdapterBinary, transport::{IoKind, LogKind, TransportDelegate}, }; -use anyhow::{Context as _, Result}; +use anyhow::Result; use dap_types::{ messages::{Message, Response}, requests::Request, @@ -110,9 +110,7 @@ impl DebugAdapterClient { self.transport_delegate .pending_requests .lock() - .as_mut() - .context("client is closed")? - .insert(sequence_id, callback_tx); + .insert(sequence_id, callback_tx)?; log::debug!( "Client {} send `{}` request with sequence_id: {}", @@ -170,6 +168,7 @@ impl DebugAdapterClient { pub fn kill(&self) { log::debug!("Killing DAP process"); self.transport_delegate.transport.lock().kill(); + self.transport_delegate.pending_requests.lock().shutdown(); } pub fn has_adapter_logs(&self) -> bool { @@ -184,11 +183,34 @@ impl DebugAdapterClient { } #[cfg(any(test, feature = "test-support"))] - pub fn on_request(&self, handler: F) + pub fn on_request(&self, mut handler: F) where F: 'static + Send + FnMut(u64, R::Arguments) -> Result, + { + use crate::transport::RequestHandling; + + self.transport_delegate + .transport + .lock() + .as_fake() + .on_request::(move |seq, request| { + RequestHandling::Respond(handler(seq, request)) + }); + } + + #[cfg(any(test, feature = "test-support"))] + pub fn on_request_ext(&self, handler: F) + where + F: 'static + + Send + + FnMut( + u64, + R::Arguments, + ) -> crate::transport::RequestHandling< + Result, + >, { self.transport_delegate .transport diff --git a/crates/dap/src/transport.rs b/crates/dap/src/transport.rs index 14370f66e4..6dadf1cf35 100644 --- a/crates/dap/src/transport.rs +++ b/crates/dap/src/transport.rs @@ -49,6 +49,12 @@ pub enum IoKind { StdErr, } +#[cfg(any(test, feature = "test-support"))] +pub enum RequestHandling { + Respond(T), + Exit, +} + type LogHandlers = Arc>>; pub trait Transport: Send + Sync { @@ -76,7 +82,11 @@ async fn start( ) -> Result> { #[cfg(any(test, feature = "test-support"))] if cfg!(any(test, feature = "test-support")) { - return Ok(Box::new(FakeTransport::start(cx).await?)); + if let Some(connection) = binary.connection.clone() { + return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?)); + } else { + return Ok(Box::new(FakeTransport::start_stdio(cx).await?)); + } } if binary.connection.is_some() { @@ -90,11 +100,57 @@ async fn start( } } +pub(crate) struct PendingRequests { + inner: Option>>>, +} + +impl PendingRequests { + fn new() -> Self { + Self { + inner: Some(HashMap::default()), + } + } + + fn flush(&mut self, e: anyhow::Error) { + let Some(inner) = self.inner.as_mut() else { + return; + }; + for (_, sender) in inner.drain() { + sender.send(Err(e.cloned())).ok(); + } + } + + pub(crate) fn insert( + &mut self, + sequence_id: u64, + callback_tx: oneshot::Sender>, + ) -> anyhow::Result<()> { + let Some(inner) = self.inner.as_mut() else { + bail!("client is 
closed") + }; + inner.insert(sequence_id, callback_tx); + Ok(()) + } + + pub(crate) fn remove( + &mut self, + sequence_id: u64, + ) -> anyhow::Result>>> { + let Some(inner) = self.inner.as_mut() else { + bail!("client is closed"); + }; + Ok(inner.remove(&sequence_id)) + } + + pub(crate) fn shutdown(&mut self) { + self.flush(anyhow!("transport shutdown")); + self.inner = None; + } +} + pub(crate) struct TransportDelegate { log_handlers: LogHandlers, - // TODO this should really be some kind of associative channel - pub(crate) pending_requests: - Arc>>>>>, + pub(crate) pending_requests: Arc>, pub(crate) transport: Mutex>, pub(crate) server_tx: smol::lock::Mutex>>, tasks: Mutex>>, @@ -108,7 +164,7 @@ impl TransportDelegate { transport: Mutex::new(transport), log_handlers, server_tx: Default::default(), - pending_requests: Arc::new(Mutex::new(Some(HashMap::default()))), + pending_requests: Arc::new(Mutex::new(PendingRequests::new())), tasks: Default::default(), }) } @@ -151,24 +207,10 @@ impl TransportDelegate { Ok(()) => { pending_requests .lock() - .take() - .into_iter() - .flatten() - .for_each(|(_, request)| { - request - .send(Err(anyhow!("debugger shutdown unexpectedly"))) - .ok(); - }); + .flush(anyhow!("debugger shutdown unexpectedly")); } Err(e) => { - pending_requests - .lock() - .take() - .into_iter() - .flatten() - .for_each(|(_, request)| { - request.send(Err(e.cloned())).ok(); - }); + pending_requests.lock().flush(e); } } })); @@ -286,7 +328,7 @@ impl TransportDelegate { async fn recv_from_server( server_stdout: Stdout, mut message_handler: DapMessageHandler, - pending_requests: Arc>>>>>, + pending_requests: Arc>, log_handlers: Option, ) -> Result<()> where @@ -303,14 +345,10 @@ impl TransportDelegate { ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"), ConnectionResult::ConnectionReset => { log::info!("Debugger closed the connection"); - break Ok(()); + return Ok(()); } ConnectionResult::Result(Ok(Message::Response(res))) => { - let tx = pending_requests - .lock() - .as_mut() - .context("client is closed")? 
- .remove(&res.request_seq); + let tx = pending_requests.lock().remove(res.request_seq)?; if let Some(tx) = tx { if let Err(e) = tx.send(Self::process_response(res)) { log::trace!("Did not send response `{:?}` for a cancelled", e); @@ -704,8 +742,7 @@ impl Drop for StdioTransport { } #[cfg(any(test, feature = "test-support"))] -type RequestHandler = - Box dap_types::messages::Response>; +type RequestHandler = Box RequestHandling>; #[cfg(any(test, feature = "test-support"))] type ResponseHandler = Box; @@ -716,23 +753,38 @@ pub struct FakeTransport { request_handlers: Arc>>, // for reverse request responses response_handlers: Arc>>, - - stdin_writer: Option, - stdout_reader: Option, message_handler: Option>>, + kind: FakeTransportKind, +} + +#[cfg(any(test, feature = "test-support"))] +pub enum FakeTransportKind { + Stdio { + stdin_writer: Option, + stdout_reader: Option, + }, + Tcp { + connection: TcpArguments, + executor: BackgroundExecutor, + }, } #[cfg(any(test, feature = "test-support"))] impl FakeTransport { pub fn on_request(&self, mut handler: F) where - F: 'static + Send + FnMut(u64, R::Arguments) -> Result, + F: 'static + + Send + + FnMut(u64, R::Arguments) -> RequestHandling>, { self.request_handlers.lock().insert( R::COMMAND, Box::new(move |seq, args| { let result = handler(seq, serde_json::from_value(args).unwrap()); - let response = match result { + let RequestHandling::Respond(response) = result else { + return RequestHandling::Exit; + }; + let response = match response { Ok(response) => Response { seq: seq + 1, request_seq: seq, @@ -750,7 +802,7 @@ impl FakeTransport { message: None, }, }; - response + RequestHandling::Respond(response) }), ); } @@ -764,86 +816,75 @@ impl FakeTransport { .insert(R::COMMAND, Box::new(handler)); } - async fn start(cx: &mut AsyncApp) -> Result { + async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result { + Ok(Self { + request_handlers: Arc::new(Mutex::new(HashMap::default())), + response_handlers: Arc::new(Mutex::new(HashMap::default())), + message_handler: None, + kind: FakeTransportKind::Tcp { + connection, + executor: cx.background_executor().clone(), + }, + }) + } + + async fn handle_messages( + request_handlers: Arc>>, + response_handlers: Arc>>, + stdin_reader: PipeReader, + stdout_writer: PipeWriter, + ) -> Result<()> { use dap_types::requests::{Request, RunInTerminal, StartDebugging}; use serde_json::json; - let (stdin_writer, stdin_reader) = async_pipe::pipe(); - let (stdout_writer, stdout_reader) = async_pipe::pipe(); - - let mut this = Self { - request_handlers: Arc::new(Mutex::new(HashMap::default())), - response_handlers: Arc::new(Mutex::new(HashMap::default())), - stdin_writer: Some(stdin_writer), - stdout_reader: Some(stdout_reader), - message_handler: None, - }; - - let request_handlers = this.request_handlers.clone(); - let response_handlers = this.response_handlers.clone(); + let mut reader = BufReader::new(stdin_reader); let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer)); + let mut buffer = String::new(); - this.message_handler = Some(cx.background_spawn(async move { - let mut reader = BufReader::new(stdin_reader); - let mut buffer = String::new(); - - loop { - match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None) - .await - { - ConnectionResult::Timeout => { - anyhow::bail!("Timed out when connecting to debugger"); - } - ConnectionResult::ConnectionReset => { - log::info!("Debugger closed the connection"); - break Ok(()); - } - ConnectionResult::Result(Err(e)) => 
break Err(e), - ConnectionResult::Result(Ok(message)) => { - match message { - Message::Request(request) => { - // redirect reverse requests to stdout writer/reader - if request.command == RunInTerminal::COMMAND - || request.command == StartDebugging::COMMAND - { - let message = - serde_json::to_string(&Message::Request(request)).unwrap(); - - let mut writer = stdout_writer.lock().await; - writer - .write_all( - TransportDelegate::build_rpc_message(message) - .as_bytes(), - ) - .await - .unwrap(); - writer.flush().await.unwrap(); - } else { - let response = if let Some(handle) = - request_handlers.lock().get_mut(request.command.as_str()) - { - handle(request.seq, request.arguments.unwrap_or(json!({}))) - } else { - panic!("No request handler for {}", request.command); - }; - let message = - serde_json::to_string(&Message::Response(response)) - .unwrap(); - - let mut writer = stdout_writer.lock().await; - writer - .write_all( - TransportDelegate::build_rpc_message(message) - .as_bytes(), - ) - .await - .unwrap(); - writer.flush().await.unwrap(); - } - } - Message::Event(event) => { + loop { + match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await { + ConnectionResult::Timeout => { + anyhow::bail!("Timed out when connecting to debugger"); + } + ConnectionResult::ConnectionReset => { + log::info!("Debugger closed the connection"); + break Ok(()); + } + ConnectionResult::Result(Err(e)) => break Err(e), + ConnectionResult::Result(Ok(message)) => { + match message { + Message::Request(request) => { + // redirect reverse requests to stdout writer/reader + if request.command == RunInTerminal::COMMAND + || request.command == StartDebugging::COMMAND + { let message = - serde_json::to_string(&Message::Event(event)).unwrap(); + serde_json::to_string(&Message::Request(request)).unwrap(); + + let mut writer = stdout_writer.lock().await; + writer + .write_all( + TransportDelegate::build_rpc_message(message).as_bytes(), + ) + .await + .unwrap(); + writer.flush().await.unwrap(); + } else { + let response = if let Some(handle) = + request_handlers.lock().get_mut(request.command.as_str()) + { + handle(request.seq, request.arguments.unwrap_or(json!({}))) + } else { + panic!("No request handler for {}", request.command); + }; + let response = match response { + RequestHandling::Respond(response) => response, + RequestHandling::Exit => { + break Err(anyhow!("exit in response to request")); + } + }; + let message = + serde_json::to_string(&Message::Response(response)).unwrap(); let mut writer = stdout_writer.lock().await; writer @@ -854,20 +895,56 @@ impl FakeTransport { .unwrap(); writer.flush().await.unwrap(); } - Message::Response(response) => { - if let Some(handle) = - response_handlers.lock().get(response.command.as_str()) - { - handle(response); - } else { - log::error!("No response handler for {}", response.command); - } + } + Message::Event(event) => { + let message = serde_json::to_string(&Message::Event(event)).unwrap(); + + let mut writer = stdout_writer.lock().await; + writer + .write_all(TransportDelegate::build_rpc_message(message).as_bytes()) + .await + .unwrap(); + writer.flush().await.unwrap(); + } + Message::Response(response) => { + if let Some(handle) = + response_handlers.lock().get(response.command.as_str()) + { + handle(response); + } else { + log::error!("No response handler for {}", response.command); } } } } } - })); + } + } + + async fn start_stdio(cx: &mut AsyncApp) -> Result { + let (stdin_writer, stdin_reader) = async_pipe::pipe(); + let 
(stdout_writer, stdout_reader) = async_pipe::pipe(); + let kind = FakeTransportKind::Stdio { + stdin_writer: Some(stdin_writer), + stdout_reader: Some(stdout_reader), + }; + + let mut this = Self { + request_handlers: Arc::new(Mutex::new(HashMap::default())), + response_handlers: Arc::new(Mutex::new(HashMap::default())), + message_handler: None, + kind, + }; + + let request_handlers = this.request_handlers.clone(); + let response_handlers = this.response_handlers.clone(); + + this.message_handler = Some(cx.background_spawn(Self::handle_messages( + request_handlers, + response_handlers, + stdin_reader, + stdout_writer, + ))); Ok(this) } @@ -876,7 +953,10 @@ impl FakeTransport { #[cfg(any(test, feature = "test-support"))] impl Transport for FakeTransport { fn tcp_arguments(&self) -> Option { - None + match &self.kind { + FakeTransportKind::Stdio { .. } => None, + FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()), + } } fn connect( @@ -887,12 +967,33 @@ impl Transport for FakeTransport { Box, )>, > { - let result = util::maybe!({ - Ok(( - Box::new(self.stdin_writer.take().context("Cannot reconnect")?) as _, - Box::new(self.stdout_reader.take().context("Cannot reconnect")?) as _, - )) - }); + let result = match &mut self.kind { + FakeTransportKind::Stdio { + stdin_writer, + stdout_reader, + } => util::maybe!({ + Ok(( + Box::new(stdin_writer.take().context("Cannot reconnect")?) as _, + Box::new(stdout_reader.take().context("Cannot reconnect")?) as _, + )) + }), + FakeTransportKind::Tcp { executor, .. } => { + let (stdin_writer, stdin_reader) = async_pipe::pipe(); + let (stdout_writer, stdout_reader) = async_pipe::pipe(); + + let request_handlers = self.request_handlers.clone(); + let response_handlers = self.response_handlers.clone(); + + self.message_handler = Some(executor.spawn(Self::handle_messages( + request_handlers, + response_handlers, + stdin_reader, + stdout_writer, + ))); + + Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _)) + } + }; Task::ready(result) } diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index 988f6f4019..ffde772c36 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -1694,6 +1694,7 @@ impl Render for DebugPanel { category_filter: Some( zed_actions::ExtensionCategoryFilter::DebugAdapters, ), + id: None, } .boxed_clone(), cx, diff --git a/crates/debugger_ui/src/session.rs b/crates/debugger_ui/src/session.rs index 482297b136..4cc2602909 100644 --- a/crates/debugger_ui/src/session.rs +++ b/crates/debugger_ui/src/session.rs @@ -122,7 +122,7 @@ impl DebugSession { .to_owned() } - pub(crate) fn running_state(&self) -> &Entity { + pub fn running_state(&self) -> &Entity { &self.running_state } diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index af8c14aef7..d308fc9bd2 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1459,7 +1459,7 @@ impl RunningState { } } - pub(crate) fn selected_thread_id(&self) -> Option { + pub fn selected_thread_id(&self) -> Option { self.thread_id } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index c5fe0db74c..0d905c99dc 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -482,9 +482,7 @@ pub enum SelectMode { #[derive(Clone, PartialEq, Eq, Debug)] pub enum EditorMode { - SingleLine { - auto_width: bool, - }, + SingleLine, AutoHeight { min_lines: usize, 
max_lines: Option, @@ -1662,13 +1660,7 @@ impl Editor { pub fn single_line(window: &mut Window, cx: &mut Context) -> Self { let buffer = cx.new(|cx| Buffer::local("", cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - Self::new( - EditorMode::SingleLine { auto_width: false }, - buffer, - None, - window, - cx, - ) + Self::new(EditorMode::SingleLine, buffer, None, window, cx) } pub fn multi_line(window: &mut Window, cx: &mut Context) -> Self { @@ -1677,18 +1669,6 @@ impl Editor { Self::new(EditorMode::full(), buffer, None, window, cx) } - pub fn auto_width(window: &mut Window, cx: &mut Context) -> Self { - let buffer = cx.new(|cx| Buffer::local("", cx)); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - Self::new( - EditorMode::SingleLine { auto_width: true }, - buffer, - None, - window, - cx, - ) - } - pub fn auto_height( min_lines: usize, max_lines: usize, @@ -20487,6 +20467,7 @@ impl Editor { if event.blurred != self.focus_handle { self.last_focused_descendant = Some(event.blurred); } + self.selection_drag_state = SelectionDragState::None; self.refresh_inlay_hints(InlayHintRefreshReason::ModifiersChanged(false), cx); } diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 8a5bfb3bab..8db65189f8 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -7777,46 +7777,13 @@ impl Element for EditorElement { editor.set_style(self.style.clone(), window, cx); let layout_id = match editor.mode { - EditorMode::SingleLine { auto_width } => { + EditorMode::SingleLine => { let rem_size = window.rem_size(); - let height = self.style.text.line_height_in_pixels(rem_size); - if auto_width { - let editor_handle = cx.entity().clone(); - let style = self.style.clone(); - window.request_measured_layout( - Style::default(), - move |_, _, window, cx| { - let editor_snapshot = editor_handle - .update(cx, |editor, cx| editor.snapshot(window, cx)); - let line = Self::layout_lines( - DisplayRow(0)..DisplayRow(1), - &editor_snapshot, - &style, - px(f32::MAX), - |_| false, // Single lines never soft wrap - window, - cx, - ) - .pop() - .unwrap(); - - let font_id = - window.text_system().resolve_font(&style.text.font()); - let font_size = - style.text.font_size.to_pixels(window.rem_size()); - let em_width = - window.text_system().em_width(font_id, font_size).unwrap(); - - size(line.width + em_width, height) - }, - ) - } else { - let mut style = Style::default(); - style.size.height = height.into(); - style.size.width = relative(1.).into(); - window.request_layout(style, None, cx) - } + let mut style = Style::default(); + style.size.height = height.into(); + style.size.width = relative(1.).into(); + window.request_layout(style, None, cx) } EditorMode::AutoHeight { min_lines, @@ -10388,7 +10355,7 @@ mod tests { }); for editor_mode_without_invisibles in [ - EditorMode::SingleLine { auto_width: false }, + EditorMode::SingleLine, EditorMode::AutoHeight { min_lines: 1, max_lines: Some(100), diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index b3007d3091..6afbee3784 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -12,7 +12,7 @@ use crate::{ }; pub use autoscroll::{Autoscroll, AutoscrollStrategy}; use core::fmt::Debug; -use gpui::{App, Axis, Context, Global, Pixels, Task, Window, point, px}; +use gpui::{Along, App, Axis, Context, Global, Pixels, Task, Window, point, px}; use language::language_settings::{AllLanguageSettings, SoftWrap}; use language::{Bias, Point}; pub use 
scroll_amount::ScrollAmount; @@ -47,14 +47,14 @@ impl ScrollAnchor { } pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> gpui::Point { - let mut scroll_position = self.offset; - if self.anchor == Anchor::min() { - scroll_position.y = 0.; - } else { - let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32(); - scroll_position.y += scroll_top; - } - scroll_position + self.offset.apply_along(Axis::Vertical, |offset| { + if self.anchor == Anchor::min() { + 0. + } else { + let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32(); + (offset + scroll_top).max(0.) + } + }) } pub fn top_row(&self, buffer: &MultiBufferSnapshot) -> u32 { diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 904eca83e6..09770364cb 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -221,9 +221,6 @@ impl ExampleContext { ThreadEvent::ShowError(thread_error) => { tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); } - ThreadEvent::RetriesFailed { .. } => { - // Ignore retries failed events - } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel(); diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index 48cb41a006..f97470da22 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -6,6 +6,7 @@ use std::sync::OnceLock; use std::time::Duration; use std::{ops::Range, sync::Arc}; +use anyhow::Context as _; use client::{ExtensionMetadata, ExtensionProvides}; use collections::{BTreeMap, BTreeSet}; use editor::{Editor, EditorElement, EditorStyle}; @@ -80,16 +81,24 @@ pub fn init(cx: &mut App) { .find_map(|item| item.downcast::()); if let Some(existing) = existing { - if provides_filter.is_some() { - existing.update(cx, |extensions_page, cx| { + existing.update(cx, |extensions_page, cx| { + if provides_filter.is_some() { extensions_page.change_provides_filter(provides_filter, cx); - }); - } + } + if let Some(id) = action.id.as_ref() { + extensions_page.focus_extension(id, window, cx); + } + }); workspace.activate_item(&existing, true, true, window, cx); } else { - let extensions_page = - ExtensionsPage::new(workspace, provides_filter, window, cx); + let extensions_page = ExtensionsPage::new( + workspace, + provides_filter, + action.id.as_deref(), + window, + cx, + ); workspace.add_item_to_active_pane( Box::new(extensions_page), None, @@ -287,6 +296,7 @@ impl ExtensionsPage { pub fn new( workspace: &Workspace, provides_filter: Option, + focus_extension_id: Option<&str>, window: &mut Window, cx: &mut Context, ) -> Entity { @@ -317,6 +327,9 @@ impl ExtensionsPage { let query_editor = cx.new(|cx| { let mut input = Editor::single_line(window, cx); input.set_placeholder_text("Search extensions...", cx); + if let Some(id) = focus_extension_id { + input.set_text(format!("id:{id}"), window, cx); + } input }); cx.subscribe(&query_editor, Self::on_query_change).detach(); @@ -340,7 +353,7 @@ impl ExtensionsPage { scrollbar_state: ScrollbarState::new(scroll_handle), }; this.fetch_extensions( - None, + this.search_query(cx), Some(BTreeSet::from_iter(this.provides_filter)), None, cx, @@ -464,9 +477,23 @@ impl ExtensionsPage { .cloned() .collect::>(); - let remote_extensions = extension_store.update(cx, |store, cx| { - store.fetch_extensions(search.as_deref(), provides_filter.as_ref(), cx) - }); + let remote_extensions = + if let Some(id) = search.as_ref().and_then(|s| s.strip_prefix("id:")) { + let versions = + 
extension_store.update(cx, |store, cx| store.fetch_extension_versions(id, cx)); + cx.foreground_executor().spawn(async move { + let versions = versions.await?; + let latest = versions + .into_iter() + .max_by_key(|v| v.published_at) + .context("no extension found")?; + Ok(vec![latest]) + }) + } else { + extension_store.update(cx, |store, cx| { + store.fetch_extensions(search.as_deref(), provides_filter.as_ref(), cx) + }) + }; cx.spawn(async move |this, cx| { let dev_extensions = if let Some(search) = search { @@ -1156,6 +1183,13 @@ impl ExtensionsPage { self.refresh_feature_upsells(cx); } + pub fn focus_extension(&mut self, id: &str, window: &mut Window, cx: &mut Context) { + self.query_editor.update(cx, |editor, cx| { + editor.set_text(format!("id:{id}"), window, cx) + }); + self.refresh_search(cx); + } + pub fn change_provides_filter( &mut self, provides_filter: Option, diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index b9496cc014..aed4397440 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -126,7 +126,7 @@ mod macos { "ContentMask".into(), "Uniforms".into(), "AtlasTile".into(), - "PathInputIndex".into(), + "PathRasterizationInputIndex".into(), "PathVertex_ScaledPixels".into(), "ShadowInputIndex".into(), "Shadow".into(), diff --git a/crates/gpui/examples/painting.rs b/crates/gpui/examples/painting.rs index 9ab58cffc9..ff4b64cbda 100644 --- a/crates/gpui/examples/painting.rs +++ b/crates/gpui/examples/painting.rs @@ -1,13 +1,9 @@ use gpui::{ Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder, - PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowBounds, - WindowOptions, canvas, div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, - size, + PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas, + div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, size, }; -const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0); -const DEFAULT_WINDOW_HEIGHT: Pixels = px(768.0); - struct PaintingViewer { default_lines: Vec<(Path, Background)>, lines: Vec>>, @@ -151,6 +147,8 @@ impl PaintingViewer { px(320.0 + (i as f32 * 10.0).sin() * 40.0), )); } + let path = builder.build().unwrap(); + lines.push((path, gpui::green().into())); Self { default_lines: lines.clone(), @@ -185,13 +183,9 @@ fn button( } impl Render for PaintingViewer { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - window.request_animation_frame(); - + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { let default_lines = self.default_lines.clone(); let lines = self.lines.clone(); - let window_size = window.bounds().size; - let scale = window_size.width / DEFAULT_WINDOW_WIDTH; let dashed = self.dashed; div() @@ -228,7 +222,7 @@ impl Render for PaintingViewer { move |_, _, _| {}, move |_, _, window, _| { for (path, color) in default_lines { - window.paint_path(path.clone().scale(scale), color); + window.paint_path(path, color); } for points in lines { @@ -304,11 +298,6 @@ fn main() { cx.open_window( WindowOptions { focus: true, - window_bounds: Some(WindowBounds::Windowed(Bounds::centered( - None, - size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT), - cx, - ))), ..Default::default() }, |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), diff --git a/crates/gpui/src/path_builder.rs b/crates/gpui/src/path_builder.rs index 13c168b0bb..6c8cfddd52 100644 --- a/crates/gpui/src/path_builder.rs +++ b/crates/gpui/src/path_builder.rs 
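// Illustrative sketch (not part of this patch): the painting example above now builds its
// wave Path once in `PaintingViewer::new` and paints it unscaled every frame, instead of
// rescaling it per frame. A minimal standalone version of that setup follows; it assumes
// the gpui items the example already imports (PathBuilder, Path, Background, point, px),
// and the PathBuilder::stroke / move_to / line_to names are inferred from the surrounding
// example code rather than taken from a verified API listing.
fn build_wave(lines: &mut Vec<(Path<Pixels>, Background)>) {
    let mut builder = PathBuilder::stroke(px(1.0));
    builder.move_to(point(px(40.0), px(320.0)));
    for i in 0..50 {
        // Same shape as the example: x advances in 10px steps, y oscillates around 320px.
        builder.line_to(point(
            px(40.0 + i as f32 * 10.0),
            px(320.0 + (i as f32 * 10.0).sin() * 40.0),
        ));
    }
    // Build the path once up front; Render::render then just calls window.paint_path(path, color).
    if let Ok(path) = builder.build() {
        lines.push((path, gpui::green().into()));
    }
}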
@@ -336,7 +336,10 @@ impl PathBuilder { let v1 = buf.vertices[i1]; let v2 = buf.vertices[i2]; - path.push_triangle((v0.into(), v1.into(), v2.into())); + path.push_triangle( + (v0.into(), v1.into(), v2.into()), + (point(0., 1.), point(0., 1.), point(0., 1.)), + ); } path diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 1ad933dac1..79ec5e5da6 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -789,6 +789,7 @@ pub(crate) struct AtlasTextureId { pub(crate) enum AtlasTextureKind { Monochrome = 0, Polychrome = 1, + Path = 2, } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] diff --git a/crates/gpui/src/platform/blade/blade_atlas.rs b/crates/gpui/src/platform/blade/blade_atlas.rs index 0b119c3910..78ba52056a 100644 --- a/crates/gpui/src/platform/blade/blade_atlas.rs +++ b/crates/gpui/src/platform/blade/blade_atlas.rs @@ -10,6 +10,8 @@ use etagere::BucketedAtlasAllocator; use parking_lot::Mutex; use std::{borrow::Cow, ops, sync::Arc}; +pub(crate) const PATH_TEXTURE_FORMAT: gpu::TextureFormat = gpu::TextureFormat::R16Float; + pub(crate) struct BladeAtlas(Mutex); struct PendingUpload { @@ -25,6 +27,7 @@ struct BladeAtlasState { tiles_by_key: FxHashMap, initializations: Vec, uploads: Vec, + path_sample_count: u32, } #[cfg(gles)] @@ -38,13 +41,13 @@ impl BladeAtlasState { } pub struct BladeTextureInfo { - #[allow(dead_code)] pub size: gpu::Extent, pub raw_view: gpu::TextureView, + pub msaa_view: Option, } impl BladeAtlas { - pub(crate) fn new(gpu: &Arc) -> Self { + pub(crate) fn new(gpu: &Arc, path_sample_count: u32) -> Self { BladeAtlas(Mutex::new(BladeAtlasState { gpu: Arc::clone(gpu), upload_belt: BufferBelt::new(BufferBeltDescriptor { @@ -56,6 +59,7 @@ impl BladeAtlas { tiles_by_key: Default::default(), initializations: Vec::new(), uploads: Vec::new(), + path_sample_count, })) } @@ -63,7 +67,6 @@ impl BladeAtlas { self.0.lock().destroy(); } - #[allow(dead_code)] pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { let mut lock = self.0.lock(); let textures = &mut lock.storage[texture_kind]; @@ -72,6 +75,19 @@ impl BladeAtlas { } } + /// Allocate a rectangle and make it available for rendering immediately (without waiting for `before_frame`) + pub fn allocate_for_rendering( + &self, + size: Size, + texture_kind: AtlasTextureKind, + gpu_encoder: &mut gpu::CommandEncoder, + ) -> AtlasTile { + let mut lock = self.0.lock(); + let tile = lock.allocate(size, texture_kind); + lock.flush_initializations(gpu_encoder); + tile + } + pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) { let mut lock = self.0.lock(); lock.flush(gpu_encoder); @@ -93,6 +109,7 @@ impl BladeAtlas { depth: 1, }, raw_view: texture.raw_view, + msaa_view: texture.msaa_view, } } } @@ -183,8 +200,48 @@ impl BladeAtlasState { format = gpu::TextureFormat::Bgra8UnormSrgb; usage = gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE; } + AtlasTextureKind::Path => { + format = PATH_TEXTURE_FORMAT; + usage = gpu::TextureUsage::COPY + | gpu::TextureUsage::RESOURCE + | gpu::TextureUsage::TARGET; + } } + // We currently only enable MSAA for path textures. 
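// Illustrative sketch (not part of this patch): what the msaa_view created above is used
// for. When path_sample_count > 1, path rasterization draws into the multisampled view and
// resolves into the plain R16Float atlas view; with a sample count of 1 it draws into the
// atlas view directly. This mirrors the blade_graphics render-target setup that appears
// later in blade_renderer.rs (gpu::RenderTarget / InitOp / FinishOp); the helper function
// itself is only a sketch.
fn path_color_target(
    atlas_view: gpu::TextureView,
    msaa_view: Option<gpu::TextureView>,
) -> gpu::RenderTarget {
    match msaa_view {
        // MSAA path: clear the multisampled texture, then resolve the result into the atlas tile.
        Some(msaa) => gpu::RenderTarget {
            view: msaa,
            init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack),
            finish_op: gpu::FinishOp::ResolveTo(atlas_view),
        },
        // No MSAA: render straight into the atlas tile and keep it.
        None => gpu::RenderTarget {
            view: atlas_view,
            init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack),
            finish_op: gpu::FinishOp::Store,
        },
    }
}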
+ let (msaa, msaa_view) = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path { + let msaa = self.gpu.create_texture(gpu::TextureDesc { + name: "msaa path texture", + format, + size: gpu::Extent { + width: size.width.into(), + height: size.height.into(), + depth: 1, + }, + array_layer_count: 1, + mip_level_count: 1, + sample_count: self.path_sample_count, + dimension: gpu::TextureDimension::D2, + usage: gpu::TextureUsage::TARGET, + external: None, + }); + + ( + Some(msaa), + Some(self.gpu.create_texture_view( + msaa, + gpu::TextureViewDesc { + name: "msaa texture view", + format, + dimension: gpu::ViewDimension::D2, + subresources: &Default::default(), + }, + )), + ) + } else { + (None, None) + }; + let raw = self.gpu.create_texture(gpu::TextureDesc { name: "atlas", format, @@ -222,6 +279,8 @@ impl BladeAtlasState { format, raw, raw_view, + msaa, + msaa_view, live_atlas_keys: 0, }; @@ -281,6 +340,7 @@ impl BladeAtlasState { struct BladeAtlasStorage { monochrome_textures: AtlasTextureList, polychrome_textures: AtlasTextureList, + path_textures: AtlasTextureList, } impl ops::Index for BladeAtlasStorage { @@ -289,6 +349,7 @@ impl ops::Index for BladeAtlasStorage { match kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + crate::AtlasTextureKind::Path => &self.path_textures, } } } @@ -298,6 +359,7 @@ impl ops::IndexMut for BladeAtlasStorage { match kind { crate::AtlasTextureKind::Monochrome => &mut self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + crate::AtlasTextureKind::Path => &mut self.path_textures, } } } @@ -308,6 +370,7 @@ impl ops::Index for BladeAtlasStorage { let textures = match id.kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + crate::AtlasTextureKind::Path => &self.path_textures, }; textures[id.index as usize].as_ref().unwrap() } @@ -321,6 +384,9 @@ impl BladeAtlasStorage { for mut texture in self.polychrome_textures.drain().flatten() { texture.destroy(gpu); } + for mut texture in self.path_textures.drain().flatten() { + texture.destroy(gpu); + } } } @@ -329,6 +395,8 @@ struct BladeAtlasTexture { allocator: BucketedAtlasAllocator, raw: gpu::Texture, raw_view: gpu::TextureView, + msaa: Option, + msaa_view: Option, format: gpu::TextureFormat, live_atlas_keys: u32, } @@ -356,6 +424,12 @@ impl BladeAtlasTexture { fn destroy(&mut self, gpu: &gpu::Context) { gpu.destroy_texture(self.raw); gpu.destroy_texture_view(self.raw_view); + if let Some(msaa) = self.msaa { + gpu.destroy_texture(msaa); + } + if let Some(msaa_view) = self.msaa_view { + gpu.destroy_texture_view(msaa_view); + } } fn bytes_per_pixel(&self) -> u8 { diff --git a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs index 1b9f111b0d..cac47434ae 100644 --- a/crates/gpui/src/platform/blade/blade_renderer.rs +++ b/crates/gpui/src/platform/blade/blade_renderer.rs @@ -1,19 +1,24 @@ // Doing `if let` gives you nice scoping with passes/encoders #![allow(irrefutable_let_patterns)] -use super::{BladeAtlas, BladeContext}; +use super::{BladeAtlas, BladeContext, PATH_TEXTURE_FORMAT}; use crate::{ - Background, Bounds, ContentMask, DevicePixels, GpuSpecs, MonochromeSprite, PathVertex, - PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline, + AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, 
GpuSpecs, + MonochromeSprite, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, Quad, + ScaledPixels, Scene, Shadow, Size, Underline, }; -use blade_graphics::{self as gpu}; +use blade_graphics as gpu; use blade_util::{BufferBelt, BufferBeltDescriptor}; use bytemuck::{Pod, Zeroable}; +use collections::HashMap; #[cfg(target_os = "macos")] use media::core_video::CVMetalTextureCache; use std::{mem, sync::Arc}; const MAX_FRAME_TIME_MS: u32 = 10000; +// Use 4x MSAA, all devices support it. +// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount +const DEFAULT_PATH_SAMPLE_COUNT: u32 = 4; #[repr(C)] #[derive(Clone, Copy, Pod, Zeroable)] @@ -61,9 +66,16 @@ struct ShaderShadowsData { } #[derive(blade_macros::ShaderData)] -struct ShaderPathsData { +struct ShaderPathRasterizationData { globals: GlobalParams, b_path_vertices: gpu::BufferPiece, +} + +#[derive(blade_macros::ShaderData)] +struct ShaderPathsData { + globals: GlobalParams, + t_sprite: gpu::TextureView, + s_sprite: gpu::Sampler, b_path_sprites: gpu::BufferPiece, } @@ -103,27 +115,13 @@ struct ShaderSurfacesData { struct PathSprite { bounds: Bounds, color: Background, -} - -/// Argument buffer layout for `draw_indirect` commands. -#[repr(C)] -#[derive(Copy, Clone, Debug, Default, Pod, Zeroable)] -pub struct DrawIndirectArgs { - /// The number of vertices to draw. - pub vertex_count: u32, - /// The number of instances to draw. - pub instance_count: u32, - /// The Index of the first vertex to draw. - pub first_vertex: u32, - /// The instance ID of the first instance to draw. - /// - /// Has to be 0, unless [`Features::INDIRECT_FIRST_INSTANCE`](crate::Features::INDIRECT_FIRST_INSTANCE) is enabled. - pub first_instance: u32, + tile: AtlasTile, } struct BladePipelines { quads: gpu::RenderPipeline, shadows: gpu::RenderPipeline, + path_rasterization: gpu::RenderPipeline, paths: gpu::RenderPipeline, underlines: gpu::RenderPipeline, mono_sprites: gpu::RenderPipeline, @@ -132,7 +130,7 @@ struct BladePipelines { } impl BladePipelines { - fn new(gpu: &gpu::Context, surface_info: gpu::SurfaceInfo, sample_count: u32) -> Self { + fn new(gpu: &gpu::Context, surface_info: gpu::SurfaceInfo, path_sample_count: u32) -> Self { use gpu::ShaderData as _; log::info!( @@ -180,10 +178,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_quad")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), shadows: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "shadows", @@ -197,8 +192,26 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_shadow")), color_targets, + multisample_state: gpu::MultisampleState::default(), + }), + path_rasterization: gpu.create_render_pipeline(gpu::RenderPipelineDesc { + name: "path_rasterization", + data_layouts: &[&ShaderPathRasterizationData::layout()], + vertex: shader.at("vs_path_rasterization"), + vertex_fetches: &[], + primitive: gpu::PrimitiveState { + topology: gpu::PrimitiveTopology::TriangleList, + ..Default::default() + }, + depth_stencil: None, + fragment: Some(shader.at("fs_path_rasterization")), + color_targets: &[gpu::ColorTargetState { + format: PATH_TEXTURE_FORMAT, + blend: Some(gpu::BlendState::ADDITIVE), + write_mask: gpu::ColorWrites::default(), + }], multisample_state: gpu::MultisampleState { - sample_count, + sample_count: path_sample_count, ..Default::default() }, }), @@ -208,16 +221,13 @@ impl 
BladePipelines { vertex: shader.at("vs_path"), vertex_fetches: &[], primitive: gpu::PrimitiveState { - topology: gpu::PrimitiveTopology::TriangleList, + topology: gpu::PrimitiveTopology::TriangleStrip, ..Default::default() }, depth_stencil: None, fragment: Some(shader.at("fs_path")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "underlines", @@ -231,10 +241,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_underline")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), mono_sprites: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "mono-sprites", @@ -248,10 +255,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_mono_sprite")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), poly_sprites: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "poly-sprites", @@ -265,10 +269,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_poly_sprite")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), surfaces: gpu.create_render_pipeline(gpu::RenderPipelineDesc { name: "surfaces", @@ -282,10 +283,7 @@ impl BladePipelines { depth_stencil: None, fragment: Some(shader.at("fs_surface")), color_targets, - multisample_state: gpu::MultisampleState { - sample_count, - ..Default::default() - }, + multisample_state: gpu::MultisampleState::default(), }), } } @@ -293,6 +291,7 @@ impl BladePipelines { fn destroy(&mut self, gpu: &gpu::Context) { gpu.destroy_render_pipeline(&mut self.quads); gpu.destroy_render_pipeline(&mut self.shadows); + gpu.destroy_render_pipeline(&mut self.path_rasterization); gpu.destroy_render_pipeline(&mut self.paths); gpu.destroy_render_pipeline(&mut self.underlines); gpu.destroy_render_pipeline(&mut self.mono_sprites); @@ -318,13 +317,12 @@ pub struct BladeRenderer { last_sync_point: Option, pipelines: BladePipelines, instance_belt: BufferBelt, + path_tiles: HashMap, atlas: Arc, atlas_sampler: gpu::Sampler, #[cfg(target_os = "macos")] core_video_texture_cache: CVMetalTextureCache, - sample_count: u32, - texture_msaa: Option, - texture_view_msaa: Option, + path_sample_count: u32, } impl BladeRenderer { @@ -333,18 +331,6 @@ impl BladeRenderer { window: &I, config: BladeSurfaceConfig, ) -> anyhow::Result { - // workaround for https://github.com/zed-industries/zed/issues/26143 - let sample_count = std::env::var("ZED_SAMPLE_COUNT") - .ok() - .or_else(|| std::env::var("ZED_PATH_SAMPLE_COUNT").ok()) - .and_then(|v| v.parse().ok()) - .or_else(|| { - [4, 2, 1] - .into_iter() - .find(|count| context.gpu.supports_texture_sample_count(*count)) - }) - .unwrap_or(1); - let surface_config = gpu::SurfaceConfig { size: config.size, usage: gpu::TextureUsage::TARGET, @@ -358,27 +344,22 @@ impl BladeRenderer { .create_surface_configured(window, surface_config) .map_err(|err| anyhow::anyhow!("Failed to create surface: {err:?}"))?; - let (texture_msaa, texture_view_msaa) = create_msaa_texture_if_needed( - &context.gpu, - surface.info().format, - config.size.width, - config.size.height, - sample_count, 
- ) - .unzip(); - let command_encoder = context.gpu.create_command_encoder(gpu::CommandEncoderDesc { name: "main", buffer_count: 2, }); - - let pipelines = BladePipelines::new(&context.gpu, surface.info(), sample_count); + // workaround for https://github.com/zed-industries/zed/issues/26143 + let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(DEFAULT_PATH_SAMPLE_COUNT); + let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count); let instance_belt = BufferBelt::new(BufferBeltDescriptor { memory: gpu::Memory::Shared, min_chunk_size: 0x1000, alignment: 0x40, // Vulkan `minStorageBufferOffsetAlignment` on Intel Xe }); - let atlas = Arc::new(BladeAtlas::new(&context.gpu)); + let atlas = Arc::new(BladeAtlas::new(&context.gpu, path_sample_count)); let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc { name: "atlas", mag_filter: gpu::FilterMode::Linear, @@ -402,13 +383,12 @@ impl BladeRenderer { last_sync_point: None, pipelines, instance_belt, + path_tiles: HashMap::default(), atlas, atlas_sampler, #[cfg(target_os = "macos")] core_video_texture_cache, - sample_count, - texture_msaa, - texture_view_msaa, + path_sample_count, }) } @@ -461,24 +441,6 @@ impl BladeRenderer { self.surface_config.size = gpu_size; self.gpu .reconfigure_surface(&mut self.surface, self.surface_config); - - if let Some(texture_msaa) = self.texture_msaa { - self.gpu.destroy_texture(texture_msaa); - } - if let Some(texture_view_msaa) = self.texture_view_msaa { - self.gpu.destroy_texture_view(texture_view_msaa); - } - - let (texture_msaa, texture_view_msaa) = create_msaa_texture_if_needed( - &self.gpu, - self.surface.info().format, - gpu_size.width, - gpu_size.height, - self.sample_count, - ) - .unzip(); - self.texture_msaa = texture_msaa; - self.texture_view_msaa = texture_view_msaa; } } @@ -489,7 +451,8 @@ impl BladeRenderer { self.gpu .reconfigure_surface(&mut self.surface, self.surface_config); self.pipelines.destroy(&self.gpu); - self.pipelines = BladePipelines::new(&self.gpu, self.surface.info(), self.sample_count); + self.pipelines = + BladePipelines::new(&self.gpu, self.surface.info(), self.path_sample_count); } } @@ -527,6 +490,80 @@ impl BladeRenderer { objc2::rc::Retained::as_ptr(&self.surface.metal_layer()) as *mut _ } + #[profiling::function] + fn rasterize_paths(&mut self, paths: &[Path]) { + self.path_tiles.clear(); + let mut vertices_by_texture_id = HashMap::default(); + + for path in paths { + let clipped_bounds = path + .bounds + .intersect(&path.content_mask.bounds) + .map_origin(|origin| origin.floor()) + .map_size(|size| size.ceil()); + let tile = self.atlas.allocate_for_rendering( + clipped_bounds.size.map(Into::into), + AtlasTextureKind::Path, + &mut self.command_encoder, + ); + vertices_by_texture_id + .entry(tile.texture_id) + .or_insert(Vec::new()) + .extend(path.vertices.iter().map(|vertex| PathVertex { + xy_position: vertex.xy_position - clipped_bounds.origin + + tile.bounds.origin.map(Into::into), + st_position: vertex.st_position, + content_mask: ContentMask { + bounds: tile.bounds.map(Into::into), + }, + })); + self.path_tiles.insert(path.id, tile); + } + + for (texture_id, vertices) in vertices_by_texture_id { + let tex_info = self.atlas.get_texture_info(texture_id); + let globals = GlobalParams { + viewport_size: [tex_info.size.width as f32, tex_info.size.height as f32], + premultiplied_alpha: 0, + pad: 0, + }; + + let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) 
}; + let frame_view = tex_info.raw_view; + let color_target = if let Some(msaa_view) = tex_info.msaa_view { + gpu::RenderTarget { + view: msaa_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack), + finish_op: gpu::FinishOp::ResolveTo(frame_view), + } + } else { + gpu::RenderTarget { + view: frame_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack), + finish_op: gpu::FinishOp::Store, + } + }; + + if let mut pass = self.command_encoder.render( + "paths", + gpu::RenderTargetSet { + colors: &[color_target], + depth_stencil: None, + }, + ) { + let mut encoder = pass.with(&self.pipelines.path_rasterization); + encoder.bind( + 0, + &ShaderPathRasterizationData { + globals, + b_path_vertices: vertex_buf, + }, + ); + encoder.draw(0, vertices.len() as u32, 0, 1); + } + } + } + pub fn destroy(&mut self) { self.wait_for_gpu(); self.atlas.destroy(); @@ -535,26 +572,17 @@ impl BladeRenderer { self.gpu.destroy_command_encoder(&mut self.command_encoder); self.pipelines.destroy(&self.gpu); self.gpu.destroy_surface(&mut self.surface); - if let Some(texture_msaa) = self.texture_msaa { - self.gpu.destroy_texture(texture_msaa); - } - if let Some(texture_view_msaa) = self.texture_view_msaa { - self.gpu.destroy_texture_view(texture_view_msaa); - } } pub fn draw(&mut self, scene: &Scene) { self.command_encoder.start(); self.atlas.before_frame(&mut self.command_encoder); + self.rasterize_paths(scene.paths()); let frame = { profiling::scope!("acquire frame"); self.surface.acquire_frame() }; - let frame_view = frame.texture_view(); - if let Some(texture_msaa) = self.texture_msaa { - self.command_encoder.init_texture(texture_msaa); - } self.command_encoder.init_texture(frame.texture()); let globals = GlobalParams { @@ -569,25 +597,14 @@ impl BladeRenderer { pad: 0, }; - let target = if let Some(texture_view_msaa) = self.texture_view_msaa { - gpu::RenderTarget { - view: texture_view_msaa, - init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), - finish_op: gpu::FinishOp::ResolveTo(frame_view), - } - } else { - gpu::RenderTarget { - view: frame_view, - init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), - finish_op: gpu::FinishOp::Store, - } - }; - - // draw to the target texture if let mut pass = self.command_encoder.render( "main", gpu::RenderTargetSet { - colors: &[target], + colors: &[gpu::RenderTarget { + view: frame.texture_view(), + init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), + finish_op: gpu::FinishOp::Store, + }], depth_stencil: None, }, ) { @@ -622,55 +639,32 @@ impl BladeRenderer { } PrimitiveBatch::Paths(paths) => { let mut encoder = pass.with(&self.pipelines.paths); - - let mut vertices = Vec::new(); - let mut sprites = Vec::with_capacity(paths.len()); - let mut draw_indirect_commands = Vec::with_capacity(paths.len()); - let mut first_vertex = 0; - - for (i, path) in paths.iter().enumerate() { - draw_indirect_commands.push(DrawIndirectArgs { - vertex_count: path.vertices.len() as u32, - instance_count: 1, - first_vertex, - first_instance: i as u32, - }); - first_vertex += path.vertices.len() as u32; - - vertices.extend(path.vertices.iter().map(|v| PathVertex { - xy_position: v.xy_position, - content_mask: ContentMask { - bounds: path.content_mask.bounds, + // todo(linux): group by texture ID + for path in paths { + let tile = &self.path_tiles[&path.id]; + let tex_info = self.atlas.get_texture_info(tile.texture_id); + let origin = path.bounds.intersect(&path.content_mask.bounds).origin; + let sprites = [PathSprite { + bounds: 
Bounds { + origin: origin.map(|p| p.floor()), + size: tile.bounds.size.map(Into::into), }, - })); - - sprites.push(PathSprite { - bounds: path.bounds, color: path.color, - }); - } + tile: (*tile).clone(), + }]; - let b_path_vertices = - unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; - let instance_buf = - unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; - let indirect_buf = unsafe { - self.instance_belt - .alloc_typed(&draw_indirect_commands, &self.gpu) - }; - - encoder.bind( - 0, - &ShaderPathsData { - globals, - b_path_vertices, - b_path_sprites: instance_buf, - }, - ); - - for i in 0..paths.len() { - encoder.draw_indirect(indirect_buf.buffer.at(indirect_buf.offset - + (i * mem::size_of::()) as u64)); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; + encoder.bind( + 0, + &ShaderPathsData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_path_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); } } PrimitiveBatch::Underlines(underlines) => { @@ -823,47 +817,9 @@ impl BladeRenderer { profiling::scope!("finish"); self.instance_belt.flush(&sync_point); self.atlas.after_frame(&sync_point); + self.atlas.clear_textures(AtlasTextureKind::Path); self.wait_for_gpu(); self.last_sync_point = Some(sync_point); } } - -fn create_msaa_texture_if_needed( - gpu: &gpu::Context, - format: gpu::TextureFormat, - width: u32, - height: u32, - sample_count: u32, -) -> Option<(gpu::Texture, gpu::TextureView)> { - if sample_count <= 1 { - return None; - } - - let texture_msaa = gpu.create_texture(gpu::TextureDesc { - name: "msaa", - format, - size: gpu::Extent { - width, - height, - depth: 1, - }, - array_layer_count: 1, - mip_level_count: 1, - sample_count, - dimension: gpu::TextureDimension::D2, - usage: gpu::TextureUsage::TARGET, - external: None, - }); - let texture_view_msaa = gpu.create_texture_view( - texture_msaa, - gpu::TextureViewDesc { - name: "msaa view", - format, - dimension: gpu::ViewDimension::D2, - subresources: &Default::default(), - }, - ); - - Some((texture_msaa, texture_view_msaa)) -} diff --git a/crates/gpui/src/platform/blade/shaders.wgsl b/crates/gpui/src/platform/blade/shaders.wgsl index 00c9d07af7..0b34a0eea3 100644 --- a/crates/gpui/src/platform/blade/shaders.wgsl +++ b/crates/gpui/src/platform/blade/shaders.wgsl @@ -922,23 +922,59 @@ fn fs_shadow(input: ShadowVarying) -> @location(0) vec4 { return blend_color(input.color, alpha); } -// --- paths --- // +// --- path rasterization --- // struct PathVertex { xy_position: vec2, + st_position: vec2, content_mask: Bounds, } +var b_path_vertices: array; + +struct PathRasterizationVarying { + @builtin(position) position: vec4, + @location(0) st_position: vec2, + //TODO: use `clip_distance` once Naga supports it + @location(3) clip_distances: vec4, +} + +@vertex +fn vs_path_rasterization(@builtin(vertex_index) vertex_id: u32) -> PathRasterizationVarying { + let v = b_path_vertices[vertex_id]; + + var out = PathRasterizationVarying(); + out.position = to_device_position_impl(v.xy_position); + out.st_position = v.st_position; + out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask); + return out; +} + +@fragment +fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) f32 { + let dx = dpdx(input.st_position); + let dy = dpdy(input.st_position); + if (any(input.clip_distances < vec4(0.0))) { + return 0.0; + } + + let gradient = 2.0 * input.st_position.xx * vec2(dx.x, dy.x) - 
vec2(dx.y, dy.y); + let f = input.st_position.x * input.st_position.x - input.st_position.y; + let distance = f / length(gradient); + return saturate(0.5 - distance); +} + +// --- paths --- // struct PathSprite { bounds: Bounds, color: Background, + tile: AtlasTile, } -var b_path_vertices: array; var b_path_sprites: array; struct PathVarying { @builtin(position) position: vec4, - @location(0) clip_distances: vec4, + @location(0) tile_position: vec2, @location(1) @interpolate(flat) instance_id: u32, @location(2) @interpolate(flat) color_solid: vec4, @location(3) @interpolate(flat) color0: vec4, @@ -947,12 +983,13 @@ struct PathVarying { @vertex fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) instance_id: u32) -> PathVarying { - let v = b_path_vertices[vertex_id]; + let unit_vertex = vec2(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u)); let sprite = b_path_sprites[instance_id]; + // Don't apply content mask because it was already accounted for when rasterizing the path. var out = PathVarying(); - out.position = to_device_position_impl(v.xy_position); - out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask); + out.position = to_device_position(unit_vertex, sprite.bounds); + out.tile_position = to_tile_position(unit_vertex, sprite.tile); out.instance_id = instance_id; let gradient = prepare_gradient_color( @@ -969,15 +1006,13 @@ fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) insta @fragment fn fs_path(input: PathVarying) -> @location(0) vec4 { - if any(input.clip_distances < vec4(0.0)) { - return vec4(0.0); - } - + let sample = textureSample(t_sprite, s_sprite, input.tile_position).r; + let mask = 1.0 - abs(1.0 - sample % 2.0); let sprite = b_path_sprites[input.instance_id]; let background = sprite.color; let color = gradient_color(background, input.position.xy, sprite.bounds, input.color_solid, input.color0, input.color1); - return blend_color(color, 1.0); + return blend_color(color, mask); } // --- underlines --- // diff --git a/crates/gpui/src/platform/mac/metal_atlas.rs b/crates/gpui/src/platform/mac/metal_atlas.rs index 0c8e1d3703..366f2dcc3c 100644 --- a/crates/gpui/src/platform/mac/metal_atlas.rs +++ b/crates/gpui/src/platform/mac/metal_atlas.rs @@ -13,12 +13,14 @@ use std::borrow::Cow; pub(crate) struct MetalAtlas(Mutex); impl MetalAtlas { - pub(crate) fn new(device: Device) -> Self { + pub(crate) fn new(device: Device, path_sample_count: u32) -> Self { MetalAtlas(Mutex::new(MetalAtlasState { device: AssertSend(device), monochrome_textures: Default::default(), polychrome_textures: Default::default(), + path_textures: Default::default(), tiles_by_key: Default::default(), + path_sample_count, })) } @@ -26,7 +28,10 @@ impl MetalAtlas { self.0.lock().texture(id).metal_texture.clone() } - #[allow(dead_code)] + pub(crate) fn msaa_texture(&self, id: AtlasTextureId) -> Option { + self.0.lock().texture(id).msaa_texture.clone() + } + pub(crate) fn allocate( &self, size: Size, @@ -35,12 +40,12 @@ impl MetalAtlas { self.0.lock().allocate(size, texture_kind) } - #[allow(dead_code)] pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { let mut lock = self.0.lock(); let textures = match texture_kind { AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + AtlasTextureKind::Path => &mut lock.path_textures, }; for texture in textures.iter_mut() { texture.clear(); @@ -52,7 +57,9 @@ struct MetalAtlasState { device: AssertSend, 
monochrome_textures: AtlasTextureList, polychrome_textures: AtlasTextureList, + path_textures: AtlasTextureList, tiles_by_key: FxHashMap, + path_sample_count: u32, } impl PlatformAtlas for MetalAtlas { @@ -87,6 +94,7 @@ impl PlatformAtlas for MetalAtlas { let textures = match id.kind { AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + AtlasTextureKind::Path => &mut lock.polychrome_textures, }; let Some(texture_slot) = textures @@ -120,6 +128,7 @@ impl MetalAtlasState { let textures = match texture_kind { AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + AtlasTextureKind::Path => &mut self.path_textures, }; if let Some(tile) = textures @@ -164,14 +173,31 @@ impl MetalAtlasState { pixel_format = metal::MTLPixelFormat::BGRA8Unorm; usage = metal::MTLTextureUsage::ShaderRead; } + AtlasTextureKind::Path => { + pixel_format = metal::MTLPixelFormat::R16Float; + usage = metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead; + } } texture_descriptor.set_pixel_format(pixel_format); texture_descriptor.set_usage(usage); let metal_texture = self.device.new_texture(&texture_descriptor); + // We currently only enable MSAA for path textures. + let msaa_texture = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path { + let mut descriptor = texture_descriptor.clone(); + descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); + descriptor.set_storage_mode(metal::MTLStorageMode::Private); + descriptor.set_sample_count(self.path_sample_count as _); + let msaa_texture = self.device.new_texture(&descriptor); + Some(msaa_texture) + } else { + None + }; + let texture_list = match kind { AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + AtlasTextureKind::Path => &mut self.path_textures, }; let index = texture_list.free_list.pop(); @@ -183,6 +209,7 @@ impl MetalAtlasState { }, allocator: etagere::BucketedAtlasAllocator::new(size.into()), metal_texture: AssertSend(metal_texture), + msaa_texture: AssertSend(msaa_texture), live_atlas_keys: 0, }; @@ -199,6 +226,7 @@ impl MetalAtlasState { let textures = match id.kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + crate::AtlasTextureKind::Path => &self.path_textures, }; textures[id.index as usize].as_ref().unwrap() } @@ -208,6 +236,7 @@ struct MetalAtlasTexture { id: AtlasTextureId, allocator: BucketedAtlasAllocator, metal_texture: AssertSend, + msaa_texture: AssertSend>, live_atlas_keys: u32, } diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs index 8936cf242c..3cdc2dd2cf 100644 --- a/crates/gpui/src/platform/mac/metal_renderer.rs +++ b/crates/gpui/src/platform/mac/metal_renderer.rs @@ -1,28 +1,27 @@ use super::metal_atlas::MetalAtlas; use crate::{ - AtlasTextureId, Background, Bounds, ContentMask, DevicePixels, MonochromeSprite, PaintSurface, - Path, PathVertex, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, - Surface, Underline, point, size, + AtlasTextureId, AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, + MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, + Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline, point, size, }; -use anyhow::Result; 
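// Illustrative sketch (not part of this patch): a scalar restatement of the two-pass path
// rendering that the shaders earlier in this diff implement. Pass one rasterizes the path's
// triangles into an R16Float atlas tile with additive blending, so overlapping triangles
// accumulate a winding count plus fractional edge coverage; pass two (fs_path) converts the
// accumulated value into a fill mask with `1.0 - abs(1.0 - sample % 2.0)`, i.e. even-odd
// winding with antialiased edges. Plain Rust for clarity only; not renderer code.
fn fill_mask(accumulated: f32) -> f32 {
    1.0 - (1.0 - accumulated % 2.0).abs()
}

#[cfg(test)]
mod fill_mask_tests {
    use super::*;

    #[test]
    fn even_odd_winding() {
        assert_eq!(fill_mask(1.0), 1.0); // covered once: inside the path
        assert_eq!(fill_mask(2.0), 0.0); // covered twice: a hole
        assert_eq!(fill_mask(0.25), 0.25); // partial coverage at an edge stays fractional
    }
}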
+use anyhow::{Context as _, Result}; use block::ConcreteBlock; use cocoa::{ base::{NO, YES}, foundation::{NSSize, NSUInteger}, quartzcore::AutoresizingMask, }; +use collections::HashMap; use core_foundation::base::TCFType; use core_video::{ metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache, pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange, }; use foreign_types::{ForeignType, ForeignTypeRef}; -use metal::{ - CAMetalLayer, CommandQueue, MTLDrawPrimitivesIndirectArguments, MTLPixelFormat, - MTLResourceOptions, NSRange, -}; +use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange}; use objc::{self, msg_send, sel, sel_impl}; use parking_lot::Mutex; +use smallvec::SmallVec; use std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc}; // Exported to metal @@ -32,6 +31,9 @@ pub(crate) type PointF = crate::Point; const SHADERS_METALLIB: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/shaders.metallib")); #[cfg(feature = "runtime_shaders")] const SHADERS_SOURCE_FILE: &str = include_str!(concat!(env!("OUT_DIR"), "/stitched_shaders.metal")); +// Use 4x MSAA, all devices support it. +// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount +const PATH_SAMPLE_COUNT: u32 = 4; pub type Context = Arc>; pub type Renderer = MetalRenderer; @@ -96,7 +98,8 @@ pub(crate) struct MetalRenderer { layer: metal::MetalLayer, presents_with_transaction: bool, command_queue: CommandQueue, - path_pipeline_state: metal::RenderPipelineState, + paths_rasterization_pipeline_state: metal::RenderPipelineState, + path_sprites_pipeline_state: metal::RenderPipelineState, shadows_pipeline_state: metal::RenderPipelineState, quads_pipeline_state: metal::RenderPipelineState, underlines_pipeline_state: metal::RenderPipelineState, @@ -108,8 +111,6 @@ pub(crate) struct MetalRenderer { instance_buffer_pool: Arc>, sprite_atlas: Arc, core_video_texture_cache: core_video::metal_texture_cache::CVMetalTextureCache, - sample_count: u64, - msaa_texture: Option, } impl MetalRenderer { @@ -168,19 +169,22 @@ impl MetalRenderer { MTLResourceOptions::StorageModeManaged, ); - let sample_count = [4, 2, 1] - .into_iter() - .find(|count| device.supports_texture_sample_count(*count)) - .unwrap_or(1); - - let path_pipeline_state = build_pipeline_state( + let paths_rasterization_pipeline_state = build_path_rasterization_pipeline_state( &device, &library, - "paths", - "path_vertex", - "path_fragment", + "paths_rasterization", + "path_rasterization_vertex", + "path_rasterization_fragment", + MTLPixelFormat::R16Float, + PATH_SAMPLE_COUNT, + ); + let path_sprites_pipeline_state = build_pipeline_state( + &device, + &library, + "path_sprites", + "path_sprite_vertex", + "path_sprite_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let shadows_pipeline_state = build_pipeline_state( &device, @@ -189,7 +193,6 @@ impl MetalRenderer { "shadow_vertex", "shadow_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let quads_pipeline_state = build_pipeline_state( &device, @@ -198,7 +201,6 @@ impl MetalRenderer { "quad_vertex", "quad_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let underlines_pipeline_state = build_pipeline_state( &device, @@ -207,7 +209,6 @@ impl MetalRenderer { "underline_vertex", "underline_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let monochrome_sprites_pipeline_state = build_pipeline_state( &device, @@ -216,7 +217,6 @@ impl MetalRenderer { "monochrome_sprite_vertex", "monochrome_sprite_fragment", 
MTLPixelFormat::BGRA8Unorm, - sample_count, ); let polychrome_sprites_pipeline_state = build_pipeline_state( &device, @@ -225,7 +225,6 @@ impl MetalRenderer { "polychrome_sprite_vertex", "polychrome_sprite_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let surfaces_pipeline_state = build_pipeline_state( &device, @@ -234,21 +233,20 @@ impl MetalRenderer { "surface_vertex", "surface_fragment", MTLPixelFormat::BGRA8Unorm, - sample_count, ); let command_queue = device.new_command_queue(); - let sprite_atlas = Arc::new(MetalAtlas::new(device.clone())); + let sprite_atlas = Arc::new(MetalAtlas::new(device.clone(), PATH_SAMPLE_COUNT)); let core_video_texture_cache = CVMetalTextureCache::new(None, device.clone(), None).unwrap(); - let msaa_texture = create_msaa_texture(&device, &layer, sample_count); Self { device, layer, presents_with_transaction: false, command_queue, - path_pipeline_state, + paths_rasterization_pipeline_state, + path_sprites_pipeline_state, shadows_pipeline_state, quads_pipeline_state, underlines_pipeline_state, @@ -259,8 +257,6 @@ impl MetalRenderer { instance_buffer_pool, sprite_atlas, core_video_texture_cache, - sample_count, - msaa_texture, } } @@ -293,8 +289,6 @@ impl MetalRenderer { setDrawableSize: size ]; } - - self.msaa_texture = create_msaa_texture(&self.device, &self.layer, self.sample_count); } pub fn update_transparency(&self, _transparent: bool) { @@ -381,23 +375,25 @@ impl MetalRenderer { let command_queue = self.command_queue.clone(); let command_buffer = command_queue.new_command_buffer(); let mut instance_offset = 0; + + let path_tiles = self + .rasterize_paths( + scene.paths(), + instance_buffer, + &mut instance_offset, + command_buffer, + ) + .with_context(|| format!("rasterizing {} paths", scene.paths().len()))?; + let render_pass_descriptor = metal::RenderPassDescriptor::new(); let color_attachment = render_pass_descriptor .color_attachments() .object_at(0) .unwrap(); - if let Some(msaa_texture_ref) = self.msaa_texture.as_deref() { - color_attachment.set_texture(Some(msaa_texture_ref)); - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); - color_attachment.set_resolve_texture(Some(drawable.texture())); - } else { - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_texture(Some(drawable.texture())); - color_attachment.set_store_action(metal::MTLStoreAction::Store); - } - + color_attachment.set_texture(Some(drawable.texture())); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_store_action(metal::MTLStoreAction::Store); let alpha = if self.layer.is_opaque() { 1. } else { 0. 
}; color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); @@ -429,6 +425,7 @@ impl MetalRenderer { ), PrimitiveBatch::Paths(paths) => self.draw_paths( paths, + &path_tiles, instance_buffer, &mut instance_offset, viewport_size, @@ -496,6 +493,106 @@ impl MetalRenderer { Ok(command_buffer.to_owned()) } + fn rasterize_paths( + &self, + paths: &[Path], + instance_buffer: &mut InstanceBuffer, + instance_offset: &mut usize, + command_buffer: &metal::CommandBufferRef, + ) -> Option> { + self.sprite_atlas.clear_textures(AtlasTextureKind::Path); + + let mut tiles = HashMap::default(); + let mut vertices_by_texture_id = HashMap::default(); + for path in paths { + let clipped_bounds = path.bounds.intersect(&path.content_mask.bounds); + + let tile = self + .sprite_atlas + .allocate(clipped_bounds.size.map(Into::into), AtlasTextureKind::Path)?; + vertices_by_texture_id + .entry(tile.texture_id) + .or_insert(Vec::new()) + .extend(path.vertices.iter().map(|vertex| PathVertex { + xy_position: vertex.xy_position - clipped_bounds.origin + + tile.bounds.origin.map(Into::into), + st_position: vertex.st_position, + content_mask: ContentMask { + bounds: tile.bounds.map(Into::into), + }, + })); + tiles.insert(path.id, tile); + } + + for (texture_id, vertices) in vertices_by_texture_id { + align_offset(instance_offset); + let vertices_bytes_len = mem::size_of_val(vertices.as_slice()); + let next_offset = *instance_offset + vertices_bytes_len; + if next_offset > instance_buffer.size { + return None; + } + + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + + let texture = self.sprite_atlas.metal_texture(texture_id); + let msaa_texture = self.sprite_atlas.msaa_texture(texture_id); + + if let Some(msaa_texture) = msaa_texture { + color_attachment.set_texture(Some(&msaa_texture)); + color_attachment.set_resolve_texture(Some(&texture)); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); + } else { + color_attachment.set_texture(Some(&texture)); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + } + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 1.)); + + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state); + command_encoder.set_vertex_buffer( + PathRasterizationInputIndex::Vertices as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + let texture_size = Size { + width: DevicePixels::from(texture.width()), + height: DevicePixels::from(texture.height()), + }; + command_encoder.set_vertex_bytes( + PathRasterizationInputIndex::AtlasTextureSize as u64, + mem::size_of_val(&texture_size) as u64, + &texture_size as *const Size as *const _, + ); + + let buffer_contents = unsafe { + (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) + }; + unsafe { + ptr::copy_nonoverlapping( + vertices.as_ptr() as *const u8, + buffer_contents, + vertices_bytes_len, + ); + } + + command_encoder.draw_primitives( + metal::MTLPrimitiveType::Triangle, + 0, + vertices.len() as u64, + ); + command_encoder.end_encoding(); + *instance_offset = next_offset; + 
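For orientation, the buffer bookkeeping in `rasterize_paths` above repeats one fixed pattern per atlas texture: align the write offset to a 256-byte multiple (see the `align_offset` helper and its "make Metal happy" comment later in this file), check that the payload still fits in the instance buffer, copy the bytes, and advance the offset. A standalone sketch of that pattern, with a plain `Vec<u8>` standing in for the shared Metal buffer (the real code writes through `instance_buffer.metal_buffer.contents()`):

```rust
// Standalone sketch with assumed types: pack variable-sized payloads into one
// byte buffer, aligning each payload's offset the same way `align_offset` does,
// and refusing the write when the buffer would overflow.
fn align_offset(offset: &mut usize) {
    *offset = (*offset).div_ceil(256) * 256;
}

/// Returns the aligned offset the payload was written at, or None if it does
/// not fit (the renderer treats that as "grow the instance buffer and retry").
fn pack(buffer: &mut Vec<u8>, capacity: usize, offset: &mut usize, payload: &[u8]) -> Option<usize> {
    align_offset(offset);
    let next_offset = *offset + payload.len();
    if next_offset > capacity {
        return None;
    }
    if buffer.len() < next_offset {
        buffer.resize(next_offset, 0);
    }
    buffer[*offset..next_offset].copy_from_slice(payload);
    let written_at = *offset;
    *offset = next_offset;
    Some(written_at)
}

fn main() {
    let mut buffer = Vec::new();
    let mut offset = 0;
    assert_eq!(pack(&mut buffer, 4096, &mut offset, &[0u8; 100]), Some(0));
    // The second payload starts at the next 256-byte boundary, not at byte 100.
    assert_eq!(pack(&mut buffer, 4096, &mut offset, &[0u8; 100]), Some(256));
    // A payload that would overflow the capacity is rejected.
    assert_eq!(pack(&mut buffer, 4096, &mut offset, &[0u8; 8192]), None);
}
```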
} + + Some(tiles) + } + fn draw_shadows( &self, shadows: &[Shadow], @@ -621,6 +718,7 @@ impl MetalRenderer { fn draw_paths( &self, paths: &[Path], + tiles_by_path_id: &HashMap, instance_buffer: &mut InstanceBuffer, instance_offset: &mut usize, viewport_size: Size, @@ -630,108 +728,100 @@ impl MetalRenderer { return true; } - command_encoder.set_render_pipeline_state(&self.path_pipeline_state); + command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state); + command_encoder.set_vertex_buffer( + SpriteInputIndex::Vertices as u64, + Some(&self.unit_vertices), + 0, + ); + command_encoder.set_vertex_bytes( + SpriteInputIndex::ViewportSize as u64, + mem::size_of_val(&viewport_size) as u64, + &viewport_size as *const Size as *const _, + ); - unsafe { - let base_addr = instance_buffer.metal_buffer.contents(); - let mut p = (base_addr as *mut u8).add(*instance_offset); - let mut draw_indirect_commands = Vec::with_capacity(paths.len()); + let mut prev_texture_id = None; + let mut sprites = SmallVec::<[_; 1]>::new(); + let mut paths_and_tiles = paths + .iter() + .map(|path| (path, tiles_by_path_id.get(&path.id).unwrap())) + .peekable(); - // copy vertices - let vertices_offset = (p as usize) - (base_addr as usize); - let mut first_vertex = 0; - for (i, path) in paths.iter().enumerate() { - if (p as usize) - (base_addr as usize) - + (mem::size_of::>() * path.vertices.len()) - > instance_buffer.size - { + loop { + if let Some((path, tile)) = paths_and_tiles.peek() { + if prev_texture_id.map_or(true, |texture_id| texture_id == tile.texture_id) { + prev_texture_id = Some(tile.texture_id); + let origin = path.bounds.intersect(&path.content_mask.bounds).origin; + sprites.push(PathSprite { + bounds: Bounds { + origin: origin.map(|p| p.floor()), + size: tile.bounds.size.map(Into::into), + }, + color: path.color, + tile: (*tile).clone(), + }); + paths_and_tiles.next(); + continue; + } + } + + if sprites.is_empty() { + break; + } else { + align_offset(instance_offset); + let texture_id = prev_texture_id.take().unwrap(); + let texture: metal::Texture = self.sprite_atlas.metal_texture(texture_id); + let texture_size = size( + DevicePixels(texture.width() as i32), + DevicePixels(texture.height() as i32), + ); + + command_encoder.set_vertex_buffer( + SpriteInputIndex::Sprites as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + command_encoder.set_vertex_bytes( + SpriteInputIndex::AtlasTextureSize as u64, + mem::size_of_val(&texture_size) as u64, + &texture_size as *const Size as *const _, + ); + command_encoder.set_fragment_buffer( + SpriteInputIndex::Sprites as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + command_encoder + .set_fragment_texture(SpriteInputIndex::AtlasTexture as u64, Some(&texture)); + + let sprite_bytes_len = mem::size_of_val(sprites.as_slice()); + let next_offset = *instance_offset + sprite_bytes_len; + if next_offset > instance_buffer.size { return false; } - for v in &path.vertices { - *(p as *mut PathVertex) = PathVertex { - xy_position: v.xy_position, - content_mask: ContentMask { - bounds: path.content_mask.bounds, - }, - }; - p = p.add(mem::size_of::>()); + let buffer_contents = unsafe { + (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) + }; + + unsafe { + ptr::copy_nonoverlapping( + sprites.as_ptr() as *const u8, + buffer_contents, + sprite_bytes_len, + ); } - draw_indirect_commands.push(MTLDrawPrimitivesIndirectArguments { - vertexCount: path.vertices.len() as u32, - instanceCount: 
1, - vertexStart: first_vertex, - baseInstance: i as u32, - }); - first_vertex += path.vertices.len() as u32; - } - - // copy sprites - let sprites_offset = (p as u64) - (base_addr as u64); - if (p as usize) - (base_addr as usize) + (mem::size_of::() * paths.len()) - > instance_buffer.size - { - return false; - } - for path in paths { - *(p as *mut PathSprite) = PathSprite { - bounds: path.bounds, - color: path.color, - }; - p = p.add(mem::size_of::()); - } - - // copy indirect commands - let icb_bytes_len = mem::size_of_val(draw_indirect_commands.as_slice()); - let icb_offset = (p as u64) - (base_addr as u64); - if (p as usize) - (base_addr as usize) + icb_bytes_len > instance_buffer.size { - return false; - } - ptr::copy_nonoverlapping( - draw_indirect_commands.as_ptr() as *const u8, - p, - icb_bytes_len, - ); - p = p.add(icb_bytes_len); - - // draw path - command_encoder.set_vertex_buffer( - PathInputIndex::Vertices as u64, - Some(&instance_buffer.metal_buffer), - vertices_offset as u64, - ); - - command_encoder.set_vertex_bytes( - PathInputIndex::ViewportSize as u64, - mem::size_of_val(&viewport_size) as u64, - &viewport_size as *const Size as *const _, - ); - - command_encoder.set_vertex_buffer( - PathInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - sprites_offset, - ); - - command_encoder.set_fragment_buffer( - PathInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - sprites_offset, - ); - - for i in 0..paths.len() { - command_encoder.draw_primitives_indirect( + command_encoder.draw_primitives_instanced( metal::MTLPrimitiveType::Triangle, - &instance_buffer.metal_buffer, - icb_offset - + (i * std::mem::size_of::()) as u64, + 0, + 6, + sprites.len() as u64, ); + *instance_offset = next_offset; + sprites.clear(); } - - *instance_offset = (p as usize) - (base_addr as usize); } - true } @@ -1053,7 +1143,6 @@ fn build_pipeline_state( vertex_fn_name: &str, fragment_fn_name: &str, pixel_format: metal::MTLPixelFormat, - sample_count: u64, ) -> metal::RenderPipelineState { let vertex_fn = library .get_function(vertex_fn_name, None) @@ -1066,7 +1155,6 @@ fn build_pipeline_state( descriptor.set_label(label); descriptor.set_vertex_function(Some(vertex_fn.as_ref())); descriptor.set_fragment_function(Some(fragment_fn.as_ref())); - descriptor.set_sample_count(sample_count); let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); color_attachment.set_pixel_format(pixel_format); color_attachment.set_blending_enabled(true); @@ -1082,45 +1170,50 @@ fn build_pipeline_state( .expect("could not create render pipeline state") } +fn build_path_rasterization_pipeline_state( + device: &metal::DeviceRef, + library: &metal::LibraryRef, + label: &str, + vertex_fn_name: &str, + fragment_fn_name: &str, + pixel_format: metal::MTLPixelFormat, + path_sample_count: u32, +) -> metal::RenderPipelineState { + let vertex_fn = library + .get_function(vertex_fn_name, None) + .expect("error locating vertex function"); + let fragment_fn = library + .get_function(fragment_fn_name, None) + .expect("error locating fragment function"); + + let descriptor = metal::RenderPipelineDescriptor::new(); + descriptor.set_label(label); + descriptor.set_vertex_function(Some(vertex_fn.as_ref())); + descriptor.set_fragment_function(Some(fragment_fn.as_ref())); + if path_sample_count > 1 { + descriptor.set_raster_sample_count(path_sample_count as _); + descriptor.set_alpha_to_coverage_enabled(true); + } + let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); + 
color_attachment.set_pixel_format(pixel_format); + color_attachment.set_blending_enabled(true); + color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One); + + device + .new_render_pipeline_state(&descriptor) + .expect("could not create render pipeline state") +} + // Align to multiples of 256 make Metal happy. fn align_offset(offset: &mut usize) { *offset = (*offset).div_ceil(256) * 256; } -fn create_msaa_texture( - device: &metal::Device, - layer: &metal::MetalLayer, - sample_count: u64, -) -> Option { - let viewport_size = layer.drawable_size(); - let width = viewport_size.width.ceil() as u64; - let height = viewport_size.height.ceil() as u64; - - if width == 0 || height == 0 { - return None; - } - - if sample_count <= 1 { - return None; - } - - let texture_descriptor = metal::TextureDescriptor::new(); - texture_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); - - // MTLStorageMode default is `shared` only for Apple silicon GPUs. Use `private` for Apple and Intel GPUs both. - // Reference: https://developer.apple.com/documentation/metal/choosing-a-resource-storage-mode-for-apple-gpus - texture_descriptor.set_storage_mode(metal::MTLStorageMode::Private); - - texture_descriptor.set_width(width); - texture_descriptor.set_height(height); - texture_descriptor.set_pixel_format(layer.pixel_format()); - texture_descriptor.set_usage(metal::MTLTextureUsage::RenderTarget); - texture_descriptor.set_sample_count(sample_count); - - let metal_texture = device.new_texture(&texture_descriptor); - Some(metal_texture) -} - #[repr(C)] enum ShadowInputIndex { Vertices = 0, @@ -1162,10 +1255,9 @@ enum SurfaceInputIndex { } #[repr(C)] -enum PathInputIndex { +enum PathRasterizationInputIndex { Vertices = 0, - ViewportSize = 1, - Sprites = 2, + AtlasTextureSize = 1, } #[derive(Clone, Debug, Eq, PartialEq)] @@ -1173,6 +1265,7 @@ enum PathInputIndex { pub struct PathSprite { pub bounds: Bounds, pub color: Background, + pub tile: AtlasTile, } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/gpui/src/platform/mac/shaders.metal b/crates/gpui/src/platform/mac/shaders.metal index 5f0dc3323d..64ebb1e22b 100644 --- a/crates/gpui/src/platform/mac/shaders.metal +++ b/crates/gpui/src/platform/mac/shaders.metal @@ -698,27 +698,76 @@ fragment float4 polychrome_sprite_fragment( return color; } -struct PathVertexOutput { +struct PathRasterizationVertexOutput { float4 position [[position]]; + float2 st_position; + float clip_rect_distance [[clip_distance]][4]; +}; + +struct PathRasterizationFragmentInput { + float4 position [[position]]; + float2 st_position; +}; + +vertex PathRasterizationVertexOutput path_rasterization_vertex( + uint vertex_id [[vertex_id]], + constant PathVertex_ScaledPixels *vertices + [[buffer(PathRasterizationInputIndex_Vertices)]], + constant Size_DevicePixels *atlas_size + [[buffer(PathRasterizationInputIndex_AtlasTextureSize)]]) { + PathVertex_ScaledPixels v = vertices[vertex_id]; + float2 vertex_position = float2(v.xy_position.x, v.xy_position.y); + float2 viewport_size = float2(atlas_size->width, atlas_size->height); + return PathRasterizationVertexOutput{ + 
float4(vertex_position / viewport_size * float2(2., -2.) + + float2(-1., 1.), + 0., 1.), + float2(v.st_position.x, v.st_position.y), + {v.xy_position.x - v.content_mask.bounds.origin.x, + v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width - + v.xy_position.x, + v.xy_position.y - v.content_mask.bounds.origin.y, + v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height - + v.xy_position.y}}; +} + +fragment float4 path_rasterization_fragment(PathRasterizationFragmentInput input + [[stage_in]]) { + float2 dx = dfdx(input.st_position); + float2 dy = dfdy(input.st_position); + float2 gradient = float2((2. * input.st_position.x) * dx.x - dx.y, + (2. * input.st_position.x) * dy.x - dy.y); + float f = (input.st_position.x * input.st_position.x) - input.st_position.y; + float distance = f / length(gradient); + float alpha = saturate(0.5 - distance); + return float4(alpha, 0., 0., 1.); +} + +struct PathSpriteVertexOutput { + float4 position [[position]]; + float2 tile_position; uint sprite_id [[flat]]; float4 solid_color [[flat]]; float4 color0 [[flat]]; float4 color1 [[flat]]; - float4 clip_distance; }; -vertex PathVertexOutput path_vertex( - uint vertex_id [[vertex_id]], - constant PathVertex_ScaledPixels *vertices [[buffer(PathInputIndex_Vertices)]], - uint sprite_id [[instance_id]], - constant PathSprite *sprites [[buffer(PathInputIndex_Sprites)]], - constant Size_DevicePixels *input_viewport_size [[buffer(PathInputIndex_ViewportSize)]]) { - PathVertex_ScaledPixels v = vertices[vertex_id]; - float2 vertex_position = float2(v.xy_position.x, v.xy_position.y); - float2 viewport_size = float2((float)input_viewport_size->width, - (float)input_viewport_size->height); +vertex PathSpriteVertexOutput path_sprite_vertex( + uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]], + constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], + constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], + constant Size_DevicePixels *viewport_size + [[buffer(SpriteInputIndex_ViewportSize)]], + constant Size_DevicePixels *atlas_size + [[buffer(SpriteInputIndex_AtlasTextureSize)]]) { + + float2 unit_vertex = unit_vertices[unit_vertex_id]; PathSprite sprite = sprites[sprite_id]; - float4 device_position = float4(vertex_position / viewport_size * float2(2., -2.) + float2(-1., 1.), 0., 1.); + // Don't apply content mask because it was already accounted for when + // rasterizing the path. 
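Stepping back from the shader source for a moment, the two passes split the work as follows: `path_rasterization_fragment` interpolates an (s, t) pair per vertex so that the quadratic curve is the zero set of f(s, t) = s² − t, estimates a pixel-space distance by dividing f by the length of its screen-space gradient, and writes a 0 to 1 coverage value into the R16Float atlas; `path_sprite_fragment` later folds the accumulated red channel back into an even-odd style mask with `1. - abs(1. - fmod(sample.r, 2.))`. The Metal code above and below is authoritative; this Rust sketch only restates the same arithmetic for readers who want to check it outside a shader:

```rust
/// Coverage for one sample, mirroring path_rasterization_fragment: `st` is the
/// interpolated (s, t) pair and `dsdx`/`dsdy` are its screen-space derivatives
/// (what the GPU provides via dfdx/dfdy).
fn curve_coverage(st: (f32, f32), dsdx: (f32, f32), dsdy: (f32, f32)) -> f32 {
    let (s, t) = st;
    let f = s * s - t;
    // Chain rule: d(s^2 - t)/dx = 2s * ds/dx - dt/dx, and likewise for y.
    let grad_x = 2.0 * s * dsdx.0 - dsdx.1;
    let grad_y = 2.0 * s * dsdy.0 - dsdy.1;
    let distance = f / (grad_x * grad_x + grad_y * grad_y).sqrt();
    (0.5 - distance).clamp(0.0, 1.0)
}

/// Even-odd decode applied by path_sprite_fragment to the accumulated red channel.
fn winding_to_mask(accumulated: f32) -> f32 {
    1.0 - (1.0 - accumulated.rem_euclid(2.0)).abs()
}

fn main() {
    // Solid triangles are emitted with st = (0, 1), so f = -1 and coverage saturates at 1.
    assert_eq!(curve_coverage((0.0, 1.0), (0.001, 0.0), (0.0, 0.001)), 1.0);
    // One fully covered layer decodes to an opaque mask; two overlapping layers cancel out.
    assert_eq!(winding_to_mask(1.0), 1.0);
    assert_eq!(winding_to_mask(2.0), 0.0);
}
```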
+ float4 device_position = + to_device_position(unit_vertex, sprite.bounds, viewport_size); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile, atlas_size); GradientColor gradient = prepare_fill_color( sprite.color.tag, @@ -728,32 +777,30 @@ vertex PathVertexOutput path_vertex( sprite.color.colors[1].color ); - return PathVertexOutput{ + return PathSpriteVertexOutput{ device_position, + tile_position, sprite_id, gradient.solid, gradient.color0, - gradient.color1, - {v.xy_position.x - v.content_mask.bounds.origin.x, - v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width - - v.xy_position.x, - v.xy_position.y - v.content_mask.bounds.origin.y, - v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height - - v.xy_position.y} + gradient.color1 }; } -fragment float4 path_fragment( - PathVertexOutput input [[stage_in]], - constant PathSprite *sprites [[buffer(PathInputIndex_Sprites)]]) { - if (any(input.clip_distance < float4(0.0))) { - return float4(0.0); - } - +fragment float4 path_sprite_fragment( + PathSpriteVertexOutput input [[stage_in]], + constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], + texture2d atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) { + constexpr sampler atlas_texture_sampler(mag_filter::linear, + min_filter::linear); + float4 sample = + atlas_texture.sample(atlas_texture_sampler, input.tile_position); + float mask = 1. - abs(1. - fmod(sample.r, 2.)); PathSprite sprite = sprites[input.sprite_id]; Background background = sprite.color; float4 color = fill_color(background, input.position.xy, sprite.bounds, input.solid_color, input.color0, input.color1); + color.a *= mask; return color; } diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs index 65ee10a13f..1b88415d3b 100644 --- a/crates/gpui/src/platform/test/window.rs +++ b/crates/gpui/src/platform/test/window.rs @@ -341,7 +341,7 @@ impl PlatformAtlas for TestAtlas { crate::AtlasTile { texture_id: AtlasTextureId { index: texture_id, - kind: crate::AtlasTextureKind::Polychrome, + kind: crate::AtlasTextureKind::Path, }, tile_id: TileId(tile_id), padding: 0, diff --git a/crates/gpui/src/scene.rs b/crates/gpui/src/scene.rs index 681444a473..4eaef64afa 100644 --- a/crates/gpui/src/scene.rs +++ b/crates/gpui/src/scene.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::{ AtlasTextureId, AtlasTile, Background, Bounds, ContentMask, Corners, Edges, Hsla, Pixels, - Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, + Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, point, }; use std::{fmt::Debug, iter::Peekable, ops::Range, slice}; @@ -43,7 +43,13 @@ impl Scene { self.surfaces.clear(); } - #[allow(dead_code)] + #[cfg_attr( + all( + any(target_os = "linux", target_os = "freebsd"), + not(any(feature = "x11", feature = "wayland")) + ), + allow(dead_code) + )] pub fn paths(&self) -> &[Path] { &self.paths } @@ -683,7 +689,6 @@ pub struct Path { start: Point
<P>, current: Point<P>
, contour_count: usize, - base_scale: f32, } impl Path { @@ -702,35 +707,25 @@ impl Path { content_mask: Default::default(), color: Default::default(), contour_count: 0, - base_scale: 1.0, } } - /// Set the base scale of the path. - pub fn scale(mut self, factor: f32) -> Self { - self.base_scale = factor; - self - } - - /// Apply a scale to the path. - pub(crate) fn apply_scale(&self, factor: f32) -> Path { + /// Scale this path by the given factor. + pub fn scale(&self, factor: f32) -> Path { Path { id: self.id, order: self.order, - bounds: self.bounds.scale(self.base_scale * factor), - content_mask: self.content_mask.scale(self.base_scale * factor), + bounds: self.bounds.scale(factor), + content_mask: self.content_mask.scale(factor), vertices: self .vertices .iter() - .map(|vertex| vertex.scale(self.base_scale * factor)) + .map(|vertex| vertex.scale(factor)) .collect(), - start: self - .start - .map(|start| start.scale(self.base_scale * factor)), - current: self.current.scale(self.base_scale * factor), + start: self.start.map(|start| start.scale(factor)), + current: self.current.scale(factor), contour_count: self.contour_count, color: self.color, - base_scale: 1.0, } } @@ -745,7 +740,10 @@ impl Path { pub fn line_to(&mut self, to: Point) { self.contour_count += 1; if self.contour_count > 1 { - self.push_triangle((self.start, self.current, to)); + self.push_triangle( + (self.start, self.current, to), + (point(0., 1.), point(0., 1.), point(0., 1.)), + ); } self.current = to; } @@ -754,15 +752,25 @@ impl Path { pub fn curve_to(&mut self, to: Point, ctrl: Point) { self.contour_count += 1; if self.contour_count > 1 { - self.push_triangle((self.start, self.current, to)); + self.push_triangle( + (self.start, self.current, to), + (point(0., 1.), point(0., 1.), point(0., 1.)), + ); } - self.push_triangle((self.current, ctrl, to)); + self.push_triangle( + (self.current, ctrl, to), + (point(0., 0.), point(0.5, 0.), point(1., 1.)), + ); self.current = to; } /// Push a triangle to the Path. - pub fn push_triangle(&mut self, xy: (Point, Point, Point)) { + pub fn push_triangle( + &mut self, + xy: (Point, Point, Point), + st: (Point, Point, Point), + ) { self.bounds = self .bounds .union(&Bounds { @@ -780,14 +788,17 @@ impl Path { self.vertices.push(PathVertex { xy_position: xy.0, + st_position: st.0, content_mask: Default::default(), }); self.vertices.push(PathVertex { xy_position: xy.1, + st_position: st.1, content_mask: Default::default(), }); self.vertices.push(PathVertex { xy_position: xy.2, + st_position: st.2, content_mask: Default::default(), }); } @@ -803,6 +814,7 @@ impl From> for Primitive { #[repr(C)] pub(crate) struct PathVertex { pub(crate) xy_position: Point
<P>, + pub(crate) st_position: Point<f32>, + pub(crate) content_mask: ContentMask<P>
, } @@ -810,6 +822,7 @@ impl PathVertex { pub fn scale(&self, factor: f32) -> PathVertex { PathVertex { xy_position: self.xy_position.scale(factor), + st_position: self.st_position, content_mask: self.content_mask.scale(factor), } } diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 8c01b8afcf..be3b753d6a 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -2633,7 +2633,7 @@ impl Window { path.color = color.opacity(opacity); self.next_frame .scene - .insert_primitive(path.apply_scale(scale_factor)); + .insert_primitive(path.scale(scale_factor)); } /// Paint an underline into the scene for the next frame at the current z-index. diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 332e38b038..332a8d5791 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -20,6 +20,7 @@ pub enum IconName { AiOpenAi, AiOpenRouter, AiVZero, + AiXAi, AiZed, ArrowCircle, ArrowDown, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 81a0f7d8a1..6bd33fcdf5 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -116,6 +116,12 @@ pub enum LanguageModelCompletionError { provider: LanguageModelProviderName, message: String, }, + #[error("{message}")] + UpstreamProviderError { + message: String, + status: StatusCode, + retry_after: Option, + }, #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] HttpResponseError { provider: LanguageModelProviderName, @@ -178,6 +184,21 @@ pub enum LanguageModelCompletionError { } impl LanguageModelCompletionError { + fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { + let error_json = serde_json::from_str::(message).ok()?; + let upstream_status = error_json + .get("upstream_status") + .and_then(|v| v.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + .and_then(|status| StatusCode::from_u16(status).ok())?; + let inner_message = error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or(message) + .to_string(); + Some((upstream_status, inner_message)) + } + pub fn from_cloud_failure( upstream_provider: LanguageModelProviderName, code: String, @@ -191,6 +212,18 @@ impl LanguageModelCompletionError { Self::PromptTooLarge { tokens: Some(tokens), } + } else if code == "upstream_http_error" { + if let Some((upstream_status, inner_message)) = + Self::parse_upstream_error_json(&message) + { + return Self::from_http_status( + upstream_provider, + upstream_status, + inner_message, + retry_after, + ); + } + anyhow!("completion request failed, code: {code}, message: {message}").into() } else if let Some(status_code) = code .strip_prefix("upstream_http_") .and_then(|code| StatusCode::from_str(code).ok()) @@ -701,3 +734,104 @@ impl From for LanguageModelProviderName { Self(SharedString::from(value)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_cloud_failure_with_upstream_http_error() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. 
} => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!(message, "Internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_from_cloud_failure_with_standard_format() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_503".to_string(), + "Service unavailable".to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!("Expected ServerOverloaded error for upstream_http_503"), + } + } + + #[test] + fn test_upstream_http_error_connection_timeout() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. 
reset reason: connection timeout" + ); + } + _ => panic!( + "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", + error + ), + } + } +} diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 514443ddec..e928df8a74 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } vercel = { workspace = true, features = ["schemars"] } +x_ai = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true proto.workspace = true release_channel.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index c7324732c9..192f5a5fae 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -20,6 +20,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider; use crate::provider::open_ai::OpenAiLanguageModelProvider; use crate::provider::open_router::OpenRouterLanguageModelProvider; use crate::provider::vercel::VercelLanguageModelProvider; +use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; pub fn init(user_store: Entity, client: Arc, cx: &mut App) { @@ -81,5 +82,6 @@ fn register_language_model_providers( VercelLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx); registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); } diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 6bc93bd366..c717be7c90 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -10,3 +10,4 @@ pub mod ollama; pub mod open_ai; pub mod open_router; pub mod vercel; +pub mod x_ai; diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 9b7fee228a..8d25af1a49 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -166,46 +166,9 @@ impl State { } let response = Self::fetch_models(client, llm_api_token, use_cloud).await?; - cx.update(|cx| { - this.update(cx, |this, cx| { - let mut models = Vec::new(); - - for model in response.models { - models.push(Arc::new(model.clone())); - - // Right now we represent thinking variants of models as separate models on the client, - // so we need to insert variants for any model that supports thinking. 
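This block is being moved rather than deleted: the same expansion reappears in `update_models` further down. In isolation, the "thinking variant" expansion the comment above describes amounts to the following sketch, where `Model` is a stand-in struct; the real code clones `zed_llm_client::LanguageModel` and wraps each entry in an `Arc`:

```rust
// Stand-in type for illustration only.
#[derive(Clone, Debug)]
struct Model {
    id: String,
    display_name: String,
    supports_thinking: bool,
}

/// For every model that supports thinking, also register an "<id>-thinking"
/// variant so the client can present it as a separate model.
fn expand_thinking_variants(models: Vec<Model>) -> Vec<Model> {
    let mut expanded = Vec::with_capacity(models.len() * 2);
    for model in models {
        expanded.push(model.clone());
        if model.supports_thinking {
            expanded.push(Model {
                id: format!("{}-thinking", model.id),
                display_name: format!("{} Thinking", model.display_name),
                ..model
            });
        }
    }
    expanded
}

fn main() {
    let models = vec![Model {
        id: "some-model".into(),
        display_name: "Some Model".into(),
        supports_thinking: true,
    }];
    let expanded = expand_thinking_variants(models);
    assert_eq!(expanded.len(), 2);
    assert_eq!(expanded[1].id, "some-model-thinking");
    assert_eq!(expanded[1].display_name, "Some Model Thinking");
}
```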
- if model.supports_thinking { - models.push(Arc::new(zed_llm_client::LanguageModel { - id: zed_llm_client::LanguageModelId( - format!("{}-thinking", model.id).into(), - ), - display_name: format!("{} Thinking", model.display_name), - ..model - })); - } - } - - this.default_model = models - .iter() - .find(|model| model.id == response.default_model) - .cloned(); - this.default_fast_model = models - .iter() - .find(|model| model.id == response.default_fast_model) - .cloned(); - this.recommended_models = response - .recommended_models - .iter() - .filter_map(|id| models.iter().find(|model| &model.id == id)) - .cloned() - .collect(); - this.models = models; - cx.notify(); - }) - })??; - - anyhow::Ok(()) + this.update(cx, |this, cx| { + this.update_models(response, cx); + }) }) .await .context("failed to fetch Zed models") @@ -216,12 +179,15 @@ impl State { }), _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, - |this, _listener, _event, cx| { + move |this, _listener, _event, cx| { let client = this.client.clone(); let llm_api_token = this.llm_api_token.clone(); - cx.spawn(async move |_this, _cx| { + cx.spawn(async move |this, cx| { llm_api_token.refresh(&client).await?; - anyhow::Ok(()) + let response = Self::fetch_models(client, llm_api_token, use_cloud).await?; + this.update(cx, |this, cx| { + this.update_models(response, cx); + }) }) .detach_and_log_err(cx); }, @@ -264,6 +230,41 @@ impl State { })); } + fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context) { + let mut models = Vec::new(); + + for model in response.models { + models.push(Arc::new(model.clone())); + + // Right now we represent thinking variants of models as separate models on the client, + // so we need to insert variants for any model that supports thinking. + if model.supports_thinking { + models.push(Arc::new(zed_llm_client::LanguageModel { + id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), + display_name: format!("{} Thinking", model.display_name), + ..model + })); + } + } + + self.default_model = models + .iter() + .find(|model| model.id == response.default_model) + .cloned(); + self.default_fast_model = models + .iter() + .find(|model| model.id == response.default_fast_model) + .cloned(); + self.recommended_models = response + .recommended_models + .iter() + .filter_map(|id| models.iter().find(|model| &model.id == id)) + .cloned() + .collect(); + self.models = models; + cx.notify(); + } + async fn fetch_models( client: Arc, llm_api_token: LlmApiToken, @@ -653,8 +654,62 @@ struct ApiError { headers: HeaderMap, } +/// Represents error responses from Zed's cloud API. 
+/// +/// Example JSON for an upstream HTTP error: +/// ```json +/// { +/// "code": "upstream_http_error", +/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout", +/// "upstream_status": 503 +/// } +/// ``` +#[derive(Debug, serde::Deserialize)] +struct CloudApiError { + code: String, + message: String, + #[serde(default)] + #[serde(deserialize_with = "deserialize_optional_status_code")] + upstream_status: Option, + #[serde(default)] + retry_after: Option, +} + +fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let opt: Option = Option::deserialize(deserializer)?; + Ok(opt.and_then(|code| StatusCode::from_u16(code).ok())) +} + impl From for LanguageModelCompletionError { fn from(error: ApiError) -> Self { + if let Ok(cloud_error) = serde_json::from_str::(&error.body) { + if cloud_error.code.starts_with("upstream_http_") { + let status = if let Some(status) = cloud_error.upstream_status { + status + } else if cloud_error.code.ends_with("_error") { + error.status + } else { + // If there's a status code in the code string (e.g. "upstream_http_429") + // then use that; otherwise, see if the JSON contains a status code. + cloud_error + .code + .strip_prefix("upstream_http_") + .and_then(|code_str| code_str.parse::().ok()) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(error.status) + }; + + return LanguageModelCompletionError::UpstreamProviderError { + message: cloud_error.message, + status, + retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), + }; + } + } + let retry_after = None; LanguageModelCompletionError::from_http_status( PROVIDER_NAME, @@ -1294,3 +1349,155 @@ impl Component for ZedAiConfiguration { ) } } + +#[cfg(test)] +mod tests { + use super::*; + use http_client::http::{HeaderMap, StatusCode}; + use language_model::LanguageModelCompletionError; + + #[test] + fn test_api_error_conversion_with_upstream_http_error() { + // upstream_http_error with 503 status should become ServerOverloaded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 503, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 500 status should become ApiInternalServerError + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. 
} => { + assert_eq!( + message, + "Received an error from the OpenAI API: internal server error" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 500, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 429 status should become RateLimitExceeded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Google API: rate limit exceeded" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 429, got: {:?}", + completion_error + ), + } + + // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed + let error_body = "Regular internal server error"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider, PROVIDER_NAME); + assert_eq!(message, "Regular internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for regular 500, got: {:?}", + completion_error + ), + } + + // upstream_http_429 format should be converted to UpstreamProviderError + let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { + message, + status, + retry_after, + } => { + assert_eq!(message, "Upstream Anthropic rate limit exceeded."); + assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); + assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); + } + _ => panic!( + "Expected UpstreamProviderError for upstream_http_429, got: {:?}", + completion_error + ), + } + + // Invalid JSON in error body should fall back to regular error handling + let error_body = "Not JSON at all"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, .. 
} => { + assert_eq!(provider, PROVIDER_NAME); + } + _ => panic!( + "Expected ApiInternalServerError for invalid JSON, got: {:?}", + completion_error + ), + } + } +} diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 5883da1e2f..90d375a6b2 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -376,7 +376,7 @@ impl LanguageModel for OpenRouterLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { let model_id = self.model.id().trim().to_lowercase(); - if model_id.contains("gemini") { + if model_id.contains("gemini") || model_id.contains("grok-4") { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs new file mode 100644 index 0000000000..5f6034571b --- /dev/null +++ b/crates/language_models/src/provider/x_ai.rs @@ -0,0 +1,571 @@ +use anyhow::{Context as _, Result, anyhow}; +use collections::BTreeMap; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role, +}; +use menu; +use open_ai::ResponseStreamEvent; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; +use strum::IntoEnumIterator; +use x_ai::Model; + +use ui::{ElevationIndex, List, Tooltip, prelude::*}; +use ui_input::SingleLineInput; +use util::ResultExt; + +use crate::{AllLanguageModelSettings, ui::InstructionListItem}; + +const PROVIDER_ID: &str = "x_ai"; +const PROVIDER_NAME: &str = "xAI"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct XAiSettings { + pub api_url: String, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: u64, + pub max_output_tokens: Option, + pub max_completion_tokens: Option, +} + +pub struct XAiLanguageModelProvider { + http_client: Arc, + state: gpui::Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + _subscription: Subscription, +} + +const XAI_API_KEY_VAR: &str = "XAI_API_KEY"; + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let settings = 
&AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + let credentials_provider = ::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } +} + +impl XAiLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::(|_this: &mut State, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: x_ai::Model) -> Arc { + Arc::new(XAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for XAiLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for XAiLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiXAi + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(x_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(x_ai::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + for model in x_ai::Model::iter() { + if !matches!(model, x_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + for model in &AllLanguageModelSettings::get_global(cx) + .x_ai + .available_models + { + models.insert( + model.name.clone(), + x_ai::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + max_output_tokens: model.max_output_tokens, + max_completion_tokens: model.max_completion_tokens, + }, + ); + } + + models + .into_values() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct XAiLanguageModel { + id: LanguageModelId, + model: x_ai::Model, + state: gpui::Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl XAiLanguageModel { + fn stream_completion( + &self, + request: open_ai::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + (state.api_key.clone(), api_url) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.context("Missing xAI API Key")?; + let request = + open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for XAiLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn supports_tools(&self) -> bool { + self.model.supports_tool() + } + + fn supports_images(&self) -> bool { + self.model.supports_images() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + let model_id = self.model.id().trim().to_lowercase(); + if model_id.eq(x_ai::Model::Grok4.id()) { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } else { + LanguageModelToolSchemaFormat::JsonSchema + } + } + + fn telemetry_id(&self) -> String { + format!("x_ai/{}", self.model.id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_xai_tokens(request, 
self.model.clone(), cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let request = crate::provider::open_ai::into_open_ai( + request, + self.model.id(), + self.model.supports_parallel_tool_calls(), + self.max_output_tokens(), + ); + let completions = self.stream_completion(request, cx); + async move { + let mapper = crate::provider::open_ai::OpenAiEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + +pub fn count_xai_tokens( + request: LanguageModelRequest, + model: Model, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + let model_name = if model.max_token_count() >= 100_000 { + "gpt-4o" + } else { + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) + }) + .boxed() +} + +struct ConfigurationView { + api_key_editor: Entity, + state: gpui::Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + SingleLineInput::new( + window, + cx, + "xai-0000000000000000000000000000000000000000000000000", + ) + .label("API key") + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self + .api_key_editor + .read(cx) + .editor() + .read(cx) + .text(cx) + .trim() + .to_string(); + + // Don't proceed if no API key is provided and we're not authenticated + if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? 
+ .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor.update(cx, |input, cx| { + input.editor.update(cx, |editor, cx| { + editor.set_text("", window, cx); + }); + }); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_from_env; + + let api_key_section = if self.should_render_editor(cx) { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create one by visiting", + Some("xAI console"), + Some("https://console.x.ai/team/default/api-keys"), + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the agent", + )), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new(format!( + "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Label::new("Note that xAI is a custom OpenAI-compatible provider.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {XAI_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-api-key", "Reset API Key") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + }; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials…")).into_any() + } else { + v_flex().size_full().child(api_key_section).into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index f96a2c0a66..dafbb62910 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -17,6 +17,7 @@ use crate::provider::{ open_ai::OpenAiSettings, open_router::OpenRouterSettings, vercel::VercelSettings, + x_ai::XAiSettings, }; /// Initializes the language model settings. 
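Since the settings plumbing for the new provider follows the existing pattern, the user-facing shape is worth spelling out: `XAiSettingsContent` (added below) deserializes from the `x_ai` object under the `language_models` settings key, with optional `api_url` and `available_models` entries matching the `AvailableModel` fields defined in x_ai.rs. A hedged round-trip sketch follows, with the struct definitions trimmed to the fields shown above; the endpoint string and model name are illustrative placeholders, and the real default URL lives in `x_ai::XAI_API_URL`:

```rust
use serde::Deserialize;

// Trimmed mirrors of XAiSettingsContent / AvailableModel, for illustration only.
#[derive(Debug, Deserialize)]
struct XAiSettingsContent {
    api_url: Option<String>,
    available_models: Option<Vec<AvailableModel>>,
}

#[derive(Debug, Deserialize)]
struct AvailableModel {
    name: String,
    display_name: Option<String>,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    max_completion_tokens: Option<u64>,
}

fn main() {
    // What a user's `"language_models": { "x_ai": { ... } }` block could contain.
    let json = r#"{
        "api_url": "https://api.x.ai/v1",
        "available_models": [
            {
                "name": "my-custom-grok",
                "display_name": "My Custom Grok",
                "max_tokens": 131072,
                "max_output_tokens": 8192
            }
        ]
    }"#;
    let settings: XAiSettingsContent = serde_json::from_str(json).unwrap();
    let models = settings.available_models.unwrap();
    assert_eq!(models[0].name, "my-custom-grok");
    assert_eq!(models[0].max_tokens, 131072);
    // Missing optional fields deserialize to None.
    assert_eq!(models[0].max_completion_tokens, None);
}
```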
@@ -28,33 +29,33 @@ pub fn init(cx: &mut App) { pub struct AllLanguageModelSettings { pub anthropic: AnthropicSettings, pub bedrock: AmazonBedrockSettings, - pub ollama: OllamaSettings, - pub openai: OpenAiSettings, - pub open_router: OpenRouterSettings, - pub zed_dot_dev: ZedDotDevSettings, - pub google: GoogleSettings, - pub vercel: VercelSettings, - - pub lmstudio: LmStudioSettings, pub deepseek: DeepSeekSettings, + pub google: GoogleSettings, + pub lmstudio: LmStudioSettings, pub mistral: MistralSettings, + pub ollama: OllamaSettings, + pub open_router: OpenRouterSettings, + pub openai: OpenAiSettings, + pub vercel: VercelSettings, + pub x_ai: XAiSettings, + pub zed_dot_dev: ZedDotDevSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct AllLanguageModelSettingsContent { pub anthropic: Option, pub bedrock: Option, - pub ollama: Option, + pub deepseek: Option, + pub google: Option, pub lmstudio: Option, - pub openai: Option, + pub mistral: Option, + pub ollama: Option, pub open_router: Option, + pub openai: Option, + pub vercel: Option, + pub x_ai: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, - pub google: Option, - pub deepseek: Option, - pub vercel: Option, - - pub mistral: Option, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -114,6 +115,12 @@ pub struct GoogleSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct XAiSettingsContent { + pub api_url: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { available_models: Option>, @@ -230,6 +237,18 @@ impl settings::Settings for AllLanguageModelSettings { vercel.as_ref().and_then(|s| s.available_models.clone()), ); + // XAI + let x_ai = value.x_ai.clone(); + merge( + &mut settings.x_ai.api_url, + x_ai.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.x_ai.available_models, + x_ai.as_ref().and_then(|s| s.available_models.clone()), + ); + + // ZedDotDev merge( &mut settings.zed_dot_dev.available_models, value diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 8e1026421e..85036eca86 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -3362,8 +3362,14 @@ impl Project { cx: &mut Context, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.definitions(buffer, position, cx) + }); + cx.spawn(async move |_, _| { + let result = task.await; + drop(guard); + result }) } @@ -3374,8 +3380,14 @@ impl Project { cx: &mut Context, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.declarations(buffer, position, cx) + }); + cx.spawn(async move |_, _| { + let result = task.await; + drop(guard); + result }) } @@ -3386,8 +3398,14 @@ impl Project { cx: &mut Context, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { 
lsp_store.type_definitions(buffer, position, cx) + }); + cx.spawn(async move |_, _| { + let result = task.await; + drop(guard); + result }) } @@ -3398,8 +3416,14 @@ impl Project { cx: &mut Context, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.implementations(buffer, position, cx) + }); + cx.spawn(async move |_, _| { + let result = task.await; + drop(guard); + result }) } @@ -3410,8 +3434,14 @@ impl Project { cx: &mut Context, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.lsp_store.update(cx, |lsp_store, cx| { + let guard = self.retain_remotely_created_models(cx); + let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.references(buffer, position, cx) + }); + cx.spawn(async move |_, _| { + let result = task.await; + drop(guard); + result }) } diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 0ec9bac33f..0281397374 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -186,7 +186,6 @@ struct EntryDetails { #[derive(Debug, PartialEq, Eq, Clone)] struct StickyDetails { sticky_index: usize, - is_last: bool, } /// Permanently deletes the selected file or directory. @@ -3938,31 +3937,14 @@ impl ProjectPanel { } }; - let show_sticky_shadow = details.sticky.as_ref().map_or(false, |item| { - if item.is_last { - let is_scrollable = self.scroll_handle.is_scrollable(); - let is_scrolled = self.scroll_handle.offset().y < px(0.); - is_scrollable && is_scrolled - } else { - false - } - }); - let shadow_color_top = hsla(0.0, 0.0, 0.0, 0.1); - let shadow_color_bottom = hsla(0.0, 0.0, 0.0, 0.); - let sticky_shadow = div() - .absolute() - .left_0() - .bottom_neg_1p5() - .h_1p5() - .w_full() - .bg(linear_gradient( - 0., - linear_color_stop(shadow_color_top, 1.), - linear_color_stop(shadow_color_bottom, 0.), - )); + let id: ElementId = if is_sticky { + SharedString::from(format!("project_panel_sticky_item_{}", entry_id.to_usize())).into() + } else { + (entry_id.to_proto() as usize).into() + }; div() - .id(entry_id.to_proto() as usize) + .id(id.clone()) .relative() .group(GROUP_NAME) .cursor_pointer() @@ -3972,7 +3954,9 @@ impl ProjectPanel { .border_r_2() .border_color(border_color) .hover(|style| style.bg(bg_hover_color).border_color(border_hover_color)) - .when(show_sticky_shadow, |this| this.child(sticky_shadow)) + .when(is_sticky, |this| { + this.block_mouse_except_scroll() + }) .when(!is_sticky, |this| { this .when(is_highlighted && folded_directory_drag_target.is_none(), |this| this.border_color(transparent_white()).bg(item_colors.drag_over)) @@ -4183,6 +4167,16 @@ impl ProjectPanel { .unwrap_or(ScrollStrategy::Top); this.scroll_handle.scroll_to_item(index, strategy); cx.notify(); + // move down by 1px so that clicked item + // don't count as sticky anymore + cx.on_next_frame(window, |_, window, cx| { + cx.on_next_frame(window, |this, _, cx| { + let mut offset = this.scroll_handle.offset(); + offset.y += px(1.); + this.scroll_handle.set_offset(offset); + cx.notify(); + }); + }); return; } } @@ -4201,7 +4195,7 @@ impl ProjectPanel { }), ) .child( - ListItem::new(entry_id.to_proto() as usize) + ListItem::new(id) .indent_level(depth) .indent_step_size(px(settings.indent_size)) .spacing(match settings.entry_spacing { @@ -4924,7 +4918,6 @@ impl ProjectPanel { 
.unwrap_or_default(); let sticky_details = Some(StickyDetails { sticky_index: index, - is_last: index == last_item_index, }); let details = self.details_for_entry( entry, @@ -4936,7 +4929,24 @@ impl ProjectPanel { window, cx, ); - self.render_entry(entry.id, details, window, cx).into_any() + self.render_entry(entry.id, details, window, cx) + .when(index == last_item_index, |this| { + let shadow_color_top = hsla(0.0, 0.0, 0.0, 0.1); + let shadow_color_bottom = hsla(0.0, 0.0, 0.0, 0.); + let sticky_shadow = div() + .absolute() + .left_0() + .bottom_neg_1p5() + .h_1p5() + .w_full() + .bg(linear_gradient( + 0., + linear_color_stop(shadow_color_top, 1.), + linear_color_stop(shadow_color_bottom, 0.), + )); + this.child(sticky_shadow) + }) + .into_any() }) .collect() } @@ -4970,7 +4980,16 @@ impl Render for ProjectPanel { let indent_size = ProjectPanelSettings::get_global(cx).indent_size; let show_indent_guides = ProjectPanelSettings::get_global(cx).indent_guides.show == ShowIndentGuides::Always; - let show_sticky_scroll = ProjectPanelSettings::get_global(cx).sticky_scroll; + let show_sticky_entries = { + if ProjectPanelSettings::get_global(cx).sticky_scroll { + let is_scrollable = self.scroll_handle.is_scrollable(); + let is_scrolled = self.scroll_handle.offset().y < px(0.); + is_scrollable && is_scrolled + } else { + false + } + }; + let is_local = project.is_local(); if has_worktree { @@ -5262,7 +5281,7 @@ impl Render for ProjectPanel { }), ) }) - .when(show_sticky_scroll, |list| { + .when(show_sticky_entries, |list| { let sticky_items = ui::sticky_items( cx.entity().clone(), |this, range, window, cx| { diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index 66f589bfd3..49eca26838 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -611,7 +611,7 @@ impl RulesLibrary { this.update_in(cx, |this, window, cx| match rule { Ok(rule) => { let title_editor = cx.new(|cx| { - let mut editor = Editor::auto_width(window, cx); + let mut editor = Editor::single_line(window, cx); editor.set_placeholder_text("Untitled", cx); editor.set_text(rule_metadata.title.unwrap_or_default(), window, cx); if prompt_id.is_built_in() { diff --git a/crates/terminal_view/src/terminal_element.rs b/crates/terminal_view/src/terminal_element.rs index c34d892644..083c07de9c 100644 --- a/crates/terminal_view/src/terminal_element.rs +++ b/crates/terminal_view/src/terminal_element.rs @@ -127,7 +127,7 @@ impl BatchedTextRun { cx: &mut App, ) { let pos = Point::new( - (origin.x + self.start_point.column as f32 * dimensions.cell_width).floor(), + origin.x + self.start_point.column as f32 * dimensions.cell_width, origin.y + self.start_point.line as f32 * dimensions.line_height, ); @@ -494,6 +494,30 @@ impl TerminalElement { } } + /// Checks if a character is a decorative block/box-like character that should + /// preserve its exact colors without contrast adjustment. + /// + /// This specifically targets characters used as visual connectors, separators, + /// and borders where color matching with adjacent backgrounds is critical. + /// Regular icons (git, folders, etc.) are excluded as they need to remain readable. + /// + /// Fixes https://github.com/zed-industries/zed/issues/34234 + fn is_decorative_character(ch: char) -> bool { + matches!( + ch as u32, + // Unicode Box Drawing and Block Elements + 0x2500..=0x257F // Box Drawing (└ ┐ ─ │ etc.) + | 0x2580..=0x259F // Block Elements (▀ ▄ █ ░ ▒ ▓ etc.) 
+ | 0x25A0..=0x25FF // Geometric Shapes (■ ▶ ● etc. - includes triangular/circular separators) + + // Private Use Area - Powerline separator symbols only + | 0xE0B0..=0xE0B7 // Powerline separators: triangles (E0B0-E0B3) and half circles (E0B4-E0B7) + | 0xE0B8..=0xE0BF // Additional Powerline separators: angles, flames, etc. + | 0xE0C0..=0xE0C8 // Powerline separators: pixelated triangles, curves + | 0xE0CC..=0xE0D4 // Powerline separators: rounded triangles, ice/lego style + ) + } + /// Converts the Alacritty cell styles to GPUI text styles and background color. fn cell_style( indexed: &IndexedCell, @@ -508,7 +532,10 @@ impl TerminalElement { let mut fg = convert_color(&fg, colors); let bg = convert_color(&bg, colors); - fg = color_contrast::ensure_minimum_contrast(fg, bg, minimum_contrast); + // Only apply contrast adjustment to non-decorative characters + if !Self::is_decorative_character(indexed.c) { + fg = color_contrast::ensure_minimum_contrast(fg, bg, minimum_contrast); + } // Ghostty uses (175/255) as the multiplier (~0.69), Alacritty uses 0.66, Kitty // uses 0.75. We're using 0.7 because it's pretty well in the middle of that. @@ -1575,6 +1602,101 @@ mod tests { use super::*; use gpui::{AbsoluteLength, Hsla, font}; + #[test] + fn test_is_decorative_character() { + // Box Drawing characters (U+2500 to U+257F) + assert!(TerminalElement::is_decorative_character('─')); // U+2500 + assert!(TerminalElement::is_decorative_character('│')); // U+2502 + assert!(TerminalElement::is_decorative_character('┌')); // U+250C + assert!(TerminalElement::is_decorative_character('┐')); // U+2510 + assert!(TerminalElement::is_decorative_character('└')); // U+2514 + assert!(TerminalElement::is_decorative_character('┘')); // U+2518 + assert!(TerminalElement::is_decorative_character('┼')); // U+253C + + // Block Elements (U+2580 to U+259F) + assert!(TerminalElement::is_decorative_character('▀')); // U+2580 + assert!(TerminalElement::is_decorative_character('▄')); // U+2584 + assert!(TerminalElement::is_decorative_character('█')); // U+2588 + assert!(TerminalElement::is_decorative_character('░')); // U+2591 + assert!(TerminalElement::is_decorative_character('▒')); // U+2592 + assert!(TerminalElement::is_decorative_character('▓')); // U+2593 + + // Geometric Shapes - block/box-like subset (U+25A0 to U+25D7) + assert!(TerminalElement::is_decorative_character('■')); // U+25A0 + assert!(TerminalElement::is_decorative_character('□')); // U+25A1 + assert!(TerminalElement::is_decorative_character('▲')); // U+25B2 + assert!(TerminalElement::is_decorative_character('▼')); // U+25BC + assert!(TerminalElement::is_decorative_character('◆')); // U+25C6 + assert!(TerminalElement::is_decorative_character('●')); // U+25CF + + // The specific character from the issue + assert!(TerminalElement::is_decorative_character('◗')); // U+25D7 + assert!(TerminalElement::is_decorative_character('◘')); // U+25D8 (now included in Geometric Shapes) + assert!(TerminalElement::is_decorative_character('◙')); // U+25D9 (now included in Geometric Shapes) + + // Powerline symbols (Private Use Area) + assert!(TerminalElement::is_decorative_character('\u{E0B0}')); // Powerline right triangle + assert!(TerminalElement::is_decorative_character('\u{E0B2}')); // Powerline left triangle + assert!(TerminalElement::is_decorative_character('\u{E0B4}')); // Powerline right half circle (the actual issue!) 
+ assert!(TerminalElement::is_decorative_character('\u{E0B6}')); // Powerline left half circle + + // Characters that should NOT be considered decorative + assert!(!TerminalElement::is_decorative_character('A')); // Regular letter + assert!(!TerminalElement::is_decorative_character('$')); // Symbol + assert!(!TerminalElement::is_decorative_character(' ')); // Space + assert!(!TerminalElement::is_decorative_character('←')); // U+2190 (Arrow, not in our ranges) + assert!(!TerminalElement::is_decorative_character('→')); // U+2192 (Arrow, not in our ranges) + assert!(!TerminalElement::is_decorative_character('\u{F00C}')); // Font Awesome check (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{E711}')); // Devicons (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{EA71}')); // Codicons folder (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{F401}')); // Octicons (icon, needs contrast) + assert!(!TerminalElement::is_decorative_character('\u{1F600}')); // Emoji (not in our ranges) + } + + #[test] + fn test_decorative_character_boundary_cases() { + // Test exact boundaries of our ranges + // Box Drawing range boundaries + assert!(TerminalElement::is_decorative_character('\u{2500}')); // First char + assert!(TerminalElement::is_decorative_character('\u{257F}')); // Last char + assert!(!TerminalElement::is_decorative_character('\u{24FF}')); // Just before + + // Block Elements range boundaries + assert!(TerminalElement::is_decorative_character('\u{2580}')); // First char + assert!(TerminalElement::is_decorative_character('\u{259F}')); // Last char + + // Geometric Shapes subset boundaries + assert!(TerminalElement::is_decorative_character('\u{25A0}')); // First char + assert!(TerminalElement::is_decorative_character('\u{25FF}')); // Last char + assert!(!TerminalElement::is_decorative_character('\u{2600}')); // Just after + } + + #[test] + fn test_decorative_characters_bypass_contrast_adjustment() { + // Decorative characters should not be affected by contrast adjustment + + // The specific character from issue #34234 + let problematic_char = '◗'; // U+25D7 + assert!( + TerminalElement::is_decorative_character(problematic_char), + "Character ◗ (U+25D7) should be recognized as decorative" + ); + + // Verify some other commonly used decorative characters + assert!(TerminalElement::is_decorative_character('│')); // Vertical line + assert!(TerminalElement::is_decorative_character('─')); // Horizontal line + assert!(TerminalElement::is_decorative_character('█')); // Full block + assert!(TerminalElement::is_decorative_character('▓')); // Dark shade + assert!(TerminalElement::is_decorative_character('■')); // Black square + assert!(TerminalElement::is_decorative_character('●')); // Black circle + + // Verify normal text characters are NOT decorative + assert!(!TerminalElement::is_decorative_character('A')); + assert!(!TerminalElement::is_decorative_character('1')); + assert!(!TerminalElement::is_decorative_character('$')); + assert!(!TerminalElement::is_decorative_character(' ')); + } + #[test] fn test_contrast_adjustment_logic() { // Test the core contrast adjustment logic without needing full app context diff --git a/crates/theme_selector/src/icon_theme_selector.rs b/crates/theme_selector/src/icon_theme_selector.rs index 40ba7bd5a6..1adfc4b5d8 100644 --- a/crates/theme_selector/src/icon_theme_selector.rs +++ b/crates/theme_selector/src/icon_theme_selector.rs @@ -327,6 +327,7 @@ impl PickerDelegate for 
IconThemeSelectorDelegate { window.dispatch_action( Box::new(Extensions { category_filter: Some(ExtensionCategoryFilter::IconThemes), + id: None, }), cx, ); diff --git a/crates/theme_selector/src/theme_selector.rs b/crates/theme_selector/src/theme_selector.rs index 09d9877df8..022daced7a 100644 --- a/crates/theme_selector/src/theme_selector.rs +++ b/crates/theme_selector/src/theme_selector.rs @@ -385,6 +385,7 @@ impl PickerDelegate for ThemeSelectorDelegate { window.dispatch_action( Box::new(Extensions { category_filter: Some(ExtensionCategoryFilter::Themes), + id: None, }), cx, ); diff --git a/crates/ui/src/components/sticky_items.rs b/crates/ui/src/components/sticky_items.rs index 218f7aae35..ca8b336a5a 100644 --- a/crates/ui/src/components/sticky_items.rs +++ b/crates/ui/src/components/sticky_items.rs @@ -149,47 +149,7 @@ where ) -> AnyElement { let entries = (self.compute_fn)(visible_range.clone(), window, cx); - struct StickyAnchor { - entry: T, - index: usize, - } - - let mut sticky_anchor = None; - let mut last_item_is_drifting = false; - - let mut iter = entries.iter().enumerate().peekable(); - while let Some((ix, current_entry)) = iter.next() { - let depth = current_entry.depth(); - - if depth < ix { - sticky_anchor = Some(StickyAnchor { - entry: current_entry.clone(), - index: visible_range.start + ix, - }); - break; - } - - if let Some(&(_next_ix, next_entry)) = iter.peek() { - let next_depth = next_entry.depth(); - let next_item_outdented = next_depth + 1 == depth; - - let depth_same_as_index = depth == ix; - let depth_greater_than_index = depth == ix + 1; - - if next_item_outdented && (depth_same_as_index || depth_greater_than_index) { - if depth_greater_than_index { - last_item_is_drifting = true; - } - sticky_anchor = Some(StickyAnchor { - entry: current_entry.clone(), - index: visible_range.start + ix, - }); - break; - } - } - } - - let Some(sticky_anchor) = sticky_anchor else { + let Some(sticky_anchor) = find_sticky_anchor(&entries, visible_range.start) else { return StickyItemsElement { drifting_element: None, drifting_decoration: None, @@ -203,23 +163,21 @@ where let mut elements = (self.render_fn)(sticky_anchor.entry, window, cx); let items_count = elements.len(); - let indents: SmallVec<[usize; 8]> = { - elements - .iter() - .enumerate() - .map(|(ix, _)| anchor_depth.saturating_sub(items_count.saturating_sub(ix))) - .collect() - }; + let indents: SmallVec<[usize; 8]> = (0..items_count) + .map(|ix| anchor_depth.saturating_sub(items_count.saturating_sub(ix))) + .collect(); let mut last_decoration_element = None; let mut rest_decoration_elements = SmallVec::new(); - let available_space = size( - AvailableSpace::Definite(bounds.size.width), + let expanded_width = bounds.size.width + scroll_offset.x.abs(); + + let decor_available_space = size( + AvailableSpace::Definite(expanded_width), AvailableSpace::Definite(bounds.size.height), ); - let drifting_y_offset = if last_item_is_drifting { + let drifting_y_offset = if sticky_anchor.drifting { let scroll_top = -scroll_offset.y; let anchor_top = item_height * (sticky_anchor.index + 1); let sticky_area_height = item_height * items_count; @@ -228,7 +186,7 @@ where Pixels::ZERO }; - let (drifting_indent, rest_indents) = if last_item_is_drifting && !indents.is_empty() { + let (drifting_indent, rest_indents) = if sticky_anchor.drifting && !indents.is_empty() { let last = indents[indents.len() - 1]; let rest: SmallVec<[usize; 8]> = indents[..indents.len() - 1].iter().copied().collect(); (Some(last), rest) @@ -236,11 +194,14 @@ where 
(None, indents) }; + let base_origin = bounds.origin - point(px(0.), scroll_offset.y); + for decoration in &self.decorations { if let Some(drifting_indent) = drifting_indent { let drifting_indent_vec: SmallVec<[usize; 8]> = [drifting_indent].into_iter().collect(); - let sticky_origin = bounds.origin - scroll_offset + + let sticky_origin = base_origin + point(px(0.), item_height * rest_indents.len() + drifting_y_offset); let decoration_bounds = Bounds::new(sticky_origin, bounds.size); @@ -252,13 +213,13 @@ where window, cx, ); - drifting_dec.layout_as_root(available_space, window, cx); + drifting_dec.layout_as_root(decor_available_space, window, cx); drifting_dec.prepaint_at(sticky_origin, window, cx); last_decoration_element = Some(drifting_dec); } if !rest_indents.is_empty() { - let decoration_bounds = Bounds::new(bounds.origin - scroll_offset, bounds.size); + let decoration_bounds = Bounds::new(base_origin, bounds.size); let mut rest_dec = decoration.as_ref().compute( &rest_indents, decoration_bounds, @@ -267,46 +228,45 @@ where window, cx, ); - rest_dec.layout_as_root(available_space, window, cx); + rest_dec.layout_as_root(decor_available_space, window, cx); rest_dec.prepaint_at(bounds.origin, window, cx); rest_decoration_elements.push(rest_dec); } } let (mut drifting_element, mut rest_elements) = - if last_item_is_drifting && !elements.is_empty() { + if sticky_anchor.drifting && !elements.is_empty() { let last = elements.pop().unwrap(); (Some(last), elements) } else { (None, elements) }; - for (ix, element) in rest_elements.iter_mut().enumerate() { - let sticky_origin = bounds.origin - scroll_offset + point(px(0.), item_height * ix); - let element_available_space = size( - AvailableSpace::Definite(bounds.size.width), - AvailableSpace::Definite(item_height), - ); - - element.layout_as_root(element_available_space, window, cx); - element.prepaint_at(sticky_origin, window, cx); - } + let element_available_space = size( + AvailableSpace::Definite(expanded_width), + AvailableSpace::Definite(item_height), + ); + // order of prepaint is important here + // mouse events checks hitboxes in reverse insertion order if let Some(ref mut drifting_element) = drifting_element { - let sticky_origin = bounds.origin - scroll_offset + let sticky_origin = base_origin + point( px(0.), item_height * rest_elements.len() + drifting_y_offset, ); - let element_available_space = size( - AvailableSpace::Definite(bounds.size.width), - AvailableSpace::Definite(item_height), - ); drifting_element.layout_as_root(element_available_space, window, cx); drifting_element.prepaint_at(sticky_origin, window, cx); } + for (ix, element) in rest_elements.iter_mut().enumerate() { + let sticky_origin = base_origin + point(px(0.), item_height * ix); + + element.layout_as_root(element_available_space, window, cx); + element.prepaint_at(sticky_origin, window, cx); + } + StickyItemsElement { drifting_element, drifting_decoration: last_decoration_element, @@ -317,6 +277,48 @@ where } } +struct StickyAnchor { + entry: T, + index: usize, + drifting: bool, +} + +fn find_sticky_anchor( + entries: &SmallVec<[T; 8]>, + visible_range_start: usize, +) -> Option> { + let mut iter = entries.iter().enumerate().peekable(); + while let Some((ix, current_entry)) = iter.next() { + let depth = current_entry.depth(); + + if depth < ix { + return Some(StickyAnchor { + entry: current_entry.clone(), + index: visible_range_start + ix, + drifting: false, + }); + } + + if let Some(&(_next_ix, next_entry)) = iter.peek() { + let next_depth = 
next_entry.depth(); + let next_item_outdented = next_depth + 1 == depth; + + let depth_same_as_index = depth == ix; + let depth_greater_than_index = depth == ix + 1; + + if next_item_outdented && (depth_same_as_index || depth_greater_than_index) { + return Some(StickyAnchor { + entry: current_entry.clone(), + index: visible_range_start + ix, + drifting: depth_greater_than_index, + }); + } + } + } + + None +} + /// A decoration for a [`StickyItems`]. This can be used for various things, /// such as rendering indent guides, or other visual effects. pub trait StickyItemsDecoration { diff --git a/crates/vim/src/normal/scroll.rs b/crates/vim/src/normal/scroll.rs index 47b9fe92fd..e2ae74b52b 100644 --- a/crates/vim/src/normal/scroll.rs +++ b/crates/vim/src/normal/scroll.rs @@ -230,7 +230,11 @@ fn scroll_editor( // column position, or the right-most column in the current // line, seeing as the cursor might be in a short line, in which // case we don't want to go past its last column. - let max_row_column = map.line_len(new_row); + let max_row_column = if new_row <= map.max_point().row() { + map.line_len(new_row) + } else { + 0 + }; let max_column = match min_column + visible_column_count as u32 { max_column if max_column >= max_row_column => max_row_column, max_column => max_column, diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml new file mode 100644 index 0000000000..7ca0ca0939 --- /dev/null +++ b/crates/x_ai/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "x_ai" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/x_ai.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +strum.workspace = true +workspace-hack.workspace = true diff --git a/crates/x_ai/LICENSE-GPL b/crates/x_ai/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/x_ai/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs new file mode 100644 index 0000000000..ac116b2f8f --- /dev/null +++ b/crates/x_ai/src/x_ai.rs @@ -0,0 +1,126 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use strum::EnumIter; + +pub const XAI_API_URL: &str = "https://api.x.ai/v1"; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[serde(rename = "grok-2-vision-latest")] + Grok2Vision, + #[default] + #[serde(rename = "grok-3-latest")] + Grok3, + #[serde(rename = "grok-3-mini-latest")] + Grok3Mini, + #[serde(rename = "grok-3-fast-latest")] + Grok3Fast, + #[serde(rename = "grok-3-mini-fast-latest")] + Grok3MiniFast, + #[serde(rename = "grok-4-latest")] + Grok4, + #[serde(rename = "custom")] + Custom { + name: String, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. 
+ display_name: Option, + max_tokens: u64, + max_output_tokens: Option, + max_completion_tokens: Option, + }, +} + +impl Model { + pub fn default_fast() -> Self { + Self::Grok3Fast + } + + pub fn from_id(id: &str) -> Result { + match id { + "grok-2-vision" => Ok(Self::Grok2Vision), + "grok-3" => Ok(Self::Grok3), + "grok-3-mini" => Ok(Self::Grok3Mini), + "grok-3-fast" => Ok(Self::Grok3Fast), + "grok-3-mini-fast" => Ok(Self::Grok3MiniFast), + _ => anyhow::bail!("invalid model id '{id}'"), + } + } + + pub fn id(&self) -> &str { + match self { + Self::Grok2Vision => "grok-2-vision", + Self::Grok3 => "grok-3", + Self::Grok3Mini => "grok-3-mini", + Self::Grok3Fast => "grok-3-fast", + Self::Grok3MiniFast => "grok-3-mini-fast", + Self::Grok4 => "grok-4", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Grok2Vision => "Grok 2 Vision", + Self::Grok3 => "Grok 3", + Self::Grok3Mini => "Grok 3 Mini", + Self::Grok3Fast => "Grok 3 Fast", + Self::Grok3MiniFast => "Grok 3 Mini Fast", + Self::Grok4 => "Grok 4", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } + + pub fn max_token_count(&self) -> u64 { + match self { + Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => 131_072, + Self::Grok4 => 256_000, + Self::Grok2Vision => 8_192, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> Option { + match self { + Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => Some(8_192), + Self::Grok4 => Some(64_000), + Self::Grok2Vision => Some(4_096), + Self::Custom { + max_output_tokens, .. + } => *max_output_tokens, + } + } + + pub fn supports_parallel_tool_calls(&self) -> bool { + match self { + Self::Grok2Vision + | Self::Grok3 + | Self::Grok3Mini + | Self::Grok3Fast + | Self::Grok3MiniFast + | Self::Grok4 => true, + Model::Custom { .. } => false, + } + } + + pub fn supports_tool(&self) -> bool { + match self { + Self::Grok2Vision + | Self::Grok3 + | Self::Grok3Mini + | Self::Grok3Fast + | Self::Grok3MiniFast + | Self::Grok4 => true, + Model::Custom { .. } => false, + } + } + + pub fn supports_images(&self) -> bool { + match self { + Self::Grok2Vision => true, + _ => false, + } + } +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 884443e770..5c4dd8d746 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." 
edition.workspace = true name = "zed" -version = "0.195.0" +version = "0.195.5" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] diff --git a/crates/zed/RELEASE_CHANNEL b/crates/zed/RELEASE_CHANNEL index 38f8e886e1..870bbe4e50 100644 --- a/crates/zed/RELEASE_CHANNEL +++ b/crates/zed/RELEASE_CHANNEL @@ -1 +1 @@ -dev +stable \ No newline at end of file diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index b5efea10e2..d3727a06dc 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -725,6 +725,23 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut return; } + if let Some(extension) = request.extension_id { + cx.spawn(async move |cx| { + let workspace = workspace::get_any_active_workspace(app_state, cx.clone()).await?; + workspace.update(cx, |_, window, cx| { + window.dispatch_action( + Box::new(zed_actions::Extensions { + category_filter: None, + id: Some(extension), + }), + cx, + ); + }) + }) + .detach_and_log_err(cx); + return; + } + if let Some(connection_options) = request.ssh_connection { cx.spawn(async move |mut cx| { let paths: Vec = request.open_paths.into_iter().map(PathBuf::from).collect(); diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index 0fb08d1be5..42eb8198a4 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -37,6 +37,7 @@ pub struct OpenRequest { pub join_channel: Option, pub ssh_connection: Option, pub dock_menu_action: Option, + pub extension_id: Option, } impl OpenRequest { @@ -54,6 +55,8 @@ impl OpenRequest { } else if let Some(file) = url.strip_prefix("zed://ssh") { let ssh_url = "ssh:/".to_string() + file; this.parse_ssh_file_path(&ssh_url, cx)? + } else if let Some(file) = url.strip_prefix("zed://extension/") { + this.extension_id = Some(file.to_string()) } else if url.starts_with("ssh://") { this.parse_ssh_file_path(&url, cx)? } else if let Some(request_path) = parse_zed_link(&url, cx) { diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index ffe232ad7b..2894a0e52f 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -76,6 +76,9 @@ pub struct Extensions { /// Filters the extensions page down to extensions that are in the specified category. #[serde(default)] pub category_filter: Option, + /// Focuses just the extension with the specified ID. + #[serde(default)] + pub id: Option, } /// Decreases the font size in the editor buffer. diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md index 907e318d05..0d10aba8f2 100644 --- a/docs/src/ai/configuration.md +++ b/docs/src/ai/configuration.md @@ -23,6 +23,8 @@ Here's an overview of the supported providers and tool call support: | [OpenAI](#openai) | ✅ | | [OpenAI API Compatible](#openai-api-compatible) | 🚫 | | [OpenRouter](#openrouter) | ✅ | +| [Vercel](#vercel-v0) | ✅ | +| [xAI](#xai) | ✅ | ## Use Your Own Keys {#use-your-own-keys} @@ -442,27 +444,30 @@ Custom models will be listed in the model dropdown in the Agent Panel. Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider. -You can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. -Here are a few model examples you can plug in by using this feature: +Zed supports using OpenAI compatible APIs by specifying a custom `api_url` and `available_models` for the OpenAI provider. 
This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. -#### X.ai Grok +To configure a compatible API, you can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. For example, to connect to [Together AI](https://www.together.ai/): -Example configuration for using X.ai Grok with Zed: +1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys). +2. Add the following to your `settings.json`: ```json +{ "language_models": { "openai": { - "api_url": "https://api.x.ai/v1", + "api_url": "https://api.together.xyz/v1", + "api_key": "YOUR_TOGETHER_AI_API_KEY", "available_models": [ { - "name": "grok-beta", - "display_name": "X.ai Grok (Beta)", - "max_tokens": 131072 + "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "display_name": "Together Mixtral 8x7B", + "max_tokens": 32768, + "supports_tools": true } - ], - "version": "1" - }, + ] + } } +} ``` ### OpenRouter {#openrouter} @@ -523,7 +528,9 @@ You can find available models and their specifications on the [OpenRouter models Custom models will be listed in the model dropdown in the Agent Panel. -### Vercel v0 +### Vercel v0 {#vercel-v0} + +> ✅ Supports tool use [Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. It supports text and image inputs and provides fast streaming responses. @@ -535,6 +542,49 @@ Once you have it, paste it directly into the Vercel provider section in the pane You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. +### xAI {#xai} + +> ✅ Supports tool use + +Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. + +1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) +2. Open the settings view (`agent: open configuration`) and go to the **xAI** section +3. Enter your xAI API key + +The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. + +> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method. + +#### Custom Models {#xai-custom-models} + +The Zed agent comes pre-configured with common Grok models. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "x_ai": { + "api_url": "https://api.x.ai/v1", + "available_models": [ + { + "name": "grok-1.5", + "display_name": "Grok 1.5", + "max_tokens": 131072, + "max_output_tokens": 8192 + }, + { + "name": "grok-1.5v", + "display_name": "Grok 1.5V (Vision)", + "max_tokens": 131072, + "max_output_tokens": 8192, + "supports_images": true + } + ] + } + } +} +``` + ## Advanced Configuration {#advanced-configuration} ### Custom Provider Endpoints {#custom-provider-endpoint} diff --git a/docs/src/linux.md b/docs/src/linux.md index 896bfdaf3f..ca65da2969 100644 --- a/docs/src/linux.md +++ b/docs/src/linux.md @@ -148,7 +148,7 @@ On some systems the file `/etc/prime-discrete` can be used to enforce the use of On others, you may be able to the environment variable `DRI_PRIME=1` when running Zed to force the use of the discrete GPU. 
-If you're using an AMD GPU and Zed crashes when selecting long lines, try setting the `ZED_SAMPLE_COUNT=0` environment variable. (See [#26143](https://github.com/zed-industries/zed/issues/26143)) +If you're using an AMD GPU and Zed crashes when selecting long lines, try setting the `ZED_PATH_SAMPLE_COUNT=0` environment variable. (See [#26143](https://github.com/zed-industries/zed/issues/26143)) If you're using an AMD GPU, you might get a 'Broken Pipe' error. Try using the RADV or Mesa drivers. (See [#13880](https://github.com/zed-industries/zed/issues/13880)) diff --git a/script/bundle-windows.ps1 b/script/bundle-windows.ps1 index 9b61d220cf..dc0bdb9b7d 100644 --- a/script/bundle-windows.ps1 +++ b/script/bundle-windows.ps1 @@ -44,8 +44,6 @@ function CheckEnvironmentVariables { } } -$innoDir = "$env:ZED_WORKSPACE\inno" - function PrepareForBundle { if (Test-Path "$innoDir") { Remove-Item -Path "$innoDir" -Recurse -Force @@ -236,6 +234,8 @@ function BuildInstaller { } ParseZedWorkspace +$innoDir = "$env:ZED_WORKSPACE\inno" + CheckEnvironmentVariables PrepareForBundle BuildZedAndItsFriends
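To close, a short usage sketch for the new `x_ai` crate added in this diff. It only exercises the API shown above (`Model::from_id`, `id`, `display_name`, `max_token_count`, `supports_tool`, `supports_images`); the surrounding binary and its `anyhow` and `x_ai` dependencies are assumed for illustration.

```rust
use x_ai::Model;

fn main() -> anyhow::Result<()> {
    // Round-trip a model id through the parser defined in crates/x_ai/src/x_ai.rs.
    let model = Model::from_id("grok-3")?;
    assert_eq!(model.id(), "grok-3");
    assert_eq!(model.display_name(), "Grok 3");
    assert_eq!(model.max_token_count(), 131_072);
    assert!(model.supports_tool());
    assert!(!model.supports_images()); // only grok-2-vision reports image support

    // Ids outside the known set hit the `anyhow::bail!` branch.
    assert!(Model::from_id("grok-99").is_err());
    Ok(())
}
```

Note that `from_id` does not yet parse `"grok-4"` even though `id()` reports it for the `Grok4` variant; models outside the parsed set can still be represented via the `Custom` variant, which is presumably where the `available_models` settings documented above end up.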