Add support for getting the token count for all parts of Gemini generation requests (#29630)

* `CountTokensRequest` now takes a full `GenerateContentRequest` instead
of just content.

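* Fixes use of the `models/` prefix in the `model` field of
`GenerateContentRequest`, since that prefix is required when the request
is embedded in a `CountTokensRequest`. The missing prefix didn't cause
issues before because the `model` field was always cleared before the
request was serialized, and its value was only used in the URL path.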

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-05-04 15:32:45 -06:00 committed by GitHub
parent 86484233c0
commit 76ad1a29a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 90 additions and 33 deletions

View file

@ -1,7 +1,9 @@
use std::mem;
use anyhow::{Result, anyhow, bail}; use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub const API_URL: &str = "https://generativelanguage.googleapis.com";
@ -11,25 +13,13 @@ pub async fn stream_generate_content(
api_key: &str, api_key: &str,
mut request: GenerateContentRequest, mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> { ) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
if request.contents.is_empty() { validate_generate_content_request(&request)?;
bail!("Request must contain at least one content item");
}
if let Some(user_content) = request // The `model` field is emptied as it is provided as a path parameter.
.contents let model_id = mem::take(&mut request.model.model_id);
.iter()
.find(|content| content.role == Role::User)
{
if user_content.parts.is_empty() {
bail!("User content must contain at least one part");
}
}
let uri = format!( let uri =
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
model = request.model
);
request.model.clear();
let request_builder = HttpRequest::builder() let request_builder = HttpRequest::builder()
.method(Method::POST) .method(Method::POST)
@ -76,18 +66,22 @@ pub async fn count_tokens(
client: &dyn HttpClient, client: &dyn HttpClient,
api_url: &str, api_url: &str,
api_key: &str, api_key: &str,
model_id: &str,
request: CountTokensRequest, request: CountTokensRequest,
) -> Result<CountTokensResponse> { ) -> Result<CountTokensResponse> {
let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",); validate_generate_content_request(&request.generate_content_request)?;
let request = serde_json::to_string(&request)?;
let uri = format!(
"{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
model_id = &request.generate_content_request.model.model_id,
);
let request = serde_json::to_string(&request)?;
let request_builder = HttpRequest::builder() let request_builder = HttpRequest::builder()
.method(Method::POST) .method(Method::POST)
.uri(&uri) .uri(&uri)
.header("Content-Type", "application/json"); .header("Content-Type", "application/json");
let http_request = request_builder.body(AsyncBody::from(request))?; let http_request = request_builder.body(AsyncBody::from(request))?;
let mut response = client.send(http_request).await?; let mut response = client.send(http_request).await?;
let mut text = String::new(); let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?; response.body_mut().read_to_string(&mut text).await?;
@ -102,6 +96,28 @@ pub async fn count_tokens(
} }
} }
/// Checks that a [`GenerateContentRequest`] is well-formed enough to be sent.
///
/// # Errors
///
/// Fails, in this order, when:
/// - the request's model name is empty,
/// - the request has no content items, or
/// - the first content item with the user role has no parts.
///
/// Only the first user-role item is inspected; later ones are not checked.
pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
    if request.model.is_empty() {
        bail!("Model must be specified");
    }

    let contents = &request.contents;
    if contents.is_empty() {
        bail!("Request must contain at least one content item");
    }

    let first_user_entry = contents.iter().find(|entry| entry.role == Role::User);
    match first_user_entry {
        Some(entry) if entry.parts.is_empty() => {
            bail!("User content must contain at least one part")
        }
        _ => Ok(()),
    }
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum Task { pub enum Task {
#[serde(rename = "generateContent")] #[serde(rename = "generateContent")]
@ -119,8 +135,8 @@ pub enum Task {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest { pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "String::is_empty")] #[serde(default, skip_serializing_if = "ModelName::is_empty")]
pub model: String, pub model: ModelName,
pub contents: Vec<Content>, pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<SystemInstruction>, pub system_instruction: Option<SystemInstruction>,
@ -350,7 +366,7 @@ pub struct SafetyRating {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CountTokensRequest { pub struct CountTokensRequest {
pub contents: Vec<Content>, pub generate_content_request: GenerateContentRequest,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -406,6 +422,47 @@ pub struct FunctionDeclaration {
pub parameters: serde_json::Value, pub parameters: serde_json::Value,
} }
/// Name of a model, stored as the bare model id.
///
/// The default value has an empty `model_id`.
#[derive(Debug, Default)]
pub struct ModelName {
    /// The model identifier.
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when `model_id` is the empty string.
    pub fn is_empty(&self) -> bool {
        let Self { model_id } = self;
        model_id.is_empty()
    }
}
/// Prefix placed before the model id in the serialized form of `ModelName`,
/// and required (then stripped) when deserializing one.
const MODEL_NAME_PREFIX: &str = "models/";
impl Serialize for ModelName {
    /// Serializes the name as a single string: `MODEL_NAME_PREFIX` followed
    /// by the model id.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let prefixed_name = format!("{MODEL_NAME_PREFIX}{}", &self.model_id);
        serializer.serialize_str(&prefixed_name)
    }
}
impl<'de> Deserialize<'de> for ModelName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
Ok(Self {
model_id: id.to_string(),
})
} else {
return Err(serde::de::Error::custom(format!(
"Expected model name to begin with {}, got: {}",
MODEL_NAME_PREFIX, string
)));
}
}
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] #[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model { pub enum Model {

View file

@ -718,7 +718,8 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Google(model) => { CloudModel::Google(model) => {
let client = self.client.clone(); let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone(); let llm_api_token = self.llm_api_token.clone();
let request = into_google(request, model.id().into()); let model_id = model.id().to_string();
let generate_content_request = into_google(request, model_id.clone());
async move { async move {
let http_client = &client.http_client(); let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?; let token = llm_api_token.acquire(&client).await?;
@ -736,9 +737,9 @@ impl LanguageModel for CloudLanguageModel {
}; };
let request_body = CountTokensBody { let request_body = CountTokensBody {
provider: zed_llm_client::LanguageModelProvider::Google, provider: zed_llm_client::LanguageModelProvider::Google,
model: model.id().into(), model: model_id,
provider_request: serde_json::to_value(&google_ai::CountTokensRequest { provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
contents: request.contents, generate_content_request,
})?, })?,
}; };
let request = request_builder let request = request_builder
@ -895,7 +896,7 @@ impl LanguageModel for CloudLanguageModel {
prompt_id, prompt_id,
mode, mode,
provider: zed_llm_client::LanguageModelProvider::Google, provider: zed_llm_client::LanguageModelProvider::Google,
model: request.model.clone(), model: request.model.model_id.clone(),
provider_request: serde_json::to_value(&request)?, provider_request: serde_json::to_value(&request)?,
}, },
) )

View file

@ -344,9 +344,8 @@ impl LanguageModel for GoogleLanguageModel {
http_client.as_ref(), http_client.as_ref(),
&api_url, &api_url,
&api_key, &api_key,
&model_id,
google_ai::CountTokensRequest { google_ai::CountTokensRequest {
contents: request.contents, generate_content_request: request,
}, },
) )
.await?; .await?;
@ -382,7 +381,7 @@ impl LanguageModel for GoogleLanguageModel {
pub fn into_google( pub fn into_google(
mut request: LanguageModelRequest, mut request: LanguageModelRequest,
model: String, model_id: String,
) -> google_ai::GenerateContentRequest { ) -> google_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> { fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content content
@ -442,7 +441,7 @@ pub fn into_google(
}; };
google_ai::GenerateContentRequest { google_ai::GenerateContentRequest {
model, model: google_ai::ModelName { model_id },
system_instruction: system_instructions, system_instruction: system_instructions,
contents: request contents: request
.messages .messages