pub mod assistant; mod assistant_settings; mod streaming_diff; use anyhow::{anyhow, Result}; pub use assistant::AssistantPanel; use assistant_settings::OpenAIModel; use chrono::{DateTime, Local}; use collections::HashMap; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use gpui::{executor::Background, AppContext}; use isahc::{http::StatusCode, Request, RequestExt}; use regex::Regex; use serde::{Deserialize, Serialize}; use std::{ cmp::Reverse, ffi::OsStr, fmt::{self, Display}, io, path::PathBuf, sync::Arc, }; use util::paths::CONVERSATIONS_DIR; const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; // Data types for chat completion requests #[derive(Debug, Serialize)] pub struct OpenAIRequest { model: String, messages: Vec, stream: bool, } #[derive( Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, )] struct MessageId(usize); #[derive(Clone, Debug, Serialize, Deserialize)] struct MessageMetadata { role: Role, sent_at: DateTime, status: MessageStatus, } #[derive(Clone, Debug, Serialize, Deserialize)] enum MessageStatus { Pending, Done, Error(Arc), } #[derive(Serialize, Deserialize)] struct SavedMessage { id: MessageId, start: usize, } #[derive(Serialize, Deserialize)] struct SavedConversation { zed: String, version: String, text: String, messages: Vec, message_metadata: HashMap, summary: String, model: OpenAIModel, } impl SavedConversation { const VERSION: &'static str = "0.1.0"; } struct SavedConversationMetadata { title: String, path: PathBuf, mtime: chrono::DateTime, } impl SavedConversationMetadata { pub async fn list(fs: Arc) -> Result> { fs.create_dir(&CONVERSATIONS_DIR).await?; let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?; let mut conversations = Vec::::new(); while let Some(path) = paths.next().await { let path = path?; if path.extension() != Some(OsStr::new("json")) { continue; } let pattern = r" - \d+.zed.json$"; let re = Regex::new(pattern).unwrap(); let metadata = fs.metadata(&path).await?; if let Some((file_name, metadata)) = path .file_name() .and_then(|name| name.to_str()) .zip(metadata) { let title = re.replace(file_name, ""); conversations.push(Self { title: title.into_owned(), path, mtime: metadata.mtime.into(), }); } } conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime)); Ok(conversations) } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] struct RequestMessage { role: Role, content: String, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ResponseMessage { role: Option, content: Option, } #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] enum Role { User, Assistant, System, } impl Role { pub fn cycle(&mut self) { *self = match self { Role::User => Role::Assistant, Role::Assistant => Role::System, Role::System => Role::User, } } } impl Display for Role { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { match self { Role::User => write!(f, "User"), Role::Assistant => write!(f, "Assistant"), Role::System => write!(f, "System"), } } } #[derive(Deserialize, Debug)] pub struct OpenAIResponseStreamEvent { pub id: Option, pub object: String, pub created: u32, pub model: String, pub choices: Vec, pub usage: Option, } #[derive(Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } #[derive(Deserialize, Debug)] pub struct ChatChoiceDelta { pub index: u32, pub delta: ResponseMessage, pub finish_reason: Option, } #[derive(Deserialize, Debug)] struct OpenAIUsage { prompt_tokens: u64, completion_tokens: u64, total_tokens: u64, } #[derive(Deserialize, Debug)] struct OpenAIChoice { text: String, index: u32, logprobs: Option, finish_reason: Option, } pub fn init(cx: &mut AppContext) { assistant::init(cx); } pub async fn stream_completion( api_key: String, executor: Arc, mut request: OpenAIRequest, ) -> Result>> { request.stream = true; let (tx, rx) = futures::channel::mpsc::unbounded::>(); let json_data = serde_json::to_string(&request)?; let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body(json_data)? .send_async() .await?; let status = response.status(); if status == StatusCode::OK { executor .spawn(async move { let mut lines = BufReader::new(response.body_mut()).lines(); fn parse_line( line: Result, ) -> Result> { if let Some(data) = line?.strip_prefix("data: ") { let event = serde_json::from_str(&data)?; Ok(Some(event)) } else { Ok(None) } } while let Some(line) = lines.next().await { if let Some(event) = parse_line(line).transpose() { let done = event.as_ref().map_or(false, |event| { event .choices .last() .map_or(false, |choice| choice.finish_reason.is_some()) }); if tx.unbounded_send(event).is_err() { break; } if done { break; } } } anyhow::Ok(()) }) .detach(); Ok(rx) } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; #[derive(Deserialize)] struct OpenAIResponse { error: OpenAIError, } #[derive(Deserialize)] struct OpenAIError { message: String, } match serde_json::from_str::(&body) { Ok(response) if !response.error.message.is_empty() => Err(anyhow!( "Failed to connect to OpenAI API: {}", response.error.message, )), _ => Err(anyhow!( "Failed to connect to OpenAI API: {} {}", response.status(), body, )), } } } #[cfg(test)] #[ctor::ctor] fn init_logger() { if std::env::var("RUST_LOG").is_ok() { env_logger::init(); } }