diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 84bb15affa..f3210d9dfe 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,7 +1,9 @@ +use std::mem; + use anyhow::{Result, anyhow, bail}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; @@ -11,25 +13,13 @@ pub async fn stream_generate_content( api_key: &str, mut request: GenerateContentRequest, ) -> Result>> { - if request.contents.is_empty() { - bail!("Request must contain at least one content item"); - } + validate_generate_content_request(&request)?; - if let Some(user_content) = request - .contents - .iter() - .find(|content| content.role == Role::User) - { - if user_content.parts.is_empty() { - bail!("User content must contain at least one part"); - } - } + // The `model` field is emptied as it is provided as a path parameter. + let model_id = mem::take(&mut request.model.model_id); - let uri = format!( - "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", - model = request.model - ); - request.model.clear(); + let uri = + format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",); let request_builder = HttpRequest::builder() .method(Method::POST) @@ -76,18 +66,22 @@ pub async fn count_tokens( client: &dyn HttpClient, api_url: &str, api_key: &str, - model_id: &str, request: CountTokensRequest, ) -> Result { - let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",); - let request = serde_json::to_string(&request)?; + validate_generate_content_request(&request.generate_content_request)?; + let uri = format!( + "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}", + model_id = &request.generate_content_request.model.model_id, + ); + + let request = serde_json::to_string(&request)?; let request_builder = HttpRequest::builder() .method(Method::POST) .uri(&uri) .header("Content-Type", "application/json"); - let http_request = request_builder.body(AsyncBody::from(request))?; + let mut response = client.send(http_request).await?; let mut text = String::new(); response.body_mut().read_to_string(&mut text).await?; @@ -102,6 +96,28 @@ pub async fn count_tokens( } } +pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> { + if request.model.is_empty() { + bail!("Model must be specified"); + } + + if request.contents.is_empty() { + bail!("Request must contain at least one content item"); + } + + if let Some(user_content) = request + .contents + .iter() + .find(|content| content.role == Role::User) + { + if user_content.parts.is_empty() { + bail!("User content must contain at least one part"); + } + } + + Ok(()) +} + #[derive(Debug, Serialize, Deserialize)] pub enum Task { #[serde(rename = "generateContent")] @@ -119,8 +135,8 @@ pub enum Task { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { - #[serde(default, skip_serializing_if = "String::is_empty")] - pub model: String, + #[serde(default, skip_serializing_if = "ModelName::is_empty")] + pub model: ModelName, pub contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub system_instruction: Option, @@ -350,7 +366,7 @@ pub struct SafetyRating { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensRequest { - pub contents: Vec, + pub generate_content_request: GenerateContentRequest, } #[derive(Debug, Serialize, Deserialize)] @@ -406,6 +422,47 @@ pub struct FunctionDeclaration { pub parameters: serde_json::Value, } +#[derive(Debug, Default)] +pub struct ModelName { + pub model_id: String, +} + +impl ModelName { + pub fn is_empty(&self) -> bool { + self.model_id.is_empty() + } +} + +const MODEL_NAME_PREFIX: &str = "models/"; + +impl Serialize for ModelName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id)) + } +} + +impl<'de> Deserialize<'de> for ModelName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let string = String::deserialize(deserializer)?; + if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) { + Ok(Self { + model_id: id.to_string(), + }) + } else { + return Err(serde::de::Error::custom(format!( + "Expected model name to begin with {}, got: {}", + MODEL_NAME_PREFIX, string + ))); + } + } +} + #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] pub enum Model { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 76f38fed71..939992c973 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -718,7 +718,8 @@ impl LanguageModel for CloudLanguageModel { CloudModel::Google(model) => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); - let request = into_google(request, model.id().into()); + let model_id = model.id().to_string(); + let generate_content_request = into_google(request, model_id.clone()); async move { let http_client = &client.http_client(); let token = llm_api_token.acquire(&client).await?; @@ -736,9 +737,9 @@ impl LanguageModel for CloudLanguageModel { }; let request_body = CountTokensBody { provider: zed_llm_client::LanguageModelProvider::Google, - model: model.id().into(), + model: model_id, provider_request: serde_json::to_value(&google_ai::CountTokensRequest { - contents: request.contents, + generate_content_request, })?, }; let request = request_builder @@ -895,7 +896,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, mode, provider: zed_llm_client::LanguageModelProvider::Google, - model: request.model.clone(), + model: request.model.model_id.clone(), provider_request: serde_json::to_value(&request)?, }, ) diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 279556e5a8..43068f1b00 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -344,9 +344,8 @@ impl LanguageModel for GoogleLanguageModel { http_client.as_ref(), &api_url, &api_key, - &model_id, google_ai::CountTokensRequest { - contents: request.contents, + generate_content_request: request, }, ) .await?; @@ -382,7 +381,7 @@ impl LanguageModel for GoogleLanguageModel { pub fn into_google( mut request: LanguageModelRequest, - model: String, + model_id: String, ) -> google_ai::GenerateContentRequest { fn map_content(content: Vec) -> Vec { content @@ -442,7 +441,7 @@ pub fn into_google( }; google_ai::GenerateContentRequest { - model, + model: google_ai::ModelName { model_id }, system_instruction: system_instructions, contents: request .messages