use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; use std::convert::TryFrom; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { User, Assistant, System, } impl TryFrom for Role { type Error = anyhow::Error; fn try_from(value: String) -> Result { match value.as_str() { "user" => Ok(Self::User), "assistant" => Ok(Self::Assistant), "system" => Ok(Self::System), _ => Err(anyhow!("invalid role '{value}'")), } } } impl From for String { fn from(val: Role) -> Self { match val { Role::User => "user".to_owned(), Role::Assistant => "assistant".to_owned(), Role::System => "system".to_owned(), } } } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub enum Model { #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")] ThreePointFiveTurbo, #[serde(rename = "gpt-4", alias = "gpt-4-0613")] Four, #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")] #[default] FourTurbo, } impl Model { pub fn from_id(id: &str) -> Result { match id { "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo), "gpt-4" => Ok(Self::Four), "gpt-4-turbo-preview" => Ok(Self::FourTurbo), _ => Err(anyhow!("invalid model id")), } } pub fn id(&self) -> &'static str { match self { Self::ThreePointFiveTurbo => "gpt-3.5-turbo", Self::Four => "gpt-4", Self::FourTurbo => "gpt-4-turbo-preview", } } pub fn display_name(&self) -> &'static str { match self { Self::ThreePointFiveTurbo => "gpt-3.5-turbo", Self::Four => "gpt-4", Self::FourTurbo => "gpt-4-turbo", } } pub fn max_token_count(&self) -> usize { match self { Model::ThreePointFiveTurbo => 4096, Model::Four => 8192, Model::FourTurbo => 128000, } } } #[derive(Debug, Serialize)] pub struct Request { pub model: Model, pub messages: Vec, pub stream: bool, pub stop: Vec, pub temperature: f32, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct RequestMessage { pub role: Role, pub content: String, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ResponseMessage { pub role: Option, pub content: Option, } #[derive(Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } #[derive(Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, pub delta: ResponseMessage, pub finish_reason: Option, } #[derive(Deserialize, Debug)] pub struct ResponseStreamEvent { pub created: u32, pub model: String, pub choices: Vec, pub usage: Option, } pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, api_key: &str, request: Request, ) -> Result>> { let uri = format!("{api_url}/chat/completions"); let request = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); Ok(reader .lines() .filter_map(|line| async move { match line { Ok(line) => { let line = line.strip_prefix("data: ")?; if line == "[DONE]" { None } else { match serde_json::from_str(line) { Ok(response) => Some(Ok(response)), Err(error) => Some(Err(anyhow!(error))), } } } Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; #[derive(Deserialize)] struct OpenAiResponse { error: OpenAiError, } #[derive(Deserialize)] struct OpenAiError { message: String, } match serde_json::from_str::(&body) { Ok(response) if !response.error.message.is_empty() => Err(anyhow!( "Failed to connect to OpenAI API: {}", response.error.message, )), _ => Err(anyhow!( "Failed to connect to OpenAI API: {} {}", response.status(), body, )), } } }