Add support for getting the token count for all parts of Gemini generation requests (#29630)

* `CountTokensRequest` now takes a full `GenerateContentRequest` instead
of just the contents (see the sketch below the release notes).

* Fixes the `model` field of `GenerateContentRequest` to carry the
`models/` prefix, which is required when the request is embedded in a
`CountTokensRequest`. This didn't cause issues before because the field
was always cleared and the model was only used in the request path.

Release Notes:

- N/A
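
For orientation, a rough sketch of the request shapes this change moves to, using the field names visible in the diff below. This is a hedged sketch only; the actual definitions live in the `google_ai` crate and may differ, and `Content` is a placeholder type.

    // Hedged sketch, not the real google_ai definitions.
    struct Content; // placeholder for the crate's actual content type

    struct ModelName {
        model_id: String, // expected to serialize as "models/<model_id>"
    }

    struct GenerateContentRequest {
        model: ModelName, // previously a bare String that was cleared before sending
        contents: Vec<Content>,
        // ... system instructions, generation config, etc.
    }

    struct CountTokensRequest {
        // previously just `contents: Vec<Content>`
        generate_content_request: GenerateContentRequest,
    }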
Authored by Michael Sloan, committed via GitHub on 2025-05-04 15:32:45 -06:00
Commit 76ad1a29a5 (parent 86484233c0)
3 changed files with 90 additions and 33 deletions


@@ -718,7 +718,8 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let request = into_google(request, model.id().into());
+                let model_id = model.id().to_string();
+                let generate_content_request = into_google(request, model_id.clone());
                 async move {
                     let http_client = &client.http_client();
                     let token = llm_api_token.acquire(&client).await?;
@@ -736,9 +737,9 @@ impl LanguageModel for CloudLanguageModel {
                     };
                     let request_body = CountTokensBody {
                         provider: zed_llm_client::LanguageModelProvider::Google,
-                        model: model.id().into(),
+                        model: model_id,
                         provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
-                            contents: request.contents,
+                            generate_content_request,
                         })?,
                     };
                     let request = request_builder
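
The net effect of this hunk is that the provider request forwarded for counting now wraps the entire generation request rather than just its contents, which is the wrapped form Gemini's countTokens endpoint accepts. Assuming camelCase serialization (the crate's serde attributes may differ) and an illustrative model name, the serialized body would look roughly like:

    // Illustrative shape only; field casing and the model name are assumptions.
    let provider_request = serde_json::json!({
        "generateContentRequest": {
            "model": "models/gemini-2.0-flash",
            "contents": [/* messages mapped by into_google */]
        }
    });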
@@ -895,7 +896,7 @@ impl LanguageModel for CloudLanguageModel {
                             prompt_id,
                             mode,
                             provider: zed_llm_client::LanguageModelProvider::Google,
-                            model: request.model.clone(),
+                            model: request.model.model_id.clone(),
                             provider_request: serde_json::to_value(&request)?,
                         },
                     )


@@ -344,9 +344,8 @@ impl LanguageModel for GoogleLanguageModel {
                 http_client.as_ref(),
                 &api_url,
                 &api_key,
-                &model_id,
                 google_ai::CountTokensRequest {
-                    contents: request.contents,
+                    generate_content_request: request,
                 },
             )
             .await?;
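
After this hunk, the Google provider converts the whole request up front and hands it to `count_tokens` intact, instead of passing a separate model id alongside bare contents. A hedged sketch of the resulting call flow, using names from the diff rather than the crate's exact signatures:

    // Sketch of the call site after the change; not verbatim from the crate.
    let generate_content_request = into_google(request, model_id);
    let response = google_ai::count_tokens(
        http_client.as_ref(),
        &api_url,
        &api_key,
        google_ai::CountTokensRequest { generate_content_request },
    )
    .await?;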
@@ -382,7 +381,7 @@ impl LanguageModel for GoogleLanguageModel {
 pub fn into_google(
     mut request: LanguageModelRequest,
-    model: String,
+    model_id: String,
 ) -> google_ai::GenerateContentRequest {
     fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
         content
@@ -442,7 +441,7 @@ pub fn into_google(
     };
     google_ai::GenerateContentRequest {
-        model,
+        model: google_ai::ModelName { model_id },
         system_instruction: system_instructions,
         contents: request
            .messages
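
The `models/` prefix mentioned in the commit message presumably comes from how `ModelName` serializes the wrapped id. One way that could be implemented, as a sketch under that assumption rather than the crate's actual code:

    // Hypothetical Serialize impl; the real ModelName in google_ai may differ.
    pub struct ModelName {
        pub model_id: String,
    }

    impl serde::Serialize for ModelName {
        fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            serializer.serialize_str(&format!("models/{}", self.model_id))
        }
    }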