Add OpenAI language model tokenizer and LanguageModel trait

This commit is contained in:
KCaverly 2023-10-17 16:21:03 -04:00
parent ad92fe49c7
commit a874a09b7e
4 changed files with 102 additions and 44 deletions

49
crates/ai/src/models.rs Normal file
View file

@ -0,0 +1,49 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
/// Abstraction over a language model's tokenizer and context window,
/// letting callers count, truncate, and bound token usage without
/// knowing which concrete model backs the implementation.
pub trait LanguageModel {
    /// The model identifier this instance was created with.
    fn name(&self) -> String;
    /// Number of tokens `content` encodes to under this model's tokenizer.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    /// `content` cut down to at most `length` tokens, re-decoded to a string.
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
    /// Maximum context size (in tokens) of the model.
    fn capacity(&self) -> anyhow::Result<usize>;
}
/// `LanguageModel` backed by an OpenAI model's tiktoken BPE tokenizer.
struct OpenAILanguageModel {
    // OpenAI model name (e.g. "gpt-4"); also used to look up context size.
    name: String,
    // `None` when the tokenizer lookup in `load` failed; token-based
    // operations return errors in that case rather than panicking.
    bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
    /// Creates a wrapper for `model_name`, attempting to load its BPE
    /// tokenizer. A failed tokenizer lookup is logged (via `log_err`) and
    /// recorded as `None` instead of being propagated, so construction
    /// itself never fails.
    pub fn load(model_name: String) -> Self {
        Self {
            // Field initializers evaluate in source order, so the borrow
            // of `model_name` here happens before it is moved into `name`.
            bpe: tiktoken_rs::get_bpe_from_model(&model_name).log_err(),
            name: model_name,
        }
    }
}
impl LanguageModel for OpenAILanguageModel {
    /// Returns the model name this wrapper was created with.
    fn name(&self) -> String {
        self.name.clone()
    }

    /// Counts the tokens `content` encodes to under this model's BPE.
    ///
    /// # Errors
    /// Fails when the tokenizer could not be retrieved in `load`.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        let bpe = self
            .bpe
            .as_ref()
            .ok_or_else(|| anyhow!("bpe for open ai model was not retrieved"))?;
        Ok(bpe.encode_with_special_tokens(content).len())
    }

    /// Truncates `content` to at most `length` tokens and decodes it back
    /// to a string.
    ///
    /// # Errors
    /// Fails when the tokenizer could not be retrieved in `load`, or when
    /// decoding the truncated token slice fails.
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
        let bpe = self
            .bpe
            .as_ref()
            .ok_or_else(|| anyhow!("bpe for open ai model was not retrieved"))?;
        let tokens = bpe.encode_with_special_tokens(content);
        // Clamp so a `length` larger than the token count returns the whole
        // content instead of panicking on an out-of-range slice.
        let end = length.min(tokens.len());
        bpe.decode(tokens[..end].to_vec())
    }

    /// Looks up the model's context window size from its name.
    fn capacity(&self) -> anyhow::Result<usize> {
        Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}