Embed Anthropic configuration view

This commit is contained in:
Agus Zubiaga 2025-08-18 17:18:21 -03:00
parent 6c7a5c50bf
commit 21082c0aba
28 changed files with 299 additions and 187 deletions

1
Cargo.lock generated
View file

@ -20,6 +20,7 @@ dependencies = [
"indoc", "indoc",
"itertools 0.14.0", "itertools 0.14.0",
"language", "language",
"language_model",
"markdown", "markdown",
"parking_lot", "parking_lot",
"project", "project",

View file

@ -28,6 +28,7 @@ futures.workspace = true
gpui.workspace = true gpui.workspace = true
itertools.workspace = true itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true
markdown.workspace = true markdown.workspace = true
parking_lot = { workspace = true, optional = true } parking_lot = { workspace = true, optional = true }
project.workspace = true project.workspace = true

View file

@ -3,6 +3,7 @@ use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
use collections::IndexMap; use collections::IndexMap;
use gpui::{Entity, SharedString, Task}; use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
use project::Project; use project::Project;
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName}; use ui::{App, IconName};
@ -82,15 +83,14 @@ pub trait AgentSessionResume {
#[derive(Debug)] #[derive(Debug)]
pub struct AuthRequired { pub struct AuthRequired {
pub description: Option<String>, pub description: Option<String>,
/// A Task that resolves when authentication is updated pub provider_id: Option<LanguageModelProviderId>,
pub update_task: Option<Task<()>>,
} }
impl AuthRequired { impl AuthRequired {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
description: None, description: None,
update_task: None, provider_id: None,
} }
} }
@ -99,8 +99,8 @@ impl AuthRequired {
self self
} }
pub fn with_update(mut self, update: Task<()>) -> Self { pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
self.update_task = Some(update); self.provider_id = Some(provider_id);
self self
} }
} }

View file

