diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 146aac9cfa..4b164d0c65 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -438,7 +438,7 @@ pub fn into_open_ai(
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
-        max_tokens: max_output_tokens,
+        max_completion_tokens: max_output_tokens,
         parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
             // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
             Some(false)
@@ -648,8 +648,6 @@ pub fn count_open_ai_tokens(
         | Model::FourPointOneMini
         | Model::FourPointOneNano
         | Model::O1
-        | Model::O1Preview
-        | Model::O1Mini
         | Model::O3
         | Model::O3Mini
         | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index fb50a26a19..3ca953d766 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -1,16 +1,9 @@
 use anyhow::{Context as _, Result, anyhow};
-use futures::{
-    AsyncBufReadExt, AsyncReadExt, StreamExt,
-    io::BufReader,
-    stream::{self, BoxStream},
-};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use std::{
-    convert::TryFrom,
-    future::{self, Future},
-};
+use std::{convert::TryFrom, future::Future};
 use strum::EnumIter;
 
 pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
@@ -75,10 +68,6 @@ pub enum Model {
     FourPointOneNano,
     #[serde(rename = "o1")]
     O1,
-    #[serde(rename = "o1-preview")]
-    O1Preview,
-    #[serde(rename = "o1-mini")]
-    O1Mini,
     #[serde(rename = "o3-mini")]
     O3Mini,
     #[serde(rename = "o3")]
@@ -113,8 +102,6 @@ impl Model {
             "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
             "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
             "o1" => Ok(Self::O1),
-            "o1-preview" => Ok(Self::O1Preview),
-            "o1-mini" => Ok(Self::O1Mini),
             "o3-mini" => Ok(Self::O3Mini),
             "o3" => Ok(Self::O3),
             "o4-mini" => Ok(Self::O4Mini),
@@ -133,8 +120,6 @@ impl Model {
             Self::FourPointOneMini => "gpt-4.1-mini",
             Self::FourPointOneNano => "gpt-4.1-nano",
             Self::O1 => "o1",
-            Self::O1Preview => "o1-preview",
-            Self::O1Mini => "o1-mini",
             Self::O3Mini => "o3-mini",
             Self::O3 => "o3",
             Self::O4Mini => "o4-mini",
@@ -153,8 +138,6 @@ impl Model {
             Self::FourPointOneMini => "gpt-4.1-mini",
             Self::FourPointOneNano => "gpt-4.1-nano",
             Self::O1 => "o1",
-            Self::O1Preview => "o1-preview",
-            Self::O1Mini => "o1-mini",
             Self::O3Mini => "o3-mini",
             Self::O3 => "o3",
             Self::O4Mini => "o4-mini",
@@ -175,8 +158,6 @@ impl Model {
             Self::FourPointOneMini => 1_047_576,
             Self::FourPointOneNano => 1_047_576,
             Self::O1 => 200_000,
-            Self::O1Preview => 128_000,
-            Self::O1Mini => 128_000,
             Self::O3Mini => 200_000,
             Self::O3 => 200_000,
             Self::O4Mini => 200_000,
@@ -198,8 +179,6 @@ impl Model {
             Self::FourPointOneMini => Some(32_768),
             Self::FourPointOneNano => Some(32_768),
             Self::O1 => Some(100_000),
-            Self::O1Preview => Some(32_768),
-            Self::O1Mini => Some(65_536),
             Self::O3Mini => Some(100_000),
             Self::O3 => Some(100_000),
             Self::O4Mini => Some(100_000),
@@ -219,13 +198,7 @@ impl Model {
             | Self::FourPointOne
             | Self::FourPointOneMini
             | Self::FourPointOneNano => true,
-            Self::O1
-            | Self::O1Preview
-            | Self::O1Mini
-            | Self::O3
-            | Self::O3Mini
-            | Self::O4Mini
-            | Model::Custom { .. } => false,
+            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
         }
     }
 }
@@ -236,7 +209,7 @@ pub struct Request {
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<u32>,
+    pub max_completion_tokens: Option<u32>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub stop: Vec<String>,
     pub temperature: f32,
@@ -249,24 +222,6 @@ pub struct Request {
     pub tools: Vec<ToolDefinition>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub struct CompletionRequest {
-    pub model: String,
-    pub prompt: String,
-    pub max_tokens: u32,
-    pub temperature: f32,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub prediction: Option<Prediction>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub rewrite_speculation: Option<bool>,
-}
-
-#[derive(Clone, Deserialize, Serialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum Prediction {
-    Content { content: String },
-}
-
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum ToolChoice {
@@ -436,204 +391,12 @@ pub struct ResponseStreamEvent {
     pub usage: Option<Usage>,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
-pub struct CompletionResponse {
-    pub id: String,
-    pub object: String,
-    pub created: u64,
-    pub model: String,
-    pub choices: Vec<CompletionChoice>,
-    pub usage: Usage,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct CompletionChoice {
-    pub text: String,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Response {
-    pub id: String,
-    pub object: String,
-    pub created: u64,
-    pub model: String,
-    pub choices: Vec<Choice>,
-    pub usage: Usage,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Choice {
-    pub index: u32,
-    pub message: RequestMessage,
-    pub finish_reason: Option<String>,
-}
-
-pub async fn complete(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: Request,
-) -> Result<Response> {
-    let uri = format!("{api_url}/chat/completions");
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key));
-
-    let mut request_body = request;
-    request_body.stream = false;
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
-    let mut response = client.send(request).await?;
-
-    if response.status().is_success() {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-        let response: Response = serde_json::from_str(&body)?;
-        Ok(response)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAiResponse {
-            error: OpenAiError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAiError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAiResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => anyhow::bail!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            ),
-            _ => anyhow::bail!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            ),
-        }
-    }
-}
-
-pub async fn complete_text(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: CompletionRequest,
-) -> Result<CompletionResponse> {
-    let uri = format!("{api_url}/completions");
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key));
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
-    let mut response = client.send(request).await?;
-
-    if response.status().is_success() {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-        let response = serde_json::from_str(&body)?;
-        Ok(response)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAiResponse {
-            error: OpenAiError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAiError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAiResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => anyhow::bail!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            ),
-            _ => anyhow::bail!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            ),
-        }
-    }
-}
-
-fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
-    ResponseStreamEvent {
-        created: response.created as u32,
-        model: response.model,
-        choices: response
-            .choices
-            .into_iter()
-            .map(|choice| {
-                let content = match &choice.message {
-                    RequestMessage::Assistant { content, .. } => content.as_ref(),
-                    RequestMessage::User { content } => Some(content),
-                    RequestMessage::System { content } => Some(content),
-                    RequestMessage::Tool { content, .. } => Some(content),
-                };
-
-                let mut text_content = String::new();
-                match content {
-                    Some(MessageContent::Plain(text)) => text_content.push_str(&text),
-                    Some(MessageContent::Multipart(parts)) => {
-                        for part in parts {
-                            match part {
-                                MessagePart::Text { text } => text_content.push_str(&text),
-                                MessagePart::Image { .. } => {}
-                            }
-                        }
-                    }
-                    None => {}
-                };
-
-                ChoiceDelta {
-                    index: choice.index,
-                    delta: ResponseMessageDelta {
-                        role: Some(match choice.message {
-                            RequestMessage::Assistant { .. } => Role::Assistant,
-                            RequestMessage::User { .. } => Role::User,
-                            RequestMessage::System { .. } => Role::System,
-                            RequestMessage::Tool { .. } => Role::Tool,
-                        }),
-                        content: if text_content.is_empty() {
-                            None
-                        } else {
-                            Some(text_content)
-                        },
-                        tool_calls: None,
-                    },
-                    finish_reason: choice.finish_reason,
-                }
-            })
-            .collect(),
-        usage: Some(response.usage),
-    }
-}
-
 pub async fn stream_completion(
     client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
     request: Request,
 ) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
-    if request.model.starts_with("o1") {
-        let response = complete(client, api_url, api_key, request).await;
-        let response_stream_event = response.map(adapt_response_to_stream);
-        return Ok(stream::once(future::ready(response_stream_event)).boxed());
-    }
-
     let uri = format!("{api_url}/chat/completions");
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
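
Note (illustration, not part of the patch): OpenAI has deprecated `max_tokens` on the Chat Completions API in favor of `max_completion_tokens`, and the o-series reasoning models reject the old field outright, which is what the rename above addresses. A minimal standalone sketch of the resulting wire format, using a hypothetical `RequestSketch` mirror type and only `serde`/`serde_json`:

```rust
use serde::Serialize;

// Hypothetical mirror of the renamed field, for illustration only.
#[derive(Serialize)]
struct RequestSketch {
    model: String,
    // Serialized as `max_completion_tokens`; omitted entirely when `None`,
    // matching the `skip_serializing_if` behavior on the real `Request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
}

fn main() {
    let request = RequestSketch {
        model: "o1".into(),
        max_completion_tokens: Some(100_000),
    };
    // Prints: {"model":"o1","max_completion_tokens":100000}
    println!("{}", serde_json::to_string(&request).unwrap());
}
```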