language_model: Return AuthenticateErrors from LanguageModelProvider::authenticate (#25126)

This PR updates the `LanguageModelProvider::authenticate` method to
return an `AuthenticateError` instead of an `anyhow::Error`.

This allows us to model the "credentials not found" state explicitly as
`AuthenticateError::CredentialsNotFound`, which enables the caller to
check for this state and act accordingly.

Planning to use this in #25123 to silence errors about missing
credentials when authenticating providers in the background.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-18 19:01:48 -05:00 committed by GitHub
parent 2627a5fdbe
commit 7a6b652ebc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 230 additions and 184 deletions

1
Cargo.lock generated
View file

@ -6951,6 +6951,7 @@ dependencies = [
"serde_json", "serde_json",
"smol", "smol",
"strum", "strum",
"thiserror 1.0.69",
"ui", "ui",
"util", "util",
] ]

View file

@ -20,7 +20,9 @@ use gpui::{
Subscription, Task, UpdateGlobal, WeakEntity, Subscription, Task, UpdateGlobal, WeakEntity,
}; };
use language::LanguageRegistry; use language::LanguageRegistry;
use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use language_model::{
AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
};
use project::Project; use project::Project;
use prompt_library::{open_prompt_library, PromptBuilder, PromptLibrary}; use prompt_library::{open_prompt_library, PromptBuilder, PromptLibrary};
use search::{buffer_search::DivRegistrar, BufferSearchBar}; use search::{buffer_search::DivRegistrar, BufferSearchBar};
@ -1156,7 +1158,10 @@ impl AssistantPanel {
.map_or(false, |provider| provider.is_authenticated(cx)) .map_or(false, |provider| provider.is_authenticated(cx))
} }
fn authenticate(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> { fn authenticate(
&mut self,
cx: &mut Context<Self>,
) -> Option<Task<Result<(), AuthenticateError>>> {
LanguageModelRegistry::read_global(cx) LanguageModelRegistry::read_global(cx)
.active_provider() .active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx))) .map_or(None, |provider| Some(provider.authenticate(cx)))

View file

@ -38,6 +38,7 @@ serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
smol.workspace = true smol.workspace = true
strum.workspace = true strum.workspace = true
thiserror.workspace = true
ui.workspace = true ui.workspace = true
util.workspace = true util.workspace = true

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderState, LanguageModelRequest,
}; };
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@ -54,7 +54,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
true true
} }
fn authenticate(&self, _: &mut App) -> Task<Result<()>> { fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }

View file

@ -21,6 +21,7 @@ use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt; use std::fmt;
use std::{future::Future, sync::Arc}; use std::{future::Future, sync::Arc};
use thiserror::Error;
use ui::IconName; use ui::IconName;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
@ -231,6 +232,15 @@ pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn description() -> String; fn description() -> String;
} }
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
#[error("credentials not found")]
CredentialsNotFound,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub trait LanguageModelProvider: 'static { pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId; fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName; fn name(&self) -> LanguageModelProviderName;
@ -240,7 +250,7 @@ pub trait LanguageModelProvider: 'static {
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>; fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {} fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
fn is_authenticated(&self, cx: &App) -> bool; fn is_authenticated(&self, cx: &App) -> bool;
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>; fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView; fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
fn must_accept_terms(&self, _cx: &App) -> bool { fn must_accept_terms(&self, _cx: &App) -> bool {
false false

View file

@ -10,8 +10,8 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
@ -105,34 +105,38 @@ impl State {
self.api_key.is_some() self.api_key.is_some()
} }
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else { }
let api_url = AllLanguageModelSettings::get_global(cx) let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic .anthropic
.api_url .api_url
.clone(); .clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
{
(api_key, true) (api_key, true)
} else { } else {
let (_, api_key) = cx let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))? .update(|cx| cx.read_credentials(&api_url))?
.await? .await?
.ok_or_else(|| anyhow!("credentials not found"))?; .ok_or(AuthenticateError::CredentialsNotFound)?;
(String::from_utf8(api_key)?, false) (
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
}; };
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key); this.api_key = Some(api_key);
this.api_key_from_env = from_env; this.api_key_from_env = from_env;
cx.notify(); cx.notify();
})?;
Ok(())
}) })
})
}
} }
} }
@ -226,7 +230,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }

View file

@ -17,9 +17,10 @@ use gpui::{
}; };
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{ use language_model::{
CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
}; };
use language_model::{ use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
@ -363,7 +364,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
!self.state.read(cx).is_signed_out() !self.state.read(cx).is_signed_out()
} }
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }

View file