@ -3,7 +3,6 @@ pub mod tools;
use collections::HashMap; use collections::HashMap;
use context_server::listener::McpServerTool; use context_server::listener::McpServerTool;
use language_model::LanguageModelRegistry;
use language_models::provider::anthropic::AnthropicLanguageModelProvider; use language_models::provider::anthropic::AnthropicLanguageModelProvider;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
@ -13,7 +12,6 @@ use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
use agent_client_protocol as acp; use agent_client_protocol as acp;
@ -99,53 +97,18 @@ impl AgentConnection for ClaudeAgentConnection {
anyhow::bail!("Failed to find claude binary"); anyhow::bail!("Failed to find claude binary");
}; };
let anthropic: Arc<AnthropicLanguageModelProvider> = cx.update(|cx| { let api_key =
let registry = LanguageModelRegistry::global(cx); cx.update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
let provider: Arc<dyn Any + Send + Sync> = registry .await
.read(cx) .map_err(|err| {
.provider(&language_model::ANTHROPIC_PROVIDER_ID) if err.is::<language_model::AuthenticateError>() {
.context("Failed to get Anthropic provider")?; anyhow!(AuthRequired::new().with_language_model_provider(
language_model::ANTHROPIC_PROVIDER_ID
Arc::downcast::<AnthropicLanguageModelProvider>(provider) ))
.map_err(|_| anyhow!("Failed to downcast provider")) } else {
})??; anyhow!(err)
}
let api_key = cx })?;
.update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
.await
.map_err(|err| {
if err.is::<language_model::AuthenticateError>() {
let (update_tx, update_rx) = oneshot::channel();
let mut update_tx = Some(update_tx);
let sub = cx
.update(|cx| {
anthropic.observe(
move |_cx| {
if let Some(update_tx) = update_tx.take() {
update_tx.send(()).ok();
}
},
cx,
)
})
.ok();
let update_task = cx.foreground_executor().spawn(async move {
update_rx.await.ok();
drop(sub)
});
anyhow!(
AuthRequired::new()
.with_description(
"To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into())
.with_update(update_task)
)
} else {
anyhow!(err)
}
})?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;

View file

@ -1,6 +1,7 @@
use acp_thread::{ use acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
UserMessageId,
}; };
use acp_thread::{AgentConnection, Plan}; use acp_thread::{AgentConnection, Plan};
use action_log::ActionLog; use action_log::ActionLog;
@ -18,13 +19,16 @@ use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
use file_icons::FileIcons; use file_icons::FileIcons;
use fs::Fs; use fs::Fs;
use gpui::{ use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState,
PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, MouseButton, PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task,
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window,
linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between, WindowHandle, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
pulsating_between,
}; };
use language::Buffer; use language::Buffer;
use language_model::LanguageModelRegistry;
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use project::Project; use project::Project;
use prompt_store::PromptId; use prompt_store::PromptId;
@ -138,6 +142,8 @@ enum ThreadState {
Unauthenticated { Unauthenticated {
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
description: Option<Entity<Markdown>>, description: Option<Entity<Markdown>>,
configuration_view: Option<AnyView>,
_subscription: Option<Subscription>,
}, },
ServerExited { ServerExited {
status: ExitStatus, status: ExitStatus,
@ -268,44 +274,16 @@ impl AcpThreadView {
}; };
let result = match result.await { let result = match result.await {
Err(e) => { Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
let mut cx = cx.clone(); Ok(err) => {
match e.downcast::<acp_thread::AuthRequired>() { cx.update(|window, cx| {
Ok(mut err) => { Self::handle_auth_required(this, err, agent, connection, window, cx)
if let Some(update_task) = err.update_task.take() { })
let this = this.clone(); .log_err();
let project = project.clone(); return;
cx.spawn(async move |cx| {
update_task.await;
this.update_in(cx, |this, window, cx| {
this.thread_state = Self::initial_state(
agent,
this.workspace.clone(),
project.clone(),
window,
cx,
);
cx.notify();
})
.ok();
})
.detach();
}
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated {
connection,
description: err.description.clone().map(|desc| {
cx.new(|cx| Markdown::new(desc.into(), None, None, cx))
}),
};
cx.notify();
})
.ok();
return;
}
Err(err) => Err(err),
} }
} Err(err) => Err(err),
},
Ok(thread) => Ok(thread), Ok(thread) => Ok(thread),
}; };
@ -371,6 +349,68 @@ impl AcpThreadView {
ThreadState::Loading { _task: load_task } ThreadState::Loading { _task: load_task }
} }
fn handle_auth_required(
this: WeakEntity<Self>,
err: AuthRequired,
agent: Rc<dyn AgentServer>,
connection: Rc<dyn AgentConnection>,
window: &mut Window,
cx: &mut App,
) {
let agent_name = agent.name();
let (configuration_view, subscription) = if let Some(provider_id) = err.provider_id {
let registry = LanguageModelRegistry::global(cx);
let sub = window.subscribe(&registry, cx, {
let provider_id = provider_id.clone();
let this = this.clone();
move |_, ev, window, cx| {
if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev {
if &provider_id == updated_provider_id {
this.update(cx, |this, cx| {
this.thread_state = Self::initial_state(
agent.clone(),
this.workspace.clone(),
this.project.clone(),
window,
cx,
);
cx.notify();
})
.ok();
}
}
}
});
let view = registry.read(cx).provider(&provider_id).map(|provider| {
provider.configuration_view(
language_model::ConfigurationViewTargetAgent::Other(agent_name),
window,
cx,
)
});
(view, Some(sub))
} else {
(None, None)
};
this.update(cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated {
connection,
configuration_view,
description: err
.description
.clone()
.map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))),
_subscription: subscription,
};
cx.notify();
})
.ok();
}
fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context<Self>) { fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context<Self>) {
if let Some(load_err) = err.downcast_ref::<LoadError>() { if let Some(load_err) = err.downcast_ref::<LoadError>() {
self.thread_state = ThreadState::LoadError(load_err.clone()); self.thread_state = ThreadState::LoadError(load_err.clone());
@ -1867,19 +1907,53 @@ impl AcpThreadView {
.into_any() .into_any()
} }
fn render_pending_auth_state(&self) -> AnyElement { fn render_auth_required_state(
&self,
connection: &Rc<dyn AgentConnection>,
description: Option<&Entity<Markdown>>,
configuration_view: Option<&AnyView>,
window: &mut Window,
cx: &Context<Self>,
) -> Div {
v_flex() v_flex()
.p_2()
.gap_2()
.flex_1()
.items_center() .items_center()
.justify_center() .justify_center()
.child(self.render_error_agent_logo())
.child( .child(
h_flex() v_flex()
.mt_4() .items_center()
.mb_1()
.justify_center() .justify_center()
.child(Headline::new("Authentication Required").size(HeadlineSize::Medium)), .child(self.render_error_agent_logo())
.child(
h_flex().mt_4().mb_1().justify_center().child(
Headline::new("Authentication Required").size(HeadlineSize::Medium),
),
)
.into_any(),
) )
.into_any() .children(description.map(|desc| {
div().text_ui(cx).text_center().child(
self.render_markdown(desc.clone(), default_markdown_style(false, window, cx)),
)
}))
.children(
configuration_view
.cloned()
.map(|view| div().px_4().w_full().max_w_128().child(view)),
)
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().into_iter().map(|method| {
Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
})
})
}),
))
} }
fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement { fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement {
@ -2804,13 +2878,6 @@ impl AcpThreadView {
cx.open_url(url.as_str()); cx.open_url(url.as_str());
} }
}) })
} else if url == "zed:///agent/settings" {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
}
});
} else { } else {
cx.open_url(&url); cx.open_url(&url);
} }
@ -3383,33 +3450,15 @@ impl Render for AcpThreadView {
ThreadState::Unauthenticated { ThreadState::Unauthenticated {
connection, connection,
description, description,
} => v_flex() configuration_view,
.p_2() ..
.gap_2() } => self.render_auth_required_state(
.flex_1() &connection,
.items_center() description.as_ref(),
.justify_center() configuration_view.as_ref(),
.child(self.render_pending_auth_state()) window,
.text_ui(cx) cx,
.text_center() ),
.text_color(cx.theme().colors().text_muted)
.children(description.clone().map(|desc| {
self.render_markdown(desc, default_markdown_style(false, window, cx))
}))
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().into_iter().map(|method| {
Button::new(
SharedString::from(method.id.0.clone()),
method.name.clone(),
)
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
})
})
}),
)),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex() ThreadState::LoadError(e) => v_flex()
.p_2() .p_2()

