use anyhow::anyhow; use tiktoken_rs::CoreBPE; use util::ResultExt; pub trait LanguageModel { fn name(&self) -> String; fn count_tokens(&self, content: &str) -> anyhow::Result; fn truncate(&self, content: &str, length: usize) -> anyhow::Result; fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result; fn capacity(&self) -> anyhow::Result; } pub struct OpenAILanguageModel { name: String, bpe: Option, } impl OpenAILanguageModel { pub fn load(model_name: &str) -> Self { let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); OpenAILanguageModel { name: model_name.to_string(), bpe, } } } impl LanguageModel for OpenAILanguageModel { fn name(&self) -> String { self.name.clone() } fn count_tokens(&self, content: &str) -> anyhow::Result { if let Some(bpe) = &self.bpe { anyhow::Ok(bpe.encode_with_special_tokens(content).len()) } else { Err(anyhow!("bpe for open ai model was not retrieved")) } } fn truncate(&self, content: &str, length: usize) -> anyhow::Result { if let Some(bpe) = &self.bpe { let tokens = bpe.encode_with_special_tokens(content); if tokens.len() > length { bpe.decode(tokens[..length].to_vec()) } else { bpe.decode(tokens) } } else { Err(anyhow!("bpe for open ai model was not retrieved")) } } fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { if let Some(bpe) = &self.bpe { let tokens = bpe.encode_with_special_tokens(content); if tokens.len() > length { bpe.decode(tokens[length..].to_vec()) } else { bpe.decode(tokens) } } else { Err(anyhow!("bpe for open ai model was not retrieved")) } } fn capacity(&self) -> anyhow::Result { anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) } }