From f833a01a7e2ee97fbdaf5e6c0d49d873631b4e3d Mon Sep 17 00:00:00 2001 From: Richard Hao Date: Tue, 18 Feb 2025 03:25:38 +0800 Subject: [PATCH] copilot: Add support for Gemini 2.0 Flash model to Copilot Chat (#24952) Co-authored-by: Peter Tripp --- crates/copilot/src/copilot_chat.rs | 8 ++++- .../src/provider/copilot_chat.rs | 6 +++- crates/language_models/src/provider/google.rs | 32 ++++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index b45bd6270c..14e1a4b210 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -40,13 +40,15 @@ pub enum Model { O3Mini, #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")] Claude3_5Sonnet, + #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")] + Gemini20Flash, } impl Model { pub fn uses_streaming(&self) -> bool { match self { Self::Gpt4o | Self::Gpt4 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet => true, - Self::O3Mini | Self::O1 => false, + Self::O3Mini | Self::O1 | Self::Gemini20Flash => false, } } @@ -58,6 +60,7 @@ impl Model { "o1" => Ok(Self::O1), "o3-mini" => Ok(Self::O3Mini), "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet), + "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash), _ => Err(anyhow!("Invalid model id: {}", id)), } } @@ -70,6 +73,7 @@ impl Model { Self::O3Mini => "o3-mini", Self::O1 => "o1", Self::Claude3_5Sonnet => "claude-3-5-sonnet", + Self::Gemini20Flash => "gemini-2.0-flash-001", } } @@ -81,6 +85,7 @@ impl Model { Self::O3Mini => "o3-mini", Self::O1 => "o1", Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Gemini20Flash => "Gemini 2.0 Flash", } } @@ -92,6 +97,7 @@ impl Model { Self::O3Mini => 20000, Self::O1 => 20000, Self::Claude3_5Sonnet => 200_000, + Model::Gemini20Flash => 128_000, } } } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 
067b4863e5..6efac131e9 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -25,6 +25,7 @@ use strum::IntoEnumIterator; use ui::prelude::*; use super::anthropic::count_anthropic_tokens; +use super::google::count_google_tokens; use super::open_ai::count_open_ai_tokens; const PROVIDER_ID: &str = "copilot_chat"; @@ -174,13 +175,16 @@ impl LanguageModel for CopilotChatLanguageModel { ) -> BoxFuture<'static, Result<usize>> { match self.model { CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx), + CopilotChatModel::Gemini20Flash => count_google_tokens(request, cx), _ => { let model = match self.model { CopilotChatModel::Gpt4o => open_ai::Model::FourOmni, CopilotChatModel::Gpt4 => open_ai::Model::Four, CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo, CopilotChatModel::O1 | CopilotChatModel::O3Mini => open_ai::Model::Four, - CopilotChatModel::Claude3_5Sonnet => unreachable!(), + CopilotChatModel::Claude3_5Sonnet | CopilotChatModel::Gemini20Flash => { + unreachable!() + } }; count_open_ai_tokens(request, model, cx) } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 791a03b78a..2054b686dc 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -11,7 +11,7 @@ use language_model::LanguageModelCompletionEvent; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, RateLimiter, + LanguageModelRequest, RateLimiter, Role, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -324,6 +324,36 @@ impl LanguageModel for GoogleLanguageModel { } } +pub fn count_google_tokens( + request: LanguageModelRequest, + cx: &App, +) -> BoxFuture<'static, Result<usize>> { + // We couldn't use the 
GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly. // So we have to use tokenizer from tiktoken_rs to count tokens. cx.background_executor() .spawn(async move { let messages = request .messages .into_iter() .map(|message| tiktoken_rs::ChatCompletionRequestMessage { role: match message.role { Role::User => "user".into(), Role::Assistant => "assistant".into(), Role::System => "system".into(), }, content: Some(message.string_contents()), name: None, function_call: None, }) .collect::<Vec<_>>(); // Tiktoken doesn't yet support these models, so we manually use the // same tokenizer as GPT-4. tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) }) .boxed() } struct ConfigurationView { api_key_editor: Entity<Editor>, state: gpui::Entity<State>,