acp: Handle Gemini Auth Better (#36631)

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
Conrad Irwin 2025-08-20 16:12:41 -06:00 committed by GitHub
parent c9c708ff08
commit 5120b6b7f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 195 additions and 19 deletions

View file

@ -278,6 +278,7 @@ enum ThreadState {
connection: Rc<dyn AgentConnection>,
description: Option<Entity<Markdown>>,
configuration_view: Option<AnyView>,
pending_auth_method: Option<acp::AuthMethodId>,
_subscription: Option<Subscription>,
},
}
@ -563,6 +564,7 @@ impl AcpThreadView {
this.update(cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated {
pending_auth_method: None,
connection,
configuration_view,
description: err
@ -999,12 +1001,74 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
let ThreadState::Unauthenticated {
connection,
pending_auth_method,
configuration_view,
..
} = &mut self.thread_state
else {
return;
};
if method.0.as_ref() == "gemini-api-key" {
let registry = LanguageModelRegistry::global(cx);
let provider = registry
.read(cx)
.provider(&language_model::GOOGLE_PROVIDER_ID)
.unwrap();
if !provider.is_authenticated(cx) {
let this = cx.weak_entity();
let agent = self.agent.clone();
let connection = connection.clone();
window.defer(cx, |window, cx| {
Self::handle_auth_required(
this,
AuthRequired {
description: Some("GEMINI_API_KEY must be set".to_owned()),
provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
},
agent,
connection,
window,
cx,
);
});
return;
}
} else if method.0.as_ref() == "vertex-ai"
&& std::env::var("GOOGLE_API_KEY").is_err()
&& (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()
|| (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()))
{
let this = cx.weak_entity();
let agent = self.agent.clone();
let connection = connection.clone();
window.defer(cx, |window, cx| {
Self::handle_auth_required(
this,
AuthRequired {
description: Some(
"GOOGLE_API_KEY must be set in the environment to use Vertex AI authentication for Gemini CLI. Please export it and restart Zed."
.to_owned(),
),
provider_id: None,
},
agent,
connection,
window,
cx,
)
});
return;
}
self.thread_error.take();
configuration_view.take();
pending_auth_method.replace(method.clone());
let authenticate = connection.authenticate(method, cx);
cx.notify();
self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone();
let agent = self.agent.clone();
@ -2425,6 +2489,7 @@ impl AcpThreadView {
connection: &Rc<dyn AgentConnection>,
description: Option<&Entity<Markdown>>,
configuration_view: Option<&AnyView>,
pending_auth_method: Option<&acp::AuthMethodId>,
window: &mut Window,
cx: &Context<Self>,
) -> Div {
@ -2456,17 +2521,80 @@ impl AcpThreadView {
.cloned()
.map(|view| div().px_4().w_full().max_w_128().child(view)),
)
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().iter().map(|method| {
Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
.when(
configuration_view.is_none()
&& description.is_none()
&& pending_auth_method.is_none(),
|el| {
el.child(
div()
.text_ui(cx)
.text_center()
.px_4()
.w_full()
.max_w_128()
.child(Label::new("Authentication required")),
)
},
)
.when_some(pending_auth_method, |el, _| {
let spinner_icon = div()
.px_0p5()
.id("generating")
.tooltip(Tooltip::text("Generating Changes…"))
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(delta)))
},
)
.into_any_element(),
)
.into_any();
el.child(
h_flex()
.text_ui(cx)
.text_center()
.justify_center()
.gap_2()
.px_4()
.w_full()
.max_w_128()
.child(Label::new("Authenticating..."))
.child(spinner_icon),
)
})
.child(
h_flex()
.mt_1p5()
.gap_1()
.flex_wrap()
.justify_center()
.children(connection.auth_methods().iter().enumerate().rev().map(
|(ix, method)| {
Button::new(
SharedString::from(method.id.0.clone()),
method.name.clone(),
)
.style(ButtonStyle::Outlined)
.when(ix == 0, |el| {
el.style(ButtonStyle::Tinted(ui::TintColor::Accent))
})
})
}),
))
.size(ButtonSize::Medium)
.label_size(LabelSize::Small)
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
})
})
},
)),
)
}
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
@ -2551,6 +2679,8 @@ impl AcpThreadView {
let install_command = install_command.clone();
container = container.child(
Button::new("install", install_message)
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.size(ButtonSize::Medium)
.tooltip(Tooltip::text(install_command.clone()))
.on_click(cx.listener(move |this, _, window, cx| {
let task = this
@ -4372,11 +4502,13 @@ impl Render for AcpThreadView {
connection,
description,
configuration_view,
pending_auth_method,
..
} => self.render_auth_required_state(
connection,
description.as_ref(),
configuration_view.as_ref(),
pending_auth_method.as_ref(),
window,
cx,
),