Reuse OpenAI low_speed_timeout setting for zed.dev provider (#18144)

Release Notes:

- N/A
This commit is contained in:
jvmncs 2024-09-20 12:57:35 -04:00 committed by GitHub
parent d97427f69e
commit 9f6ff29a54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 31 additions and 2 deletions

View file

@ -32,6 +32,7 @@ futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
isahc.workspace = true
inline_completion_button.workspace = true
log.workspace = true
menu.workspace = true

View file

@ -19,6 +19,7 @@ use gpui::{
Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
@ -27,6 +28,7 @@ use smol::{
io::{AsyncReadExt, BufReader},
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::time::Duration;
use std::{
future,
sync::{Arc, LazyLock},
@ -56,6 +58,7 @@ fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
pub available_models: Vec<AvailableModel>,
pub low_speed_timeout: Option<Duration>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -380,6 +383,7 @@ impl CloudLanguageModel {
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: PerformCompletionParams,
low_speed_timeout: Option<Duration>,
) -> Result<Response<AsyncBody>> {
let http_client = &client.http_client();
@ -387,7 +391,11 @@ impl CloudLanguageModel {
let mut did_retry = false;
let response = loop {
let request = http_client::Request::builder()
let mut request_builder = http_client::Request::builder();
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
};
let request = request_builder
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
.header("Content-Type", "application/json")
@ -501,8 +509,11 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
_cx: &AsyncAppContext,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let openai_low_speed_timeout =
AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout.unwrap());
match &self.model {
CloudModel::Anthropic(model) => {
let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
@ -519,6 +530,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
Ok(map_to_language_model_completion_events(Box::pin(
@ -542,6 +554,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
openai_low_speed_timeout,
)
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
@ -569,6 +582,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
Ok(google_ai::extract_text_from_events(response_lines(
@ -599,6 +613,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
@ -650,6 +665,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
@ -694,6 +710,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
@ -741,6 +758,7 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;

View file

@ -231,6 +231,7 @@ pub struct GoogleSettingsContent {
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct ZedDotDevSettingsContent {
available_models: Option<Vec<cloud::AvailableModel>>,
pub low_speed_timeout_in_seconds: Option<u64>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -333,6 +334,14 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.zed_dot_dev
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.zed_dot_dev.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.google.api_url,