Use anthropic provider key
This commit is contained in:
parent
1c43338056
commit
6c7a5c50bf
9 changed files with 171 additions and 45 deletions
|
@ -80,12 +80,35 @@ pub trait AgentSessionResume {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
pub struct AuthRequired {
|
||||
pub description: Option<String>,
|
||||
/// A Task that resolves when authentication is updated
|
||||
pub update_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl AuthRequired {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
description: None,
|
||||
update_task: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_description(mut self, description: String) -> Self {
|
||||
self.description = Some(description);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_update(mut self, update: Task<()>) -> Self {
|
||||
self.update_task = Some(update);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
write!(f, "Authentication required")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection {
|
|||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(AuthRequired)
|
||||
anyhow::bail!(AuthRequired::new())
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
|
|
|
@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection {
|
|||
.await
|
||||
.map_err(|err| {
|
||||
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
|
||||
anyhow!(AuthRequired)
|
||||
let mut error = AuthRequired::new();
|
||||
|
||||
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
|
||||
error = error.with_description(err.message);
|
||||
}
|
||||
|
||||
anyhow!(error)
|
||||
} else {
|
||||
anyhow!(err)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ pub mod tools;
|
|||
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
|
@ -11,6 +13,7 @@ use std::cell::RefCell;
|
|||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
|
@ -96,12 +99,49 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
anyhow::bail!("Failed to find claude binary");
|
||||
};
|
||||
|
||||
let anthropic: Arc<AnthropicLanguageModelProvider> = cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let provider: Arc<dyn Any + Send + Sync> = registry
|
||||
.read(cx)
|
||||
.provider(&language_model::ANTHROPIC_PROVIDER_ID)
|
||||
.context("Failed to get Anthropic provider")?;
|
||||
|
||||
Arc::downcast::<AnthropicLanguageModelProvider>(provider)
|
||||
.map_err(|_| anyhow!("Failed to downcast provider"))
|
||||
})??;
|
||||
|
||||
let api_key = cx
|
||||
.update(|cx| language_models::provider::anthropic::ApiKey::get(cx))?
|
||||
.update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if err.is::<language_model::AuthenticateError>() {
|
||||
anyhow!(AuthRequired)
|
||||
let (update_tx, update_rx) = oneshot::channel();
|
||||
let mut update_tx = Some(update_tx);
|
||||
|
||||
let sub = cx
|
||||
.update(|cx| {
|
||||
anthropic.observe(
|
||||
move |_cx| {
|
||||
if let Some(update_tx) = update_tx.take() {
|
||||
update_tx.send(()).ok();
|
||||
}
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok();
|
||||
|
||||
let update_task = cx.foreground_executor().spawn(async move {
|
||||
update_rx.await.ok();
|
||||
drop(sub)
|
||||
});
|
||||
|
||||
anyhow!(
|
||||
AuthRequired::new()
|
||||
.with_description(
|
||||
"To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into())
|
||||
.with_update(update_task)
|
||||
)
|
||||
} else {
|
||||
anyhow!(err)
|
||||
}
|
||||
|
|
|
@ -137,6 +137,7 @@ enum ThreadState {
|
|||
LoadError(LoadError),
|
||||
Unauthenticated {
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
description: Option<Entity<Markdown>>,
|
||||
},
|
||||
ServerExited {
|
||||
status: ExitStatus,
|
||||
|
@ -269,15 +270,40 @@ impl AcpThreadView {
|
|||
let result = match result.await {
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
if e.is::<acp_thread::AuthRequired>() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
} else {
|
||||
Err(e)
|
||||
match e.downcast::<acp_thread::AuthRequired>() {
|
||||
Ok(mut err) => {
|
||||
if let Some(update_task) = err.update_task.take() {
|
||||
let this = this.clone();
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
update_task.await;
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.thread_state = Self::initial_state(
|
||||
agent,
|
||||
this.workspace.clone(),
|
||||
project.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated {
|
||||
connection,
|
||||
description: err.description.clone().map(|desc| {
|
||||
cx.new(|cx| Markdown::new(desc.into(), None, None, cx))
|
||||
}),
|
||||
};
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
Ok(thread) => Ok(thread),
|
||||
|
@ -369,7 +395,7 @@ impl AcpThreadView {
|
|||
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
|
||||
ThreadState::Loading { .. } => "Loading…".into(),
|
||||
ThreadState::LoadError(_) => "Failed to load".into(),
|
||||
ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
|
||||
ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
|
||||
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
|
||||
}
|
||||
}
|
||||
|
@ -708,7 +734,7 @@ impl AcpThreadView {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -1851,7 +1877,7 @@ impl AcpThreadView {
|
|||
.mt_4()
|
||||
.mb_1()
|
||||
.justify_center()
|
||||
.child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)),
|
||||
.child(Headline::new("Authentication Required").size(HeadlineSize::Medium)),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
@ -2778,6 +2804,13 @@ impl AcpThreadView {
|
|||
cx.open_url(url.as_str());
|
||||
}
|
||||
})
|
||||
} else if url == "zed:///agent/settings" {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
cx.open_url(&url);
|
||||
}
|
||||
|
@ -3347,12 +3380,22 @@ impl Render for AcpThreadView {
|
|||
.on_action(cx.listener(Self::toggle_burn_mode))
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { connection } => v_flex()
|
||||
ThreadState::Unauthenticated {
|
||||
connection,
|
||||
description,
|
||||
} => v_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.text_ui(cx)
|
||||
.text_center()
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.children(description.clone().map(|desc| {
|
||||
self.render_markdown(desc, default_markdown_style(false, window, cx))
|
||||
}))
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(
|
||||
|
|
|
@ -201,3 +201,9 @@ impl Drop for Subscription {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Subscription {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Subscription").finish()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ use icons::IconName;
|
|||
use parking_lot::Mutex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use std::any::Any;
|
||||
use std::ops::{Add, Sub};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
@ -620,7 +621,7 @@ pub enum AuthenticateError {
|
|||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
pub trait LanguageModelProvider: 'static {
|
||||
pub trait LanguageModelProvider: Any + Send + Sync {
|
||||
fn id(&self) -> LanguageModelProviderId;
|
||||
fn name(&self) -> LanguageModelProviderName;
|
||||
fn icon(&self) -> IconName {
|
||||
|
|
|
@ -108,6 +108,7 @@ pub enum Event {
|
|||
CommitMessageModelChanged,
|
||||
ThreadSummaryModelChanged,
|
||||
ProviderStateChanged,
|
||||
ProviderAuthUpdated,
|
||||
AddedProvider(LanguageModelProviderId),
|
||||
RemovedProvider(LanguageModelProviderId),
|
||||
}
|
||||
|
|
|
@ -153,7 +153,7 @@ impl State {
|
|||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let key = ApiKey::get(cx);
|
||||
let key = AnthropicLanguageModelProvider::api_key(cx);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let key = key.await?;
|
||||
|
@ -174,8 +174,30 @@ pub struct ApiKey {
|
|||
pub from_env: bool,
|
||||
}
|
||||
|
||||
impl ApiKey {
|
||||
pub fn get(cx: &mut App) -> Task<Result<Self>> {
|
||||
impl AnthropicLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
|
||||
fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(AnthropicModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
|
@ -201,29 +223,13 @@ impl ApiKey {
|
|||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
|
||||
fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(AnthropicModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
pub fn observe(
|
||||
&self,
|
||||
mut on_notify: impl FnMut(&mut App) + 'static,
|
||||
cx: &mut App,
|
||||
) -> Subscription {
|
||||
cx.observe(&self.state, move |_, cx| on_notify(cx))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue