language_model: Return AuthenticateErrors from LanguageModelProvider::authenticate (#25126)

This PR updates the `LanguageModelProvider::authenticate` method to
return an `AuthenticateError` instead of an `anyhow::Error`.

This allows us to model the "credentials not found" state explicitly as
`AuthenticateError::CredentialsNotFound`, which enables the caller to
check for this state and act accordingly.

Planning to use this in #25123 to silence errors about missing
credentials when authenticating providers in the background.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-18 19:01:48 -05:00 committed by GitHub
parent 2627a5fdbe
commit 7a6b652ebc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 230 additions and 184 deletions

View file

@ -10,8 +10,8 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
@ -105,34 +105,38 @@ impl State {
self.api_key.is_some()
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR)
{
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
(String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})
})
return Task::ready(Ok(()));
}
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
@ -226,7 +230,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}

View file

@ -17,9 +17,10 @@ use gpui::{
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID,
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
@ -363,7 +364,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
!self.state.read(cx).is_signed_out()
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
Task::ready(Ok(()))
}

View file

@ -15,8 +15,8 @@ use gpui::{
Task, Transformation,
};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use settings::SettingsStore;
@ -104,26 +104,28 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
self.state.read(cx).is_authenticated(cx)
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
let result = if self.is_authenticated(cx) {
Ok(())
} else if let Some(copilot) = Copilot::global(cx) {
let error_msg = match copilot.read(cx).status() {
Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
Status::Authorized => return Task::ready(Ok(())),
Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
};
Err(error_msg)
} else {
Err(anyhow::anyhow!(
"Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
))
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated(cx) {
return Task::ready(Ok(()));
};
Task::ready(result)
let Some(copilot) = Copilot::global(cx) else {
return Task::ready( Err(anyhow!(
"Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
).into()));
};
let err = match copilot.read(cx).status() {
Status::Authorized => return Task::ready(Ok(())),
Status::Disabled => anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
Status::Error(err) => anyhow!(format!("Received the following error while signing into Copilot: {err}")),
Status::Starting { task: _ } => anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
Status::Unauthorized => anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
Status::SignedOut => anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
};
Task::ready(Err(err.into()))
}
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@ -8,8 +8,8 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
@ -83,33 +83,38 @@ impl State {
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.deepseek
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
(String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})
})
return Task::ready(Ok(()));
}
let api_url = AllLanguageModelSettings::get_global(cx)
.deepseek
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
@ -188,7 +193,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
@ -322,6 +327,7 @@ impl LanguageModel for DeepSeekLanguageModel {
}
.boxed()
}
fn use_any_tool(
&self,
request: LanguageModelRequest,

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -7,7 +7,7 @@ use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -85,34 +85,38 @@ impl State {
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR)
{
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
(String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})
})
return Task::ready(Ok(()));
}
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
@ -194,7 +198,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -90,12 +90,13 @@ impl State {
self.fetch_model_task.replace(task);
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
return Task::ready(Ok(()));
}
let fetch_models_task = self.fetch_models(cx);
cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
}
}
@ -201,7 +202,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -7,8 +7,8 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
@ -88,31 +88,36 @@ impl State {
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.mistral
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
(String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})
})
return Task::ready(Ok(()));
}
let api_url = AllLanguageModelSettings::get_global(cx)
.mistral
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
@ -196,7 +201,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -95,12 +95,13 @@ impl State {
self.fetch_model_task.replace(task);
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
return Task::ready(Ok(()));
}
let fetch_models_task = self.fetch_models(cx);
cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
}
}
@ -207,7 +208,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -7,8 +7,8 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use open_ai::{
@ -89,31 +89,36 @@ impl State {
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
(String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})
})
return Task::ready(Ok(()));
}
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
@ -197,7 +202,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}