diff --git a/Cargo.lock b/Cargo.lock index cd83e61508..96c50f7e5f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2619,6 +2619,7 @@ dependencies = [ "clock", "cocoa 0.26.0", "collections", + "credentials_provider", "feature_flags", "futures 0.3.31", "gpui", @@ -3477,6 +3478,19 @@ dependencies = [ "crc", ] +[[package]] +name = "credentials_provider" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "gpui", + "paths", + "release_channel", + "serde", + "serde_json", +] + [[package]] name = "criterion" version = "0.5.1" @@ -6985,6 +6999,7 @@ dependencies = [ "client", "collections", "copilot", + "credentials_provider", "deepseek", "editor", "feature_flags", @@ -16682,6 +16697,7 @@ dependencies = [ "command_palette_hooks", "component_preview", "copilot", + "credentials_provider", "db", "diagnostics", "editor", diff --git a/Cargo.toml b/Cargo.toml index 15edf729ca..448f653a51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "crates/auto_update", "crates/auto_update_ui", "crates/breadcrumbs", + "crates/buffer_diff", "crates/call", "crates/channel", "crates/cli", @@ -31,10 +32,10 @@ members = [ "crates/context_server", "crates/context_server_settings", "crates/copilot", + "crates/credentials_provider", "crates/db", "crates/deepseek", "crates/diagnostics", - "crates/buffer_diff", "crates/docs_preprocessor", "crates/editor", "crates/evals", @@ -233,6 +234,7 @@ component_preview = { path = "crates/component_preview" } context_server = { path = "crates/context_server" } context_server_settings = { path = "crates/context_server_settings" } copilot = { path = "crates/copilot" } +credentials_provider = { path = "crates/credentials_provider" } db = { path = "crates/db" } deepseek = { path = "crates/deepseek" } diagnostics = { path = "crates/diagnostics" } diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index f52f9559e8..e36d71b3dc 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -22,6 +22,7 @@ async-tungstenite = { workspace = true, features = ["async-std", "async-tls"] } chrono = { workspace = true, features = ["serde"] } clock.workspace = true collections.workspace = true +credentials_provider.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 1b2d007f54..658c32ecfa 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -15,6 +15,7 @@ use async_tungstenite::tungstenite::{ }; use chrono::{DateTime, Utc}; use clock::SystemClock; +use credentials_provider::CredentialsProvider; use futures::{ channel::oneshot, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, @@ -57,14 +58,6 @@ static ZED_SERVER_URL: LazyLock> = LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok()); static ZED_RPC_URL: LazyLock> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok()); -/// An environment variable whose presence indicates that the development auth -/// provider should be used. -/// -/// Only works in development. Setting this environment variable in other release -/// channels is a no-op. -pub static ZED_DEVELOPMENT_AUTH: LazyLock = LazyLock::new(|| { - std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty()) -}); pub static IMPERSONATE_LOGIN: LazyLock> = LazyLock::new(|| { std::env::var("ZED_IMPERSONATE") .ok() @@ -193,7 +186,7 @@ pub struct Client { peer: Arc, http: Arc, telemetry: Arc, - credentials_provider: Arc, + credentials_provider: ClientCredentialsProvider, state: RwLock, handler_set: parking_lot::Mutex, @@ -304,16 +297,46 @@ impl Credentials { } } -/// A provider for [`Credentials`]. -/// -/// Used to abstract over reading and writing credentials to some form of -/// persistence (like the system keychain). -trait CredentialsProvider { +pub struct ClientCredentialsProvider { + provider: Arc, +} + +impl ClientCredentialsProvider { + pub fn new(cx: &App) -> Self { + Self { + provider: ::global(cx), + } + } + + fn server_url(&self, cx: &AsyncApp) -> Result { + cx.update(|cx| ClientSettings::get_global(cx).server_url.clone()) + } + /// Reads the credentials from the provider. fn read_credentials<'a>( &'a self, cx: &'a AsyncApp, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a>> { + async move { + if IMPERSONATE_LOGIN.is_some() { + return None; + } + + let server_url = self.server_url(cx).ok()?; + let (user_id, access_token) = self + .provider + .read_credentials(&server_url, cx) + .await + .log_err() + .flatten()?; + + Some(Credentials { + user_id: user_id.parse().ok()?, + access_token: String::from_utf8(access_token).ok()?, + }) + } + .boxed_local() + } /// Writes the credentials to the provider. fn write_credentials<'a>( @@ -321,13 +344,32 @@ trait CredentialsProvider { user_id: u64, access_token: String, cx: &'a AsyncApp, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a>> { + async move { + let server_url = self.server_url(cx)?; + self.provider + .write_credentials( + &server_url, + &user_id.to_string(), + access_token.as_bytes(), + cx, + ) + .await + } + .boxed_local() + } /// Deletes the credentials from the provider. fn delete_credentials<'a>( &'a self, cx: &'a AsyncApp, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a>> { + async move { + let server_url = self.server_url(cx)?; + self.provider.delete_credentials(&server_url, cx).await + } + .boxed_local() + } } impl Default for ClientState { @@ -484,27 +526,12 @@ impl Client { http: Arc, cx: &mut App, ) -> Arc { - let use_zed_development_auth = match ReleaseChannel::try_global(cx) { - Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH, - Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) - | None => false, - }; - - let credentials_provider: Arc = - if use_zed_development_auth { - Arc::new(DevelopmentCredentialsProvider { - path: paths::config_dir().join("development_auth"), - }) - } else { - Arc::new(KeychainCredentialsProvider) - }; - Arc::new(Self { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), http, - credentials_provider, + credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), handler_set: Default::default(), @@ -842,8 +869,7 @@ impl Client { Ok(conn) => { self.state.write().credentials = Some(credentials.clone()); if !read_from_provider && IMPERSONATE_LOGIN.is_none() { - self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); - + self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); } futures::select_biased! { @@ -1588,130 +1614,6 @@ impl ProtoClient for Client { } } -#[derive(Serialize, Deserialize)] -struct DevelopmentCredentials { - user_id: u64, - access_token: String, -} - -/// A credentials provider that stores credentials in a local file. -/// -/// This MUST only be used in development, as this is not a secure way of storing -/// credentials on user machines. -/// -/// Its existence is purely to work around the annoyance of having to constantly -/// re-allow access to the system keychain when developing Zed. -struct DevelopmentCredentialsProvider { - path: PathBuf, -} - -impl CredentialsProvider for DevelopmentCredentialsProvider { - fn read_credentials<'a>( - &'a self, - _cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - if IMPERSONATE_LOGIN.is_some() { - return None; - } - - let json = std::fs::read(&self.path).log_err()?; - - let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?; - - Some(Credentials { - user_id: credentials.user_id, - access_token: credentials.access_token, - }) - } - .boxed_local() - } - - fn write_credentials<'a>( - &'a self, - user_id: u64, - access_token: String, - _cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - let json = serde_json::to_string(&DevelopmentCredentials { - user_id, - access_token, - })?; - - std::fs::write(&self.path, json)?; - - Ok(()) - } - .boxed_local() - } - - fn delete_credentials<'a>( - &'a self, - _cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local() - } -} - -/// A credentials provider that stores credentials in the system keychain. -struct KeychainCredentialsProvider; - -impl CredentialsProvider for KeychainCredentialsProvider { - fn read_credentials<'a>( - &'a self, - cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - if IMPERSONATE_LOGIN.is_some() { - return None; - } - - let (user_id, access_token) = cx - .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url)) - .log_err()? - .await - .log_err()??; - - Some(Credentials { - user_id: user_id.parse().ok()?, - access_token: String::from_utf8(access_token).ok()?, - }) - } - .boxed_local() - } - - fn write_credentials<'a>( - &'a self, - user_id: u64, - access_token: String, - cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - cx.update(move |cx| { - cx.write_credentials( - &ClientSettings::get_global(cx).server_url, - &user_id.to_string(), - access_token.as_bytes(), - ) - })? - .await - } - .boxed_local() - } - - fn delete_credentials<'a>( - &'a self, - cx: &'a AsyncApp, - ) -> Pin> + 'a>> { - async move { - cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))? - .await - } - .boxed_local() - } -} - /// prefix for the zed:// url scheme pub const ZED_URL_SCHEME: &str = "zed"; diff --git a/crates/credentials_provider/Cargo.toml b/crates/credentials_provider/Cargo.toml new file mode 100644 index 0000000000..bf47bb24b1 --- /dev/null +++ b/crates/credentials_provider/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "credentials_provider" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/credentials_provider.rs" + +[dependencies] +anyhow.workspace = true +futures.workspace = true +gpui.workspace = true +paths.workspace = true +release_channel.workspace = true +serde.workspace = true +serde_json.workspace = true diff --git a/crates/credentials_provider/LICENSE-GPL b/crates/credentials_provider/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/credentials_provider/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/credentials_provider/src/credentials_provider.rs b/crates/credentials_provider/src/credentials_provider.rs new file mode 100644 index 0000000000..86561b2d04 --- /dev/null +++ b/crates/credentials_provider/src/credentials_provider.rs @@ -0,0 +1,187 @@ +use std::collections::HashMap; +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::{Arc, LazyLock}; + +use anyhow::Result; +use futures::FutureExt as _; +use gpui::{App, AsyncApp}; +use release_channel::ReleaseChannel; + +/// An environment variable whose presence indicates that the development auth +/// provider should be used. +/// +/// Only works in development. Setting this environment variable in other release +/// channels is a no-op. +pub static ZED_DEVELOPMENT_AUTH: LazyLock = LazyLock::new(|| { + std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty()) +}); + +/// A provider for credentials. +/// +/// Used to abstract over reading and writing credentials to some form of +/// persistence (like the system keychain). +pub trait CredentialsProvider: Send + Sync { + /// Reads the credentials from the provider. + fn read_credentials<'a>( + &'a self, + url: &'a str, + cx: &'a AsyncApp, + ) -> Pin)>>> + 'a>>; + + /// Writes the credentials to the provider. + fn write_credentials<'a>( + &'a self, + url: &'a str, + username: &'a str, + password: &'a [u8], + cx: &'a AsyncApp, + ) -> Pin> + 'a>>; + + /// Deletes the credentials from the provider. + fn delete_credentials<'a>( + &'a self, + url: &'a str, + cx: &'a AsyncApp, + ) -> Pin> + 'a>>; +} + +impl dyn CredentialsProvider { + /// Returns the global [`CredentialsProvider`]. + pub fn global(cx: &App) -> Arc { + // The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it + // seems like this is a false positive from Clippy. + #[allow(clippy::arc_with_non_send_sync)] + Self::new(cx) + } + + fn new(cx: &App) -> Arc { + let use_development_backend = match ReleaseChannel::try_global(cx) { + Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH, + Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) + | None => false, + }; + + if use_development_backend { + Arc::new(DevelopmentCredentialsProvider::new()) + } else { + Arc::new(KeychainCredentialsProvider) + } + } +} + +/// A credentials provider that stores credentials in the system keychain. +struct KeychainCredentialsProvider; + +impl CredentialsProvider for KeychainCredentialsProvider { + fn read_credentials<'a>( + &'a self, + url: &'a str, + cx: &'a AsyncApp, + ) -> Pin)>>> + 'a>> { + async move { cx.update(|cx| cx.read_credentials(url))?.await }.boxed_local() + } + + fn write_credentials<'a>( + &'a self, + url: &'a str, + username: &'a str, + password: &'a [u8], + cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { + cx.update(move |cx| cx.write_credentials(url, username, password))? + .await + } + .boxed_local() + } + + fn delete_credentials<'a>( + &'a self, + url: &'a str, + cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { cx.update(move |cx| cx.delete_credentials(url))?.await }.boxed_local() + } +} + +/// A credentials provider that stores credentials in a local file. +/// +/// This MUST only be used in development, as this is not a secure way of storing +/// credentials on user machines. +/// +/// Its existence is purely to work around the annoyance of having to constantly +/// re-allow access to the system keychain when developing Zed. +struct DevelopmentCredentialsProvider { + path: PathBuf, +} + +impl DevelopmentCredentialsProvider { + fn new() -> Self { + let path = paths::config_dir().join("development_credentials"); + + Self { path } + } + + fn load_credentials(&self) -> Result)>> { + let json = std::fs::read(&self.path)?; + let credentials: HashMap)> = serde_json::from_slice(&json)?; + + Ok(credentials) + } + + fn save_credentials(&self, credentials: &HashMap)>) -> Result<()> { + let json = serde_json::to_string(credentials)?; + std::fs::write(&self.path, json)?; + + Ok(()) + } +} + +impl CredentialsProvider for DevelopmentCredentialsProvider { + fn read_credentials<'a>( + &'a self, + url: &'a str, + _cx: &'a AsyncApp, + ) -> Pin)>>> + 'a>> { + async move { + Ok(self + .load_credentials() + .unwrap_or_default() + .get(url) + .cloned()) + } + .boxed_local() + } + + fn write_credentials<'a>( + &'a self, + url: &'a str, + username: &'a str, + password: &'a [u8], + _cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { + let mut credentials = self.load_credentials().unwrap_or_default(); + credentials.insert(url.to_string(), (username.to_string(), password.to_vec())); + + self.save_credentials(&credentials) + } + .boxed_local() + } + + fn delete_credentials<'a>( + &'a self, + url: &'a str, + _cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + async move { + let mut credentials = self.load_credentials()?; + credentials.remove(url); + + self.save_credentials(&credentials) + } + .boxed_local() + } +} diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b3b9ddd7f1..9a9196ee1f 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -16,6 +16,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true client.workspace = true collections.workspace = true +credentials_provider.workspace = true copilot = { workspace = true, features = ["schemars"] } deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index e990868b9c..e3ca4998fe 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -2,6 +2,7 @@ use crate::AllLanguageModelSettings; use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent}; use anyhow::{anyhow, Context as _, Result}; use collections::{BTreeMap, HashMap}; +use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _}; @@ -70,10 +71,16 @@ pub struct State { impl State { fn reset_api_key(&self, cx: &mut Context) -> Task> { - let delete_credentials = - cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url); + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - delete_credentials.await.ok(); + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .ok(); this.update(&mut cx, |this, cx| { this.api_key = None; this.api_key_from_env = false; @@ -83,16 +90,16 @@ impl State { } fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let write_credentials = cx.write_credentials( - AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .as_str(), - "Bearer", - api_key.as_bytes(), - ); + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - write_credentials.await?; + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .ok(); this.update(&mut cx, |this, cx| { this.api_key = Some(api_key); @@ -110,6 +117,7 @@ impl State { return Task::ready(Ok(())); } + let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .anthropic .api_url @@ -119,8 +127,8 @@ impl State { let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) { (api_key, true) } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 9a65273d12..91cc02149d 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{ @@ -57,10 +58,16 @@ impl State { } fn reset_api_key(&self, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).deepseek; - let delete_credentials = cx.delete_credentials(&settings.api_url); + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .deepseek + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - delete_credentials.await.log_err(); + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); this.update(&mut cx, |this, cx| { this.api_key = None; this.api_key_from_env = false; @@ -70,12 +77,15 @@ impl State { } fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).deepseek; - let write_credentials = - cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); - + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .deepseek + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - write_credentials.await?; + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await?; this.update(&mut cx, |this, cx| { this.api_key = Some(api_key); cx.notify(); @@ -88,17 +98,17 @@ impl State { return Task::ready(Ok(())); } + let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .deepseek .api_url .clone(); - cx.spawn(|this, mut cx| async move { let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { (api_key, true) } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index d7a6e8ba2a..9e313935c2 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use google_ai::stream_generate_content; @@ -59,10 +60,16 @@ impl State { } fn reset_api_key(&self, cx: &mut Context) -> Task> { - let delete_credentials = - cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - delete_credentials.await.ok(); + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); this.update(&mut cx, |this, cx| { this.api_key = None; this.api_key_from_env = false; @@ -72,12 +79,15 @@ impl State { } fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).google; - let write_credentials = - cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); - + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - write_credentials.await?; + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await?; this.update(&mut cx, |this, cx| { this.api_key = Some(api_key); cx.notify(); @@ -90,6 +100,7 @@ impl State { return Task::ready(Ok(())); } + let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .google .api_url @@ -99,8 +110,8 @@ impl State { let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) { (api_key, true) } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -208,16 +219,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - let state = self.state.clone(); - let delete_credentials = - cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); - cx.spawn(|mut cx| async move { - delete_credentials.await.log_err(); - state.update(&mut cx, |this, cx| { - this.api_key = None; - cx.notify(); - }) - }) + self.state.update(cx, |state, cx| state.reset_api_key(cx)) } } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index a5cc3dac16..032ee38c42 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use gpui::{ @@ -62,10 +63,16 @@ impl State { } fn reset_api_key(&self, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).mistral; - let delete_credentials = cx.delete_credentials(&settings.api_url); + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .mistral + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - delete_credentials.await.log_err(); + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); this.update(&mut cx, |this, cx| { this.api_key = None; this.api_key_from_env = false; @@ -75,12 +82,15 @@ impl State { } fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).mistral; - let write_credentials = - cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); - + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .mistral + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - write_credentials.await?; + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await?; this.update(&mut cx, |this, cx| { this.api_key = Some(api_key); cx.notify(); @@ -93,6 +103,7 @@ impl State { return Task::ready(Ok(())); } + let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .mistral .api_url @@ -101,8 +112,8 @@ impl State { let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { (api_key, true) } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 765eae9b15..ee277247b8 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; +use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use gpui::{ @@ -63,10 +64,16 @@ impl State { } fn reset_api_key(&self, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - let delete_credentials = cx.delete_credentials(&settings.api_url); + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - delete_credentials.await.log_err(); + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); this.update(&mut cx, |this, cx| { this.api_key = None; this.api_key_from_env = false; @@ -76,12 +83,16 @@ impl State { } fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - let write_credentials = - cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); - + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); cx.spawn(|this, mut cx| async move { - write_credentials.await?; + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); this.update(&mut cx, |this, cx| { this.api_key = Some(api_key); cx.notify(); @@ -94,6 +105,7 @@ impl State { return Task::ready(Ok(())); } + let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .openai .api_url @@ -102,8 +114,8 @@ impl State { let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) { (api_key, true) } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 529c5f8b6b..f9f13e4de1 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -41,6 +41,7 @@ command_palette.workspace = true command_palette_hooks.workspace = true component_preview.workspace = true copilot.workspace = true +credentials_provider.workspace = true db.workspace = true diagnostics.workspace = true editor.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 0ca2d93bc3..8ab32e9da7 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -723,7 +723,7 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut async fn authenticate(client: Arc, cx: &AsyncApp) -> Result<()> { if stdout_is_a_pty() { - if *client::ZED_DEVELOPMENT_AUTH { + if *credentials_provider::ZED_DEVELOPMENT_AUTH { client.authenticate_and_connect(true, cx).await?; } else if client::IMPERSONATE_LOGIN.is_some() { client.authenticate_and_connect(false, cx).await?;