ZIm/crates/reqwest_client/src/reqwest_client.rs
Thorsten Ball aee01f2c50
assistant: Remove low_speed_timeout (#20681)
This removes the `low_speed_timeout` setting from all providers as a
response to issue #19509.

The reason is that the original `low_speed_timeout` was only added as part of
#9913 because users wanted to _get rid of timeouts_. They wanted to bump
the default timeout from 5sec to a lot more.

Then, in the meantime, the meaning of `low_speed_timeout` changed in
#19055 and was changed to a normal `timeout`, which is a different thing
and breaks slower LLMs that don't reply with a complete response in the
configured timeout.

So we figured: let's remove the whole thing and replace it with a
default _connect_ timeout to make sure that we can connect to a server
in 10s, but then give the server as long as it wants to complete its
response.

Closes #19509

Release Notes:

- Removed the `low_speed_timeout` setting from LLM provider settings,
since it was only used to _increase_ the timeout to give LLMs more time,
but since we don't have any other use for it, we simply remove the
setting to give LLMs as long as they need.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Peter Tripp <peter@zed.dev>
2024-11-15 07:37:31 +01:00

279 lines
9.4 KiB
Rust

use std::{any::type_name, mem, pin::Pin, sync::OnceLock, task::Poll, time::Duration};
use anyhow::anyhow;
use bytes::{BufMut, Bytes, BytesMut};
use futures::{AsyncRead, TryStreamExt as _};
use http_client::{http, RedirectPolicy};
use reqwest::{
header::{HeaderMap, HeaderValue},
redirect,
};
use smol::future::FutureExt;
/// Number of bytes a `StreamReader` reserves in its read buffer whenever that
/// buffer has no capacity left.
const DEFAULT_CAPACITY: usize = 4096;
/// Lazily created Tokio runtime, used as a fallback when a `ReqwestClient` is
/// built on a thread that has no current Tokio runtime.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
/// An `http_client::HttpClient` implementation backed by `reqwest`.
pub struct ReqwestClient {
// The underlying reqwest client that actually performs requests.
client: reqwest::Client,
// The proxy URI this client was constructed with, if any; returned as-is by
// `HttpClient::proxy`.
proxy: Option<http::Uri>,
// Handle to the Tokio runtime on which outgoing requests are spawned.
handle: tokio::runtime::Handle,
}
impl ReqwestClient {
fn builder() -> reqwest::ClientBuilder {
reqwest::Client::builder()
.use_rustls_tls()
.connect_timeout(Duration::from_secs(10))
}
pub fn new() -> Self {
Self::builder()
.build()
.expect("Failed to initialize HTTP client")
.into()
}
pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
let mut map = HeaderMap::new();
map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
let client = Self::builder().default_headers(map).build()?;
Ok(client.into())
}
pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
let mut map = HeaderMap::new();
map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
let mut client = Self::builder().default_headers(map);
if let Some(proxy) = proxy.clone().and_then(|proxy_uri| {
reqwest::Proxy::all(proxy_uri.to_string())
.inspect_err(|e| log::error!("Failed to parse proxy URI {}: {}", proxy_uri, e))
.ok()
}) {
client = client.proxy(proxy);
}
let client = client.build()?;
let mut client: ReqwestClient = client.into();
client.proxy = proxy;
Ok(client)
}
}
impl From<reqwest::Client> for ReqwestClient {
    /// Wraps an already-built `reqwest::Client`, with no proxy recorded.
    ///
    /// Uses the current Tokio runtime when there is one; otherwise falls back
    /// to the shared `RUNTIME`, creating it on first use.
    fn from(client: reqwest::Client) -> Self {
        let handle = match tokio::runtime::Handle::try_current() {
            Ok(current) => current,
            Err(_) => {
                log::debug!("no tokio runtime found, creating one for Reqwest...");
                RUNTIME
                    .get_or_init(|| {
                        // A second executor now lives in the process, so keep
                        // its footprint down to a single worker thread.
                        tokio::runtime::Builder::new_multi_thread()
                            .worker_threads(1)
                            .enable_all()
                            .build()
                            .expect("Failed to initialize HTTP client")
                    })
                    .handle()
                    .clone()
            }
        };

        Self {
            proxy: None,
            handle,
            client,
        }
    }
}
// This struct is essentially a re-implementation of
// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
// except outside of Tokio's aegis
//
// It adapts a `futures::AsyncRead` into a `futures::Stream` of `Bytes` chunks.
struct StreamReader {
// The wrapped reader; set to `None` once the stream has finished or failed.
reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
// Working buffer that reads land in before being split off as a chunk.
buf: BytesMut,
// Number of bytes to reserve in `buf` whenever it has no capacity.
capacity: usize,
}
impl StreamReader {
    /// Wraps `reader` in a stream that starts with an empty buffer and grows
    /// it in `DEFAULT_CAPACITY`-sized steps.
    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
        let buf = BytesMut::new();
        Self {
            capacity: DEFAULT_CAPACITY,
            buf,
            reader: Some(reader),
        }
    }
}
impl futures::Stream for StreamReader {
    type Item = std::io::Result<Bytes>;

    /// Reads the next chunk from the wrapped reader.
    ///
    /// Yields `Some(Ok(chunk))` for every non-empty read, `Some(Err(_))` once
    /// if the reader fails, and `None` at end of input. After an error or end
    /// of input the reader is dropped and every later poll returns `None`.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `StreamReader` is `Unpin`, so a plain mutable reference is fine.
        let this = self.get_mut();

