diff --git a/Cargo.lock b/Cargo.lock index 52b017b876..c707c3349d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6951,6 +6951,7 @@ dependencies = [ "serde_json", "smol", "strum", + "thiserror 1.0.69", "ui", "util", ] diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 585f809e25..ecf2e2c421 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -20,7 +20,9 @@ use gpui::{ Subscription, Task, UpdateGlobal, WeakEntity, }; use language::LanguageRegistry; -use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use language_model::{ + AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, +}; use project::Project; use prompt_library::{open_prompt_library, PromptBuilder, PromptLibrary}; use search::{buffer_search::DivRegistrar, BufferSearchBar}; @@ -1156,7 +1158,10 @@ impl AssistantPanel { .map_or(false, |provider| provider.is_authenticated(cx)) } - fn authenticate(&mut self, cx: &mut Context) -> Option>> { + fn authenticate( + &mut self, + cx: &mut Context, + ) -> Option>> { LanguageModelRegistry::read_global(cx) .active_provider() .map_or(None, |provider| Some(provider.authenticate(cx))) diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 7b570a54f7..51f205dced 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -38,6 +38,7 @@ serde.workspace = true serde_json.workspace = true smol.workspace = true strum.workspace = true +thiserror.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index c3eca1a2f4..a955638b21 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -1,6 +1,6 @@ use crate::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, }; use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; @@ -54,7 +54,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider { true } - fn authenticate(&self, _: &mut App) -> Task> { + fn authenticate(&self, _: &mut App) -> Task> { Task::ready(Ok(())) } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 4cdf432572..6219fda739 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -21,6 +21,7 @@ use schemars::JsonSchema; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::fmt; use std::{future::Future, sync::Arc}; +use thiserror::Error; use ui::IconName; pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; @@ -231,6 +232,15 @@ pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { fn description() -> String; } +/// An error that occurred when trying to authenticate the language model provider. +#[derive(Debug, Error)] +pub enum AuthenticateError { + #[error("credentials not found")] + CredentialsNotFound, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + pub trait LanguageModelProvider: 'static { fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; @@ -240,7 +250,7 @@ pub trait LanguageModelProvider: 'static { fn provided_models(&self, cx: &App) -> Vec>; fn load_model(&self, _model: Arc, _cx: &App) {} fn is_authenticated(&self, cx: &App) -> bool; - fn authenticate(&self, cx: &mut App) -> Task>; + fn authenticate(&self, cx: &mut App) -> Task>; fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView; fn must_accept_terms(&self, _cx: &App) -> bool { false diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 835229ab63..e990868b9c 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -10,8 +10,8 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; @@ -105,34 +105,38 @@ impl State { self.api_key.is_some() } - fn authenticate(&self, cx: &mut Context) -> Task> { + fn authenticate(&self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - - cx.spawn(|this, mut cx| async move { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) - { - (api_key, true) - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - (String::from_utf8(api_key)?, false) - }; - - this.update(&mut cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - }) - }) + return Task::ready(Ok(())); } + + let api_url = AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .clone(); + + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) } } @@ -226,7 +230,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 27d0e12596..05544f40db 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -17,9 +17,10 @@ use gpui::{ }; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ - CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID, + AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, + LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, + ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, @@ -363,7 +364,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { !self.state.read(cx).is_signed_out() } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn authenticate(&self, _cx: &mut App) -> Task> { Task::ready(Ok(())) } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 6efac131e9..1c4a4273ac 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -15,8 +15,8 @@ use gpui::{ Task, Transformation, }; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use settings::SettingsStore; @@ -104,26 +104,28 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { self.state.read(cx).is_authenticated(cx) } - fn authenticate(&self, cx: &mut App) -> Task> { - let result = if self.is_authenticated(cx) { - Ok(()) - } else if let Some(copilot) = Copilot::global(cx) { - let error_msg = match copilot.read(cx).status() { - Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."), - Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")), - Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"), - Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."), - Status::Authorized => return Task::ready(Ok(())), - Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."), - Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."), - }; - Err(error_msg) - } else { - Err(anyhow::anyhow!( - "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." - )) + fn authenticate(&self, cx: &mut App) -> Task> { + if self.is_authenticated(cx) { + return Task::ready(Ok(())); }; - Task::ready(result) + + let Some(copilot) = Copilot::global(cx) else { + return Task::ready( Err(anyhow!( + "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." + ).into())); + }; + + let err = match copilot.read(cx).status() { + Status::Authorized => return Task::ready(Ok(())), + Status::Disabled => anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."), + Status::Error(err) => anyhow!(format!("Received the following error while signing into Copilot: {err}")), + Status::Starting { task: _ } => anyhow!("Copilot is still starting, please wait for Copilot to start then try again"), + Status::Unauthorized => anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."), + Status::SignedOut => anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."), + Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."), + }; + + Task::ready(Err(err.into())) } fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index ef7a60deeb..9a65273d12 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; @@ -8,8 +8,8 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use schemars::JsonSchema; @@ -83,33 +83,38 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { + fn authenticate(&self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - - cx.spawn(|this, mut cx| async move { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - (String::from_utf8(api_key)?, false) - }; - - this.update(&mut cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - }) - }) + return Task::ready(Ok(())); } + + let api_url = AllLanguageModelSettings::get_global(cx) + .deepseek + .api_url + .clone(); + + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) } } @@ -188,7 +193,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } @@ -322,6 +327,7 @@ impl LanguageModel for DeepSeekLanguageModel { } .boxed() } + fn use_any_tool( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 48960e8ecd..d7a6e8ba2a 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; @@ -7,7 +7,7 @@ use gpui::{ AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, }; use http_client::HttpClient; -use language_model::LanguageModelCompletionEvent; +use language_model::{AuthenticateError, LanguageModelCompletionEvent}; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, @@ -85,34 +85,38 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { + fn authenticate(&self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - cx.spawn(|this, mut cx| async move { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) - { - (api_key, true) - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - (String::from_utf8(api_key)?, false) - }; - - this.update(&mut cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - }) - }) + return Task::ready(Ok(())); } + + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) } } @@ -194,7 +198,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 9096c3fb32..76832a44e1 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use http_client::HttpClient; -use language_model::LanguageModelCompletionEvent; +use language_model::{AuthenticateError, LanguageModelCompletionEvent}; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, @@ -90,12 +90,13 @@ impl State { self.fetch_model_task.replace(task); } - fn authenticate(&mut self, cx: &mut Context) -> Task> { + fn authenticate(&mut self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - self.fetch_models(cx) + return Task::ready(Ok(())); } + + let fetch_models_task = self.fetch_models(cx); + cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) }) } } @@ -201,7 +202,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 383bff1437..a5cc3dac16 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; @@ -7,8 +7,8 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; @@ -88,31 +88,36 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { + fn authenticate(&self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(|this, mut cx| async move { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - (String::from_utf8(api_key)?, false) - }; - this.update(&mut cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - }) - }) + return Task::ready(Ok(())); } + + let api_url = AllLanguageModelSettings::get_global(cx) + .mistral + .api_url + .clone(); + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) } } @@ -196,7 +201,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index a1acf5b695..a982eb3aa7 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, bail, Result}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use http_client::HttpClient; -use language_model::LanguageModelCompletionEvent; +use language_model::{AuthenticateError, LanguageModelCompletionEvent}; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, @@ -95,12 +95,13 @@ impl State { self.fetch_model_task.replace(task); } - fn authenticate(&mut self, cx: &mut Context) -> Task> { + fn authenticate(&mut self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - self.fetch_models(cx) + return Task::ready(Ok(())); } + + let fetch_models_task = self.fetch_models(cx); + cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) }) } } @@ -207,7 +208,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index f47ccc2eca..765eae9b15 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; @@ -7,8 +7,8 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use open_ai::{ @@ -89,31 +89,36 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { + fn authenticate(&self, cx: &mut Context) -> Task> { if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(|this, mut cx| async move { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - (String::from_utf8(api_key)?, false) - }; - this.update(&mut cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - }) - }) + return Task::ready(Ok(())); } + + let api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) } } @@ -197,7 +202,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { self.state.read(cx).is_authenticated() } - fn authenticate(&self, cx: &mut App) -> Task> { + fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) }