diff --git a/Cargo.lock b/Cargo.lock index bb3e4024ba..07a445eefe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8955,6 +8955,7 @@ dependencies = [ "credentials_provider", "deepseek", "editor", + "feature_flags", "fs", "futures 0.3.31", "google_ai", @@ -18296,6 +18297,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "feature_flags", "futures 0.3.31", "gpui", "http_client", diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 6c0cb763ef..2c2bbfe30d 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -92,6 +92,17 @@ impl FeatureFlag for JjUiFeatureFlag { const NAME: &'static str = "jj-ui"; } +pub struct ZedCloudFeatureFlag {} + +impl FeatureFlag for ZedCloudFeatureFlag { + const NAME: &'static str = "zed-cloud"; + + fn enabled_for_staff() -> bool { + // Require individual opt-in, for now. + false + } +} + pub trait FeatureFlagViewExt { fn observe_flag(&mut self, window: &Window, callback: F) -> Subscription where diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 288dec9a31..c60a56002f 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -226,10 +226,21 @@ impl HttpClientWithUrl { } /// Builds a Zed LLM URL using the given path. - pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result { + pub fn build_zed_llm_url( + &self, + path: &str, + query: &[(&str, &str)], + use_cloud: bool, + ) -> Result { let base_url = self.base_url(); let base_api_url = match base_url.as_ref() { - "https://zed.dev" => "https://llm.zed.dev", + "https://zed.dev" => { + if use_cloud { + "https://cloud.zed.dev" + } else { + "https://llm.zed.dev" + } + } "https://staging.zed.dev" => "https://llm-staging.zed.dev", "http://localhost:3000" => "http://localhost:8787", other => other, diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 0f248edd57..514443ddec 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -28,6 +28,7 @@ credentials_provider.workspace = true copilot.workspace = true deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true +feature_flags.workspace = true fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 1cd673710c..9b7fee228a 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -2,6 +2,7 @@ use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; +use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -136,6 +137,7 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + let use_cloud = cx.has_flag::(); Self { client: client.clone(), @@ -163,7 +165,7 @@ impl State { .await; } - let response = Self::fetch_models(client, llm_api_token).await?; + let response = Self::fetch_models(client, llm_api_token, use_cloud).await?; cx.update(|cx| { this.update(cx, |this, cx| { let mut models = Vec::new(); @@ -265,13 +267,18 @@ impl State { async fn fetch_models( client: Arc, llm_api_token: LlmApiToken, + use_cloud: bool, ) -> Result { let http_client = &client.http_client(); let token = llm_api_token.acquire(&client).await?; let request = http_client::Request::builder() .method(Method::GET) - .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) + .uri( + http_client + .build_zed_llm_url("/models", &[], use_cloud)? + .as_ref(), + ) .header("Authorization", format!("Bearer {token}")) .body(AsyncBody::empty())?; let mut response = http_client @@ -535,6 +542,7 @@ impl CloudLanguageModel { llm_api_token: LlmApiToken, app_version: Option, body: CompletionBody, + use_cloud: bool, ) -> Result { let http_client = &client.http_client(); @@ -542,9 +550,11 @@ impl CloudLanguageModel { let mut refreshed_token = false; loop { - let request_builder = http_client::Request::builder() - .method(Method::POST) - .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()); + let request_builder = http_client::Request::builder().method(Method::POST).uri( + http_client + .build_zed_llm_url("/completions", &[], use_cloud)? + .as_ref(), + ); let request_builder = if let Some(app_version) = app_version { request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) } else { @@ -771,6 +781,7 @@ impl LanguageModel for CloudLanguageModel { let model_id = self.model.id.to_string(); let generate_content_request = into_google(request, model_id.clone(), GoogleModelMode::Default); + let use_cloud = cx.has_flag::(); async move { let http_client = &client.http_client(); let token = llm_api_token.acquire(&client).await?; @@ -786,7 +797,7 @@ impl LanguageModel for CloudLanguageModel { .method(Method::POST) .uri( http_client - .build_zed_llm_url("/count_tokens", &[])? + .build_zed_llm_url("/count_tokens", &[], use_cloud)? .as_ref(), ) .header("Content-Type", "application/json") @@ -835,6 +846,9 @@ impl LanguageModel for CloudLanguageModel { let intent = request.intent; let mode = request.mode; let app_version = cx.update(|cx| AppVersion::global(cx)).ok(); + let use_cloud = cx + .update(|cx| cx.has_flag::()) + .unwrap_or(false); match self.model.provider { zed_llm_client::LanguageModelProvider::Anthropic => { let request = into_anthropic( @@ -872,6 +886,7 @@ impl LanguageModel for CloudLanguageModel { provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, + use_cloud, ) .await .map_err(|err| match err.downcast::() { @@ -924,6 +939,7 @@ impl LanguageModel for CloudLanguageModel { provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, + use_cloud, ) .await?; @@ -964,6 +980,7 @@ impl LanguageModel for CloudLanguageModel { provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, + use_cloud, ) .await?; diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index 2e052796c4..208cb63593 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index adf79b0ff6..79ccf97e47 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; +use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; use http_client::{HttpClient, Method}; @@ -62,7 +63,10 @@ impl WebSearchProvider for CloudWebSearchProvider { let client = state.client.clone(); let llm_api_token = state.llm_api_token.clone(); let body = WebSearchBody { query }; - cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await }) + let use_cloud = cx.has_flag::(); + cx.background_spawn(async move { + perform_web_search(client, llm_api_token, body, use_cloud).await + }) } } @@ -70,6 +74,7 @@ async fn perform_web_search( client: Arc, llm_api_token: LlmApiToken, body: WebSearchBody, + use_cloud: bool, ) -> Result { const MAX_RETRIES: usize = 3; @@ -86,7 +91,11 @@ async fn perform_web_search( let request = http_client::Request::builder() .method(Method::POST) - .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref()) + .uri( + http_client + .build_zed_llm_url("/web_search", &[], use_cloud)? + .as_ref(), + ) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {token}")) .body(serde_json::to_string(&body)?.into())?; diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 87cd1e604c..12d3d4bfbc 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,6 +8,7 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; +use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag}; pub use init::*; use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; @@ -390,6 +391,7 @@ impl Zeta { let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); + let use_cloud = cx.has_flag::(); let buffer = buffer.clone(); @@ -480,6 +482,7 @@ impl Zeta { llm_token, app_version, body, + use_cloud, }) .await; let (response, usage) = match response { @@ -745,6 +748,7 @@ and then another llm_token, app_version, body, + use_cloud, .. } = params; @@ -760,7 +764,7 @@ and then another } else { request_builder.uri( http_client - .build_zed_llm_url("/predict_edits/v2", &[])? + .build_zed_llm_url("/predict_edits/v2", &[], use_cloud)? .as_ref(), ) }; @@ -820,6 +824,7 @@ and then another let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); + let use_cloud = cx.has_flag::(); cx.spawn(async move |this, cx| { let http_client = client.http_client(); let mut response = llm_token_retry(&llm_token, &client, |token| { @@ -830,7 +835,7 @@ and then another } else { request_builder.uri( http_client - .build_zed_llm_url("/predict_edits/accept", &[])? + .build_zed_llm_url("/predict_edits/accept", &[], use_cloud)? .as_ref(), ) }; @@ -1126,6 +1131,7 @@ struct PerformPredictEditsParams { pub llm_token: LlmApiToken, pub app_version: SemanticVersion, pub body: PredictEditsBody, + pub use_cloud: bool, } #[derive(Error, Debug)]