move OpenAILanguageModel to providers folder

parent a62baf34f2
commit 3712794e56

6 changed files with 62 additions and 56 deletions
@@ -2,3 +2,4 @@ pub mod completion;
 pub mod embedding;
 pub mod models;
 pub mod prompts;
+pub mod providers;
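The crate root gains a providers namespace alongside the existing modules. A sketch of the resulting layout, inferred from the new-file hunks below:

    crates/ai/src/
    ├── providers/
    │   ├── mod.rs           (pub mod open_ai;)
    │   └── open_ai/
    │       ├── mod.rs       (re-exports OpenAILanguageModel)
    │       └── model.rs     (OpenAILanguageModel implementation)
    └── ...                  (completion, embedding, models, prompts)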
@@ -1,7 +1,3 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
-
 pub enum TruncationDirection {
     Start,
     End,
@@ -18,54 +14,3 @@ pub trait LanguageModel {
     ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
-
-pub struct OpenAILanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
-            name: model_name.to_string(),
-            bpe,
-        }
-    }
-}
-
-impl LanguageModel for OpenAILanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                match direction {
-                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
-                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
-                }
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}
crates/ai/src/providers/mod.rs (new file, 1 line)

@@ -0,0 +1 @@
+pub mod open_ai;
crates/ai/src/providers/open_ai/mod.rs (new file, 2 lines)

@@ -0,0 +1,2 @@
+pub mod model;
+pub use model::OpenAILanguageModel;
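The one-line re-export keeps call sites short: downstream crates can name the type at the provider level instead of reaching into the model submodule. For illustration (this snippet is not part of the commit):

    // Via the re-export, as the final hunk below does:
    use ai::providers::open_ai::OpenAILanguageModel;
    // Without it, callers would need the full path:
    // use ai::providers::open_ai::model::OpenAILanguageModel;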
crates/ai/src/providers/open_ai/model.rs (new file, 56 lines)

@@ -0,0 +1,56 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
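For orientation, a minimal usage sketch of the relocated type through its new import path. This is not part of the commit; the model name and token length are illustrative assumptions:

    use ai::models::{LanguageModel, TruncationDirection};
    use ai::providers::open_ai::OpenAILanguageModel;

    fn demo() -> anyhow::Result<()> {
        // `load` logs and stores `None` if tiktoken_rs doesn't recognize the
        // model name, so the methods below return Err rather than panic.
        let model = OpenAILanguageModel::load("gpt-3.5-turbo"); // assumed model name
        let text = "The quick brown fox jumps over the lazy dog";
        let count = model.count_tokens(text)?;
        // TruncationDirection::End keeps the first 8 tokens and drops the tail.
        let short = model.truncate(text, 8, TruncationDirection::End)?;
        println!("{count} tokens, capacity {}: {short}", model.capacity()?);
        Ok(())
    }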
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::models::LanguageModel;
 use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
 use ai::prompts::file_context::FileContext;
 use ai::prompts::generate::GenerateInlineContent;
 use ai::prompts::preamble::EngineerPreamble;
 use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;