use anyhow::{anyhow, Result}; use editor::Editor; use futures::AsyncBufReadExt; use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt}; use gpui::executor::Background; use gpui::{actions, AppContext, Task, ViewContext}; use indoc::indoc; use isahc::prelude::*; use isahc::{http::StatusCode, Request}; use serde::{Deserialize, Serialize}; use std::{io, sync::Arc}; use util::ResultExt; actions!(ai, [Assist]); // Data types for chat completion requests #[derive(Serialize)] struct OpenAIRequest { model: String, messages: Vec, stream: bool, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] struct RequestMessage { role: Role, content: String, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] struct ResponseMessage { role: Option, content: Option, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] enum Role { User, Assistant, System, } #[derive(Deserialize, Debug)] struct OpenAIResponseStreamEvent { pub id: Option, pub object: String, pub created: u32, pub model: String, pub choices: Vec, pub usage: Option, } #[derive(Deserialize, Debug)] struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } #[derive(Deserialize, Debug)] struct ChatChoiceDelta { pub index: u32, pub delta: ResponseMessage, pub finish_reason: Option, } #[derive(Deserialize, Debug)] struct OpenAIUsage { prompt_tokens: u64, completion_tokens: u64, total_tokens: u64, } #[derive(Deserialize, Debug)] struct OpenAIChoice { text: String, index: u32, logprobs: Option, finish_reason: Option, } pub fn init(cx: &mut AppContext) { cx.add_async_action(assist) } fn assist( editor: &mut Editor, _: &Assist, cx: &mut ViewContext, ) -> Option>> { let api_key = std::env::var("OPENAI_API_KEY").log_err()?; const SYSTEM_MESSAGE: &'static str = indoc! {r#" You an AI language model embedded in a code editor named Zed, authored by Zed Industries. The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor. A model mention is indicated via a leading / on a line. The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text. In this sentence, the word ->->example<-<- is selected. Respond to any selected model mention. Summarize each mention in a single short sentence like: > The user selected the word \"example\". Then provide your response to that mention below its summary. "#}; let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| { // Insert ->-> <-<- around selected text as described in the system prompt above. let snapshot = buffer.snapshot(cx); let mut user_message = String::new(); let mut buffer_offset = 0; for selection in editor.selections.all(cx) { user_message.extend(snapshot.text_for_range(buffer_offset..selection.start)); user_message.push_str("->->"); user_message.extend(snapshot.text_for_range(selection.start..selection.end)); buffer_offset = selection.end; user_message.push_str("<-<-"); } if buffer_offset < snapshot.len() { user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len())); } // Ensure the document ends with 4 trailing newlines. let trailing_newline_count = snapshot .reversed_chars_at(snapshot.len()) .take_while(|c| *c == '\n') .take(4); let suffix = "\n".repeat(4 - trailing_newline_count.count()); buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx); let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing. let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below. (user_message, insertion_site) }); let stream = stream_completion( api_key, cx.background_executor().clone(), OpenAIRequest { model: "gpt-4".to_string(), messages: vec![ RequestMessage { role: Role::System, content: SYSTEM_MESSAGE.to_string(), }, RequestMessage { role: Role::User, content: user_message, }, ], stream: false, }, ); let buffer = editor.buffer().clone(); Some(cx.spawn(|_, mut cx| async move { let mut messages = stream.await?; while let Some(message) = messages.next().await { let mut message = message?; if let Some(choice) = message.choices.pop() { buffer.update(&mut cx, |buffer, cx| { let text: Arc = choice.delta.content?.into(); buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx); Some(()) }); } } Ok(()) })) } async fn stream_completion( api_key: String, executor: Arc, mut request: OpenAIRequest, ) -> Result>> { request.stream = true; let (tx, rx) = futures::channel::mpsc::unbounded::>(); let json_data = serde_json::to_string(&request)?; let mut response = Request::post("https://api.openai.com/v1/chat/completions") .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body(json_data)? .send_async() .await?; let status = response.status(); if status == StatusCode::OK { executor .spawn(async move { let mut lines = BufReader::new(response.body_mut()).lines(); fn parse_line( line: Result, ) -> Result> { if let Some(data) = line?.strip_prefix("data: ") { let event = serde_json::from_str(&data)?; Ok(Some(event)) } else { Ok(None) } } while let Some(line) = lines.next().await { if let Some(event) = parse_line(line).transpose() { tx.unbounded_send(event).log_err(); } } anyhow::Ok(()) }) .detach(); Ok(rx) } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; Err(anyhow!( "Failed to connect to OpenAI API: {} {}", response.status(), body, )) } } #[cfg(test)] mod tests {}