Compare commits
5 commits
main
...
claude-key
Author | SHA1 | Date | |
---|---|---|---|
![]() |
072703c8b2 | ||
![]() |
2ed2b82e02 | ||
![]() |
21082c0aba | ||
![]() |
6c7a5c50bf | ||
![]() |
1c43338056 |
32 changed files with 410 additions and 124 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -20,6 +20,7 @@ dependencies = [
|
|||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"markdown",
|
||||
"parking_lot",
|
||||
"project",
|
||||
|
@ -267,6 +268,8 @@ dependencies = [
|
|||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"libc",
|
||||
"log",
|
||||
"nix 0.29.0",
|
||||
|
|
|
@ -28,6 +28,7 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
markdown.workspace = true
|
||||
parking_lot = { workspace = true, optional = true }
|
||||
project.workspace = true
|
||||
|
|
|
@ -3,6 +3,7 @@ use agent_client_protocol::{self as acp};
|
|||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use language_model::LanguageModelProviderId;
|
||||
use project::Project;
|
||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
|
@ -80,12 +81,34 @@ pub trait AgentSessionResume {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
pub struct AuthRequired {
|
||||
pub description: Option<String>,
|
||||
pub provider_id: Option<LanguageModelProviderId>,
|
||||
}
|
||||
|
||||
impl AuthRequired {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
description: None,
|
||||
provider_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_description(mut self, description: String) -> Self {
|
||||
self.description = Some(description);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
|
||||
self.provider_id = Some(provider_id);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
write!(f, "Authentication required")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,8 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
|
|
|
@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection {
|
|||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(AuthRequired)
|
||||
anyhow::bail!(AuthRequired::new())
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
|
|
|
@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection {
|
|||
.await
|
||||
.map_err(|err| {
|
||||
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
|
||||
anyhow!(AuthRequired)
|
||||
let mut error = AuthRequired::new();
|
||||
|
||||
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
|
||||
error = error.with_description(err.message);
|
||||
}
|
||||
|
||||
anyhow!(error)
|
||||
} else {
|
||||
anyhow!(err)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ pub mod tools;
|
|||
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
|
@ -30,7 +31,7 @@ use util::{ResultExt, debug_panic};
|
|||
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClaudeCode;
|
||||
|
@ -79,6 +80,36 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let cwd = cwd.to_owned();
|
||||
cx.spawn(async move |cx| {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) = AgentServerCommand::resolve(
|
||||
"claude",
|
||||
&[],
|
||||
Some(&util::paths::home_dir().join(".claude/local/claude")),
|
||||
settings,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
else {
|
||||
anyhow::bail!("Failed to find claude binary");
|
||||
};
|
||||
|
||||
let api_key =
|
||||
cx.update(AnthropicLanguageModelProvider::api_key)?
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if err.is::<language_model::AuthenticateError>() {
|
||||
anyhow!(AuthRequired::new().with_language_model_provider(
|
||||
language_model::ANTHROPIC_PROVIDER_ID
|
||||
))
|
||||
} else {
|
||||
anyhow!(err)
|
||||
}
|
||||
})?;
|
||||
|
||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
|
||||
|
||||
|
@ -98,23 +129,6 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
.await?;
|
||||
mcp_config_file.flush().await?;
|
||||
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) = AgentServerCommand::resolve(
|
||||
"claude",
|
||||
&[],
|
||||
Some(&util::paths::home_dir().join(".claude/local/claude")),
|
||||
settings,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
else {
|
||||
anyhow::bail!("Failed to find claude binary");
|
||||
};
|
||||
|
||||
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
||||
|
||||
|
@ -126,6 +140,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
&command,
|
||||
ClaudeSessionMode::Start,
|
||||
session_id.clone(),
|
||||
api_key,
|
||||
&mcp_config_path,
|
||||
&cwd,
|
||||
)?;
|
||||
|
@ -320,6 +335,7 @@ fn spawn_claude(
|
|||
command: &AgentServerCommand,
|
||||
mode: ClaudeSessionMode,
|
||||
session_id: acp::SessionId,
|
||||
api_key: language_models::provider::anthropic::ApiKey,
|
||||
mcp_config_path: &Path,
|
||||
root_dir: &Path,
|
||||
) -> Result<Child> {
|
||||
|
@ -355,6 +371,8 @@ fn spawn_claude(
|
|||
ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
|
||||
})
|
||||
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||
.envs(command.env.iter().flatten())
|
||||
.env("ANTHROPIC_API_KEY", api_key.key)
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use acp_thread::{
|
||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
||||
LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId,
|
||||
AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
||||
UserMessageId,
|
||||
};
|
||||
use acp_thread::{AgentConnection, Plan};
|
||||
use action_log::ActionLog;
|
||||
|
@ -18,13 +19,16 @@ use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
|
|||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
|
||||
Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
|
||||
PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
|
||||
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
|
||||
linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
||||
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
|
||||
EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState,
|
||||
MouseButton, PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task,
|
||||
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window,
|
||||
WindowHandle, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
|
||||
pulsating_between,
|
||||
};
|
||||
use language::Buffer;
|
||||
|
||||
use language_model::LanguageModelRegistry;
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
||||
use project::Project;
|
||||
use prompt_store::PromptId;
|
||||
|
@ -137,6 +141,9 @@ enum ThreadState {
|
|||
LoadError(LoadError),
|
||||
Unauthenticated {
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
description: Option<Entity<Markdown>>,
|
||||
configuration_view: Option<AnyView>,
|
||||
_subscription: Option<Subscription>,
|
||||
},
|
||||
ServerExited {
|
||||
status: ExitStatus,
|
||||
|
@ -267,19 +274,16 @@ impl AcpThreadView {
|
|||
};
|
||||
|
||||
let result = match result.await {
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
if e.is::<acp_thread::AuthRequired>() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
Err(e) => match e.downcast::<acp_thread::AuthRequired>() {
|
||||
Ok(err) => {
|
||||
cx.update(|window, cx| {
|
||||
Self::handle_auth_required(this, err, agent, connection, window, cx)
|
||||
})
|
||||
.ok();
|
||||
.log_err();
|
||||
return;
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
},
|
||||
Ok(thread) => Ok(thread),
|
||||
};
|
||||
|
||||
|
@ -345,6 +349,68 @@ impl AcpThreadView {
|
|||
ThreadState::Loading { _task: load_task }
|
||||
}
|
||||
|
||||
fn handle_auth_required(
|
||||
this: WeakEntity<Self>,
|
||||
err: AuthRequired,
|
||||
agent: Rc<dyn AgentServer>,
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let agent_name = agent.name();
|
||||
let (configuration_view, subscription) = if let Some(provider_id) = err.provider_id {
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
|
||||
let sub = window.subscribe(®istry, cx, {
|
||||
let provider_id = provider_id.clone();
|
||||
let this = this.clone();
|
||||
move |_, ev, window, cx| {
|
||||
if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev {
|
||||
if &provider_id == updated_provider_id {
|
||||
this.update(cx, |this, cx| {
|
||||
this.thread_state = Self::initial_state(
|
||||
agent.clone(),
|
||||
this.workspace.clone(),
|
||||
this.project.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let view = registry.read(cx).provider(&provider_id).map(|provider| {
|
||||
provider.configuration_view(
|
||||
language_model::ConfigurationViewTargetAgent::Other(agent_name),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
(view, Some(sub))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated {
|
||||
connection,
|
||||
configuration_view,
|
||||
description: err
|
||||
.description
|
||||
.clone()
|
||||
.map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))),
|
||||
_subscription: subscription,
|
||||
};
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context<Self>) {
|
||||
if let Some(load_err) = err.downcast_ref::<LoadError>() {
|
||||
self.thread_state = ThreadState::LoadError(load_err.clone());
|
||||
|
@ -369,7 +435,7 @@ impl AcpThreadView {
|
|||
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
|
||||
ThreadState::Loading { .. } => "Loading…".into(),
|
||||
ThreadState::LoadError(_) => "Failed to load".into(),
|
||||
ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
|
||||
ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
|
||||
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
|
||||
}
|
||||
}
|
||||
|
@ -708,7 +774,7 @@ impl AcpThreadView {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -1841,19 +1907,63 @@ impl AcpThreadView {
|
|||
.into_any()
|
||||
}
|
||||
|
||||
fn render_pending_auth_state(&self) -> AnyElement {
|
||||
fn render_auth_required_state(
|
||||
&self,
|
||||
connection: &Rc<dyn AgentConnection>,
|
||||
description: Option<&Entity<Markdown>>,
|
||||
configuration_view: Option<&AnyView>,
|
||||
window: &mut Window,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
v_flex()
|
||||
.py_2()
|
||||
.px_8()
|
||||
.w_full()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_error_agent_logo())
|
||||
.child(
|
||||
h_flex()
|
||||
.mt_4()
|
||||
.mb_1()
|
||||
v_flex()
|
||||
.w_full()
|
||||
.max_w(px(530.))
|
||||
.justify_center()
|
||||
.child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)),
|
||||
.gap_2()
|
||||
.child(
|
||||
v_flex()
|
||||
.justify_center()
|
||||
.items_center()
|
||||
.child(self.render_error_agent_logo())
|
||||
.child(h_flex().mt_4().mb_1().justify_center().child(
|
||||
Headline::new("Authentication Required").size(HeadlineSize::Medium),
|
||||
))
|
||||
.into_any(),
|
||||
)
|
||||
.children(description.map(|desc| {
|
||||
div().text_ui(cx).text_center().child(self.render_markdown(
|
||||
desc.clone(),
|
||||
default_markdown_style(false, window, cx),
|
||||
))
|
||||
}))
|
||||
.children(
|
||||
configuration_view
|
||||
.cloned()
|
||||
.map(|view| div().w_full().child(view)),
|
||||
)
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.name.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement {
|
||||
|
@ -3347,26 +3457,18 @@ impl Render for AcpThreadView {
|
|||
.on_action(cx.listener(Self::toggle_burn_mode))
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { connection } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.name.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
ThreadState::Unauthenticated {
|
||||
connection,
|
||||
description,
|
||||
configuration_view,
|
||||
..
|
||||
} => self.render_auth_required_state(
|
||||
&connection,
|
||||
description.as_ref(),
|
||||
configuration_view.as_ref(),
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
ThreadState::LoadError(e) => v_flex()
|
||||
.p_2()
|
||||
|
|
|
@ -137,7 +137,11 @@ impl AgentConfiguration {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let configuration_view = provider.configuration_view(window, cx);
|
||||
let configuration_view = provider.configuration_view(
|
||||
language_model::ConfigurationViewTargetAgent::ZedAgent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.configuration_views_by_provider
|
||||
.insert(provider.id(), configuration_view);
|
||||
}
|
||||
|
|
|
@ -320,7 +320,7 @@ fn init_language_model_settings(cx: &mut App) {
|
|||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|_, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
|
|
|
@ -104,7 +104,7 @@ impl LanguageModelPickerDelegate {
|
|||
window,
|
||||
|picker, _, event, window, cx| {
|
||||
match event {
|
||||
language_model::Event::ProviderStateChanged
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let query = picker.query(cx);
|
||||
|
|
|
@ -11,7 +11,7 @@ impl ApiKeysWithProviders {
|
|||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
this.configured_providers = Self::compute_configured_providers(cx)
|
||||
|
|
|
@ -25,7 +25,7 @@ impl AgentPanelOnboarding {
|
|||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
this.configured_providers = Self::compute_available_providers(cx)
|
||||
|
|
|
@ -201,3 +201,9 @@ impl Drop for Subscription {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Subscription {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Subscription").finish()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice,
|
||||
AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolChoice,
|
||||
};
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||
|
@ -62,7 +62,12 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: ConfigurationViewTargetAgent,
|
||||
_window: &mut Window,
|
||||
_: &mut App,
|
||||
) -> AnyView {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
|
|
@ -634,7 +634,12 @@ pub trait LanguageModelProvider: 'static {
|
|||
}
|
||||
fn is_authenticated(&self, cx: &App) -> bool;
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
|
||||
fn configuration_view(
|
||||
&self,
|
||||
target_agent: ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView;
|
||||
fn must_accept_terms(&self, _cx: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
@ -648,6 +653,13 @@ pub trait LanguageModelProvider: 'static {
|
|||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy)]
|
||||
pub enum ConfigurationViewTargetAgent {
|
||||
#[default]
|
||||
ZedAgent,
|
||||
Other(&'static str),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum LanguageModelProviderTosView {
|
||||
/// When there are some past interactions in the Agent Panel.
|
||||
|
|
|
@ -107,7 +107,7 @@ pub enum Event {
|
|||
InlineAssistantModelChanged,
|
||||
CommitMessageModelChanged,
|
||||
ThreadSummaryModelChanged,
|
||||
ProviderStateChanged,
|
||||
ProviderStateChanged(LanguageModelProviderId),
|
||||
AddedProvider(LanguageModelProviderId),
|
||||
RemovedProvider(LanguageModelProviderId),
|
||||
}
|
||||
|
@ -148,8 +148,11 @@ impl LanguageModelRegistry {
|
|||
) {
|
||||
let id = provider.id();
|
||||
|
||||
let subscription = provider.subscribe(cx, |_, cx| {
|
||||
cx.emit(Event::ProviderStateChanged);
|
||||
let subscription = provider.subscribe(cx, {
|
||||
let id = id.clone();
|
||||
move |_, cx| {
|
||||
cx.emit(Event::ProviderStateChanged(id.clone()));
|
||||
}
|
||||
});
|
||||
if let Some(subscription) = subscription {
|
||||
subscription.detach();
|
||||
|
|
|
@ -15,11 +15,11 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
|
||||
RateLimiter, Role,
|
||||
AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
|
||||
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
|
||||
LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
|
||||
};
|
||||
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||
use schemars::JsonSchema;
|
||||
|
@ -153,29 +153,14 @@ impl State {
|
|||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.clone();
|
||||
let key = AnthropicLanguageModelProvider::api_key(cx);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, &cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
let key = key.await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
this.api_key = Some(key.key);
|
||||
this.api_key_from_env = key.from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
|
@ -184,6 +169,11 @@ impl State {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ApiKey {
|
||||
pub key: String,
|
||||
pub from_env: bool,
|
||||
}
|
||||
|
||||
impl AnthropicLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
|
@ -206,6 +196,33 @@ impl AnthropicLanguageModelProvider {
|
|||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
|
||||
Task::ready(Ok(ApiKey {
|
||||
key,
|
||||
from_env: true,
|
||||
}))
|
||||
} else {
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, &cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
|
||||
Ok(ApiKey {
|
||||
key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
from_env: false,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
||||
|
@ -299,8 +316,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
fn configuration_view(
|
||||
&self,
|
||||
target_agent: ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
|
@ -902,12 +924,18 @@ struct ConfigurationView {
|
|||
api_key_editor: Entity<Editor>,
|
||||
state: gpui::Entity<State>,
|
||||
load_credentials_task: Option<Task<()>>,
|
||||
target_agent: ConfigurationViewTargetAgent,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
|
||||
|
||||
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
fn new(
|
||||
state: gpui::Entity<State>,
|
||||
target_agent: ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
|
@ -939,6 +967,7 @@ impl ConfigurationView {
|
|||
}),
|
||||
state,
|
||||
load_credentials_task,
|
||||
target_agent,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1012,7 +1041,10 @@ impl Render for ConfigurationView {
|
|||
v_flex()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:"))
|
||||
.child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent {
|
||||
ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic",
|
||||
ConfigurationViewTargetAgent::Other(agent) => agent,
|
||||
})))
|
||||
.child(
|
||||
List::new()
|
||||
.child(
|
||||
|
@ -1023,7 +1055,7 @@ impl Render for ConfigurationView {
|
|||
)
|
||||
)
|
||||
.child(
|
||||
InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant")
|
||||
InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
|
||||
)
|
||||
)
|
||||
.child(
|
||||
|
|
|
@ -348,7 +348,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -391,7 +391,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
_: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|_| ConfigurationView::new(self.state.clone()))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -176,7 +176,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||
Task::ready(Err(err.into()))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
_: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
let state = self.state.clone();
|
||||
cx.new(|cx| ConfigurationView::new(state, cx)).into()
|
||||
}
|
||||
|
|
|
@ -229,7 +229,12 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -277,7 +277,12 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -226,7 +226,12 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
_window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
let state = self.state.clone();
|
||||
cx.new(|cx| ConfigurationView::new(state, cx)).into()
|
||||
}
|
||||
|
|
|
@ -243,7 +243,12 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -255,7 +255,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
let state = self.state.clone();
|
||||
cx.new(|cx| ConfigurationView::new(state, window, cx))
|
||||
.into()
|
||||
|
|
|
@ -233,7 +233,12 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -243,7 +243,12 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -306,7 +306,12 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -230,7 +230,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -230,7 +230,12 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
|
|||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
fn configuration_view(
|
||||
&self,
|
||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -329,7 +329,11 @@ impl AiConfigurationModal {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let focus_handle = cx.focus_handle();
|
||||
let configuration_view = selected_provider.configuration_view(window, cx);
|
||||
let configuration_view = selected_provider.configuration_view(
|
||||
language_model::ConfigurationViewTargetAgent::ZedAgent,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
Self {
|
||||
focus_handle,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue