Use anthropic provider key

This commit is contained in:
Agus Zubiaga 2025-08-18 16:19:36 -03:00
parent 1c43338056
commit 6c7a5c50bf
9 changed files with 171 additions and 45 deletions

View file

@ -80,12 +80,35 @@ pub trait AgentSessionResume {
}
#[derive(Debug)]
pub struct AuthRequired;
pub struct AuthRequired {
pub description: Option<String>,
/// A Task that resolves when authentication is updated
pub update_task: Option<Task<()>>,
}
impl AuthRequired {
pub fn new() -> Self {
Self {
description: None,
update_task: None,
}
}
pub fn with_description(mut self, description: String) -> Self {
self.description = Some(description);
self
}
pub fn with_update(mut self, update: Task<()>) -> Self {
self.update_task = Some(update);
self
}
}
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
write!(f, "Authentication required")
}
}

View file

@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection {
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(AuthRequired)
anyhow::bail!(AuthRequired::new())
}
cx.update(|cx| {

View file

@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection {
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
anyhow!(AuthRequired)
let mut error = AuthRequired::new();
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
error = error.with_description(err.message);
}
anyhow!(error)
} else {
anyhow!(err)
}

View file

@ -3,6 +3,8 @@ pub mod tools;
use collections::HashMap;
use context_server::listener::McpServerTool;
use language_model::LanguageModelRegistry;
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
@ -11,6 +13,7 @@ use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use uuid::Uuid;
use agent_client_protocol as acp;
@ -96,12 +99,49 @@ impl AgentConnection for ClaudeAgentConnection {
anyhow::bail!("Failed to find claude binary");
};
let anthropic: Arc<AnthropicLanguageModelProvider> = cx.update(|cx| {
let registry = LanguageModelRegistry::global(cx);
let provider: Arc<dyn Any + Send + Sync> = registry
.read(cx)
.provider(&language_model::ANTHROPIC_PROVIDER_ID)
.context("Failed to get Anthropic provider")?;
Arc::downcast::<AnthropicLanguageModelProvider>(provider)
.map_err(|_| anyhow!("Failed to downcast provider"))
})??;
let api_key = cx
.update(|cx| language_models::provider::anthropic::ApiKey::get(cx))?
.update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
.await
.map_err(|err| {
if err.is::<language_model::AuthenticateError>() {
anyhow!(AuthRequired)
let (update_tx, update_rx) = oneshot::channel();
let mut update_tx = Some(update_tx);
let sub = cx
.update(|cx| {
anthropic.observe(
move |_cx| {
if let Some(update_tx) = update_tx.take() {
update_tx.send(()).ok();
}
},
cx,
)
})
.ok();
let update_task = cx.foreground_executor().spawn(async move {
update_rx.await.ok();
drop(sub)
});
anyhow!(
AuthRequired::new()
.with_description(
"To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into())
.with_update(update_task)
)
} else {
anyhow!(err)
}

View file

@ -137,6 +137,7 @@ enum ThreadState {
LoadError(LoadError),
Unauthenticated {
connection: Rc<dyn AgentConnection>,
description: Option<Entity<Markdown>>,
},
ServerExited {
status: ExitStatus,
@ -269,15 +270,40 @@ impl AcpThreadView {
let result = match result.await {
Err(e) => {
let mut cx = cx.clone();
if e.is::<acp_thread::AuthRequired>() {
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
})
.ok();
return;
} else {
Err(e)
match e.downcast::<acp_thread::AuthRequired>() {
Ok(mut err) => {
if let Some(update_task) = err.update_task.take() {
let this = this.clone();
let project = project.clone();
cx.spawn(async move |cx| {
update_task.await;
this.update_in(cx, |this, window, cx| {
this.thread_state = Self::initial_state(
agent,
this.workspace.clone(),
project.clone(),
window,
cx,
);
cx.notify();
})
.ok();
})
.detach();
}
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated {
connection,
description: err.description.clone().map(|desc| {
cx.new(|cx| Markdown::new(desc.into(), None, None, cx))
}),
};
cx.notify();
})
.ok();
return;
}
Err(err) => Err(err),
}
}
Ok(thread) => Ok(thread),
@ -369,7 +395,7 @@ impl AcpThreadView {
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
ThreadState::Loading { .. } => "Loading…".into(),
ThreadState::LoadError(_) => "Failed to load".into(),
ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
}
}
@ -708,7 +734,7 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
return;
};
@ -1851,7 +1877,7 @@ impl AcpThreadView {
.mt_4()
.mb_1()
.justify_center()
.child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)),
.child(Headline::new("Authentication Required").size(HeadlineSize::Medium)),
)
.into_any()
}
@ -2778,6 +2804,13 @@ impl AcpThreadView {
cx.open_url(url.as_str());
}
})
} else if url == "zed:///agent/settings" {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
}
});
} else {
cx.open_url(&url);
}
@ -3347,12 +3380,22 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::toggle_burn_mode))
.bg(cx.theme().colors().panel_background)
.child(match &self.thread_state {
ThreadState::Unauthenticated { connection } => v_flex()
ThreadState::Unauthenticated {
connection,
description,
} => v_flex()
.p_2()
.gap_2()
.flex_1()
.items_center()
.justify_center()
.child(self.render_pending_auth_state())
.text_ui(cx)
.text_center()
.text_color(cx.theme().colors().text_muted)
.children(description.clone().map(|desc| {
self.render_markdown(desc, default_markdown_style(false, window, cx))
}))
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().into_iter().map(|method| {
Button::new(

View file

@ -201,3 +201,9 @@ impl Drop for Subscription {
}
}
}
impl std::fmt::Debug for Subscription {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Subscription").finish()
}
}

View file

@ -20,6 +20,7 @@ use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::any::Any;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
@ -620,7 +621,7 @@ pub enum AuthenticateError {
Other(#[from] anyhow::Error),
}
pub trait LanguageModelProvider: 'static {
pub trait LanguageModelProvider: Any + Send + Sync {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
fn icon(&self) -> IconName {

View file

@ -108,6 +108,7 @@ pub enum Event {
CommitMessageModelChanged,
ThreadSummaryModelChanged,
ProviderStateChanged,
ProviderAuthUpdated,
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
}

View file

@ -153,7 +153,7 @@ impl State {
return Task::ready(Ok(()));
}
let key = ApiKey::get(cx);
let key = AnthropicLanguageModelProvider::api_key(cx);
cx.spawn(async move |this, cx| {
let key = key.await?;
@ -174,8 +174,30 @@ pub struct ApiKey {
pub from_env: bool,
}
impl ApiKey {
pub fn get(cx: &mut App) -> Task<Result<Self>> {
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
Arc::new(AnthropicModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
}
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
@ -201,29 +223,13 @@ impl ApiKey {
})
}
}
}
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
Arc::new(AnthropicModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
pub fn observe(
&self,
mut on_notify: impl FnMut(&mut App) + 'static,
cx: &mut App,
) -> Subscription {
cx.observe(&self.state, move |_, cx| on_notify(cx))
}
}