
This removes the `low_speed_timeout` setting from all providers as a response to issue #19509. The reason is that the original `low_speed_timeout` was only added as part of #9913 because users wanted to _get rid of_ timeouts: they wanted to bump the default timeout from 5sec to a lot more. Then, in the meantime, the meaning of `low_speed_timeout` changed in #19055: it became a normal `timeout`, which is a different thing and breaks slower LLMs that don't reply with a complete response within the configured timeout. So we figured: let's remove the whole thing and replace it with a default _connect_ timeout to make sure that we can connect to a server in 10s, but then give the server as long as it wants to complete its response. Closes #19509 Release Notes: - Removed the `low_speed_timeout` setting from LLM provider settings, since it was only used to _increase_ the timeout to give LLMs more time. Since we don't have any other use for it, we simply remove the setting to give LLMs as long as they need. --------- Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Peter Tripp <peter@zed.dev>
588 lines
18 KiB
Rust
588 lines
18 KiB
Rust
mod supported_countries;
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
use futures::{
|
|
io::BufReader,
|
|
stream::{self, BoxStream},
|
|
AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
|
|
};
|
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use std::{
|
|
convert::TryFrom,
|
|
future::{self, Future},
|
|
pin::Pin,
|
|
};
|
|
use strum::EnumIter;
|
|
|
|
pub use supported_countries::*;
|
|
|
|
/// Default base URL for the OpenAI REST API.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
|
|
|
|
/// Returns `true` when `opt` is `None` or when the contained slice-like
/// value has no elements. Used as a serde `skip_serializing_if` predicate.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
|
|
|
|
/// The author role of a chat message, serialized in lowercase to match the
/// OpenAI chat-completions wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// A message authored by the end user.
    User,
    /// A message produced by the model.
    Assistant,
    /// A system/instruction message.
    System,
    /// The result of a tool call, fed back to the model.
    Tool,
}
|
|
|
|
impl TryFrom<String> for Role {
|
|
type Error = anyhow::Error;
|
|
|
|
fn try_from(value: String) -> Result<Self> {
|
|
match value.as_str() {
|
|
"user" => Ok(Self::User),
|
|
"assistant" => Ok(Self::Assistant),
|
|
"system" => Ok(Self::System),
|
|
"tool" => Ok(Self::Tool),
|
|
_ => Err(anyhow!("invalid role '{value}'")),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<Role> for String {
|
|
fn from(val: Role) -> Self {
|
|
match val {
|
|
Role::User => "user".to_owned(),
|
|
Role::Assistant => "assistant".to_owned(),
|
|
Role::System => "system".to_owned(),
|
|
Role::Tool => "tool".to_owned(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A known OpenAI chat model, or a user-configured custom model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    /// A model not in the list above, configured by the user.
    #[serde(rename = "custom")]
    Custom {
        /// The raw model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context window size in tokens.
        max_tokens: usize,
        /// Optional cap on output tokens.
        max_output_tokens: Option<u32>,
        /// Optional `max_completion_tokens` value for this model.
        max_completion_tokens: Option<u32>,
    },
}
|
|
|
|
impl Model {
|
|
pub fn from_id(id: &str) -> Result<Self> {
|
|
match id {
|
|
"gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
|
|
"gpt-4" => Ok(Self::Four),
|
|
"gpt-4-turbo-preview" => Ok(Self::FourTurbo),
|
|
"gpt-4o" => Ok(Self::FourOmni),
|
|
"gpt-4o-mini" => Ok(Self::FourOmniMini),
|
|
"o1-preview" => Ok(Self::O1Preview),
|
|
"o1-mini" => Ok(Self::O1Mini),
|
|
_ => Err(anyhow!("invalid model id")),
|
|
}
|
|
}
|
|
|
|
pub fn id(&self) -> &str {
|
|
match self {
|
|
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
|
Self::Four => "gpt-4",
|
|
Self::FourTurbo => "gpt-4-turbo",
|
|
Self::FourOmni => "gpt-4o",
|
|
Self::FourOmniMini => "gpt-4o-mini",
|
|
Self::O1Preview => "o1-preview",
|
|
Self::O1Mini => "o1-mini",
|
|
Self::Custom { name, .. } => name,
|
|
}
|
|
}
|
|
|
|
pub fn display_name(&self) -> &str {
|
|
match self {
|
|
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
|
Self::Four => "gpt-4",
|
|
Self::FourTurbo => "gpt-4-turbo",
|
|
Self::FourOmni => "gpt-4o",
|
|
Self::FourOmniMini => "gpt-4o-mini",
|
|
Self::O1Preview => "o1-preview",
|
|
Self::O1Mini => "o1-mini",
|
|
Self::Custom {
|
|
name, display_name, ..
|
|
} => display_name.as_ref().unwrap_or(name),
|
|
}
|
|
}
|
|
|
|
pub fn max_token_count(&self) -> usize {
|
|
match self {
|
|
Self::ThreePointFiveTurbo => 16385,
|
|
Self::Four => 8192,
|
|
Self::FourTurbo => 128000,
|
|
Self::FourOmni => 128000,
|
|
Self::FourOmniMini => 128000,
|
|
Self::O1Preview => 128000,
|
|
Self::O1Mini => 128000,
|
|
Self::Custom { max_tokens, .. } => *max_tokens,
|
|
}
|
|
}
|
|
|
|
pub fn max_output_tokens(&self) -> Option<u32> {
|
|
match self {
|
|
Self::Custom {
|
|
max_output_tokens, ..
|
|
} => *max_output_tokens,
|
|
_ => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Body of a chat-completions request (`POST {api_url}/chat/completions`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id as understood by the API (see [`Model::id`]).
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether the server should stream the response as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Tools the model may call; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
#[serde(untagged)]
|
|
pub enum ToolChoice {
|
|
Auto,
|
|
Required,
|
|
None,
|
|
Other(ToolDefinition),
|
|
}
|
|
|
|
/// A tool the model may call; currently only function tools exist.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Description of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema for the function's arguments, if any.
    pub parameters: Option<Value>,
}
|
|
|
|
/// A chat message in the request payload, tagged by its `role` field on the
/// wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Text content; may be absent when the assistant only issues tool
        /// calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        /// Id of the tool call this message responds to.
        tool_call_id: String,
    },
}

/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so the `type`/`function` fields sit next to `id` in JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The payload of a tool call; currently always a function call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A concrete function invocation: the function name plus its arguments as a
/// raw JSON string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
|
|
|
|
/// Incremental message delta carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a tool call, streamed across multiple chunks and matched up
/// by `index`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// A fragment of a function call's name and/or argument text.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
|
|
|
|
/// Token accounting reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice's delta within a streaming chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// A streamed payload is either a normal event or an error object;
/// `untagged` lets serde try both shapes.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

/// One server-sent event of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Token usage totals, when the server includes them.
    pub usage: Option<Usage>,
}
|
|
|
|
/// A complete (non-streaming) chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One completion alternative in a non-streaming response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
|
|
|
|
pub async fn complete(
|
|
client: &dyn HttpClient,
|
|
api_url: &str,
|
|
api_key: &str,
|
|
request: Request,
|
|
) -> Result<Response> {
|
|
let uri = format!("{api_url}/chat/completions");
|
|
let request_builder = HttpRequest::builder()
|
|
.method(Method::POST)
|
|
.uri(uri)
|
|
.header("Content-Type", "application/json")
|
|
.header("Authorization", format!("Bearer {}", api_key));
|
|
|
|
let mut request_body = request;
|
|
request_body.stream = false;
|
|
|
|
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
|
|
let mut response = client.send(request).await?;
|
|
|
|
if response.status().is_success() {
|
|
let mut body = String::new();
|
|
response.body_mut().read_to_string(&mut body).await?;
|
|
let response: Response = serde_json::from_str(&body)?;
|
|
Ok(response)
|
|
} else {
|
|
let mut body = String::new();
|
|
response.body_mut().read_to_string(&mut body).await?;
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAiResponse {
|
|
error: OpenAiError,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OpenAiError {
|
|
message: String,
|
|
}
|
|
|
|
match serde_json::from_str::<OpenAiResponse>(&body) {
|
|
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
|
"Failed to connect to OpenAI API: {}",
|
|
response.error.message,
|
|
)),
|
|
|
|
_ => Err(anyhow!(
|
|
"Failed to connect to OpenAI API: {} {}",
|
|
response.status(),
|
|
body,
|
|
)),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
|
|
ResponseStreamEvent {
|
|
created: response.created as u32,
|
|
model: response.model,
|
|
choices: response
|
|
.choices
|
|
.into_iter()
|
|
.map(|choice| ChoiceDelta {
|
|
index: choice.index,
|
|
delta: ResponseMessageDelta {
|
|
role: Some(match choice.message {
|
|
RequestMessage::Assistant { .. } => Role::Assistant,
|
|
RequestMessage::User { .. } => Role::User,
|
|
RequestMessage::System { .. } => Role::System,
|
|
RequestMessage::Tool { .. } => Role::Tool,
|
|
}),
|
|
content: match choice.message {
|
|
RequestMessage::Assistant { content, .. } => content,
|
|
RequestMessage::User { content } => Some(content),
|
|
RequestMessage::System { content } => Some(content),
|
|
RequestMessage::Tool { content, .. } => Some(content),
|
|
},
|
|
tool_calls: None,
|
|
},
|
|
finish_reason: choice.finish_reason,
|
|
})
|
|
.collect(),
|
|
usage: Some(response.usage),
|
|
}
|
|
}
|
|
|
|
/// Performs a streaming chat-completion request, returning a boxed stream of
/// parsed server-sent events.
///
/// The o1 models do not stream here: for them this falls back to a blocking
/// [`complete`] call and adapts the single response into a one-event stream.
///
/// # Errors
///
/// Returns an error when the request cannot be built or sent, or when the
/// server responds with a non-success status; individual malformed or
/// error-bearing stream chunks are surfaced as `Err` items in the stream.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        // Each SSE line looks like `data: {json}`; the terminal line is
        // `data: [DONE]`. Lines without the `data: ` prefix are skipped.
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-2xx: read the whole body and surface the API's own error
        // message when it parses, otherwise include the raw status and body.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
|
|
|
|
/// Models accepted by the OpenAI embeddings endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// Request body for `POST {api_url}/embeddings`.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

/// Response body of the embeddings endpoint.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector, one per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
|
|
|
|
/// Requests embeddings for `texts` from the OpenAI embeddings endpoint.
///
/// The HTTP request is built eagerly, but any build error is deferred into
/// the returned future so the function itself stays infallible and
/// `'static`.
///
/// # Errors
///
/// The returned future fails when the request cannot be built or sent, when
/// the body cannot be read or parsed, or when the server responds with a
/// non-success status.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    // `request` becomes Result<future>: builder errors surface via the `?`
    // inside the async block below.
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}
|
|
|
|
/// Extracts the streamed argument fragments of a tool call named
/// `tool_name` from a stream of completion events.
///
/// First drains events until the named tool is first mentioned (recording
/// its call index and any argument text carried by that first chunk), then
/// returns a stream of the remaining argument fragments for that call index.
///
/// # Errors
///
/// Returns an error if the event stream ends without the tool being used.
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        // Find the first chunk that names the tool we care about.
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        // Prepend the argument text captured alongside the
                        // tool name in the first matching chunk (happens at
                        // most once, since `take()` empties the option).
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}
|
|
|
|
pub fn extract_text_from_events(
|
|
response: impl Stream<Item = Result<ResponseStreamEvent>>,
|
|
) -> impl Stream<Item = Result<String>> {
|
|
response.filter_map(|response| async move {
|
|
match response {
|
|
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
|
Err(error) => Some(Err(error)),
|
|
}
|
|
})
|
|
}
|