assistant: Remove low_speed_timeout (#20681)

This removes the `low_speed_timeout` setting from all providers as a
response to issue #19509.

Reason being that the original `low_speed_timeout` was only as part of
#9913 because users wanted to _get rid of timeouts_. They wanted to bump
the default timeout from 5sec to a lot more.

Then, in the meantime, the meaning of `low_speed_timeout` changed in
#19055 and was changed to a normal `timeout`, which is a different thing
and breaks slower LLMs that don't reply with a complete response in the
configured timeout.

So we figured: let's remove the whole thing and replace it with a
default _connect_ timeout to make sure that we can connect to a server
in 10s, but then give the server as long as it wants to complete its
response.

Closes #19509

Release Notes:

- Removed the `low_speed_timeout` setting from LLM provider settings,
since it was only used to _increase_ the timeout to give LLMs more time,
but since we don't have any other use for it, we simply remove the
setting to give LLMs as long as they need.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Peter Tripp <peter@zed.dev>
This commit is contained in:
Thorsten Ball 2024-11-15 07:37:31 +01:00 committed by GitHub
parent c9546070ac
commit aee01f2c50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 109 additions and 345 deletions

View file

@ -1053,13 +1053,11 @@
"api_url": "https://generativelanguage.googleapis.com"
},
"ollama": {
"api_url": "http://localhost:11434",
"low_speed_timeout_in_seconds": 60
"api_url": "http://localhost:11434"
},
"openai": {
"version": "1",
"api_url": "https://api.openai.com/v1",
"low_speed_timeout_in_seconds": 600
"api_url": "https://api.openai.com/v1"
}
},
// Zed's Prettier integration settings.

View file

@ -1,13 +1,12 @@
mod supported_countries;
use std::time::Duration;
use std::{pin::Pin, str::FromStr};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@ -207,9 +206,8 @@ pub async fn stream_completion(
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
stream_completion_with_rate_limit_info(client, api_url, api_key, request)
.await
.map(|output| output.0)
}
@ -261,7 +259,6 @@ pub async fn stream_completion_with_rate_limit_info(
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<
(
BoxStream<'static, Result<Event, AnthropicError>>,
@ -274,7 +271,7 @@ pub async fn stream_completion_with_rate_limit_info(
stream: true,
};
let uri = format!("{api_url}/v1/messages");
let mut request_builder = HttpRequest::builder()
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
@ -284,9 +281,6 @@ pub async fn stream_completion_with_rate_limit_info(
)
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
}
let serialized_request =
serde_json::to_string(&request).context("failed to serialize request")?;
let request = request_builder

View file

@ -35,20 +35,17 @@ pub enum AssistantProviderContentV1 {
OpenAi {
default_model: Option<OpenAiModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
available_models: Option<Vec<OpenAiModel>>,
},
#[serde(rename = "anthropic")]
Anthropic {
default_model: Option<AnthropicModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
#[serde(rename = "ollama")]
Ollama {
default_model: Option<OllamaModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
}
@ -115,47 +112,41 @@ impl AssistantSettingsContent {
if let VersionedAssistantSettingsContent::V1(settings) = settings {
if let Some(provider) = settings.provider.clone() {
match provider {
AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.anthropic.is_none() {
content.anthropic = Some(AnthropicSettingsContent::Versioned(
VersionedAnthropicSettingsContent::V1(
AnthropicSettingsContentV1 {
api_url,
low_speed_timeout_in_seconds,
available_models: None,
},
),
));
}
},
),
AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.ollama.is_none() {
content.ollama = Some(OllamaSettingsContent {
api_url,
low_speed_timeout_in_seconds,
available_models: None,
});
}
},
),
AssistantProviderContentV1::Anthropic { api_url, .. } => {
update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.anthropic.is_none() {
content.anthropic =
Some(AnthropicSettingsContent::Versioned(
VersionedAnthropicSettingsContent::V1(
AnthropicSettingsContentV1 {
api_url,
available_models: None,
},
),
));
}
},
)
}
AssistantProviderContentV1::Ollama { api_url, .. } => {
update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.ollama.is_none() {
content.ollama = Some(OllamaSettingsContent {
api_url,
available_models: None,
});
}
},
)
}
AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
} => update_settings_file::<AllLanguageModelSettings>(
@ -188,7 +179,6 @@ impl AssistantSettingsContent {
VersionedOpenAiSettingsContent::V1(
OpenAiSettingsContentV1 {
api_url,
low_speed_timeout_in_seconds,
available_models,
},
),
@ -298,54 +288,41 @@ impl AssistantSettingsContent {
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
});
}
"ollama" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model, None, None)),
api_url,
low_speed_timeout_in_seconds,
});
}
"openai" => {
let (api_url, low_speed_timeout_in_seconds, available_models) =
match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
}) => (
api_url.clone(),
*low_speed_timeout_in_seconds,
available_models.clone(),
),
_ => (None, None, None),
};
let (api_url, available_models) = match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
available_models,
..
}) => (api_url.clone(), available_models.clone()),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: OpenAiModel::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
available_models,
});
}

View file

@ -267,7 +267,6 @@ async fn perform_completion(
anthropic::ANTHROPIC_API_URL,
api_key,
request,
None,
)
.await
.map_err(|err| match err {
@ -357,7 +356,6 @@ async fn perform_completion(
open_ai::OPEN_AI_API_URL,
api_key,
serde_json::from_str(params.provider_request.get())?,
None,
)
.await?;
@ -390,7 +388,6 @@ async fn perform_completion(
google_ai::API_URL,
api_key,
serde_json::from_str(params.provider_request.get())?,
None,
)
.await?;

View file

@ -3621,7 +3621,6 @@ async fn count_language_model_tokens(
google_ai::API_URL,
api_key,
serde_json::from_str(&request.request)?,
None,
)
.await?
}

View file

@ -1,13 +1,13 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::{sync::Arc, time::Duration};
use anyhow::{anyhow, Result};
use chrono::DateTime;
use fs::Fs;
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Global};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use paths::home_dir;
use serde::{Deserialize, Serialize};
use settings::watch_config_file;
@ -254,7 +254,6 @@ impl CopilotChat {
pub async fn stream_completion(
request: Request,
low_speed_timeout: Option<Duration>,
mut cx: AsyncAppContext,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
@ -274,8 +273,7 @@ impl CopilotChat {
let token = match api_token {
Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
_ => {
let token =
request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
let token = request_api_token(&oauth_token, client.clone()).await?;
this.update(&mut cx, |this, cx| {
this.api_token = Some(token.clone());
cx.notify();
@ -284,25 +282,17 @@ impl CopilotChat {
}
};
stream_completion(client.clone(), token.api_key, request, low_speed_timeout).await
stream_completion(client.clone(), token.api_key, request).await
}
}
async fn request_api_token(
oauth_token: &str,
client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
) -> Result<ApiToken> {
let mut request_builder = HttpRequest::builder()
async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
let request_builder = HttpRequest::builder()
.method(Method::GET)
.uri(COPILOT_CHAT_AUTH_URL)
.header("Authorization", format!("token {}", oauth_token))
.header("Accept", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
}
let request = request_builder.body(AsyncBody::empty())?;
let mut response = client.send(request).await?;
@ -340,9 +330,8 @@ async fn stream_completion(
client: Arc<dyn HttpClient>,
api_key: String,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let mut request_builder = HttpRequest::builder()
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(COPILOT_CHAT_COMPLETION_URL)
.header(
@ -356,9 +345,6 @@ async fn stream_completion(
.header("Content-Type", "application/json")
.header("Copilot-Integration-Id", "vscode-chat");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
}
let is_streaming = request.stream;
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;

View file

@ -2,9 +2,8 @@ mod supported_countries;
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use std::time::Duration;
pub use supported_countries::*;
@ -15,7 +14,6 @@ pub async fn stream_generate_content(
api_url: &str,
api_key: &str,
mut request: GenerateContentRequest,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
let uri = format!(
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
@ -23,15 +21,11 @@ pub async fn stream_generate_content(
);
request.model.clear();
let mut request_builder = HttpRequest::builder()
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
};
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
@ -70,7 +64,6 @@ pub async fn count_tokens(
api_url: &str,
api_key: &str,
request: CountTokensRequest,
low_speed_timeout: Option<Duration>,
) -> Result<CountTokensResponse> {
let uri = format!(
"{}/v1beta/models/gemini-pro:countTokens?key={}",
@ -78,15 +71,11 @@ pub async fn count_tokens(
);
let request = serde_json::to_string(&request)?;
let mut request_builder = HttpRequest::builder()
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(&uri)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
}
let http_request = request_builder.body(AsyncBody::from(request))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();

View file

@ -13,21 +13,10 @@ use std::fmt;
use std::{
any::type_name,
sync::{Arc, Mutex},
time::Duration,
};
pub use url::Url;
#[derive(Clone, Debug)]
pub struct ReadTimeout(pub Duration);
impl Default for ReadTimeout {
fn default() -> Self {
Self(Duration::from_secs(5))
}
}
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub enum RedirectPolicy {
#[default]
NoFollow,
@ -37,20 +26,11 @@ pub enum RedirectPolicy {
pub struct FollowRedirects(pub bool);
pub trait HttpRequestExt {
/// Set a read timeout on the request.
/// For isahc, this is the low_speed_timeout.
/// For other clients, this is the timeout used for read calls when reading the response.
/// In all cases this prevents servers stalling completely, but allows them to send data slowly.
fn read_timeout(self, timeout: Duration) -> Self;
/// Whether or not to follow redirects
fn follow_redirects(self, follow: RedirectPolicy) -> Self;
}
impl HttpRequestExt for http::request::Builder {
fn read_timeout(self, timeout: Duration) -> Self {
self.extension(ReadTimeout(timeout))
}
fn follow_redirects(self, follow: RedirectPolicy) -> Self {
self.extension(follow)
}

View file

@ -20,7 +20,7 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
@ -32,7 +32,6 @@ const PROVIDER_NAME: &str = "Anthropic";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
/// Extend Zed's list of Anthropic models.
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
@ -309,26 +308,17 @@ impl AnthropicModel {
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Anthropic API Key"))?;
let request = anthropic::stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.context("failed to stream completion")
}
.boxed()

View file

@ -21,7 +21,7 @@ use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
ModelContext, ReadGlobal, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -32,7 +32,6 @@ use smol::{
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::fmt;
use std::time::Duration;
use std::{
future,
sync::{Arc, LazyLock},
@ -63,7 +62,6 @@ fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
pub available_models: Vec<AvailableModel>,
pub low_speed_timeout: Option<Duration>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -475,7 +473,6 @@ impl CloudLanguageModel {
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: PerformCompletionParams,
low_speed_timeout: Option<Duration>,
) -> Result<Response<AsyncBody>> {
let http_client = &client.http_client();
@ -483,10 +480,7 @@ impl CloudLanguageModel {
let mut did_retry = false;
let response = loop {
let mut request_builder = http_client::Request::builder();
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
};
let request_builder = http_client::Request::builder();
let request = request_builder
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
@ -607,11 +601,8 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let openai_low_speed_timeout =
AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout.unwrap());
match &self.model {
CloudModel::Anthropic(model) => {
let request = request.into_anthropic(
@ -632,7 +623,6 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
Ok(map_to_language_model_completion_events(Box::pin(
@ -656,7 +646,6 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
openai_low_speed_timeout,
)
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
@ -684,7 +673,6 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
Ok(google_ai::extract_text_from_events(response_lines(
@ -741,7 +729,6 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;
@ -786,7 +773,6 @@ impl LanguageModel for CloudLanguageModel {
&request,
)?)?,
},
None,
)
.await?;

View file

@ -14,7 +14,7 @@ use gpui::{
percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
Subscription, Task, Transformation,
};
use settings::{Settings, SettingsStore};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::{
@ -23,7 +23,6 @@ use ui::{
ViewContext, VisualContext, WindowContext,
};
use crate::settings::AllLanguageModelSettings;
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
@ -37,9 +36,7 @@ const PROVIDER_ID: &str = "copilot_chat";
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatSettings {
pub low_speed_timeout: Option<Duration>,
}
pub struct CopilotChatSettings {}
pub struct CopilotChatLanguageModelProvider {
state: Model<State>,
@ -218,17 +215,10 @@ impl LanguageModel for CopilotChatLanguageModel {
let copilot_request = self.to_copilot_chat_request(request);
let is_streaming = copilot_request.stream;
let Ok(low_speed_timeout) = cx.update(|cx| {
AllLanguageModelSettings::get_global(cx)
.copilot_chat
.low_speed_timeout
}) else {
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
};
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(|cx| async move {
let response = CopilotChat::stream_completion(copilot_request, low_speed_timeout, cx);
let response = CopilotChat::stream_completion(copilot_request, cx);
request_limiter.stream(async move {
let response = response.await?;
let stream = response

View file

@ -11,7 +11,7 @@ use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
@ -30,7 +30,6 @@ const PROVIDER_NAME: &str = "Google AI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<AvailableModel>,
}
@ -262,7 +261,6 @@ impl LanguageModel for GoogleLanguageModel {
let settings = &AllLanguageModelSettings::get_global(cx).google;
let api_url = settings.api_url.clone();
let low_speed_timeout = settings.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
@ -273,7 +271,6 @@ impl LanguageModel for GoogleLanguageModel {
google_ai::CountTokensRequest {
contents: request.contents,
},
low_speed_timeout,
)
.await?;
Ok(response.total_tokens)
@ -292,26 +289,17 @@ impl LanguageModel for GoogleLanguageModel {
let request = request.into_google(self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
let response = stream_generate_content(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
});

View file

@ -9,7 +9,7 @@ use ollama::{
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
@ -30,7 +30,6 @@ const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<AvailableModel>,
}
@ -327,17 +326,15 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
.await?;
let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream = response
.filter_map(|response| async move {
match response {

View file

@ -13,7 +13,7 @@ use open_ai::{
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
@ -32,7 +32,6 @@ const PROVIDER_NAME: &str = "OpenAI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
@ -229,26 +228,16 @@ impl OpenAiLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
});

View file

@ -1,4 +1,4 @@
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use anyhow::Result;
use gpui::AppContext;
@ -87,7 +87,6 @@ impl AnthropicSettingsContent {
AnthropicSettingsContent::Legacy(content) => (
AnthropicSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
@ -132,7 +131,6 @@ impl AnthropicSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyAnthropicSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<anthropic::Model>>,
}
@ -146,14 +144,12 @@ pub enum VersionedAnthropicSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContentV1 {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OllamaSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
}
@ -170,7 +166,6 @@ impl OpenAiSettingsContent {
OpenAiSettingsContent::Legacy(content) => (
OpenAiSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
@ -205,7 +200,6 @@ impl OpenAiSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyOpenAiSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<open_ai::Model>>,
}
@ -219,27 +213,22 @@ pub enum VersionedOpenAiSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContentV1 {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::google::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct ZedDotDevSettingsContent {
available_models: Option<Vec<cloud::AvailableModel>>,
pub low_speed_timeout_in_seconds: Option<u64>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct CopilotChatSettingsContent {
low_speed_timeout_in_seconds: Option<u64>,
}
pub struct CopilotChatSettingsContent {}
impl settings::Settings for AllLanguageModelSettings {
const KEY: Option<&'static str> = Some("language_models");
@ -272,13 +261,6 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.anthropic.api_url,
anthropic.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = anthropic
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.anthropic.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.anthropic.available_models,
anthropic.as_ref().and_then(|s| s.available_models.clone()),
@ -291,14 +273,6 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.ollama.api_url,
value.ollama.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.ollama
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.ollama.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.ollama.available_models,
ollama.as_ref().and_then(|s| s.available_models.clone()),
@ -318,17 +292,10 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.openai.api_url,
openai.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) =
openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.openai.available_models,
openai.as_ref().and_then(|s| s.available_models.clone()),
);
merge(
&mut settings.zed_dot_dev.available_models,
value
@ -336,27 +303,10 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.zed_dot_dev
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.zed_dot_dev.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.google.api_url,
value.google.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.google
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.google.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.google.available_models,
value
@ -364,15 +314,6 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
if let Some(low_speed_timeout) = value
.copilot_chat
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.copilot_chat.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout));
}
}
Ok(settings)

View file

@ -1,6 +1,6 @@
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{http, AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use http_client::{http, AsyncBody, HttpClient, Method, Request as HttpRequest};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
@ -262,18 +262,13 @@ pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,
request: ChatRequest,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
let uri = format!("{api_url}/api/chat");
let mut request_builder = http::Request::builder()
let request_builder = http::Request::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
}
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
@ -281,7 +276,7 @@ pub async fn stream_chat_completion(
Ok(reader
.lines()
.filter_map(|line| async move {
.filter_map(move |line| async move {
match line {
Ok(line) => {
Some(serde_json::from_str(&line).context("Unable to parse chat response"))

View file

@ -6,14 +6,13 @@ use futures::{
stream::{self, BoxStream},
AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
convert::TryFrom,
future::{self, Future},
pin::Pin,
time::Duration,
};
use strum::EnumIter;
@ -308,17 +307,13 @@ pub async fn complete(
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<Response> {
let uri = format!("{api_url}/chat/completions");
let mut request_builder = HttpRequest::builder()
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key));
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
};
let mut request_body = request;
request_body.stream = false;
@ -396,25 +391,20 @@ pub async fn stream_completion(
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
if request.model == "o1-preview" || request.model == "o1-mini" {
let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
let response = complete(client, api_url, api_key, request).await;
let response_stream_event = response.map(adapt_response_to_stream);
return Ok(stream::once(future::ready(response_stream_event)).boxed());
}
let uri = format!("{api_url}/chat/completions");
let mut request_builder = HttpRequest::builder()
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key));
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
};
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {

View file

@ -1,9 +1,9 @@
use std::{any::type_name, mem, pin::Pin, sync::OnceLock, task::Poll};
use std::{any::type_name, mem, pin::Pin, sync::OnceLock, task::Poll, time::Duration};
use anyhow::anyhow;
use bytes::{BufMut, Bytes, BytesMut};
use futures::{AsyncRead, TryStreamExt as _};
use http_client::{http, ReadTimeout, RedirectPolicy};
use http_client::{http, RedirectPolicy};
use reqwest::{
header::{HeaderMap, HeaderValue},
redirect,
@ -20,9 +20,14 @@ pub struct ReqwestClient {
}
impl ReqwestClient {
pub fn new() -> Self {
fn builder() -> reqwest::ClientBuilder {
reqwest::Client::builder()
.use_rustls_tls()
.connect_timeout(Duration::from_secs(10))
}
pub fn new() -> Self {
Self::builder()
.build()
.expect("Failed to initialize HTTP client")
.into()
@ -31,19 +36,14 @@ impl ReqwestClient {
pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
let mut map = HeaderMap::new();
map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
let client = reqwest::Client::builder()
.default_headers(map)
.use_rustls_tls()
.build()?;
let client = Self::builder().default_headers(map).build()?;
Ok(client.into())
}
pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
let mut map = HeaderMap::new();
map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
let mut client = reqwest::Client::builder()
.use_rustls_tls()
.default_headers(map);
let mut client = Self::builder().default_headers(map);
if let Some(proxy) = proxy.clone().and_then(|proxy_uri| {
reqwest::Proxy::all(proxy_uri.to_string())
.inspect_err(|e| log::error!("Failed to parse proxy URI {}: {}", proxy_uri, e))
@ -204,9 +204,6 @@ impl http_client::HttpClient for ReqwestClient {
RedirectPolicy::FollowAll => redirect::Policy::limited(100),
});
}
if let Some(ReadTimeout(timeout)) = parts.extensions.get::<ReadTimeout>() {
request = request.timeout(*timeout);
}
let request = request.body(match body.0 {
http_client::Inner::Empty => reqwest::Body::default(),
http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),

View file

@ -124,8 +124,6 @@ Download and install Ollama from [ollama.com/download](https://ollama.com/downlo
3. In the assistant panel, select one of the Ollama models using the model dropdown.
4. (Optional) Specify an [`api_url`](#custom-endpoint) or [`low_speed_timeout_in_seconds`](#provider-timeout) if required.
#### Ollama Context Length {#ollama-context}
Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. Zed API requests to Ollama include this as `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of ram are able to use most models out of the box. See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults.
@ -139,7 +137,6 @@ Depending on your hardware or use-case you may wish to limit or increase the con
"language_models": {
"ollama": {
"api_url": "http://localhost:11434",
"low_speed_timeout_in_seconds": 120,
"available_models": [
{
"name": "qwen2.5-coder",
@ -233,22 +230,6 @@ To do so, add the following to your Zed `settings.json`:
Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`.
#### Custom timeout {#provider-timeout}
You can customize the timeout that's used for LLM requests, by adding the following to your Zed `settings.json`:
```json
{
"language_models": {
"some-provider": {
"low_speed_timeout_in_seconds": 10
}
}
}
```
Where `some-provider` can be any of the following values: `anthropic`, `copilot_chat`, `google`, `ollama`, `openai`.
#### Configuring the default model {#default-model}
The default model can be set via the model dropdown in the assistant panel's top-right corner. Selecting a model saves it as the default.