From 799e81ffe5006cac517579f2a6cdcfa0abe34cfa Mon Sep 17 00:00:00 2001
From: volt <110602200+respberryx@users.noreply.github.com>
Date: Mon, 6 Jan 2025 22:28:31 +0100
Subject: [PATCH] google_ai: Add Gemini 2.0 Flash support (#22665)

Release Notes:

- Added support for Google's Gemini 2.0 Flash experimental model.

Note: Oddly, in my tests the model is slow on small-talk responses like
"hi", but very fast on prompts that need more tokens, such as "write me a
snake game in Python". This is likely an API problem.

TESTED ONLY ON WINDOWS! I would test further, but I don't have Linux
installed and don't have a Mac. It will likely work everywhere.

Why?: I think Gemini 2.0 Flash is an incredibly good model at coding and
following instructions, and it would be nice to have it in the editor. I
kept the changes as minimal as possible while adding the model and the
streaming request validation. I think the commits are worth merging, as
they bring good improvements.

---------

Co-authored-by: Marshall Bowers
---
 crates/google_ai/src/google_ai.rs           | 23 +++++++++++++++++--
 .../language_model/src/model/cloud_model.rs |  1 +
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index bb32cb48c0..b40c5714b8 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -1,6 +1,6 @@
 mod supported_countries;
 
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use serde::{Deserialize, Serialize};
@@ -15,6 +15,20 @@ pub async fn stream_generate_content(
     api_key: &str,
     mut request: GenerateContentRequest,
 ) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
+    if request.contents.is_empty() {
+        bail!("Request must contain at least one content item");
+    }
+
+    if let Some(user_content) = request
+        .contents
+        .iter()
+        .find(|content| content.role == Role::User)
+    {
+        if user_content.parts.is_empty() {
+            bail!("User content must contain at least one part");
+        }
+    }
+
     let uri = format!(
         "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
         model = request.model
@@ -140,7 +154,7 @@ pub struct Content {
     pub role: Role,
 }
 
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Debug, PartialEq, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub enum Role {
     User,
@@ -291,6 +305,8 @@ pub enum Model {
     Gemini15Pro,
     #[serde(rename = "gemini-1.5-flash")]
     Gemini15Flash,
+    #[serde(rename = "gemini-2.0-flash-exp")]
+    Gemini20Flash,
     #[serde(rename = "custom")]
     Custom {
         name: String,
@@ -305,6 +321,7 @@ impl Model {
         match self {
             Model::Gemini15Pro => "gemini-1.5-pro",
             Model::Gemini15Flash => "gemini-1.5-flash",
+            Model::Gemini20Flash => "gemini-2.0-flash-exp",
             Model::Custom { name, .. } => name,
         }
     }
@@ -313,6 +330,7 @@ impl Model {
         match self {
             Model::Gemini15Pro => "Gemini 1.5 Pro",
             Model::Gemini15Flash => "Gemini 1.5 Flash",
+            Model::Gemini20Flash => "Gemini 2.0 Flash",
             Self::Custom {
                 name, display_name, ..
             } => display_name.as_ref().unwrap_or(name),
@@ -323,6 +341,7 @@ impl Model {
         match self {
             Model::Gemini15Pro => 2_000_000,
             Model::Gemini15Flash => 1_000_000,
+            Model::Gemini20Flash => 1_000_000,
             Model::Custom { max_tokens, .. } => *max_tokens,
         }
     }
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index 804adb22f1..57ed28d625 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -88,6 +88,7 @@ impl CloudModel {
             Self::Google(model) => match model {
                 google_ai::Model::Gemini15Pro
                 | google_ai::Model::Gemini15Flash
+                | google_ai::Model::Gemini20Flash
                 | google_ai::Model::Custom { .. } => {
                     LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                 }
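
A quick sketch of what the new variant exposes, assuming the accessors in
`impl Model` are named `id`, `display_name`, and `max_token_count` (the
method names themselves are not visible in the hunk context above); the
expected values mirror the match arms added by this patch:

    // Hedged sketch, not part of the patch: exercises the three match arms
    // extended above, assuming the accessor names listed in the lead-in.
    fn main() {
        let model = google_ai::Model::Gemini20Flash;
        assert_eq!(model.id(), "gemini-2.0-flash-exp");        // API id / serde rename
        assert_eq!(model.display_name(), "Gemini 2.0 Flash");  // human-readable name
        assert_eq!(model.max_token_count(), 1_000_000);        // context window
    }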