View file

@ -137,7 +137,11 @@ impl AgentConfiguration {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let configuration_view = provider.configuration_view(window, cx); let configuration_view = provider.configuration_view(
language_model::ConfigurationViewTargetAgent::ZedAgent,
window,
cx,
);
self.configuration_views_by_provider self.configuration_views_by_provider
.insert(provider.id(), configuration_view); .insert(provider.id(), configuration_view);
} }

View file

@ -320,7 +320,7 @@ fn init_language_model_settings(cx: &mut App) {
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
|_, event: &language_model::Event, cx| match event { |_, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
update_active_language_model_from_settings(cx); update_active_language_model_from_settings(cx);

View file

@ -104,7 +104,7 @@ impl LanguageModelPickerDelegate {
window, window,
|picker, _, event, window, cx| { |picker, _, event, window, cx| {
match event { match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx); let query = picker.query(cx);

View file

@ -11,7 +11,7 @@ impl ApiKeysWithProviders {
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
|this: &mut Self, _registry, event: &language_model::Event, cx| match event { |this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_configured_providers(cx) this.configured_providers = Self::compute_configured_providers(cx)

View file

@ -25,7 +25,7 @@ impl AgentPanelOnboarding {
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
|this: &mut Self, _registry, event: &language_model::Event, cx| match event { |this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_available_providers(cx) this.configured_providers = Self::compute_available_providers(cx)

View file

@ -1,8 +1,8 @@
use crate::{ use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelToolChoice, LanguageModelRequest, LanguageModelToolChoice,
}; };
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@ -62,7 +62,12 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: ConfigurationViewTargetAgent,
_window: &mut Window,
_: &mut App,
) -> AnyView {
unimplemented!() unimplemented!()
} }

View file

@ -20,7 +20,6 @@ use icons::IconName;
use parking_lot::Mutex; use parking_lot::Mutex;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::any::Any;
use std::ops::{Add, Sub}; use std::ops::{Add, Sub};
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
@ -621,7 +620,7 @@ pub enum AuthenticateError {
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
pub trait LanguageModelProvider: Any + Send + Sync { pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId; fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName; fn name(&self) -> LanguageModelProviderName;
fn icon(&self) -> IconName { fn icon(&self) -> IconName {
@ -635,7 +634,12 @@ pub trait LanguageModelProvider: Any + Send + Sync {
} }
fn is_authenticated(&self, cx: &App) -> bool; fn is_authenticated(&self, cx: &App) -> bool;
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>; fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView; fn configuration_view(
&self,
target_agent: ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView;
fn must_accept_terms(&self, _cx: &App) -> bool { fn must_accept_terms(&self, _cx: &App) -> bool {
false false
} }
@ -649,6 +653,13 @@ pub trait LanguageModelProvider: Any + Send + Sync {
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>; fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
} }
#[derive(Default, Clone, Copy)]
pub enum ConfigurationViewTargetAgent {
#[default]
ZedAgent,
Other(&'static str),
}
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView { pub enum LanguageModelProviderTosView {
/// When there are some past interactions in the Agent Panel. /// When there are some past interactions in the Agent Panel.

View file

@ -107,8 +107,7 @@ pub enum Event {
InlineAssistantModelChanged, InlineAssistantModelChanged,
CommitMessageModelChanged, CommitMessageModelChanged,
ThreadSummaryModelChanged, ThreadSummaryModelChanged,
ProviderStateChanged, ProviderStateChanged(LanguageModelProviderId),
ProviderAuthUpdated,
AddedProvider(LanguageModelProviderId), AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId),
} }
@ -149,8 +148,11 @@ impl LanguageModelRegistry {
) { ) {
let id = provider.id(); let id = provider.id();
let subscription = provider.subscribe(cx, |_, cx| { let subscription = provider.subscribe(cx, {
cx.emit(Event::ProviderStateChanged); let id = id.clone();
move |_, cx| {
cx.emit(Event::ProviderStateChanged(id.clone()));
}
}); });
if let Some(subscription) = subscription { if let Some(subscription) = subscription {
subscription.detach(); subscription.detach();

View file

@ -15,11 +15,11 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
RateLimiter, Role, LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
}; };
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -223,14 +223,6 @@ impl AnthropicLanguageModelProvider {
}) })
} }
} }
pub fn observe(
&self,
mut on_notify: impl FnMut(&mut App) + 'static,
cx: &mut App,
) -> Subscription {
cx.observe(&self.state, move |_, cx| on_notify(cx))
}
} }
impl LanguageModelProviderState for AnthropicLanguageModelProvider { impl LanguageModelProviderState for AnthropicLanguageModelProvider {
@ -324,8 +316,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) &self,
target_agent: ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
.into() .into()
} }
@ -927,12 +924,18 @@ struct ConfigurationView {
api_key_editor: Entity<Editor>, api_key_editor: Entity<Editor>,
state: gpui::Entity<State>, state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>, load_credentials_task: Option<Task<()>>,
target_agent: ConfigurationViewTargetAgent,
} }
impl ConfigurationView { impl ConfigurationView {
const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self { fn new(
state: gpui::Entity<State>,
target_agent: ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
cx.observe(&state, |_, _, cx| { cx.observe(&state, |_, _, cx| {
cx.notify(); cx.notify();
}) })
@ -964,6 +967,7 @@ impl ConfigurationView {
}), }),
state, state,
load_credentials_task, load_credentials_task,
target_agent,
} }
} }
@ -1037,7 +1041,10 @@ impl Render for ConfigurationView {
v_flex() v_flex()
.size_full() .size_full()
.on_action(cx.listener(Self::save_api_key)) .on_action(cx.listener(Self::save_api_key))
.child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:")) .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent {
ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic",
ConfigurationViewTargetAgent::Other(agent) => agent,
})))
.child( .child(
List::new() List::new()
.child( .child(
@ -1048,7 +1055,7 @@ impl Render for ConfigurationView {
) )
) )
.child( .child(
InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant") InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
) )
) )
.child( .child(

View file

@ -348,7 +348,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -391,7 +391,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
_: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|_| ConfigurationView::new(self.state.clone())) cx.new(|_| ConfigurationView::new(self.state.clone()))
.into() .into()
} }

