// crates/ai/src/ai.rs

use anyhow::{anyhow, Result};
use editor::Editor;
use futures::AsyncBufReadExt;
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
use gpui::executor::Background;
use gpui::{actions, AppContext, Task, ViewContext};
use indoc::indoc;
use isahc::prelude::*;
use isahc::{http::StatusCode, Request};
use serde::{Deserialize, Serialize};
use std::{io, sync::Arc};
use util::ResultExt;

actions!(ai, [Assist]);

// Data types for chat completion requests.
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<RequestMessage>,
    stream: bool,
}
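
// For illustration: serde serializes the request above to JSON shaped like this
// (hand-written example; field order follows the struct declaration):
//
//   {"model":"gpt-4",
//    "messages":[{"role":"system","content":"..."},
//                {"role":"user","content":"..."}],
//    "stream":true}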

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}

#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}
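
// For reference, each streamed chunk deserializes from JSON shaped like this
// hand-written example; `delta` carries the incremental message content, and
// `usage` is typically absent on streamed chunks:
//
//   {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1684000000,
//    "model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}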

// Note: OpenAIUsage and OpenAIChoice are not referenced elsewhere in this file;
// OpenAIChoice is shaped like a choice from the legacy (non-chat) completions API.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}

#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}

pub fn init(cx: &mut AppContext) {
    cx.add_async_action(assist)
}

fn assist(
    editor: &mut Editor,
    _: &Assist,
    cx: &mut ViewContext<Editor>,
) -> Option<Task<Result<()>>> {
    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;

    const SYSTEM_MESSAGE: &'static str = indoc! {r#"
        You are an AI language model embedded in a code editor named Zed, authored by Zed Industries.
        The input you are currently processing was produced by a special "model mention" in a document that is open in the editor.
        A model mention is indicated via a leading / on a line.
        The user's currently selected text is indicated via ->->selected text<-<- surrounding the selected text.
        In this sentence, the word ->->example<-<- is selected.
        Respond to any selected model mention.
        Summarize each mention in a single short sentence like:
        > The user selected the word "example".
        Then provide your response to that mention below its summary.
    "#};

    let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
        // Insert ->-> <-<- around each selection, as described in the system prompt above.
        let snapshot = buffer.snapshot(cx);
        let mut user_message = String::new();
        let mut buffer_offset = 0;
        for selection in editor.selections.all(cx) {
            user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
            user_message.push_str("->->");
            user_message.extend(snapshot.text_for_range(selection.start..selection.end));
            buffer_offset = selection.end;
            user_message.push_str("<-<-");
        }
        if buffer_offset < snapshot.len() {
            user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
        }
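
        // E.g., if the buffer contains "fix this function" and "this" is selected,
        // user_message becomes "fix ->->this<-<- function".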

        // Ensure the document ends with 4 trailing newlines, so the response can
        // be inserted with an empty line above and below it.
        let trailing_newlines = snapshot
            .reversed_chars_at(snapshot.len())
            .take_while(|c| *c == '\n')
            .take(4)
            .count();
        let suffix = "\n".repeat(4 - trailing_newlines);
        buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);

        // Take a new snapshot after editing, and insert the response 2 characters
        // before the end of the buffer, between the 2nd and 3rd trailing newline.
        let snapshot = buffer.snapshot(cx);
        let insertion_site = snapshot.len() - 2;
        (user_message, insertion_site)
    });

    let stream = stream_completion(
        api_key,
        cx.background().clone(),
        OpenAIRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                RequestMessage {
                    role: Role::System,
                    content: SYSTEM_MESSAGE.to_string(),
                },
                RequestMessage {
                    role: Role::User,
                    content: user_message,
                },
            ],
            stream: false, // stream_completion sets this to true before sending.
        },
    );

    let buffer = editor.buffer().clone();
    Some(cx.spawn(|_, mut cx| async move {
        let mut insertion_site = insertion_site;
        let mut messages = stream.await?;
        while let Some(message) = messages.next().await {
            let mut message = message?;
            if let Some(choice) = message.choices.pop() {
                buffer.update(&mut cx, |buffer, cx| {
                    let text: Arc<str> = choice.delta.content?.into();
                    let len = text.len();
                    buffer.edit([(insertion_site..insertion_site, text)], None, cx);
                    // Advance the insertion point past the chunk we just inserted,
                    // so the next chunk is appended after it rather than before it.
                    insertion_site += len;
                    Some(())
                });
            }
        }
        Ok(())
    }))
}

async fn stream_completion(
    api_key: String,
    executor: Arc<Background>,
    mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    request.stream = true;
    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
    let json_data = serde_json::to_string(&request)?;
    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        // The stream's final event is the sentinel "data: [DONE]",
                        // which isn't JSON; treat it as end-of-stream rather than
                        // a parse error.
                        if data == "[DONE]" {
                            return Ok(None);
                        }
                        let event = serde_json::from_str(data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }
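
                // For reference, the response body is a server-sent-events stream;
                // each event arrives as a line of the form `data: {...json...}`,
                // separated by blank lines (which parse_line maps to Ok(None)).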

                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        tx.unbounded_send(event).log_err();
                    }
                }
                anyhow::Ok(())
            })
            .detach();
        Ok(rx)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        Err(anyhow!(
            "OpenAI API request failed with status {}: {}",
            status,
            body,
        ))
    }
}

#[cfg(test)]
mod tests {}
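
// A minimal test sketch added for illustration: the event below is a hand-written
// example of the OpenAI chat streaming format, not a captured API response.
#[cfg(test)]
mod streaming_format_tests {
    use super::*;

    #[test]
    fn roles_serialize_lowercase() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), r#""user""#);
        assert_eq!(serde_json::to_string(&Role::Assistant).unwrap(), r#""assistant""#);
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), r#""system""#);
    }

    #[test]
    fn deserialize_stream_event() {
        let line = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1684000000,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hi"},"finish_reason":null}]}"#;
        let data = line.strip_prefix("data: ").unwrap();
        let event: OpenAIResponseStreamEvent = serde_json::from_str(data).unwrap();
        assert_eq!(event.choices[0].delta.role, Some(Role::Assistant));
        assert_eq!(event.choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(event.usage.is_none());
    }
}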