diff --git a/Cargo.lock b/Cargo.lock index 3158a61ad8..3bc2b63843 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,7 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "parking_lot", "project", @@ -267,6 +268,8 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", + "language_models", "libc", "log", "nix 0.29.0", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 2b9a6513c8..173f4c4208 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -28,6 +28,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true parking_lot = { workspace = true, optional = true } project.workspace = true diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index a328499bbc..0d4116321d 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -3,6 +3,7 @@ use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; use gpui::{Entity, SharedString, Task}; +use language_model::LanguageModelProviderId; use project::Project; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; @@ -80,12 +81,34 @@ pub trait AgentSessionResume { } #[derive(Debug)] -pub struct AuthRequired; +pub struct AuthRequired { + pub description: Option, + pub provider_id: Option, +} + +impl AuthRequired { + pub fn new() -> Self { + Self { + description: None, + provider_id: None, + } + } + + pub fn with_description(mut self, description: String) -> Self { + self.description = Some(description); + self + } + + pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self { + self.provider_id = Some(provider_id); + self + } +} impl Error for AuthRequired {} impl fmt::Display for AuthRequired { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AuthRequired") + write!(f, "Authentication required") } } diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 81c97c8aa6..f894bb15bf 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -27,6 +27,8 @@ futures.workspace = true gpui.workspace = true indoc.workspace = true itertools.workspace = true +language_model.workspace = true +language_models.workspace = true log.workspace = true paths.workspace = true project.workspace = true diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 74647f7313..551e9fa01a 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection { let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!(AuthRequired) + anyhow::bail!(AuthRequired::new()) } cx.update(|cx| { diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index b77b5ef36d..93a5ae757a 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection { .await .map_err(|err| { if err.code == acp::ErrorCode::AUTH_REQUIRED.code { - anyhow!(AuthRequired) + let mut error = AuthRequired::new(); + + if err.message != acp::ErrorCode::AUTH_REQUIRED.message { + error = error.with_description(err.message); + } + + anyhow!(error) } else { anyhow!(err) } diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index d15cc1dd89..d80d040aad 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -3,6 +3,7 @@ pub mod tools; use collections::HashMap; use context_server::listener::McpServerTool; +use language_models::provider::anthropic::AnthropicLanguageModelProvider; use project::Project; use settings::SettingsStore; use smol::process::Child; @@ -30,7 +31,7 @@ use util::{ResultExt, debug_panic}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; #[derive(Clone)] pub struct ClaudeCode; @@ -79,6 +80,36 @@ impl AgentConnection for ClaudeAgentConnection { ) -> Task>> { let cwd = cwd.to_owned(); cx.spawn(async move |cx| { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).claude.clone() + })?; + + let Some(command) = AgentServerCommand::resolve( + "claude", + &[], + Some(&util::paths::home_dir().join(".claude/local/claude")), + settings, + &project, + cx, + ) + .await + else { + anyhow::bail!("Failed to find claude binary"); + }; + + let api_key = + cx.update(AnthropicLanguageModelProvider::api_key)? + .await + .map_err(|err| { + if err.is::() { + anyhow!(AuthRequired::new().with_language_model_provider( + language_model::ANTHROPIC_PROVIDER_ID + )) + } else { + anyhow!(err) + } + })?; + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; @@ -98,23 +129,6 @@ impl AgentConnection for ClaudeAgentConnection { .await?; mcp_config_file.flush().await?; - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).claude.clone() - })?; - - let Some(command) = AgentServerCommand::resolve( - "claude", - &[], - Some(&util::paths::home_dir().join(".claude/local/claude")), - settings, - &project, - cx, - ) - .await - else { - anyhow::bail!("Failed to find claude binary"); - }; - let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); @@ -126,6 +140,7 @@ impl AgentConnection for ClaudeAgentConnection { &command, ClaudeSessionMode::Start, session_id.clone(), + api_key, &mcp_config_path, &cwd, )?; @@ -320,6 +335,7 @@ fn spawn_claude( command: &AgentServerCommand, mode: ClaudeSessionMode, session_id: acp::SessionId, + api_key: language_models::provider::anthropic::ApiKey, mcp_config_path: &Path, root_dir: &Path, ) -> Result { @@ -355,6 +371,8 @@ fn spawn_claude( ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()], }) .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .env("ANTHROPIC_API_KEY", api_key.key) .current_dir(root_dir) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2c02027c4d..e2e5820812 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,6 +1,7 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, + AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, + UserMessageId, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -18,13 +19,16 @@ use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use fs::Fs; use gpui::{ - Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, - Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, - PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, - TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, - linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between, + Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem, + EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, + MouseButton, PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, + TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, + WindowHandle, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, + pulsating_between, }; use language::Buffer; + +use language_model::LanguageModelRegistry; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use project::Project; use prompt_store::PromptId; @@ -137,6 +141,9 @@ enum ThreadState { LoadError(LoadError), Unauthenticated { connection: Rc, + description: Option>, + configuration_view: Option, + _subscription: Option, }, ServerExited { status: ExitStatus, @@ -267,19 +274,16 @@ impl AcpThreadView { }; let result = match result.await { - Err(e) => { - let mut cx = cx.clone(); - if e.is::() { - this.update(&mut cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { connection }; - cx.notify(); + Err(e) => match e.downcast::() { + Ok(err) => { + cx.update(|window, cx| { + Self::handle_auth_required(this, err, agent, connection, window, cx) }) - .ok(); + .log_err(); return; - } else { - Err(e) } - } + Err(err) => Err(err), + }, Ok(thread) => Ok(thread), }; @@ -345,6 +349,68 @@ impl AcpThreadView { ThreadState::Loading { _task: load_task } } + fn handle_auth_required( + this: WeakEntity, + err: AuthRequired, + agent: Rc, + connection: Rc, + window: &mut Window, + cx: &mut App, + ) { + let agent_name = agent.name(); + let (configuration_view, subscription) = if let Some(provider_id) = err.provider_id { + let registry = LanguageModelRegistry::global(cx); + + let sub = window.subscribe(®istry, cx, { + let provider_id = provider_id.clone(); + let this = this.clone(); + move |_, ev, window, cx| { + if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev { + if &provider_id == updated_provider_id { + this.update(cx, |this, cx| { + this.thread_state = Self::initial_state( + agent.clone(), + this.workspace.clone(), + this.project.clone(), + window, + cx, + ); + cx.notify(); + }) + .ok(); + } + } + } + }); + + let view = registry.read(cx).provider(&provider_id).map(|provider| { + provider.configuration_view( + language_model::ConfigurationViewTargetAgent::Other(agent_name), + window, + cx, + ) + }); + + (view, Some(sub)) + } else { + (None, None) + }; + + this.update(cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { + connection, + configuration_view, + description: err + .description + .clone() + .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))), + _subscription: subscription, + }; + cx.notify(); + }) + .ok(); + } + fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context) { if let Some(load_err) = err.downcast_ref::() { self.thread_state = ThreadState::LoadError(load_err.clone()); @@ -369,7 +435,7 @@ impl AcpThreadView { ThreadState::Ready { thread, .. } => thread.read(cx).title(), ThreadState::Loading { .. } => "Loading…".into(), ThreadState::LoadError(_) => "Failed to load".into(), - ThreadState::Unauthenticated { .. } => "Not authenticated".into(), + ThreadState::Unauthenticated { .. } => "Authentication Required".into(), ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(), } } @@ -708,7 +774,7 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let ThreadState::Unauthenticated { ref connection } = self.thread_state else { + let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else { return; }; @@ -1841,19 +1907,53 @@ impl AcpThreadView { .into_any() } - fn render_pending_auth_state(&self) -> AnyElement { + fn render_auth_required_state( + &self, + connection: &Rc, + description: Option<&Entity>, + configuration_view: Option<&AnyView>, + window: &mut Window, + cx: &Context, + ) -> Div { v_flex() + .p_2() + .gap_2() + .flex_1() .items_center() .justify_center() - .child(self.render_error_agent_logo()) .child( - h_flex() - .mt_4() - .mb_1() + v_flex() + .items_center() .justify_center() - .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)), + .child(self.render_error_agent_logo()) + .child( + h_flex().mt_4().mb_1().justify_center().child( + Headline::new("Authentication Required").size(HeadlineSize::Medium), + ), + ) + .into_any(), ) - .into_any() + .children(description.map(|desc| { + div().text_ui(cx).text_center().child( + self.render_markdown(desc.clone(), default_markdown_style(false, window, cx)), + ) + })) + .children( + configuration_view + .cloned() + .map(|view| div().px_4().w_full().max_w_128().child(view)), + ) + .child(h_flex().mt_1p5().justify_center().children( + connection.auth_methods().into_iter().map(|method| { + Button::new(SharedString::from(method.id.0.clone()), method.name.clone()) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )) } fn render_server_exited(&self, status: ExitStatus, _cx: &Context) -> AnyElement { @@ -3347,26 +3447,18 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::toggle_burn_mode)) .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { - ThreadState::Unauthenticated { connection } => v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child(h_flex().mt_1p5().justify_center().children( - connection.auth_methods().into_iter().map(|method| { - Button::new( - SharedString::from(method.id.0.clone()), - method.name.clone(), - ) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) - }) - }) - }), - )), + ThreadState::Unauthenticated { + connection, + description, + configuration_view, + .. + } => self.render_auth_required_state( + &connection, + description.as_ref(), + configuration_view.as_ref(), + window, + cx, + ), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index b4ebb8206c..a0584f9e2e 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -137,7 +137,11 @@ impl AgentConfiguration { window: &mut Window, cx: &mut Context, ) { - let configuration_view = provider.configuration_view(window, cx); + let configuration_view = provider.configuration_view( + language_model::ConfigurationViewTargetAgent::ZedAgent, + window, + cx, + ); self.configuration_views_by_provider .insert(provider.id(), configuration_view); } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index ce1c2203bf..8525d7f9e5 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -320,7 +320,7 @@ fn init_language_model_settings(cx: &mut App) { cx.subscribe( &LanguageModelRegistry::global(cx), |_, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { update_active_language_model_from_settings(cx); diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index bb8514a224..fa8ca490d8 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -104,7 +104,7 @@ impl LanguageModelPickerDelegate { window, |picker, _, event, window, cx| { match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { let query = picker.query(cx); diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index b55ad4c895..0a34a29068 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -11,7 +11,7 @@ impl ApiKeysWithProviders { cx.subscribe( &LanguageModelRegistry::global(cx), |this: &mut Self, _registry, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.configured_providers = Self::compute_configured_providers(cx) diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index f1629eeff8..23810b74f3 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -25,7 +25,7 @@ impl AgentPanelOnboarding { cx.subscribe( &LanguageModelRegistry::global(cx), |this: &mut Self, _registry, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.configured_providers = Self::compute_available_providers(cx) diff --git a/crates/gpui/src/subscription.rs b/crates/gpui/src/subscription.rs index a584f1a45f..bd869f8d32 100644 --- a/crates/gpui/src/subscription.rs +++ b/crates/gpui/src/subscription.rs @@ -201,3 +201,9 @@ impl Drop for Subscription { } } } + +impl std::fmt::Debug for Subscription { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Subscription").finish() + } +} diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index a9c7d5c034..67fba44887 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -1,8 +1,8 @@ use crate::{ - AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, }; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; @@ -62,7 +62,12 @@ impl LanguageModelProvider for FakeLanguageModelProvider { Task::ready(Ok(())) } - fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: ConfigurationViewTargetAgent, + _window: &mut Window, + _: &mut App, + ) -> AnyView { unimplemented!() } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 1637d2de8a..70e42cb02d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -634,7 +634,12 @@ pub trait LanguageModelProvider: 'static { } fn is_authenticated(&self, cx: &App) -> bool; fn authenticate(&self, cx: &mut App) -> Task>; - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView; + fn configuration_view( + &self, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView; fn must_accept_terms(&self, _cx: &App) -> bool { false } @@ -648,6 +653,13 @@ pub trait LanguageModelProvider: 'static { fn reset_credentials(&self, cx: &mut App) -> Task>; } +#[derive(Default, Clone, Copy)] +pub enum ConfigurationViewTargetAgent { + #[default] + ZedAgent, + Other(&'static str), +} + #[derive(PartialEq, Eq)] pub enum LanguageModelProviderTosView { /// When there are some past interactions in the Agent Panel. diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 7cf071808a..078b90a291 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -107,7 +107,7 @@ pub enum Event { InlineAssistantModelChanged, CommitMessageModelChanged, ThreadSummaryModelChanged, - ProviderStateChanged, + ProviderStateChanged(LanguageModelProviderId), AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), } @@ -148,8 +148,11 @@ impl LanguageModelRegistry { ) { let id = provider.id(); - let subscription = provider.subscribe(cx, |_, cx| { - cx.emit(Event::ProviderStateChanged); + let subscription = provider.subscribe(cx, { + let id = id.clone(); + move |_, cx| { + cx.emit(Event::ProviderStateChanged(id.clone())); + } }); if let Some(subscription) = subscription { subscription.detach(); diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index ef21e85f71..810d4a5f44 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -15,11 +15,11 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, - RateLimiter, Role, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, + LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -153,29 +153,14 @@ impl State { return Task::ready(Ok(())); } - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); + let key = AnthropicLanguageModelProvider::api_key(cx); cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; + let key = key.await?; this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; + this.api_key = Some(key.key); + this.api_key_from_env = key.from_env; cx.notify(); })?; @@ -184,6 +169,11 @@ impl State { } } +pub struct ApiKey { + pub key: String, + pub from_env: bool, +} + impl AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { let state = cx.new(|cx| State { @@ -206,6 +196,33 @@ impl AnthropicLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + pub fn api_key(cx: &mut App) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .clone(); + + if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) { + Task::ready(Ok(ApiKey { + key, + from_env: true, + })) + } else { + cx.spawn(async move |cx| { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + + Ok(ApiKey { + key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + from_env: false, + }) + }) + } + } } impl LanguageModelProviderState for AnthropicLanguageModelProvider { @@ -299,8 +316,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { - cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + fn configuration_view( + &self, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) .into() } @@ -902,12 +924,18 @@ struct ConfigurationView { api_key_editor: Entity, state: gpui::Entity, load_credentials_task: Option>, + target_agent: ConfigurationViewTargetAgent, } impl ConfigurationView { const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + fn new( + state: gpui::Entity, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut Context, + ) -> Self { cx.observe(&state, |_, _, cx| { cx.notify(); }) @@ -939,6 +967,7 @@ impl ConfigurationView { }), state, load_credentials_task, + target_agent, } } @@ -1012,7 +1041,10 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:")) + .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent { + ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic", + ConfigurationViewTargetAgent::Other(agent) => agent, + }))) .child( List::new() .child( @@ -1023,7 +1055,7 @@ impl Render for ConfigurationView { ) ) .child( - InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant") + InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent") ) ) .child( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 6df96c5c56..4e6744d745 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -348,7 +348,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index c1337399f9..c3f4399832 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -391,7 +391,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider { Task::ready(Ok(())) } - fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|_| ConfigurationView::new(self.state.clone())) .into() } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 73f73a9a31..eb12c0056f 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -176,7 +176,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { Task::ready(Err(err.into())) } - fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, cx)).into() } diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index a568ef4034..2b30d456ee 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -229,7 +229,12 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index b287e8181a..32f8838df7 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -277,7 +277,12 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 36a32ab941..7ac08f2c15 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -226,7 +226,12 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _window: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, cx)).into() } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 4a0d740334..e1d55801eb 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -243,7 +243,12 @@ impl LanguageModelProvider for MistralLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 0c2b1107b1..93844542ea 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -255,7 +255,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, window, cx)) .into() diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index eaf8d885b3..04d89f2db1 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -233,7 +233,12 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index e2d3adb198..c6b980c3ec 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -243,7 +243,12 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 3a492086f1..5d8bace6d3 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -306,7 +306,12 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 9f447cb68b..98e4f60b6b 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -230,7 +230,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index fed6fe92bf..2b8238cc5c 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -230,7 +230,12 @@ impl LanguageModelProvider for XAiLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index bb1932bdf2..d700fa08bd 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -329,7 +329,11 @@ impl AiConfigurationModal { cx: &mut Context, ) -> Self { let focus_handle = cx.focus_handle(); - let configuration_view = selected_provider.configuration_view(window, cx); + let configuration_view = selected_provider.configuration_view( + language_model::ConfigurationViewTargetAgent::ZedAgent, + window, + cx, + ); Self { focus_handle,