replace OpenAIRequest with more generalized Box<dyn CompletionRequest>

This commit is contained in:
KCaverly 2023-10-22 14:33:19 +02:00
parent 05ae978cb7
commit d813ae8845
6 changed files with 58 additions and 23 deletions

View file

@ -1,11 +1,13 @@
use anyhow::Result; use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream}; use futures::{future::BoxFuture, stream::BoxStream};
use crate::providers::open_ai::completion::OpenAIRequest; pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider { pub trait CompletionProvider {
fn complete( fn complete(
&self, &self,
prompt: OpenAIRequest, prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>; ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
} }

View file

@ -0,0 +1,13 @@
use crate::completion::CompletionRequest;
use serde::Serialize;
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}

View file

@ -1 +1,2 @@
pub mod dummy;
pub mod open_ai; pub mod open_ai;

View file

@ -12,7 +12,7 @@ use std::{
sync::Arc, sync::Arc,
}; };
use crate::completion::CompletionProvider; use crate::completion::{CompletionProvider, CompletionRequest};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@ -59,6 +59,12 @@ pub struct OpenAIRequest {
pub temperature: f32, pub temperature: f32,
} }
impl CompletionRequest for OpenAIRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage { pub struct ResponseMessage {
pub role: Option<Role>, pub role: Option<Role>,
@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
pub async fn stream_completion( pub async fn stream_completion(
api_key: String, api_key: String,
executor: Arc<Background>, executor: Arc<Background>,
mut request: OpenAIRequest, request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> { ) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>(); let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?; let json_data = request.data()?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key)) .header("Authorization", format!("Bearer {}", api_key))
@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
impl CompletionProvider for OpenAICompletionProvider { impl CompletionProvider for OpenAICompletionProvider {
fn complete( fn complete(
&self, &self,
prompt: OpenAIRequest, prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> { ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move { async move {

View file

@ -6,8 +6,11 @@ use crate::{
SavedMessage, SavedMessage,
}; };
use ai::providers::open_ai::{ use ai::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, completion::CompletionRequest,
providers::open_ai::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
},
}; };
use ai::prompts::repository_context::PromptCodeSnippet; use ai::prompts::repository_context::PromptCodeSnippet;
@ -745,13 +748,14 @@ impl AssistantPanel {
content: prompt, content: prompt,
}); });
let request = OpenAIRequest { let request = Box::new(OpenAIRequest {
model: model.full_name().into(), model: model.full_name().into(),
messages, messages,
stream: true, stream: true,
stop: vec!["|END|>".to_string()], stop: vec!["|END|>".to_string()],
temperature, temperature,
}; });
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
anyhow::Ok(()) anyhow::Ok(())
}) })
@ -1735,7 +1739,7 @@ impl Conversation {
return Default::default(); return Default::default();
}; };
let request = OpenAIRequest { let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(), model: self.model.full_name().to_string(),
messages: self messages: self
.messages(cx) .messages(cx)
@ -1745,7 +1749,7 @@ impl Conversation {
stream: true, stream: true,
stop: vec![], stop: vec![],
temperature: 1.0, temperature: 1.0,
}; });
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self let assistant_message = self
@ -2025,13 +2029,13 @@ impl Conversation {
"Summarize the conversation into a short title without punctuation" "Summarize the conversation into a short title without punctuation"
.into(), .into(),
})); }));
let request = OpenAIRequest { let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(), model: self.model.full_name().to_string(),
messages: messages.collect(), messages: messages.collect(),
stream: true, stream: true,
stop: vec![], stop: vec![],
temperature: 1.0, temperature: 1.0,
}; });
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
self.pending_summary = cx.spawn(|this, mut cx| { self.pending_summary = cx.spawn(|this, mut cx| {

View file

@ -1,6 +1,5 @@
use crate::streaming_diff::{Hunk, StreamingDiff}; use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::CompletionProvider; use ai::completion::{CompletionProvider, CompletionRequest};
use ai::providers::open_ai::OpenAIRequest;
use anyhow::Result; use anyhow::Result;
use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@ -96,7 +95,7 @@ impl Codegen {
self.error.as_ref() self.error.as_ref()
} }
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) { pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range(); let range = self.range();
let snapshot = self.snapshot.clone(); let snapshot = self.snapshot.clone();
let selected_text = snapshot let selected_text = snapshot
@ -336,6 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use ai::providers::dummy::DummyCompletionRequest;
use futures::{ use futures::{
future::BoxFuture, future::BoxFuture,
stream::{self, BoxStream}, stream::{self, BoxStream},
@ -381,7 +381,10 @@ mod tests {
cx, cx,
) )
}); });
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!( let mut new_text = concat!(
" let mut x = 0;\n", " let mut x = 0;\n",
@ -443,7 +446,11 @@ mod tests {
cx, cx,
) )
}); });
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!( let mut new_text = concat!(
"t mut x = 0;\n", "t mut x = 0;\n",
@ -505,7 +512,11 @@ mod tests {
cx, cx,
) )
}); });
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!( let mut new_text = concat!(
"let mut x = 0;\n", "let mut x = 0;\n",
@ -617,7 +628,7 @@ mod tests {
impl CompletionProvider for TestCompletionProvider { impl CompletionProvider for TestCompletionProvider {
fn complete( fn complete(
&self, &self,
_prompt: OpenAIRequest, _prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> { ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1); let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx); *self.last_completion_tx.lock() = Some(tx);