diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index bb32cb48c0..b40c5714b8 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,6 +1,6 @@ mod supported_countries; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; @@ -15,6 +15,20 @@ pub async fn stream_generate_content( api_key: &str, mut request: GenerateContentRequest, ) -> Result>> { + if request.contents.is_empty() { + bail!("Request must contain at least one content item"); + } + + if let Some(user_content) = request + .contents + .iter() + .find(|content| content.role == Role::User) + { + if user_content.parts.is_empty() { + bail!("User content must contain at least one part"); + } + } + let uri = format!( "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", model = request.model @@ -140,7 +154,7 @@ pub struct Content { pub role: Role, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, PartialEq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub enum Role { User, @@ -291,6 +305,8 @@ pub enum Model { Gemini15Pro, #[serde(rename = "gemini-1.5-flash")] Gemini15Flash, + #[serde(rename = "gemini-2.0-flash-exp")] + Gemini20Flash, #[serde(rename = "custom")] Custom { name: String, @@ -305,6 +321,7 @@ impl Model { match self { Model::Gemini15Pro => "gemini-1.5-pro", Model::Gemini15Flash => "gemini-1.5-flash", + Model::Gemini20Flash => "gemini-2.0-flash-exp", Model::Custom { name, .. } => name, } } @@ -313,6 +330,7 @@ impl Model { match self { Model::Gemini15Pro => "Gemini 1.5 Pro", Model::Gemini15Flash => "Gemini 1.5 Flash", + Model::Gemini20Flash => "Gemini 2.0 Flash", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name), @@ -323,6 +341,7 @@ impl Model { match self { Model::Gemini15Pro => 2_000_000, Model::Gemini15Flash => 1_000_000, + Model::Gemini20Flash => 1_000_000, Model::Custom { max_tokens, .. } => *max_tokens, } } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 804adb22f1..57ed28d625 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -88,6 +88,7 @@ impl CloudModel { Self::Google(model) => match model { google_ai::Model::Gemini15Pro | google_ai::Model::Gemini15Flash + | google_ai::Model::Gemini20Flash | google_ai::Model::Custom { .. } => { LanguageModelAvailability::RequiresPlan(Plan::ZedPro) }