use std::mem; use anyhow::{Result, anyhow, bail}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub async fn stream_generate_content( client: &dyn HttpClient, api_url: &str, api_key: &str, mut request: GenerateContentRequest, ) -> Result>> { validate_generate_content_request(&request)?; // The `model` field is emptied as it is provided as a path parameter. let model_id = mem::take(&mut request.model.model_id); let uri = format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); Ok(reader .lines() .filter_map(|line| async move { match line { Ok(line) => { if let Some(line) = line.strip_prefix("data: ") { match serde_json::from_str(line) { Ok(response) => Some(Ok(response)), Err(error) => Some(Err(anyhow!(format!( "Error parsing JSON: {error:?}\n{line:?}" )))), } } else { None } } Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) } else { let mut text = String::new(); response.body_mut().read_to_string(&mut text).await?; Err(anyhow!( "error during streamGenerateContent, status code: {:?}, body: {}", response.status(), text )) } } pub async fn count_tokens( client: &dyn HttpClient, api_url: &str, api_key: &str, request: CountTokensRequest, ) -> Result { validate_generate_content_request(&request.generate_content_request)?; let uri = format!( "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}", model_id = 
&request.generate_content_request.model.model_id, ); let request = serde_json::to_string(&request)?; let request_builder = HttpRequest::builder() .method(Method::POST) .uri(&uri) .header("Content-Type", "application/json"); let http_request = request_builder.body(AsyncBody::from(request))?; let mut response = client.send(http_request).await?; let mut text = String::new(); response.body_mut().read_to_string(&mut text).await?; anyhow::ensure!( response.status().is_success(), "error during countTokens, status code: {:?}, body: {}", response.status(), text ); Ok(serde_json::from_str::(&text)?) } pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> { if request.model.is_empty() { bail!("Model must be specified"); } if request.contents.is_empty() { bail!("Request must contain at least one content item"); } if let Some(user_content) = request .contents .iter() .find(|content| content.role == Role::User) { if user_content.parts.is_empty() { bail!("User content must contain at least one part"); } } Ok(()) } #[derive(Debug, Serialize, Deserialize)] pub enum Task { #[serde(rename = "generateContent")] GenerateContent, #[serde(rename = "streamGenerateContent")] StreamGenerateContent, #[serde(rename = "countTokens")] CountTokens, #[serde(rename = "embedContent")] EmbedContent, #[serde(rename = "batchEmbedContents")] BatchEmbedContents, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { #[serde(default, skip_serializing_if = "ModelName::is_empty")] pub model: ModelName, pub contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub system_instruction: Option, #[serde(skip_serializing_if = "Option::is_none")] pub generation_config: Option, #[serde(skip_serializing_if = "Option::is_none")] pub safety_settings: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_config: Option, } 
// NOTE(review): the generic type arguments in this section (`Option<…>`, `Vec<…>`, and the
// `<S>` / `<D>` parameters and `Result<…>` return types of the manual serde impls) were missing
// from the text under review and have been reconstructed. Arguments that name another item in
// this file follow from the field names; the primitive ones (`usize`, `f64`, `String`) are
// assumptions — confirm against the upstream definitions and the API schema.

/// One parsed response (or streamed chunk) of a generate-content call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
}

/// A single candidate inside a [`GenerateContentResponse`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<usize>,
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
}

/// A message: a list of parts attributed to a role.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Defaults to an empty list when the field is absent on deserialization.
    #[serde(default)]
    pub parts: Vec<Part>,
    pub role: Role,
}

/// System instruction: parts without a role.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
    pub parts: Vec<Part>,
}

/// Author of a [`Content`] item; serialized as `"user"` / `"model"`.
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
    User,
    Model,
}

/// One part of a [`Content`] item.
///
/// The enum is untagged, so the variant is chosen by shape: on deserialization serde tries
/// the variants in declaration order and keeps the first one that matches.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
    FunctionCallPart(FunctionCallPart),
    FunctionResponsePart(FunctionResponsePart),
    ThoughtPart(ThoughtPart),
}

/// Part carrying a `text` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

/// Part carrying an `inlineData` blob.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

/// Payload of an [`InlineDataPart`]: a MIME type plus the data as a string.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    pub data: String,
}

/// Part carrying a `functionCall`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
    pub function_call: FunctionCall,
}

/// Part carrying a `functionResponse`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
    pub function_response: FunctionResponse,
}

/// Part carrying `thought` and `thoughtSignature` fields.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
    pub thought: bool,
    pub thought_signature: String,
}

/// One entry of [`CitationMetadata`]; every field is optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_index: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
}

/// Citation sources attached to a candidate.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}

/// Prompt-level feedback included in a [`GenerateContentResponse`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<String>,
    pub safety_ratings: Vec<SafetyRating>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason_message: Option<String>,
}

/// Token counters included in a [`GenerateContentResponse`]; every counter is optional.
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<usize>,
}

/// Thinking settings nested in [`GenerationConfig`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

/// Mode a [`Model`] runs in; see [`Model::mode`].
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum GoogleModelMode {
    #[default]
    Default,
    Thinking {
        budget_tokens: Option<u32>,
    },
}

/// Generation parameters of a [`GenerateContentRequest`]; unset fields are not serialized.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}

/// Pairs a harm category with the threshold to apply to it.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}

/// Harm categories, (de)serialized under their `HARM_CATEGORY_*` names.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}

/// Blocking thresholds, (de)serialized in SCREAMING_SNAKE_CASE; `Unspecified` has an
/// explicit name.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
    Unspecified,
    BlockLowAndAbove,
    BlockMediumAndAbove,
    BlockOnlyHigh,
    BlockNone,
}

/// Harm probabilities, (de)serialized in SCREAMING_SNAKE_CASE; `Unspecified` has an
/// explicit name.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}

/// Pairs a harm category with a probability.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

/// Body of a `countTokens` request: wraps the generate-content request to be counted.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub generate_content_request: GenerateContentRequest,
}

/// Body of a `countTokens` response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: usize,
}

/// A function call: the function name and its arguments as arbitrary JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

/// A function result: the function name and its response as arbitrary JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

/// A tool: a list of function declarations.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Vec<FunctionDeclaration>,
}

/// Tool configuration of a [`GenerateContentRequest`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

/// Function-calling mode plus an optional list of allowed function names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: FunctionCallingMode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}

/// Function-calling modes, (de)serialized in lowercase.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
    Auto,
    Any,
    None,
}

/// Declaration of a callable function; `parameters` is arbitrary JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

/// A model identifier, stored without the `models/` prefix.
///
/// The prefix is added on serialization and required (then stripped) on deserialization.
#[derive(Debug, Default)]
pub struct ModelName {
    pub model_id: String,
}

impl ModelName {
    /// Returns `true` when the model id is the empty string.
    pub fn is_empty(&self) -> bool {
        self.model_id.is_empty()
    }
}

const MODEL_NAME_PREFIX: &str = "models/";

impl Serialize for ModelName {
    /// Serializes as the string `models/<model_id>`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
    }
}

impl<'de> Deserialize<'de> for ModelName {
    /// Deserializes from a string that must start with `models/`; the prefix is stripped.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let string = String::deserialize(deserializer)?;
        match string.strip_prefix(MODEL_NAME_PREFIX) {
            Some(id) => Ok(Self {
                model_id: id.to_string(),
            }),
            None => Err(serde::de::Error::custom(format!(
                "Expected model name to begin with {}, got: {}",
                MODEL_NAME_PREFIX, string
            ))),
        }
    }
}

/// Known models plus a user-defined `Custom` variant.
///
/// The serde names are the model ids; the two `*-preview-latest` variants also accept their
/// dated ids as aliases when deserializing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(rename = "gemini-1.5-pro")]
    Gemini15Pro,
    #[serde(rename = "gemini-1.5-flash")]
    Gemini15Flash,
    #[serde(rename = "gemini-2.0-pro-exp")]
    Gemini20Pro,
    #[serde(rename = "gemini-2.0-flash")]
    #[default]
    Gemini20Flash,
    #[serde(rename = "gemini-2.0-flash-thinking-exp")]
    Gemini20FlashThinking,
    #[serde(rename = "gemini-2.0-flash-lite-preview")]
    Gemini20FlashLite,
    #[serde(rename = "gemini-2.5-pro-exp-03-25")]
    Gemini25ProExp0325,
    #[serde(rename = "gemini-2.5-pro-preview-03-25")]
    Gemini25ProPreview0325,
    #[serde(rename = "gemini-2.5-flash-preview-04-17")]
    Gemini25FlashPreview0417,
    #[serde(
        rename = "gemini-2.5-flash-preview-latest",
        alias = "gemini-2.5-flash-preview-05-20"
    )]
    Gemini25FlashPreview,
    #[serde(
        rename = "gemini-2.5-pro-preview-latest",
        alias = "gemini-2.5-pro-preview-06-05"
    )]
    Gemini25ProPreview,
    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        #[serde(default)]
        mode: GoogleModelMode,
    },
}

impl Model {
    /// Returns the model used as the fast default ([`Model::Gemini20Flash`]).
    pub fn default_fast() -> Model {
        Model::Gemini20Flash
    }

    /// Returns the model's id; identical to its serde name, or `name` for `Custom`.
    ///
    /// The two rolling preview variants return their `*-latest` ids here; see
    /// [`Model::request_id`] for the dated ids.
    pub fn id(&self) -> &str {
        match self {
            Model::Gemini15Pro => "gemini-1.5-pro",
            Model::Gemini15Flash => "gemini-1.5-flash",
            Model::Gemini20Pro => "gemini-2.0-pro-exp",
            Model::Gemini20Flash => "gemini-2.0-flash",
            Model::Gemini20FlashThinking => "gemini-2.0-flash-thinking-exp",
            Model::Gemini20FlashLite => "gemini-2.0-flash-lite-preview",
            Model::Gemini25ProExp0325 => "gemini-2.5-pro-exp-03-25",
            Model::Gemini25ProPreview0325 => "gemini-2.5-pro-preview-03-25",
            Model::Gemini25FlashPreview0417 => "gemini-2.5-flash-preview-04-17",
            Model::Gemini25FlashPreview => "gemini-2.5-flash-preview-latest",
            Model::Gemini25ProPreview => "gemini-2.5-pro-preview-latest",
            Model::Custom { name, .. } => name,
        }
    }

    /// Returns the id to use in requests.
    ///
    /// Same as [`Model::id`] except for the two rolling preview variants, which map to
    /// their dated ids (`…-05-20` and `…-06-05`).
    pub fn request_id(&self) -> &str {
        match self {
            Model::Gemini15Pro => "gemini-1.5-pro",
            Model::Gemini15Flash => "gemini-1.5-flash",
            Model::Gemini20Pro => "gemini-2.0-pro-exp",
            Model::Gemini20Flash => "gemini-2.0-flash",
            Model::Gemini20FlashThinking => "gemini-2.0-flash-thinking-exp",
            Model::Gemini20FlashLite => "gemini-2.0-flash-lite-preview",
            Model::Gemini25ProExp0325 => "gemini-2.5-pro-exp-03-25",
            Model::Gemini25ProPreview0325 => "gemini-2.5-pro-preview-03-25",
            Model::Gemini25FlashPreview0417 => "gemini-2.5-flash-preview-04-17",
            Model::Gemini25FlashPreview => "gemini-2.5-flash-preview-05-20",
            Model::Gemini25ProPreview => "gemini-2.5-pro-preview-06-05",
            Model::Custom { name, .. } => name,
        }
    }

    /// Returns the human-readable name; `Custom` falls back to `name` when it has no
    /// `display_name`.
    pub fn display_name(&self) -> &str {
        match self {
            Model::Gemini15Pro => "Gemini 1.5 Pro",
            Model::Gemini15Flash => "Gemini 1.5 Flash",
            Model::Gemini20Pro => "Gemini 2.0 Pro",
            Model::Gemini20Flash => "Gemini 2.0 Flash",
            Model::Gemini20FlashThinking => "Gemini 2.0 Flash Thinking",
            Model::Gemini20FlashLite => "Gemini 2.0 Flash Lite",
            Model::Gemini25ProExp0325 => "Gemini 2.5 Pro Exp",
            Model::Gemini25ProPreview0325 => "Gemini 2.5 Pro Preview (0325)",
            Model::Gemini25FlashPreview0417 => "Gemini 2.5 Flash Preview (0417)",
            Model::Gemini25FlashPreview => "Gemini 2.5 Flash Preview",
            Model::Gemini25ProPreview => "Gemini 2.5 Pro Preview",
            Self::Custom {
                name, display_name, ..
            } => display_name.as_ref().unwrap_or(name),
        }
    }

    /// Returns the model's maximum token count; `Custom` returns its `max_tokens`.
    pub fn max_token_count(&self) -> usize {
        const ONE_MILLION: usize = 1_048_576;
        const TWO_MILLION: usize = 2_097_152;
        match self {
            Model::Gemini15Pro => TWO_MILLION,
            Model::Gemini15Flash => ONE_MILLION,
            Model::Gemini20Pro => TWO_MILLION,
            Model::Gemini20Flash => ONE_MILLION,
            Model::Gemini20FlashThinking => ONE_MILLION,
            Model::Gemini20FlashLite => ONE_MILLION,
            Model::Gemini25ProExp0325 => ONE_MILLION,
            Model::Gemini25ProPreview0325 => ONE_MILLION,
            Model::Gemini25FlashPreview0417 => ONE_MILLION,
            Model::Gemini25FlashPreview => ONE_MILLION,
            Model::Gemini25ProPreview => ONE_MILLION,
            Model::Custom { max_tokens, .. } => *max_tokens,
        }
    }

    /// Returns the model's mode: [`GoogleModelMode::Default`] for every built-in variant,
    /// the configured `mode` for `Custom`.
    pub fn mode(&self) -> GoogleModelMode {
        match self {
            Self::Gemini15Pro
            | Self::Gemini15Flash
            | Self::Gemini20Pro
            | Self::Gemini20Flash
            | Self::Gemini20FlashThinking
            | Self::Gemini20FlashLite
            | Self::Gemini25ProExp0325
            | Self::Gemini25ProPreview
            | Self::Gemini25FlashPreview
            | Self::Gemini25ProPreview0325
            | Self::Gemini25FlashPreview0417 => GoogleModelMode::Default,
            Self::Custom { mode, .. } => *mode,
        }
    }
}

impl std::fmt::Display for Model {
    /// Formats the model as its [`Model::id`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.id())
    }
}