acp: Handle Gemini Auth Better (#36631)
Release Notes: - N/A --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
parent
00ff7b72d7
commit
5b443bb49e
3 changed files with 195 additions and 19 deletions
|
@ -5,6 +5,7 @@ use crate::{AgentServer, AgentServerCommand};
|
||||||
use acp_thread::{AgentConnection, LoadError};
|
use acp_thread::{AgentConnection, LoadError};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{Entity, Task};
|
use gpui::{Entity, Task};
|
||||||
|
use language_models::provider::google::GoogleLanguageModelProvider;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use ui::App;
|
use ui::App;
|
||||||
|
@ -47,7 +48,7 @@ impl AgentServer for Gemini {
|
||||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let Some(command) =
|
let Some(mut command) =
|
||||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
|
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
|
||||||
else {
|
else {
|
||||||
return Err(LoadError::NotInstalled {
|
return Err(LoadError::NotInstalled {
|
||||||
|
@ -57,6 +58,10 @@ impl AgentServer for Gemini {
|
||||||
}.into());
|
}.into());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Some(api_key)= cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
|
||||||
|
command.env.get_or_insert_default().insert("GEMINI_API_KEY".to_owned(), api_key.key);
|
||||||
|
}
|
||||||
|
|
||||||
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
|
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
|
||||||
if result.is_err() {
|
if result.is_err() {
|
||||||
let version_fut = util::command::new_smol_command(&command.path)
|
let version_fut = util::command::new_smol_command(&command.path)
|
||||||
|
|
|
@ -278,6 +278,7 @@ enum ThreadState {
|
||||||
connection: Rc<dyn AgentConnection>,
|
connection: Rc<dyn AgentConnection>,
|
||||||
description: Option<Entity<Markdown>>,
|
description: Option<Entity<Markdown>>,
|
||||||
configuration_view: Option<AnyView>,
|
configuration_view: Option<AnyView>,
|
||||||
|
pending_auth_method: Option<acp::AuthMethodId>,
|
||||||
_subscription: Option<Subscription>,
|
_subscription: Option<Subscription>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -563,6 +564,7 @@ impl AcpThreadView {
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.thread_state = ThreadState::Unauthenticated {
|
this.thread_state = ThreadState::Unauthenticated {
|
||||||
|
pending_auth_method: None,
|
||||||
connection,
|
connection,
|
||||||
configuration_view,
|
configuration_view,
|
||||||
description: err
|
description: err
|
||||||
|
@ -999,12 +1001,74 @@ impl AcpThreadView {
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
|
let ThreadState::Unauthenticated {
|
||||||
|
connection,
|
||||||
|
pending_auth_method,
|
||||||
|
configuration_view,
|
||||||
|
..
|
||||||
|
} = &mut self.thread_state
|
||||||
|
else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if method.0.as_ref() == "gemini-api-key" {
|
||||||
|
let registry = LanguageModelRegistry::global(cx);
|
||||||
|
let provider = registry
|
||||||
|
.read(cx)
|
||||||
|
.provider(&language_model::GOOGLE_PROVIDER_ID)
|
||||||
|
.unwrap();
|
||||||
|
if !provider.is_authenticated(cx) {
|
||||||
|
let this = cx.weak_entity();
|
||||||
|
let agent = self.agent.clone();
|
||||||
|
let connection = connection.clone();
|
||||||
|
window.defer(cx, |window, cx| {
|
||||||
|
Self::handle_auth_required(
|
||||||
|
this,
|
||||||
|
AuthRequired {
|
||||||
|
description: Some("GEMINI_API_KEY must be set".to_owned()),
|
||||||
|
provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
|
||||||
|
},
|
||||||
|
agent,
|
||||||
|
connection,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if method.0.as_ref() == "vertex-ai"
|
||||||
|
&& std::env::var("GOOGLE_API_KEY").is_err()
|
||||||
|
&& (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()
|
||||||
|
|| (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()))
|
||||||
|
{
|
||||||
|
let this = cx.weak_entity();
|
||||||
|
let agent = self.agent.clone();
|
||||||
|
let connection = connection.clone();
|
||||||
|
|
||||||
|
window.defer(cx, |window, cx| {
|
||||||
|
Self::handle_auth_required(
|
||||||
|
this,
|
||||||
|
AuthRequired {
|
||||||
|
description: Some(
|
||||||
|
"GOOGLE_API_KEY must be set in the environment to use Vertex AI authentication for Gemini CLI. Please export it and restart Zed."
|
||||||
|
.to_owned(),
|
||||||
|
),
|
||||||
|
provider_id: None,
|
||||||
|
},
|
||||||
|
agent,
|
||||||
|
connection,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
self.thread_error.take();
|
self.thread_error.take();
|
||||||
|
configuration_view.take();
|
||||||
|
pending_auth_method.replace(method.clone());
|
||||||
let authenticate = connection.authenticate(method, cx);
|
let authenticate = connection.authenticate(method, cx);
|
||||||
|
cx.notify();
|
||||||
self.auth_task = Some(cx.spawn_in(window, {
|
self.auth_task = Some(cx.spawn_in(window, {
|
||||||
let project = self.project.clone();
|
let project = self.project.clone();
|
||||||
let agent = self.agent.clone();
|
let agent = self.agent.clone();
|
||||||
|
@ -2425,6 +2489,7 @@ impl AcpThreadView {
|
||||||
connection: &Rc<dyn AgentConnection>,
|
connection: &Rc<dyn AgentConnection>,
|
||||||
description: Option<&Entity<Markdown>>,
|
description: Option<&Entity<Markdown>>,
|
||||||
configuration_view: Option<&AnyView>,
|
configuration_view: Option<&AnyView>,
|
||||||
|
pending_auth_method: Option<&acp::AuthMethodId>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &Context<Self>,
|
cx: &Context<Self>,
|
||||||
) -> Div {
|
) -> Div {
|
||||||
|
@ -2456,17 +2521,80 @@ impl AcpThreadView {
|
||||||
.cloned()
|
.cloned()
|
||||||
.map(|view| div().px_4().w_full().max_w_128().child(view)),
|
.map(|view| div().px_4().w_full().max_w_128().child(view)),
|
||||||
)
|
)
|
||||||
.child(h_flex().mt_1p5().justify_center().children(
|
.when(
|
||||||
connection.auth_methods().iter().map(|method| {
|
configuration_view.is_none()
|
||||||
Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
|
&& description.is_none()
|
||||||
.on_click({
|
&& pending_auth_method.is_none(),
|
||||||
let method_id = method.id.clone();
|
|el| {
|
||||||
cx.listener(move |this, _, window, cx| {
|
el.child(
|
||||||
this.authenticate(method_id.clone(), window, cx)
|
div()
|
||||||
|
.text_ui(cx)
|
||||||
|
.text_center()
|
||||||
|
.px_4()
|
||||||
|
.w_full()
|
||||||
|
.max_w_128()
|
||||||
|
.child(Label::new("Authentication required")),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.when_some(pending_auth_method, |el, _| {
|
||||||
|
let spinner_icon = div()
|
||||||
|
.px_0p5()
|
||||||
|
.id("generating")
|
||||||
|
.tooltip(Tooltip::text("Generating Changes…"))
|
||||||
|
.child(
|
||||||
|
Icon::new(IconName::ArrowCircle)
|
||||||
|
.size(IconSize::Small)
|
||||||
|
.with_animation(
|
||||||
|
"arrow-circle",
|
||||||
|
Animation::new(Duration::from_secs(2)).repeat(),
|
||||||
|
|icon, delta| {
|
||||||
|
icon.transform(Transformation::rotate(percentage(delta)))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.into_any_element(),
|
||||||
|
)
|
||||||
|
.into_any();
|
||||||
|
el.child(
|
||||||
|
h_flex()
|
||||||
|
.text_ui(cx)
|
||||||
|
.text_center()
|
||||||
|
.justify_center()
|
||||||
|
.gap_2()
|
||||||
|
.px_4()
|
||||||
|
.w_full()
|
||||||
|
.max_w_128()
|
||||||
|
.child(Label::new("Authenticating..."))
|
||||||
|
.child(spinner_icon),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.mt_1p5()
|
||||||
|
.gap_1()
|
||||||
|
.flex_wrap()
|
||||||
|
.justify_center()
|
||||||
|
.children(connection.auth_methods().iter().enumerate().rev().map(
|
||||||
|
|(ix, method)| {
|
||||||
|
Button::new(
|
||||||
|
SharedString::from(method.id.0.clone()),
|
||||||
|
method.name.clone(),
|
||||||
|
)
|
||||||
|
.style(ButtonStyle::Outlined)
|
||||||
|
.when(ix == 0, |el| {
|
||||||
|
el.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||||
})
|
})
|
||||||
})
|
.size(ButtonSize::Medium)
|
||||||
}),
|
.label_size(LabelSize::Small)
|
||||||
))
|
.on_click({
|
||||||
|
let method_id = method.id.clone();
|
||||||
|
cx.listener(move |this, _, window, cx| {
|
||||||
|
this.authenticate(method_id.clone(), window, cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
|
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
|
||||||
|
@ -2551,6 +2679,8 @@ impl AcpThreadView {
|
||||||
let install_command = install_command.clone();
|
let install_command = install_command.clone();
|
||||||
container = container.child(
|
container = container.child(
|
||||||
Button::new("install", install_message)
|
Button::new("install", install_message)
|
||||||
|
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||||
|
.size(ButtonSize::Medium)
|
||||||
.tooltip(Tooltip::text(install_command.clone()))
|
.tooltip(Tooltip::text(install_command.clone()))
|
||||||
.on_click(cx.listener(move |this, _, window, cx| {
|
.on_click(cx.listener(move |this, _, window, cx| {
|
||||||
let task = this
|
let task = this
|
||||||
|
@ -4372,11 +4502,13 @@ impl Render for AcpThreadView {
|
||||||
connection,
|
connection,
|
||||||
description,
|
description,
|
||||||
configuration_view,
|
configuration_view,
|
||||||
|
pending_auth_method,
|
||||||
..
|
..
|
||||||
} => self.render_auth_required_state(
|
} => self.render_auth_required_state(
|
||||||
connection,
|
connection,
|
||||||
description.as_ref(),
|
description.as_ref(),
|
||||||
configuration_view.as_ref(),
|
configuration_view.as_ref(),
|
||||||
|
pending_auth_method.as_ref(),
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
),
|
),
|
||||||
|
|
|
@ -12,9 +12,9 @@ use gpui::{
|
||||||
};
|
};
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
|
||||||
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
|
LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
|
||||||
LanguageModelToolUseId, MessageContent, StopReason,
|
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||||
|
@ -37,6 +37,8 @@ use util::ResultExt;
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
use crate::ui::InstructionListItem;
|
use crate::ui::InstructionListItem;
|
||||||
|
|
||||||
|
use super::anthropic::ApiKey;
|
||||||
|
|
||||||
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
|
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
|
||||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
|
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
|
||||||
|
|
||||||
|
@ -198,6 +200,33 @@ impl GoogleLanguageModelProvider {
|
||||||
request_limiter: RateLimiter::new(4),
|
request_limiter: RateLimiter::new(4),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
|
||||||
|
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||||
|
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||||
|
.google
|
||||||
|
.api_url
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
|
||||||
|
Task::ready(Ok(ApiKey {
|
||||||
|
key,
|
||||||
|
from_env: true,
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
let (_, api_key) = credentials_provider
|
||||||
|
.read_credentials(&api_url, cx)
|
||||||
|
.await?
|
||||||
|
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||||
|
|
||||||
|
Ok(ApiKey {
|
||||||
|
key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||||
|
from_env: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelProviderState for GoogleLanguageModelProvider {
|
impl LanguageModelProviderState for GoogleLanguageModelProvider {
|
||||||
|
@ -279,11 +308,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
||||||
|
|
||||||
fn configuration_view(
|
fn configuration_view(
|
||||||
&self,
|
&self,
|
||||||
_target_agent: language_model::ConfigurationViewTargetAgent,
|
target_agent: language_model::ConfigurationViewTargetAgent,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> AnyView {
|
) -> AnyView {
|
||||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -776,11 +805,17 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
|
||||||
struct ConfigurationView {
|
struct ConfigurationView {
|
||||||
api_key_editor: Entity<Editor>,
|
api_key_editor: Entity<Editor>,
|
||||||
state: gpui::Entity<State>,
|
state: gpui::Entity<State>,
|
||||||
|
target_agent: language_model::ConfigurationViewTargetAgent,
|
||||||
load_credentials_task: Option<Task<()>>,
|
load_credentials_task: Option<Task<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConfigurationView {
|
impl ConfigurationView {
|
||||||
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
fn new(
|
||||||
|
state: gpui::Entity<State>,
|
||||||
|
target_agent: language_model::ConfigurationViewTargetAgent,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Self {
|
||||||
cx.observe(&state, |_, _, cx| {
|
cx.observe(&state, |_, _, cx| {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})
|
})
|
||||||
|
@ -810,6 +845,7 @@ impl ConfigurationView {
|
||||||
editor.set_placeholder_text("AIzaSy...", cx);
|
editor.set_placeholder_text("AIzaSy...", cx);
|
||||||
editor
|
editor
|
||||||
}),
|
}),
|
||||||
|
target_agent,
|
||||||
state,
|
state,
|
||||||
load_credentials_task,
|
load_credentials_task,
|
||||||
}
|
}
|
||||||
|
@ -885,7 +921,10 @@ impl Render for ConfigurationView {
|
||||||
v_flex()
|
v_flex()
|
||||||
.size_full()
|
.size_full()
|
||||||
.on_action(cx.listener(Self::save_api_key))
|
.on_action(cx.listener(Self::save_api_key))
|
||||||
.child(Label::new("To use Zed's agent with Google AI, you need to add an API key. Follow these steps:"))
|
.child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent {
|
||||||
|
ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI",
|
||||||
|
ConfigurationViewTargetAgent::Other(agent) => agent,
|
||||||
|
})))
|
||||||
.child(
|
.child(
|
||||||
List::new()
|
List::new()
|
||||||
.child(InstructionListItem::new(
|
.child(InstructionListItem::new(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue