cleaned up truncate vs truncate start
This commit is contained in:
parent
9c49191031
commit
0b57ab7303
4 changed files with 56 additions and 32 deletions
|
@ -2,11 +2,20 @@ use anyhow::anyhow;
|
|||
use tiktoken_rs::CoreBPE;
|
||||
use util::ResultExt;
|
||||
|
||||
pub enum TruncationDirection {
|
||||
Start,
|
||||
End,
|
||||
}
|
||||
|
||||
pub trait LanguageModel {
|
||||
fn name(&self) -> String;
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
|
||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String>;
|
||||
fn capacity(&self) -> anyhow::Result<usize>;
|
||||
}
|
||||
|
||||
|
@ -36,23 +45,19 @@ impl LanguageModel for OpenAILanguageModel {
|
|||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
bpe.decode(tokens[..length].to_vec())
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
bpe.decode(tokens[length..].to_vec())
|
||||
match direction {
|
||||
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
|
||||
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
|
||||
}
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue