use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
    ThinkingConfig, UsageMetadata,
};
use gpui::{
    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's
        /// `max_output_tokens`.
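        /// If `None`, no explicit thinking budget is included in the request and the API's
        /// default behavior applies.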
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for GoogleModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => GoogleModelMode::Default,
            ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
        }
    }
}

impl From<GoogleModelMode> for ModelMode {
    fn from(value: GoogleModelMode) -> Self {
        match value {
            GoogleModelMode::Default => ModelMode::Default,
            GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
        }
    }
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    name: String,
    display_name: Option<String>,
    max_tokens: u64,
    mode: Option<ModelMode>,
}

pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    api_key: Option<String>,
    api_key_from_env: bool,
    _subscription: Subscription,
}

const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .delete_credentials(&api_url, cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
                .await?;
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
                (api_key, true)
            } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;
                (
                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
                    false,
                )
            };
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                cx.notify();
            })?;
            Ok(())
        })
    }
}

impl GoogleLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for GoogleLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiGoogle
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(google_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(google_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from google_ai::Model::iter()
        for model in google_ai::Model::iter() {
            if !matches!(model, google_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .google
            .available_models
        {
            models.insert(
                model.name.clone(),
                google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    mode: model.mode.unwrap_or_default().into(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(GoogleLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl GoogleLanguageModel {
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).google;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request =
                google_ai::stream_generate_content(
                    http_client.as_ref(),
                    &api_url,
                    &api_key,
                    request,
                );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}

impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_key = self.state.read(cx).api_key.clone();

        let settings = &AllLanguageModelSettings::get_global(cx).google;
        let api_url = settings.api_url.clone();

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { ..
                } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({ "output": text }),
                                    },
                                },
                            )]
                        }
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            // Google AI doesn't have a system role
                            Role::System => google_ai::Role::User,
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}

pub struct GoogleEventMapper {
    usage: UsageMetadata,
    stop_reason: StopReason,
}

impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    pub fn
    map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();
                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                // TODO: Can we populate this from thought summaries?
                                text: "(Encrypted thought)".to_string(),
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}

pub fn count_google_tokens(
    request: LanguageModelRequest,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    // We can't use GoogleLanguageModelProvider to count tokens here, because the GitHub
    // Copilot provider doesn't have direct access to google_ai, so we fall back to the
    // tiktoken_rs tokenizer instead.
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        // Tiktoken doesn't yet support these models, so we manually use the
        // same tokenizer as GPT-4.
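        // Note: this only approximates Gemini token counts, since Gemini's tokenizer
        // differs from GPT-4's.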
        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}

fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    if let Some(prompt_token_count) = new.prompt_token_count {
        usage.prompt_token_count = Some(prompt_token_count);
    }
    if let Some(cached_content_token_count) = new.cached_content_token_count {
        usage.cached_content_token_count = Some(cached_content_token_count);
    }
    if let Some(candidates_token_count) = new.candidates_token_count {
        usage.candidates_token_count = Some(candidates_token_count);
    }
    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
    }
    if let Some(thoughts_token_count) = new.thoughts_token_count {
        usage.thoughts_token_count = Some(thoughts_token_count);
    }
    if let Some(total_token_count) = new.total_token_count {
        usage.total_token_count = Some(total_token_count);
    }
}

fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
    let input_tokens = prompt_tokens - cached_tokens;
    let output_tokens = usage.candidates_token_count.unwrap_or(0);

    language_model::TokenUsage {
        input_tokens,
        output_tokens,
        cache_read_input_tokens: cached_tokens,
        cache_creation_input_tokens: 0,
    }
}

struct ConfigurationView {
    api_key_editor: Entity<Editor>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| {
                let mut editor = Editor::single_line(window, cx);
                editor.set_placeholder_text("AIzaSy...", cx);
                editor
            }),
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            white_space: WhiteSpace::Normal,
            ..Default::default()
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Google AI, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Google AI's console"),
                            Some("https://aistudio.google.com/app/apikey"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .rounded_sm()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(format!(
                        "You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!(
                                "To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset."
                            )))
                        })
                        .on_click(
                            cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)),
                        ),
                )
                .into_any()
        }
    }
}