diff --git a/Cargo.lock b/Cargo.lock index 26b8847041..a19506829e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6285,6 +6285,7 @@ dependencies = [ "http_client", "image", "inline_completion_button", + "isahc", "language", "log", "menu", diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index ef273ac44f..b63428c544 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -32,6 +32,7 @@ futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true http_client.workspace = true +isahc.workspace = true inline_completion_button.workspace = true log.workspace = true menu.workspace = true diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index f8f64ff3b8..58efb4cfe1 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -19,6 +19,7 @@ use gpui::{ Subscription, Task, }; use http_client::{AsyncBody, HttpClient, Method, Response}; +use isahc::config::Configurable; use schemars::JsonSchema; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::value::RawValue; @@ -27,6 +28,7 @@ use smol::{ io::{AsyncReadExt, BufReader}, lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}, }; +use std::time::Duration; use std::{ future, sync::{Arc, LazyLock}, @@ -56,6 +58,7 @@ fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] { #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { pub available_models: Vec, + pub low_speed_timeout: Option, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -380,6 +383,7 @@ impl CloudLanguageModel { client: Arc, llm_api_token: LlmApiToken, body: PerformCompletionParams, + low_speed_timeout: Option, ) -> Result> { let http_client = &client.http_client(); @@ -387,7 +391,11 @@ impl CloudLanguageModel { let mut did_retry = false; let response = loop { - let 
request = http_client::Request::builder() + let mut request_builder = http_client::Request::builder(); + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + }; + let request = request_builder .method(Method::POST) .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref()) .header("Content-Type", "application/json") @@ -501,8 +509,12 @@ impl LanguageModel for CloudLanguageModel { fn stream_completion( &self, request: LanguageModelRequest, - _cx: &AsyncAppContext, + cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { + // Yields `None` when the settings cannot be read or no OpenAI timeout is configured, instead of panicking on an unset value. + let openai_low_speed_timeout = + AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout).flatten(); + match &self.model { CloudModel::Anthropic(model) => { let request = request.into_anthropic(model.id().into(), model.max_output_tokens()); @@ -519,6 +531,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + None, ) .await?; Ok(map_to_language_model_completion_events(Box::pin( @@ -542,6 +555,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + openai_low_speed_timeout, ) .await?; Ok(open_ai::extract_text_from_events(response_lines(response))) @@ -569,6 +583,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + None, ) .await?; Ok(google_ai::extract_text_from_events(response_lines( @@ -599,6 +614,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + None, ) .await?; Ok(open_ai::extract_text_from_events(response_lines(response))) @@ -650,6 +666,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + None, ) .await?; @@ -694,6 +711,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + None, ) .await?; @@ -741,6 +759,7 @@ impl LanguageModel for CloudLanguageModel { &request, )?)?, }, + None, ) .await?; diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 80749c0bdb..8888d51e11 100644 --- 
a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -231,6 +231,7 @@ pub struct GoogleSettingsContent { #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { available_models: Option>, + pub low_speed_timeout_in_seconds: Option, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -333,6 +334,14 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.available_models.clone()), ); + if let Some(low_speed_timeout_in_seconds) = value + .zed_dot_dev + .as_ref() + .and_then(|s| s.low_speed_timeout_in_seconds) + { + settings.zed_dot_dev.low_speed_timeout = + Some(Duration::from_secs(low_speed_timeout_in_seconds)); + } merge( &mut settings.google.api_url,