diff --git a/Cargo.lock b/Cargo.lock index 58ee6fb684..010af1f469 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,7 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "parking_lot", "project", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 2b9a6513c8..173f4c4208 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -28,6 +28,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true parking_lot = { workspace = true, optional = true } project.workspace = true diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 9c64ae65b0..f0cad7f11b 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -3,6 +3,7 @@ use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; use gpui::{Entity, SharedString, Task}; +use language_model::LanguageModelProviderId; use project::Project; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; @@ -82,15 +83,14 @@ pub trait AgentSessionResume { #[derive(Debug)] pub struct AuthRequired { pub description: Option, - /// A Task that resolves when authentication is updated - pub update_task: Option>, + pub provider_id: Option, } impl AuthRequired { pub fn new() -> Self { Self { description: None, - update_task: None, + provider_id: None, } } @@ -99,8 +99,8 @@ impl AuthRequired { self } - pub fn with_update(mut self, update: Task<()>) -> Self { - self.update_task = Some(update); + pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self { + self.provider_id = Some(provider_id); self } } diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 8fb4b898b1..a74090a5fe 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -3,7 +3,6 @@ pub mod tools; use collections::HashMap; use context_server::listener::McpServerTool; -use language_model::LanguageModelRegistry; use language_models::provider::anthropic::AnthropicLanguageModelProvider; use project::Project; use settings::SettingsStore; @@ -13,7 +12,6 @@ use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; -use std::sync::Arc; use uuid::Uuid; use agent_client_protocol as acp; @@ -99,53 +97,18 @@ impl AgentConnection for ClaudeAgentConnection { anyhow::bail!("Failed to find claude binary"); }; - let anthropic: Arc = cx.update(|cx| { - let registry = LanguageModelRegistry::global(cx); - let provider: Arc = registry - .read(cx) - .provider(&language_model::ANTHROPIC_PROVIDER_ID) - .context("Failed to get Anthropic provider")?; - - Arc::downcast::(provider) - .map_err(|_| anyhow!("Failed to downcast provider")) - })??; - - let api_key = cx - .update(|cx| AnthropicLanguageModelProvider::api_key(cx))? - .await - .map_err(|err| { - if err.is::() { - let (update_tx, update_rx) = oneshot::channel(); - let mut update_tx = Some(update_tx); - - let sub = cx - .update(|cx| { - anthropic.observe( - move |_cx| { - if let Some(update_tx) = update_tx.take() { - update_tx.send(()).ok(); - } - }, - cx, - ) - }) - .ok(); - - let update_task = cx.foreground_executor().spawn(async move { - update_rx.await.ok(); - drop(sub) - }); - - anyhow!( - AuthRequired::new() - .with_description( - "To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into()) - .with_update(update_task) - ) - } else { - anyhow!(err) - } - })?; + let api_key = + cx.update(|cx| AnthropicLanguageModelProvider::api_key(cx))? + .await + .map_err(|err| { + if err.is::() { + anyhow!(AuthRequired::new().with_language_model_provider( + language_model::ANTHROPIC_PROVIDER_ID + )) + } else { + anyhow!(err) + } + })?; let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 7ad71e1f46..02e263da1a 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,6 +1,7 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, + AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, + UserMessageId, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -18,13 +19,16 @@ use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use fs::Fs; use gpui::{ - Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, - Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, - PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, - TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, - linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between, + Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem, + EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, + MouseButton, PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, + TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, + WindowHandle, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, + pulsating_between, }; use language::Buffer; + +use language_model::LanguageModelRegistry; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use project::Project; use prompt_store::PromptId; @@ -138,6 +142,8 @@ enum ThreadState { Unauthenticated { connection: Rc, description: Option>, + configuration_view: Option, + _subscription: Option, }, ServerExited { status: ExitStatus, @@ -268,44 +274,16 @@ impl AcpThreadView { }; let result = match result.await { - Err(e) => { - let mut cx = cx.clone(); - match e.downcast::() { - Ok(mut err) => { - if let Some(update_task) = err.update_task.take() { - let this = this.clone(); - let project = project.clone(); - cx.spawn(async move |cx| { - update_task.await; - this.update_in(cx, |this, window, cx| { - this.thread_state = Self::initial_state( - agent, - this.workspace.clone(), - project.clone(), - window, - cx, - ); - cx.notify(); - }) - .ok(); - }) - .detach(); - } - this.update(&mut cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { - connection, - description: err.description.clone().map(|desc| { - cx.new(|cx| Markdown::new(desc.into(), None, None, cx)) - }), - }; - cx.notify(); - }) - .ok(); - return; - } - Err(err) => Err(err), + Err(e) => match e.downcast::() { + Ok(err) => { + cx.update(|window, cx| { + Self::handle_auth_required(this, err, agent, connection, window, cx) + }) + .log_err(); + return; } - } + Err(err) => Err(err), + }, Ok(thread) => Ok(thread), }; @@ -371,6 +349,68 @@ impl AcpThreadView { ThreadState::Loading { _task: load_task } } + fn handle_auth_required( + this: WeakEntity, + err: AuthRequired, + agent: Rc, + connection: Rc, + window: &mut Window, + cx: &mut App, + ) { + let agent_name = agent.name(); + let (configuration_view, subscription) = if let Some(provider_id) = err.provider_id { + let registry = LanguageModelRegistry::global(cx); + + let sub = window.subscribe(®istry, cx, { + let provider_id = provider_id.clone(); + let this = this.clone(); + move |_, ev, window, cx| { + if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev { + if &provider_id == updated_provider_id { + this.update(cx, |this, cx| { + this.thread_state = Self::initial_state( + agent.clone(), + this.workspace.clone(), + this.project.clone(), + window, + cx, + ); + cx.notify(); + }) + .ok(); + } + } + } + }); + + let view = registry.read(cx).provider(&provider_id).map(|provider| { + provider.configuration_view( + language_model::ConfigurationViewTargetAgent::Other(agent_name), + window, + cx, + ) + }); + + (view, Some(sub)) + } else { + (None, None) + }; + + this.update(cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { + connection, + configuration_view, + description: err + .description + .clone() + .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))), + _subscription: subscription, + }; + cx.notify(); + }) + .ok(); + } + fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context) { if let Some(load_err) = err.downcast_ref::() { self.thread_state = ThreadState::LoadError(load_err.clone()); @@ -1867,19 +1907,53 @@ impl AcpThreadView { .into_any() } - fn render_pending_auth_state(&self) -> AnyElement { + fn render_auth_required_state( + &self, + connection: &Rc, + description: Option<&Entity>, + configuration_view: Option<&AnyView>, + window: &mut Window, + cx: &Context, + ) -> Div { v_flex() + .p_2() + .gap_2() + .flex_1() .items_center() .justify_center() - .child(self.render_error_agent_logo()) .child( - h_flex() - .mt_4() - .mb_1() + v_flex() + .items_center() .justify_center() - .child(Headline::new("Authentication Required").size(HeadlineSize::Medium)), + .child(self.render_error_agent_logo()) + .child( + h_flex().mt_4().mb_1().justify_center().child( + Headline::new("Authentication Required").size(HeadlineSize::Medium), + ), + ) + .into_any(), ) - .into_any() + .children(description.map(|desc| { + div().text_ui(cx).text_center().child( + self.render_markdown(desc.clone(), default_markdown_style(false, window, cx)), + ) + })) + .children( + configuration_view + .cloned() + .map(|view| div().px_4().w_full().max_w_128().child(view)), + ) + .child(h_flex().mt_1p5().justify_center().children( + connection.auth_methods().into_iter().map(|method| { + Button::new(SharedString::from(method.id.0.clone()), method.name.clone()) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )) } fn render_server_exited(&self, status: ExitStatus, _cx: &Context) -> AnyElement { @@ -2804,13 +2878,6 @@ impl AcpThreadView { cx.open_url(url.as_str()); } }) - } else if url == "zed:///agent/settings" { - workspace.update(cx, |workspace, cx| { - if let Some(panel) = workspace.panel::(cx) { - workspace.focus_panel::(window, cx); - panel.update(cx, |panel, cx| panel.open_configuration(window, cx)); - } - }); } else { cx.open_url(&url); } @@ -3383,33 +3450,15 @@ impl Render for AcpThreadView { ThreadState::Unauthenticated { connection, description, - } => v_flex() - .p_2() - .gap_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .text_ui(cx) - .text_center() - .text_color(cx.theme().colors().text_muted) - .children(description.clone().map(|desc| { - self.render_markdown(desc, default_markdown_style(false, window, cx)) - })) - .child(h_flex().mt_1p5().justify_center().children( - connection.auth_methods().into_iter().map(|method| { - Button::new( - SharedString::from(method.id.0.clone()), - method.name.clone(), - ) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) - }) - }) - }), - )), + configuration_view, + .. + } => self.render_auth_required_state( + &connection, + description.as_ref(), + configuration_view.as_ref(), + window, + cx, + ), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index b4ebb8206c..a0584f9e2e 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -137,7 +137,11 @@ impl AgentConfiguration { window: &mut Window, cx: &mut Context, ) { - let configuration_view = provider.configuration_view(window, cx); + let configuration_view = provider.configuration_view( + language_model::ConfigurationViewTargetAgent::ZedAgent, + window, + cx, + ); self.configuration_views_by_provider .insert(provider.id(), configuration_view); } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index ce1c2203bf..8525d7f9e5 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -320,7 +320,7 @@ fn init_language_model_settings(cx: &mut App) { cx.subscribe( &LanguageModelRegistry::global(cx), |_, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { update_active_language_model_from_settings(cx); diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index bb8514a224..fa8ca490d8 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -104,7 +104,7 @@ impl LanguageModelPickerDelegate { window, |picker, _, event, window, cx| { match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { let query = picker.query(cx); diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index b55ad4c895..0a34a29068 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -11,7 +11,7 @@ impl ApiKeysWithProviders { cx.subscribe( &LanguageModelRegistry::global(cx), |this: &mut Self, _registry, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.configured_providers = Self::compute_configured_providers(cx) diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index f1629eeff8..23810b74f3 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -25,7 +25,7 @@ impl AgentPanelOnboarding { cx.subscribe( &LanguageModelRegistry::global(cx), |this: &mut Self, _registry, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.configured_providers = Self::compute_available_providers(cx) diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index a9c7d5c034..67fba44887 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -1,8 +1,8 @@ use crate::{ - AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, }; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; @@ -62,7 +62,12 @@ impl LanguageModelProvider for FakeLanguageModelProvider { Task::ready(Ok(())) } - fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: ConfigurationViewTargetAgent, + _window: &mut Window, + _: &mut App, + ) -> AnyView { unimplemented!() } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index d74e8b7076..70e42cb02d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -20,7 +20,6 @@ use icons::IconName; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use std::any::Any; use std::ops::{Add, Sub}; use std::str::FromStr; use std::sync::Arc; @@ -621,7 +620,7 @@ pub enum AuthenticateError { Other(#[from] anyhow::Error), } -pub trait LanguageModelProvider: Any + Send + Sync { +pub trait LanguageModelProvider: 'static { fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; fn icon(&self) -> IconName { @@ -635,7 +634,12 @@ pub trait LanguageModelProvider: Any + Send + Sync { } fn is_authenticated(&self, cx: &App) -> bool; fn authenticate(&self, cx: &mut App) -> Task>; - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView; + fn configuration_view( + &self, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView; fn must_accept_terms(&self, _cx: &App) -> bool { false } @@ -649,6 +653,13 @@ pub trait LanguageModelProvider: Any + Send + Sync { fn reset_credentials(&self, cx: &mut App) -> Task>; } +#[derive(Default, Clone, Copy)] +pub enum ConfigurationViewTargetAgent { + #[default] + ZedAgent, + Other(&'static str), +} + #[derive(PartialEq, Eq)] pub enum LanguageModelProviderTosView { /// When there are some past interactions in the Agent Panel. diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 6b4f471b0f..078b90a291 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -107,8 +107,7 @@ pub enum Event { InlineAssistantModelChanged, CommitMessageModelChanged, ThreadSummaryModelChanged, - ProviderStateChanged, - ProviderAuthUpdated, + ProviderStateChanged(LanguageModelProviderId), AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), } @@ -149,8 +148,11 @@ impl LanguageModelRegistry { ) { let id = provider.id(); - let subscription = provider.subscribe(cx, |_, cx| { - cx.emit(Event::ProviderStateChanged); + let subscription = provider.subscribe(cx, { + let id = id.clone(); + move |_, cx| { + cx.emit(Event::ProviderStateChanged(id.clone())); + } }); if let Some(subscription) = subscription { subscription.detach(); diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 3f14841210..810d4a5f44 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -15,11 +15,11 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, - RateLimiter, Role, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, + LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -223,14 +223,6 @@ impl AnthropicLanguageModelProvider { }) } } - - pub fn observe( - &self, - mut on_notify: impl FnMut(&mut App) + 'static, - cx: &mut App, - ) -> Subscription { - cx.observe(&self.state, move |_, cx| on_notify(cx)) - } } impl LanguageModelProviderState for AnthropicLanguageModelProvider { @@ -324,8 +316,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { - cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + fn configuration_view( + &self, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) .into() } @@ -927,12 +924,18 @@ struct ConfigurationView { api_key_editor: Entity, state: gpui::Entity, load_credentials_task: Option>, + target_agent: ConfigurationViewTargetAgent, } impl ConfigurationView { const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + fn new( + state: gpui::Entity, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut Context, + ) -> Self { cx.observe(&state, |_, _, cx| { cx.notify(); }) @@ -964,6 +967,7 @@ impl ConfigurationView { }), state, load_credentials_task, + target_agent, } } @@ -1037,7 +1041,10 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:")) + .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent { + ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic", + ConfigurationViewTargetAgent::Other(agent) => agent, + }))) .child( List::new() .child( @@ -1048,7 +1055,7 @@ impl Render for ConfigurationView { ) ) .child( - InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant") + InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent") ) ) .child( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 6df96c5c56..4e6744d745 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -348,7 +348,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index c1337399f9..c3f4399832 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -391,7 +391,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider { Task::ready(Ok(())) } - fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|_| ConfigurationView::new(self.state.clone())) .into() } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 73f73a9a31..eb12c0056f 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -176,7 +176,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { Task::ready(Err(err.into())) } - fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, cx)).into() } diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index a568ef4034..2b30d456ee 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -229,7 +229,12 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index b287e8181a..32f8838df7 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -277,7 +277,12 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 36a32ab941..7ac08f2c15 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -226,7 +226,12 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _window: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, cx)).into() } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 4a0d740334..e1d55801eb 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -243,7 +243,12 @@ impl LanguageModelProvider for MistralLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 0c2b1107b1..93844542ea 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -255,7 +255,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, window, cx)) .into() diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index eaf8d885b3..04d89f2db1 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -233,7 +233,12 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index e2d3adb198..c6b980c3ec 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -243,7 +243,12 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 3a492086f1..5d8bace6d3 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -306,7 +306,12 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 9f447cb68b..98e4f60b6b 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -230,7 +230,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index fed6fe92bf..2b8238cc5c 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -230,7 +230,12 @@ impl LanguageModelProvider for XAiLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index bb1932bdf2..d700fa08bd 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -329,7 +329,11 @@ impl AiConfigurationModal { cx: &mut Context, ) -> Self { let focus_handle = cx.focus_handle(); - let configuration_view = selected_provider.configuration_view(window, cx); + let configuration_view = selected_provider.configuration_view( + language_model::ConfigurationViewTargetAgent::ZedAgent, + window, + cx, + ); Self { focus_handle,