replace OpenAIRequest with more generalized Box<dyn CompletionRequest>

This commit is contained in:
KCaverly 2023-10-22 14:33:19 +02:00
parent 05ae978cb7
commit d813ae8845
6 changed files with 58 additions and 23 deletions

View file

@ -1,11 +1,13 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::providers::open_ai::completion::OpenAIRequest;
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}

View file

@ -0,0 +1,13 @@
use crate::completion::CompletionRequest;
use serde::Serialize;
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}

View file

@ -1 +1,2 @@
pub mod dummy;
pub mod open_ai;

View file

@ -12,7 +12,7 @@ use std::{
sync::Arc,
};
use crate::completion::CompletionProvider;
use crate::completion::{CompletionProvider, CompletionRequest};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@ -59,6 +59,12 @@ pub struct OpenAIRequest {
pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let json_data = request.data()?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {