Add support for getting the token count for all parts of Gemini generation requests (#29630)
* `CountTokensRequest` now takes a full `GenerateContentRequest` instead of just content. * Fixes use of `models/` prefix in `model` field of `GenerateContentRequest`, since that's required for use in `CountTokensRequest`. This didn't cause issues before because it was always cleared and used in the path. Release Notes: - N/A
This commit is contained in:
parent
86484233c0
commit
76ad1a29a5
3 changed files with 90 additions and 33 deletions
|
@ -718,7 +718,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let request = into_google(request, model.id().into());
|
||||
let model_id = model.id().to_string();
|
||||
let generate_content_request = into_google(request, model_id.clone());
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
@ -736,9 +737,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||
};
|
||||
let request_body = CountTokensBody {
|
||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||
model: model.id().into(),
|
||||
model: model_id,
|
||||
provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
generate_content_request,
|
||||
})?,
|
||||
};
|
||||
let request = request_builder
|
||||
|
@ -895,7 +896,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
prompt_id,
|
||||
mode,
|
||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||
model: request.model.clone(),
|
||||
model: request.model.model_id.clone(),
|
||||
provider_request: serde_json::to_value(&request)?,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -344,9 +344,8 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
&model_id,
|
||||
google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
generate_content_request: request,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
@ -382,7 +381,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
|
||||
pub fn into_google(
|
||||
mut request: LanguageModelRequest,
|
||||
model: String,
|
||||
model_id: String,
|
||||
) -> google_ai::GenerateContentRequest {
|
||||
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
|
||||
content
|
||||
|
@ -442,7 +441,7 @@ pub fn into_google(
|
|||
};
|
||||
|
||||
google_ai::GenerateContentRequest {
|
||||
model,
|
||||
model: google_ai::ModelName { model_id },
|
||||
system_instruction: system_instructions,
|
||||
contents: request
|
||||
.messages
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue