diff --git a/assets/settings/default.json b/assets/settings/default.json
index 65254afb7c..22dafb2890 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -916,7 +916,8 @@
     },
     "openai": {
       "version": "1",
-      "api_url": "https://api.openai.com/v1"
+      "api_url": "https://api.openai.com/v1",
+      "low_speed_timeout_in_seconds": 600
     }
   },
   // Zed's Prettier integration settings.
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 7939eacd93..e2c6a8eb24 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -163,11 +163,13 @@ impl AssistantSettingsContent {
                             display_name,
                             max_tokens,
                             max_output_tokens,
+                            max_completion_tokens: None,
                         } => Some(open_ai::AvailableModel {
                             name,
                             display_name,
                             max_tokens,
                             max_output_tokens,
+                            max_completion_tokens: None,
                         }),
                         _ => None,
                     })
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 246a408477..b01a712a7e 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -2407,7 +2407,7 @@ impl Codegen {
         Ok(LanguageModelRequest {
             messages,
             tools: Vec::new(),
-            stop: vec!["|END|>".to_string()],
+            stop: Vec::new(),
             temperature: 1.,
         })
     }
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 0de7fb3feb..f8f64ff3b8 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -78,6 +78,8 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     /// The maximum number of output tokens allowed by the model.
     pub max_output_tokens: Option<u32>,
+    /// The maximum number of completion tokens allowed by the model (o1-* only)
+    pub max_completion_tokens: Option<u32>,
     /// Override this model with a different Anthropic model for tool calls.
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
@@ -257,6 +259,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
+                    max_completion_tokens: model.max_completion_tokens,
                 }),
                 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                     name: model.name.clone(),
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index fe5e60caec..98424a23aa 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -43,6 +43,7 @@ pub struct AvailableModel {
     pub display_name: Option<String>,
     pub max_tokens: usize,
     pub max_output_tokens: Option<u32>,
+    pub max_completion_tokens: Option<u32>,
 }
 
 pub struct OpenAiLanguageModelProvider {
@@ -175,6 +176,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
+                    max_completion_tokens: model.max_completion_tokens,
                 },
             );
         }
diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs
index 0059ed56c4..80749c0bdb 100644
--- a/crates/language_model/src/settings.rs
+++ b/crates/language_model/src/settings.rs
@@ -178,11 +178,13 @@ impl OpenAiSettingsContent {
                         display_name,
                         max_tokens,
                         max_output_tokens,
+                        max_completion_tokens,
                     } => Some(provider::open_ai::AvailableModel {
                         name,
                         max_tokens,
                         max_output_tokens,
                         display_name,
+                        max_completion_tokens,
                     }),
                     _ => None,
                 })
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 5b621d6bb8..7b0294bd9c 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -1,12 +1,21 @@
 mod supported_countries;
 
 use anyhow::{anyhow, Context, Result};
-use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use futures::{
+    io::BufReader,
+    stream::{self, BoxStream},
+    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
+use std::{
+    convert::TryFrom,
+    future::{self, Future},
+    pin::Pin,
+    time::Duration,
+};
 use strum::EnumIter;
 
 pub use supported_countries::*;
@@ -72,6 +81,7 @@ pub enum Model {
         display_name: Option<String>,
         max_tokens: usize,
         max_output_tokens: Option<u32>,
+        max_completion_tokens: Option<u32>,
     },
 }
 
@@ -139,6 +149,7 @@ pub struct Request {
     pub stream: bool,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub max_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
@@ -263,6 +274,111 @@ pub struct ResponseStreamEvent {
     pub usage: Option<Usage>,
 }
 
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Response {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<Choice>,
+    pub usage: Usage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Choice {
+    pub index: u32,
+    pub message: RequestMessage,
+    pub finish_reason: Option<String>,
+}
+
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+    low_speed_timeout: Option<Duration>,
+) -> Result<Response> {
+    let uri = format!("{api_url}/chat/completions");
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
.header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)); + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + }; + + let mut request_body = request; + request_body.stream = false; + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: Response = serde_json::from_str(&body)?; + Ok(response) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAiResponse { + error: OpenAiError, + } + + #[derive(Deserialize)] + struct OpenAiError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent { + ResponseStreamEvent { + created: response.created as u32, + model: response.model, + choices: response + .choices + .into_iter() + .map(|choice| ChoiceDelta { + index: choice.index, + delta: ResponseMessageDelta { + role: Some(match choice.message { + RequestMessage::Assistant { .. } => Role::Assistant, + RequestMessage::User { .. } => Role::User, + RequestMessage::System { .. } => Role::System, + RequestMessage::Tool { .. } => Role::Tool, + }), + content: match choice.message { + RequestMessage::Assistant { content, .. } => content, + RequestMessage::User { content } => Some(content), + RequestMessage::System { content } => Some(content), + RequestMessage::Tool { content, .. } => Some(content), + }, + tool_calls: None, + }, + finish_reason: choice.finish_reason, + }) + .collect(), + usage: Some(response.usage), + } +} + pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, @@ -270,6 +386,12 @@ pub async fn stream_completion( request: Request, low_speed_timeout: Option, ) -> Result>> { + if request.model == "o1-preview" || request.model == "o1-mini" { + let response = complete(client, api_url, api_key, request, low_speed_timeout).await; + let response_stream_event = response.map(adapt_response_to_stream); + return Ok(stream::once(future::ready(response_stream_event)).boxed()); + } + let uri = format!("{api_url}/chat/completions"); let mut request_builder = HttpRequest::builder() .method(Method::POST)