ZIm/crates/google_ai/src/google_ai.rs
Thorsten Ball aee01f2c50
assistant: Remove low_speed_timeout (#20681)
This removes the `low_speed_timeout` setting from all providers as a
response to issue #19509.

The reason is that the original `low_speed_timeout` was only added as part of
#9913 because users wanted to _get rid of timeouts_. They wanted to bump
the default timeout from 5sec to a lot more.

Then, in the meantime, the meaning of `low_speed_timeout` changed in
#19055 and was changed to a normal `timeout`, which is a different thing
and breaks slower LLMs that don't reply with a complete response in the
configured timeout.

So we figured: let's remove the whole thing and replace it with a
default _connect_ timeout to make sure that we can connect to a server
in 10s, but then give the server as long as it wants to complete its
response.

Closes #19509

Release Notes:

- Removed the `low_speed_timeout` setting from LLM provider settings,
since it was only used to _increase_ the timeout to give LLMs more time,
but since we don't have any other use for it, we simply remove the
setting to give LLMs as long as they need.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Peter Tripp <peter@zed.dev>
2024-11-15 07:37:31 +01:00

356 lines
10 KiB
Rust

mod supported_countries;
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
pub use supported_countries::*;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
let uri = format!(
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
model = request.model
);
request.model.clear();
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
if let Some(line) = line.strip_prefix("data: ") {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
}
} else {
None
}
}
Err(error) => Some(Err(anyhow!(error))),
}
})
.boxed())
} else {
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
Err(anyhow!(
"error during streamGenerateContent, status code: {:?}, body: {}",
response.status(),
text
))
}
}
/// Sends a `countTokens` request and returns the deserialized response.
///
/// The request is POSTed as JSON to `{api_url}/v1beta/models/gemini-pro:countTokens?key={api_key}`;
/// note that the model in the URL is fixed to `gemini-pro`.
///
/// # Errors
///
/// Returns an error if the request cannot be serialized, built or sent, if the response body
/// cannot be read or deserialized, or if the server replies with a non-success status (the
/// error message carries the status code and the response body).
pub async fn count_tokens(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: CountTokensRequest,
) -> Result<CountTokensResponse> {
    let uri = format!("{api_url}/v1beta/models/gemini-pro:countTokens?key={api_key}");
    let payload = serde_json::to_string(&request)?;
    let http_request = HttpRequest::builder()
        .method(Method::POST)
        .uri(&uri)
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(payload))?;

    let mut response = client.send(http_request).await?;
    // The body is read in full before the status check so it can be quoted in the error.
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if !response.status().is_success() {
        return Err(anyhow!(
            "error during countTokens, status code: {:?}, body: {}",
            response.status(),
            body
        ));
    }
    Ok(serde_json::from_str::<CountTokensResponse>(&body)?)
}
/// Task names; each variant (de)serializes as the camelCase string given in its
/// `rename` attribute (e.g. `StreamGenerateContent` <-> `"streamGenerateContent"`).
#[derive(Debug, Serialize, Deserialize)]
pub enum Task {
#[serde(rename = "generateContent")]
GenerateContent,
#[serde(rename = "streamGenerateContent")]
StreamGenerateContent,
#[serde(rename = "countTokens")]
CountTokens,
#[serde(rename = "embedContent")]
EmbedContent,
#[serde(rename = "batchEmbedContents")]
BatchEmbedContents,
}
/// Request payload accepted by `stream_generate_content`; fields are (de)serialized
/// with camelCase names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
/// Model name. Omitted from the JSON when empty and defaults to an empty string when
/// absent on input; `stream_generate_content` places it in the URL and then clears it.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub model: String,
/// The content items sent with the request.
pub contents: Vec<Content>,
/// Optional generation parameters.
pub generation_config: Option<GenerationConfig>,
/// Optional list of safety settings.
pub safety_settings: Option<Vec<SafetySetting>>,
}
/// One response chunk, as deserialized from a `data` line by `stream_generate_content`;
/// fields use camelCase names and both are optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
/// Candidates carried by this chunk, if any.
pub candidates: Option<Vec<GenerateContentCandidate>>,
/// Feedback about the prompt, if any.
pub prompt_feedback: Option<PromptFeedback>,
}
/// A single candidate inside a [`GenerateContentResponse`]; fields use camelCase names.
/// Only `content` is mandatory, every other field is optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
/// Index of the candidate, when provided.
pub index: Option<usize>,
/// The candidate's content.
pub content: Content,
/// Finish reason, kept as the raw string.
pub finish_reason: Option<String>,
/// Finish message, kept as the raw string.
pub finish_message: Option<String>,
/// Safety ratings attached to the candidate, when provided.
pub safety_ratings: Option<Vec<SafetyRating>>,
/// Citation metadata attached to the candidate, when provided.
pub citation_metadata: Option<CitationMetadata>,
}
/// A list of parts together with the role they are attributed to.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
/// The parts making up this content.
pub parts: Vec<Part>,
/// The role this content belongs to.
pub role: Role,
}
/// Role attached to a [`Content`]; (de)serialized in camelCase, i.e. `"user"` and `"model"`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
User,
Model,
}
/// One part of a [`Content`]: either text or inline data.
///
/// The enum is untagged, so no variant name appears in the JSON; the wrapped struct's own
/// fields are (de)serialized directly and variants are tried in declaration order on input.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
TextPart(TextPart),
InlineDataPart(InlineDataPart),
}
/// A [`Part`] holding plain text.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
/// The text of this part.
pub text: String,
}
/// A [`Part`] holding an inline data blob; the field is (de)serialized as `inlineData`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
/// The embedded blob.
pub inline_data: GenerativeContentBlob,
}
/// Data blob carried by an [`InlineDataPart`]; fields use camelCase names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
/// MIME type of the data, (de)serialized as `mimeType`.
pub mime_type: String,
/// The data as a string. NOTE(review): presumably base64-encoded -- confirm against the API docs.
pub data: String,
}
/// A single citation entry inside [`CitationMetadata`]; every field is optional and
/// (de)serialized with a camelCase name.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
/// Start index of the cited span, when provided.
pub start_index: Option<usize>,
/// End index of the cited span, when provided.
pub end_index: Option<usize>,
/// URI of the source, when provided.
pub uri: Option<String>,
/// License of the source, when provided.
pub license: Option<String>,
}
/// Citation information attached to a [`GenerateContentCandidate`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
/// The individual citation entries, (de)serialized as `citationSources`.
pub citation_sources: Vec<CitationSource>,
}
/// Prompt feedback carried by a [`GenerateContentResponse`]; fields use camelCase names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
/// Block reason, kept as the raw string, when provided.
pub block_reason: Option<String>,
/// Safety ratings; unlike the other fields this one is not optional.
pub safety_ratings: Vec<SafetyRating>,
/// Message accompanying the block reason, when provided.
pub block_reason_message: Option<String>,
}
/// Optional generation parameters of a [`GenerateContentRequest`]; every field is optional
/// and (de)serialized with a camelCase name (e.g. `maxOutputTokens`, `topP`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
/// Number of candidates.
pub candidate_count: Option<usize>,
/// Stop sequences.
pub stop_sequences: Option<Vec<String>>,
/// Maximum number of output tokens.
pub max_output_tokens: Option<usize>,
/// Sampling temperature.
pub temperature: Option<f64>,
/// Top-p sampling parameter.
pub top_p: Option<f64>,
/// Top-k sampling parameter.
pub top_k: Option<usize>,
}
/// Pairs a harm category with the blocking threshold to use for it.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
/// The category this setting applies to.
pub category: HarmCategory,
/// The threshold configured for that category.
pub threshold: HarmBlockThreshold,
}
/// Harm categories; each variant (de)serializes as the `HARM_CATEGORY_*` string given in
/// its `rename` attribute.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
Unspecified,
#[serde(rename = "HARM_CATEGORY_DEROGATORY")]
Derogatory,
#[serde(rename = "HARM_CATEGORY_TOXICITY")]
Toxicity,
#[serde(rename = "HARM_CATEGORY_VIOLENCE")]
Violence,
#[serde(rename = "HARM_CATEGORY_SEXUAL")]
Sexual,
#[serde(rename = "HARM_CATEGORY_MEDICAL")]
Medical,
#[serde(rename = "HARM_CATEGORY_DANGEROUS")]
Dangerous,
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
Harassment,
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
HateSpeech,
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
SexuallyExplicit,
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
DangerousContent,
}
/// Blocking thresholds used in a [`SafetySetting`]; each variant (de)serializes as the
/// string given in its `rename` attribute.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
BlockLowAndAbove,
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
BlockMediumAndAbove,
#[serde(rename = "BLOCK_ONLY_HIGH")]
BlockOnlyHigh,
#[serde(rename = "BLOCK_NONE")]
BlockNone,
}
/// Harm probability levels used in a [`SafetyRating`].
///
/// Variants (de)serialize in SCREAMING_SNAKE_CASE (`NEGLIGIBLE`, `LOW`, `MEDIUM`, `HIGH`),
/// except `Unspecified`, which is explicitly renamed.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
// The explicit rename takes the place of the plain `UNSPECIFIED` that `rename_all` would give.
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
Unspecified,
Negligible,
Low,
Medium,
High,
}
/// Pairs a harm category with a harm probability.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
/// The category being rated.
pub category: HarmCategory,
/// The probability reported for that category.
pub probability: HarmProbability,
}
/// Request payload serialized as the JSON body by `count_tokens`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
/// The content items sent with the request.
pub contents: Vec<Content>,
}
/// Response deserialized by `count_tokens` from a successful reply.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
/// Total token count, (de)serialized as `totalTokens`.
pub total_tokens: usize,
}
// Known Gemini models plus a user-defined `Custom` entry. Each variant (de)serializes
// under the name given in its `rename` attribute. With the `schemars` feature enabled
// the enum also derives `JsonSchema`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
#[serde(rename = "custom")]
Custom {
// Returned by `id()`, and by `display_name()` when `display_name` is `None`.
name: String,
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
display_name: Option<String>,
// Returned by `max_token_count()`.
max_tokens: usize,
},
}
impl Model {
pub fn id(&self) -> &str {
match self {
Model::Gemini15Pro => "gemini-1.5-pro",
Model::Gemini15Flash => "gemini-1.5-flash",
Model::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
match self {
Model::Gemini15Pro => "Gemini 1.5 Pro",
Model::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom {
name, display_name, ..
} => display_name.as_ref().unwrap_or(name),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Model::Gemini15Pro => 2_000_000,
Model::Gemini15Flash => 1_000_000,
Model::Custom { max_tokens, .. } => *max_tokens,
}
}
}
/// Formats a model as its identifier, i.e. the string returned by `Model::id`.
impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.id())
    }
}
/// Maps a stream of response events to a stream of text chunks.
///
/// For each successful event only the first part of the first candidate is looked at: if it
/// is a text part, its text is yielded; otherwise (no candidates, no parts, or a non-text
/// first part) the event produces no item. Errors are passed through unchanged.
pub fn extract_text_from_events(
    events: impl Stream<Item = Result<GenerateContentResponse>>,
) -> impl Stream<Item = Result<String>> {
    events.filter_map(|event| async move {
        let response = match event {
            Ok(response) => response,
            Err(error) => return Some(Err(error)),
        };

        let first_part = response
            .candidates
            .and_then(|candidates| candidates.into_iter().next())
            .and_then(|candidate| candidate.content.parts.into_iter().next());

        match first_part {
            Some(Part::TextPart(TextPart { text })) => Some(Ok(text)),
            Some(Part::InlineDataPart(_)) | None => None,
        }
    })
}