View file

@ -176,7 +176,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
Task::ready(Err(err.into())) Task::ready(Err(err.into()))
} }
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
_: &mut Window,
cx: &mut App,
) -> AnyView {
let state = self.state.clone(); let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, cx)).into() cx.new(|cx| ConfigurationView::new(state, cx)).into()
} }

View file

@ -229,7 +229,12 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -277,7 +277,12 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -226,7 +226,12 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
_window: &mut Window,
cx: &mut App,
) -> AnyView {
let state = self.state.clone(); let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, cx)).into() cx.new(|cx| ConfigurationView::new(state, cx)).into()
} }

View file

@ -243,7 +243,12 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -255,7 +255,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
let state = self.state.clone(); let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, window, cx)) cx.new(|cx| ConfigurationView::new(state, window, cx))
.into() .into()

View file

@ -233,7 +233,12 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -243,7 +243,12 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -306,7 +306,12 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -230,7 +230,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -230,7 +230,12 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
self.state.update(cx, |state, cx| state.authenticate(cx)) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into() .into()
} }

View file

@ -329,7 +329,11 @@ impl AiConfigurationModal {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let focus_handle = cx.focus_handle(); let focus_handle = cx.focus_handle();
let configuration_view = selected_provider.configuration_view(window, cx); let configuration_view = selected_provider.configuration_view(
language_model::ConfigurationViewTargetAgent::ZedAgent,
window,
cx,
);
Self { Self {
focus_handle, focus_handle,