assistant: Fix Google AI provider not respecting low_speed_timeout_in_seconds (#17423)

Release Notes:

- Fixed an issue when using Google Gemini models, where the setting
`low_speed_timeout_in_seconds` was not respected
This commit is contained in:
Bennet Bo Fenner 2024-09-05 18:16:30 +02:00 committed by GitHub
parent a1c676128a
commit f413ea90bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 50 additions and 12 deletions

View file

@ -18,6 +18,7 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true

View file

@ -2,8 +2,10 @@ mod supported_countries;
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::HttpClient;
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::time::Duration;
pub use supported_countries::*;
@ -14,6 +16,7 @@ pub async fn stream_generate_content(
api_url: &str,
api_key: &str,
mut request: GenerateContentRequest,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
let uri = format!(
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
@ -21,8 +24,17 @@ pub async fn stream_generate_content(
);
request.model.clear();
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
};
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
@ -59,13 +71,25 @@ pub async fn count_tokens(
api_url: &str,
api_key: &str,
request: CountTokensRequest,
low_speed_timeout: Option<Duration>,
) -> Result<CountTokensResponse> {
let uri = format!(
"{}/v1beta/models/gemini-pro:countTokens?key={}",
api_url, api_key
);
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(&uri)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
}
let http_request = request_builder.body(AsyncBody::from(request))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
if response.status().is_success() {