add concept of LanguageModel to CompletionProvider

This commit is contained in:
KCaverly 2023-10-27 08:51:30 +02:00
parent 6c8bb4b05e
commit ec9d79b6fe
5 changed files with 27 additions and 4 deletions

View file

@ -1,11 +1,14 @@
use anyhow::Result; use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream}; use futures::{future::BoxFuture, stream::BoxStream};
use crate::models::LanguageModel;
pub trait CompletionRequest: Send + Sync { pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>; fn data(&self) -> serde_json::Result<String>;
} }
pub trait CompletionProvider { pub trait CompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete( fn complete(
&self, &self,
prompt: Box<dyn CompletionRequest>, prompt: Box<dyn CompletionRequest>,

View file

@ -12,7 +12,12 @@ use std::{
sync::Arc, sync::Arc,
}; };
use crate::completion::{CompletionProvider, CompletionRequest}; use crate::{
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use super::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@ -180,17 +185,27 @@ pub async fn stream_completion(
} }
pub struct OpenAICompletionProvider { pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
api_key: String, api_key: String,
executor: Arc<Background>, executor: Arc<Background>,
} }
impl OpenAICompletionProvider { impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self { pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor } let model = OpenAILanguageModel::load(model_name);
Self {
model,
api_key,
executor,
}
} }
} }
impl CompletionProvider for OpenAICompletionProvider { impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete( fn complete(
&self, &self,
prompt: Box<dyn CompletionRequest>, prompt: Box<dyn CompletionRequest>,

View file

@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::auth::OpenAICredentialProvider; use crate::providers::open_ai::auth::OpenAICredentialProvider;
lazy_static! { lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
} }

View file

@ -328,6 +328,7 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new( let provider = Arc::new(OpenAICompletionProvider::new(
"gpt-4",
api_key, api_key,
cx.background().clone(), cx.background().clone(),
)); ));

View file

@ -335,6 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use ai::{models::LanguageModel, test::FakeLanguageModel};
use futures::{ use futures::{
future::BoxFuture, future::BoxFuture,
stream::{self, BoxStream}, stream::{self, BoxStream},
@ -638,6 +639,10 @@ mod tests {
} }
impl CompletionProvider for TestCompletionProvider { impl CompletionProvider for TestCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete( fn complete(
&self, &self,
_prompt: Box<dyn CompletionRequest>, _prompt: Box<dyn CompletionRequest>,