        // Borrow the reader in place instead of taking it out of `self`.
        // Taking it meant a `Poll::Pending` result dropped the reader, so the
        // next poll ended the stream early and truncated the body.
        let Some(reader) = this.reader.as_mut() else {
            return Poll::Ready(None);
        };

        if this.buf.capacity() == 0 {
            this.buf.reserve(this.capacity);
        }

        match poll_read_buf(reader, cx, &mut this.buf) {
            // Not ready yet: the reader stays in place for the next poll.
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => {
                this.reader = None;
                Poll::Ready(Some(Err(err)))
            }
            // Zero bytes read means end of input.
            Poll::Ready(Ok(0)) => {
                this.reader = None;
                Poll::Ready(None)
            }
            Poll::Ready(Ok(_)) => {
                let chunk = this.buf.split();
                Poll::Ready(Some(Ok(chunk.freeze())))
            }
        }
    }
}
/// Implementation from https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html
/// Specialized for this use case
///
/// Polls `io` for data and appends whatever it produces to `buf`.
///
/// Returns `Poll::Ready(Ok(n))` with the number of bytes appended (`0` means
/// end of input, or that `buf` cannot grow any further), `Poll::Ready(Err(_))`
/// if the read failed, and `Poll::Pending` if no data is available yet.
///
/// # Panics
///
/// Panics if `io` reports more bytes read than the slice it was handed.
pub fn poll_read_buf(
    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
    cx: &mut std::task::Context<'_>,
    buf: &mut BytesMut,
) -> Poll<std::io::Result<usize>> {
    if !buf.has_remaining_mut() {
        return Poll::Ready(Ok(0));
    }

    let n = {
        let dst = buf.chunk_mut();

        // SAFETY: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
        // transparent wrapper around `[MaybeUninit<u8>]`.
        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
        let mut buf = tokio::io::ReadBuf::uninit(dst);
        let ptr = buf.filled().as_ptr();
        let unfilled_portion = buf.initialize_unfilled();

        // `Pin<Box<_>>` can be re-pinned safely with `as_mut`; no `unsafe` needed.
        let read = std::task::ready!(io.as_mut().poll_read(cx, unfilled_portion)?);

        // `futures::AsyncRead` writes into a plain byte slice and reports the
        // count through its return value, so the `ReadBuf` has to be advanced
        // by hand. Without this `filled()` stays empty and every read looks
        // like end of input.
        buf.advance(read);

        // Ensure the pointer does not change from under us
        assert_eq!(ptr, buf.filled().as_ptr());
        buf.filled().len()
    };

    // SAFETY: `n` is the length of the `ReadBuf`'s filled region, which only
    // covers bytes that were initialized by `initialize_unfilled` above.
    unsafe {
        buf.advance_mut(n);
    }

    Poll::Ready(Ok(n))
}
impl http_client::HttpClient for ReqwestClient {
    /// Returns the proxy URI recorded on this client, if any.
    fn proxy(&self) -> Option<&http::Uri> {
        self.proxy.as_ref()
    }

    /// Returns the Rust type name of this client implementation.
    fn type_name(&self) -> &'static str {
        type_name::<Self>()
    }

    /// Translates `req` into a reqwest request, runs it on the Tokio runtime
    /// behind `self.handle`, and converts the response back into an
    /// `http_client::Response` whose body is streamed.
    fn send(
        &self,
        req: http::Request<http_client::AsyncBody>,
    ) -> futures::future::BoxFuture<
        'static,
        Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
    > {
        let (parts, body) = req.into_parts();

        let mut builder = self
            .client
            .request(parts.method, parts.uri.to_string())
            .headers(parts.headers);

        // A per-request redirect policy may be attached as a request extension.
        if let Some(requested) = parts.extensions.get::<RedirectPolicy>() {
            let policy = match requested {
                RedirectPolicy::NoFollow => redirect::Policy::none(),
                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
            };
            builder = builder.redirect_policy(policy);
        }

        let payload: reqwest::Body = match body.0 {
            http_client::Inner::Empty => reqwest::Body::default(),
            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
            http_client::Inner::AsyncReader(reader) => {
                reqwest::Body::wrap_stream(StreamReader::new(reader))
            }
        };
        let request = builder.body(payload);

        let runtime = self.handle.clone();
        async move {
            // The request future is created and driven inside the spawned task,
            // i.e. on the Tokio runtime rather than on the caller's executor.
            let mut response = runtime.spawn(async { request.send().await }).await??;

            let headers = mem::take(response.headers_mut());
            let status = response.status().as_u16();
            let version = response.version();

            let mut out = http::Response::builder().status(status).version(version);
            *out.headers_mut().unwrap() = headers;

            let reader = response
                .bytes_stream()
                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
                .into_async_read();

            out.body(http_client::AsyncBody::from_reader(reader))
                .map_err(|e| anyhow!(e))
        }
        .boxed()
    }
}
#[cfg(test)]
mod tests {
    use http_client::{http, HttpClient};

    use crate::ReqwestClient;

    /// Proxy URIs, one per scheme, that the client is expected to accept.
    const PROXY_URIS: [&str; 6] = [
        "http://localhost:10809",
        "https://localhost:10809",
        "socks4://localhost:10808",
        "socks4a://localhost:10808",
        "socks5://localhost:10808",
        "socks5h://localhost:10808",
    ];

    /// A default client reports no proxy; a client built with a proxy reports
    /// exactly the URI it was given.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        for &uri in &PROXY_URIS {
            let proxy = http::Uri::from_static(uri);
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    /// Building a client from a `file://` URI is expected to panic.
    #[test]
    #[should_panic]
    fn test_invalid_proxy_uri() {
        let proxy = http::Uri::from_static("file:///etc/hosts");
        ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
    }
}