Add support for getting the token count for all parts of Gemini generation requests (#29630)
* `CountTokensRequest` now takes a full `GenerateContentRequest` instead of just content, so token counting covers every part of the generation request.
* Fixes use of the `models/` prefix in the `model` field of `GenerateContentRequest`, since that prefix is required when the request is embedded in a `CountTokensRequest`. This didn't cause issues before because the field was always cleared and only used in the URL path.

Release Notes:

- N/A
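For illustration only (not part of the diff), a minimal sketch of the `models/` prefix round-trip that the new `ModelName` wrapper below implements; the model id string and the use of `serde_json`/`anyhow` in a standalone snippet are assumptions:

```rust
// Sketch: exercises the `ModelName` wrapper introduced in this commit.
// Assumes it is exported from the `google_ai` crate as in the diff below;
// the model id is an arbitrary example.
fn main() -> anyhow::Result<()> {
    let name = google_ai::ModelName {
        model_id: "gemini-2.0-flash".to_string(),
    };

    // The custom `Serialize` impl prepends the required "models/" prefix...
    let json = serde_json::to_string(&name)?;
    assert_eq!(json, r#""models/gemini-2.0-flash""#);

    // ...and the `Deserialize` impl strips it again, rejecting values
    // that don't start with "models/".
    let parsed: google_ai::ModelName = serde_json::from_str(&json)?;
    assert_eq!(parsed.model_id, "gemini-2.0-flash");

    Ok(())
}
```

On the wire, the `countTokens` body now nests the whole generation request under a camelCase `generateContentRequest` key (per the `rename_all` attribute on `CountTokensRequest`), with `model` carrying the `models/` prefix.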
parent 86484233c0
commit 76ad1a29a5

3 changed files with 90 additions and 33 deletions
File 1 of 3 (google_ai client crate):

@@ -1,7 +1,9 @@
+use std::mem;
+
 use anyhow::{Result, anyhow, bail};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};

 pub const API_URL: &str = "https://generativelanguage.googleapis.com";
@@ -11,25 +13,13 @@ pub async fn stream_generate_content(
     api_key: &str,
     mut request: GenerateContentRequest,
 ) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
-    if request.contents.is_empty() {
-        bail!("Request must contain at least one content item");
-    }
-
-    if let Some(user_content) = request
-        .contents
-        .iter()
-        .find(|content| content.role == Role::User)
-    {
-        if user_content.parts.is_empty() {
-            bail!("User content must contain at least one part");
-        }
-    }
-
-    let uri = format!(
-        "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
-        model = request.model
-    );
-    request.model.clear();
+    validate_generate_content_request(&request)?;
+
+    // The `model` field is emptied as it is provided as a path parameter.
+    let model_id = mem::take(&mut request.model.model_id);
+
+    let uri =
+        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);

     let request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -76,18 +66,22 @@ pub async fn count_tokens(
     client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
-    model_id: &str,
     request: CountTokensRequest,
 ) -> Result<CountTokensResponse> {
-    let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",);
-    let request = serde_json::to_string(&request)?;
+    validate_generate_content_request(&request.generate_content_request)?;
+
+    let uri = format!(
+        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
+        model_id = &request.generate_content_request.model.model_id,
+    );
+
+    let request = serde_json::to_string(&request)?;
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(&uri)
         .header("Content-Type", "application/json");

     let http_request = request_builder.body(AsyncBody::from(request))?;

     let mut response = client.send(http_request).await?;
     let mut text = String::new();
     response.body_mut().read_to_string(&mut text).await?;
@@ -102,6 +96,28 @@ pub async fn count_tokens(
     }
 }

+pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+    if request.model.is_empty() {
+        bail!("Model must be specified");
+    }
+
+    if request.contents.is_empty() {
+        bail!("Request must contain at least one content item");
+    }
+
+    if let Some(user_content) = request
+        .contents
+        .iter()
+        .find(|content| content.role == Role::User)
+    {
+        if user_content.parts.is_empty() {
+            bail!("User content must contain at least one part");
+        }
+    }
+
+    Ok(())
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub enum Task {
     #[serde(rename = "generateContent")]
@@ -119,8 +135,8 @@ pub enum Task {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerateContentRequest {
-    #[serde(default, skip_serializing_if = "String::is_empty")]
-    pub model: String,
+    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
+    pub model: ModelName,
     pub contents: Vec<Content>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub system_instruction: Option<SystemInstruction>,
@@ -350,7 +366,7 @@ pub struct SafetyRating {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct CountTokensRequest {
-    pub contents: Vec<Content>,
+    pub generate_content_request: GenerateContentRequest,
 }

 #[derive(Debug, Serialize, Deserialize)]
@@ -406,6 +422,47 @@ pub struct FunctionDeclaration {
     pub parameters: serde_json::Value,
 }

+#[derive(Debug, Default)]
+pub struct ModelName {
+    pub model_id: String,
+}
+
+impl ModelName {
+    pub fn is_empty(&self) -> bool {
+        self.model_id.is_empty()
+    }
+}
+
+const MODEL_NAME_PREFIX: &str = "models/";
+
+impl Serialize for ModelName {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+    }
+}
+
+impl<'de> Deserialize<'de> for ModelName {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let string = String::deserialize(deserializer)?;
+        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
+            Ok(Self {
+                model_id: id.to_string(),
+            })
+        } else {
+            return Err(serde::de::Error::custom(format!(
+                "Expected model name to begin with {}, got: {}",
+                MODEL_NAME_PREFIX, string
+            )));
+        }
+    }
+}
+
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
 pub enum Model {
File 2 of 3 (cloud language-model provider):

@@ -718,7 +718,8 @@ impl LanguageModel for CloudLanguageModel {
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
-                let request = into_google(request, model.id().into());
+                let model_id = model.id().to_string();
+                let generate_content_request = into_google(request, model_id.clone());
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;
@@ -736,9 +737,9 @@ impl LanguageModel for CloudLanguageModel {
                    };
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
-                        model: model.id().into(),
+                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
-                            contents: request.contents,
+                            generate_content_request,
                        })?,
                    };
                    let request = request_builder
@@ -895,7 +896,7 @@ impl LanguageModel for CloudLanguageModel {
                        prompt_id,
                        mode,
                        provider: zed_llm_client::LanguageModelProvider::Google,
-                        model: request.model.clone(),
+                        model: request.model.model_id.clone(),
                        provider_request: serde_json::to_value(&request)?,
                    },
                )
File 3 of 3 (Google language-model provider):

@@ -344,9 +344,8 @@ impl LanguageModel for GoogleLanguageModel {
                http_client.as_ref(),
                &api_url,
                &api_key,
-                &model_id,
                google_ai::CountTokensRequest {
-                    contents: request.contents,
+                    generate_content_request: request,
                },
            )
            .await?;
@@ -382,7 +381,7 @@ impl LanguageModel for GoogleLanguageModel {

 pub fn into_google(
     mut request: LanguageModelRequest,
-    model: String,
+    model_id: String,
 ) -> google_ai::GenerateContentRequest {
     fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
         content
@@ -442,7 +441,7 @@ pub fn into_google(
     };

     google_ai::GenerateContentRequest {
-        model,
+        model: google_ai::ModelName { model_id },
         system_instruction: system_instructions,
         contents: request
             .messages
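As a usage note, a minimal sketch of how a provider calls the updated `count_tokens` entry point after this change; the function name and surrounding values here are placeholders, but the call shape follows the new signature in the first file above:

```rust
// Sketch: the model id now travels inside the request itself, so there is no
// separate `model_id` argument to `count_tokens` any more.
async fn count_request_tokens(
    client: &dyn http_client::HttpClient,
    api_url: &str,
    api_key: &str,
    generate_content_request: google_ai::GenerateContentRequest,
) -> anyhow::Result<google_ai::CountTokensResponse> {
    google_ai::count_tokens(
        client,
        api_url,
        api_key,
        google_ai::CountTokensRequest {
            generate_content_request,
        },
    )
    .await
}
```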