Reuse OpenAI low_speed_timeout setting for zed.dev provider (#18144)

Release Notes:

- N/A
jvmncs 2024-09-20 12:57:35 -04:00 committed by GitHub
parent d97427f69e
commit 9f6ff29a54
4 changed files with 31 additions and 2 deletions

Cargo.lock (generated)

@@ -6285,6 +6285,7 @@ dependencies = [
  "http_client",
  "image",
  "inline_completion_button",
+ "isahc",
  "language",
  "log",
  "menu",


@@ -32,6 +32,7 @@ futures.workspace = true
 google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
+isahc.workspace = true
 inline_completion_button.workspace = true
 log.workspace = true
 menu.workspace = true


@@ -19,6 +19,7 @@ use gpui::{
     Subscription, Task,
 };
 use http_client::{AsyncBody, HttpClient, Method, Response};
+use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -27,6 +28,7 @@ use smol::{
     io::{AsyncReadExt, BufReader},
     lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 };
+use std::time::Duration;
 use std::{
     future,
     sync::{Arc, LazyLock},
@@ -56,6 +58,7 @@ fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
     pub available_models: Vec<AvailableModel>,
+    pub low_speed_timeout: Option<Duration>,
 }

 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
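
For orientation: the resolved ZedDotDevSettings above holds a ready-to-use Duration, while the user-facing ZedDotDevSettingsContent in the last file of this commit deserializes plain seconds. A minimal sketch of that split, using illustrative names rather than the actual Zed types:

use serde::Deserialize;
use std::time::Duration;

// User-facing content struct: what gets deserialized from the settings file.
#[derive(Default, Deserialize)]
struct ProviderSettingsContent {
    low_speed_timeout_in_seconds: Option<u64>,
}

// Resolved settings struct: what the provider code actually consumes.
#[derive(Default)]
struct ProviderSettings {
    low_speed_timeout: Option<Duration>,
}

fn resolve(content: &ProviderSettingsContent) -> ProviderSettings {
    ProviderSettings {
        low_speed_timeout: content.low_speed_timeout_in_seconds.map(Duration::from_secs),
    }
}
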
@@ -380,6 +383,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: PerformCompletionParams,
+        low_speed_timeout: Option<Duration>,
     ) -> Result<Response<AsyncBody>> {
         let http_client = &client.http_client();
@@ -387,7 +391,11 @@
         let mut did_retry = false;

         let response = loop {
-            let request = http_client::Request::builder()
+            let mut request_builder = http_client::Request::builder();
+            if let Some(low_speed_timeout) = low_speed_timeout {
+                request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+            };
+            let request = request_builder
                 .method(Method::POST)
                 .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                 .header("Content-Type", "application/json")
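
For context: Configurable is isahc's extension trait on http::request::Builder, and low_speed_timeout(limit, duration) roughly means "abort the transfer if throughput stays below limit bytes per second for duration". A standalone sketch of the same pattern, with a placeholder URL and simplified body (not the actual Zed helper):

use isahc::config::Configurable; // adds low_speed_timeout() to http::request::Builder
use std::time::Duration;

fn build_completion_request(
    low_speed_timeout: Option<Duration>,
) -> http::Result<http::Request<&'static str>> {
    let mut builder = http::Request::builder()
        .method("POST")
        .uri("https://example.invalid/completion") // placeholder URL
        .header("Content-Type", "application/json");
    if let Some(timeout) = low_speed_timeout {
        // Abort the transfer if it averages under 100 bytes/s for `timeout`.
        builder = builder.low_speed_timeout(100, timeout);
    }
    builder.body("{}")
}
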
@@ -501,8 +509,11 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        _cx: &AsyncAppContext,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+        let openai_low_speed_timeout =
+            AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout.unwrap());
+
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
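
The hunks that follow thread the new perform_llm_completion parameter through each call site: the OpenAI arm of stream_completion forwards the reused OpenAI timeout, and the other call sites pass None. A condensed, illustrative sketch of that dispatch (not the real signatures):

use std::time::Duration;

enum CloudModel {
    Anthropic,
    OpenAi,
    Google,
}

// Only the OpenAI arm forwards the (reused) OpenAI low_speed_timeout.
fn timeout_for(model: &CloudModel, openai_low_speed_timeout: Option<Duration>) -> Option<Duration> {
    match model {
        CloudModel::OpenAi => openai_low_speed_timeout,
        _ => None,
    }
}
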
@@ -519,6 +530,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
                     Ok(map_to_language_model_completion_events(Box::pin(
@@ -542,6 +554,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        openai_low_speed_timeout,
                     )
                     .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
@@ -569,6 +582,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
                     Ok(google_ai::extract_text_from_events(response_lines(
@@ -599,6 +613,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
@@ -650,6 +665,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
@@ -694,6 +710,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
@@ -741,6 +758,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;


@@ -231,6 +231,7 @@ pub struct GoogleSettingsContent {
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct ZedDotDevSettingsContent {
     available_models: Option<Vec<cloud::AvailableModel>>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
 }

 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -333,6 +334,14 @@ impl settings::Settings for AllLanguageModelSettings {
                     .as_ref()
                     .and_then(|s| s.available_models.clone()),
             );
+            if let Some(low_speed_timeout_in_seconds) = value
+                .zed_dot_dev
+                .as_ref()
+                .and_then(|s| s.low_speed_timeout_in_seconds)
+            {
+                settings.zed_dot_dev.low_speed_timeout =
+                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
+            }

             merge(
                 &mut settings.google.api_url,