@ -15,8 +15,8 @@ use gpui::{
Task, Transformation, Task, Transformation,
}; };
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use settings::SettingsStore; use settings::SettingsStore;
@ -104,26 +104,28 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
self.state.read(cx).is_authenticated(cx) self.state.read(cx).is_authenticated(cx)
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
let result = if self.is_authenticated(cx) { if self.is_authenticated(cx) {
Ok(()) return Task::ready(Ok(()));
} else if let Some(copilot) = Copilot::global(cx) {
let error_msg = match copilot.read(cx).status() {
Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
Status::Authorized => return Task::ready(Ok(())),
Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
}; };
Err(error_msg)
} else { let Some(copilot) = Copilot::global(cx) else {
Err(anyhow::anyhow!( return Task::ready( Err(anyhow!(
"Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
)) ).into()));
}; };
Task::ready(result)
let err = match copilot.read(cx).status() {
Status::Authorized => return Task::ready(Ok(())),
Status::Disabled => anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
Status::Error(err) => anyhow!(format!("Received the following error while signing into Copilot: {err}")),
Status::Starting { task: _ } => anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
Status::Unauthorized => anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
Status::SignedOut => anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
};
Task::ready(Err(err.into()))
} }
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap; use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@ -8,8 +8,8 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
@ -83,10 +83,11 @@ impl State {
}) })
} }
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else { }
let api_url = AllLanguageModelSettings::get_global(cx) let api_url = AllLanguageModelSettings::get_global(cx)
.deepseek .deepseek
.api_url .api_url
@ -99,17 +100,21 @@ impl State {
let (_, api_key) = cx let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))? .update(|cx| cx.read_credentials(&api_url))?
.await? .await?
.ok_or_else(|| anyhow!("credentials not found"))?; .ok_or(AuthenticateError::CredentialsNotFound)?;
(String::from_utf8(api_key)?, false) (
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
}; };
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key); this.api_key = Some(api_key);
this.api_key_from_env = from_env; this.api_key_from_env = from_env;
cx.notify(); cx.notify();
})?;
Ok(())
}) })
})
}
} }
} }
@ -188,7 +193,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
@ -322,6 +327,7 @@ impl LanguageModel for DeepSeekLanguageModel {
} }
.boxed() .boxed()
} }
fn use_any_tool( fn use_any_tool(
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap; use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt}; use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -7,7 +7,7 @@ use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent; use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -85,34 +85,38 @@ impl State {
}) })
} }
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else { }
let api_url = AllLanguageModelSettings::get_global(cx) let api_url = AllLanguageModelSettings::get_global(cx)
.google .google
.api_url .api_url
.clone(); .clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
{
(api_key, true) (api_key, true)
} else { } else {
let (_, api_key) = cx let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))? .update(|cx| cx.read_credentials(&api_url))?
.await? .await?
.ok_or_else(|| anyhow!("credentials not found"))?; .ok_or(AuthenticateError::CredentialsNotFound)?;
(String::from_utf8(api_key)?, false) (
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
}; };
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key); this.api_key = Some(api_key);
this.api_key_from_env = from_env; this.api_key_from_env = from_env;
cx.notify(); cx.notify();
})?;
Ok(())
}) })
})
}
} }
} }
@ -194,7 +198,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient; use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent; use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -90,12 +90,13 @@ impl State {
self.fetch_model_task.replace(task); self.fetch_model_task.replace(task);
} }
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else {
self.fetch_models(cx)
} }
let fetch_models_task = self.fetch_models(cx);
cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
} }
} }
@ -201,7 +202,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap; use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt}; use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -7,8 +7,8 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
@ -88,10 +88,11 @@ impl State {
}) })
} }
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else { }
let api_url = AllLanguageModelSettings::get_global(cx) let api_url = AllLanguageModelSettings::get_global(cx)
.mistral .mistral
.api_url .api_url
@ -103,16 +104,20 @@ impl State {
let (_, api_key) = cx let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))? .update(|cx| cx.read_credentials(&api_url))?
.await? .await?
.ok_or_else(|| anyhow!("credentials not found"))?; .ok_or(AuthenticateError::CredentialsNotFound)?;
(String::from_utf8(api_key)?, false) (
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
}; };
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key); this.api_key = Some(api_key);
this.api_key_from_env = from_env; this.api_key_from_env = from_env;
cx.notify(); cx.notify();
})?;
Ok(())
}) })
})
}
} }
} }
@ -196,7 +201,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient; use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent; use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -95,12 +95,13 @@ impl State {
self.fetch_model_task.replace(task); self.fetch_model_task.replace(task);
} }
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else {
self.fetch_models(cx)
} }
let fetch_models_task = self.fetch_models(cx);
cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
} }
} }
@ -207,7 +208,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap; use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt}; use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -7,8 +7,8 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use open_ai::{ use open_ai::{
@ -89,10 +89,11 @@ impl State {
}) })
} }
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() { if self.is_authenticated() {
Task::ready(Ok(())) return Task::ready(Ok(()));
} else { }
let api_url = AllLanguageModelSettings::get_global(cx) let api_url = AllLanguageModelSettings::get_global(cx)
.openai .openai
.api_url .api_url
@ -104,16 +105,20 @@ impl State {
let (_, api_key) = cx let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))? .update(|cx| cx.read_credentials(&api_url))?
.await? .await?
.ok_or_else(|| anyhow!("credentials not found"))?; .ok_or(AuthenticateError::CredentialsNotFound)?;
(String::from_utf8(api_key)?, false) (
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
}; };
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key); this.api_key = Some(api_key);
this.api_key_from_env = from_env; this.api_key_from_env = from_env;
cx.notify(); cx.notify();
})?;
Ok(())
}) })
})
}
} }
} }
@ -197,7 +202,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.read(cx).is_authenticated() self.state.read(cx).is_authenticated()
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }