mod supported_countries; use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; pub use supported_countries::*; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub async fn stream_generate_content( client: &dyn HttpClient, api_url: &str, api_key: &str, mut request: GenerateContentRequest, ) -> Result>> { let uri = format!( "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", model = request.model ); request.model.clear(); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); Ok(reader .lines() .filter_map(|line| async move { match line { Ok(line) => { if let Some(line) = line.strip_prefix("data: ") { match serde_json::from_str(line) { Ok(response) => Some(Ok(response)), Err(error) => Some(Err(anyhow!(error))), } } else { None } } Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) } else { let mut text = String::new(); response.body_mut().read_to_string(&mut text).await?; Err(anyhow!( "error during streamGenerateContent, status code: {:?}, body: {}", response.status(), text )) } } pub async fn count_tokens( client: &dyn HttpClient, api_url: &str, api_key: &str, request: CountTokensRequest, ) -> Result { let uri = format!( "{}/v1beta/models/gemini-pro:countTokens?key={}", api_url, api_key ); let request = serde_json::to_string(&request)?; let request_builder = HttpRequest::builder() .method(Method::POST) .uri(&uri) .header("Content-Type", "application/json"); let http_request = request_builder.body(AsyncBody::from(request))?; let mut response = client.send(http_request).await?; let mut text = String::new(); response.body_mut().read_to_string(&mut text).await?; if response.status().is_success() { Ok(serde_json::from_str::(&text)?) } else { Err(anyhow!( "error during countTokens, status code: {:?}, body: {}", response.status(), text )) } } #[derive(Debug, Serialize, Deserialize)] pub enum Task { #[serde(rename = "generateContent")] GenerateContent, #[serde(rename = "streamGenerateContent")] StreamGenerateContent, #[serde(rename = "countTokens")] CountTokens, #[serde(rename = "embedContent")] EmbedContent, #[serde(rename = "batchEmbedContents")] BatchEmbedContents, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { #[serde(default, skip_serializing_if = "String::is_empty")] pub model: String, pub contents: Vec, pub generation_config: Option, pub safety_settings: Option>, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentResponse { pub candidates: Option>, pub prompt_feedback: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentCandidate { pub index: Option, pub content: Content, pub finish_reason: Option, pub finish_message: Option, pub safety_ratings: Option>, pub citation_metadata: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Content { pub parts: Vec, pub role: Role, } #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub enum Role { User, Model, } #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum Part { TextPart(TextPart), InlineDataPart(InlineDataPart), } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TextPart { pub text: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InlineDataPart { pub inline_data: GenerativeContentBlob, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerativeContentBlob { pub mime_type: String, pub data: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationSource { pub start_index: Option, pub end_index: Option, pub uri: Option, pub license: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationMetadata { pub citation_sources: Vec, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptFeedback { pub block_reason: Option, pub safety_ratings: Vec, pub block_reason_message: Option, } #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { pub candidate_count: Option, pub stop_sequences: Option>, pub max_output_tokens: Option, pub temperature: Option, pub top_p: Option, pub top_k: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetySetting { pub category: HarmCategory, pub threshold: HarmBlockThreshold, } #[derive(Debug, Serialize, Deserialize)] pub enum HarmCategory { #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] Unspecified, #[serde(rename = "HARM_CATEGORY_DEROGATORY")] Derogatory, #[serde(rename = "HARM_CATEGORY_TOXICITY")] Toxicity, #[serde(rename = "HARM_CATEGORY_VIOLENCE")] Violence, #[serde(rename = "HARM_CATEGORY_SEXUAL")] Sexual, #[serde(rename = "HARM_CATEGORY_MEDICAL")] Medical, #[serde(rename = "HARM_CATEGORY_DANGEROUS")] Dangerous, #[serde(rename = "HARM_CATEGORY_HARASSMENT")] Harassment, #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")] HateSpeech, #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")] SexuallyExplicit, #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")] DangerousContent, } #[derive(Debug, Serialize, Deserialize)] pub enum HarmBlockThreshold { #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] Unspecified, #[serde(rename = "BLOCK_LOW_AND_ABOVE")] BlockLowAndAbove, #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")] BlockMediumAndAbove, #[serde(rename = "BLOCK_ONLY_HIGH")] BlockOnlyHigh, #[serde(rename = "BLOCK_NONE")] BlockNone, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum HarmProbability { #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")] Unspecified, Negligible, Low, Medium, High, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetyRating { pub category: HarmCategory, pub probability: HarmProbability, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensRequest { pub contents: Vec, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensResponse { pub total_tokens: usize, } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] pub enum Model { #[serde(rename = "gemini-1.5-pro")] Gemini15Pro, #[serde(rename = "gemini-1.5-flash")] Gemini15Flash, #[serde(rename = "custom")] Custom { name: String, /// The name displayed in the UI, such as in the assistant panel model dropdown menu. display_name: Option, max_tokens: usize, }, } impl Model { pub fn id(&self) -> &str { match self { Model::Gemini15Pro => "gemini-1.5-pro", Model::Gemini15Flash => "gemini-1.5-flash", Model::Custom { name, .. } => name, } } pub fn display_name(&self) -> &str { match self { Model::Gemini15Pro => "Gemini 1.5 Pro", Model::Gemini15Flash => "Gemini 1.5 Flash", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name), } } pub fn max_token_count(&self) -> usize { match self { Model::Gemini15Pro => 2_000_000, Model::Gemini15Flash => 1_000_000, Model::Custom { max_tokens, .. } => *max_tokens, } } } impl std::fmt::Display for Model { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.id()) } } pub fn extract_text_from_events( events: impl Stream>, ) -> impl Stream> { events.filter_map(|event| async move { match event { Ok(event) => event.candidates.and_then(|candidates| { candidates.into_iter().next().and_then(|candidate| { candidate.content.parts.into_iter().next().and_then(|part| { if let Part::TextPart(TextPart { text }) = part { Some(Ok(text)) } else { None } }) }) }), Err(error) => Some(Err(error)), } }) }