replace OpenAIRequest with more generalized Box<dyn CompletionRequest>
This commit is contained in:
parent
05ae978cb7
commit
d813ae8845
6 changed files with 58 additions and 23 deletions
|
@ -1,11 +1,13 @@
|
|||
use anyhow::Result;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
|
||||
use crate::providers::open_ai::completion::OpenAIRequest;
|
||||
pub trait CompletionRequest: Send + Sync {
|
||||
fn data(&self) -> serde_json::Result<String>;
|
||||
}
|
||||
|
||||
pub trait CompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: OpenAIRequest,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
}
|
||||
|
|
13
crates/ai/src/providers/dummy.rs
Normal file
13
crates/ai/src/providers/dummy.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
use crate::completion::CompletionRequest;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct DummyCompletionRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl CompletionRequest for DummyCompletionRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
|
@ -1 +1,2 @@
|
|||
pub mod dummy;
|
||||
pub mod open_ai;
|
||||
|
|
|
@ -12,7 +12,7 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::completion::CompletionProvider;
|
||||
use crate::completion::{CompletionProvider, CompletionRequest};
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
|
||||
|
@ -59,6 +59,12 @@ pub struct OpenAIRequest {
|
|||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl CompletionRequest for OpenAIRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
|
@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
|
|||
pub async fn stream_completion(
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
mut request: OpenAIRequest,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
request.stream = true;
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = serde_json::to_string(&request)?;
|
||||
let json_data = request.data()?;
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
|
@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
|
|||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: OpenAIRequest,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||
async move {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue