port ai crate to ai2, with all tests passing

parent 204aba07f6
commit 04ab68502b
22 changed files with 1930 additions and 0 deletions
Cargo.lock (generated, 28 additions)

@@ -108,6 +108,33 @@ dependencies = [
 "util",
]

[[package]]
name = "ai2"
version = "0.1.0"
dependencies = [
 "anyhow",
 "async-trait",
 "bincode",
 "futures 0.3.28",
 "gpui2",
 "isahc",
 "language2",
 "lazy_static",
 "log",
 "matrixmultiply",
 "ordered-float 2.10.0",
 "parking_lot 0.11.2",
 "parse_duration",
 "postage",
 "rand 0.8.5",
 "regex",
 "rusqlite",
 "serde",
 "serde_json",
 "tiktoken-rs",
 "util",
]

[[package]]
name = "alacritty_config"
version = "0.1.2-dev"

@@ -10903,6 +10930,7 @@ dependencies = [
name = "zed2"
version = "0.109.0"
dependencies = [
 "ai2",
 "anyhow",
 "async-compression",
 "async-recursion 0.3.2",
crates/ai/Cargo.toml (new file, 38 lines)

@@ -0,0 +1,38 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
path = "src/ai.rs"
doctest = false

[features]
test-support = []

[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"

[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
crates/ai2/Cargo.toml (new file, 38 lines)

@@ -0,0 +1,38 @@
[package]
name = "ai2"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
path = "src/ai2.rs"
doctest = false

[features]
test-support = []

[dependencies]
gpui2 = { path = "../gpui2" }
util = { path = "../util" }
language2 = { path = "../language2" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"

[dev-dependencies]
gpui2 = { path = "../gpui2", features = ["test-support"] }
crates/ai2/src/ai2.rs (new file, 8 lines)

@@ -0,0 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
crates/ai2/src/auth.rs (new file, 17 lines)

@@ -0,0 +1,17 @@
use async_trait::async_trait;
use gpui2::AppContext;

#[derive(Clone, Debug)]
pub enum ProviderCredential {
    Credentials { api_key: String },
    NoCredentials,
    NotNeeded,
}

#[async_trait]
pub trait CredentialProvider: Send + Sync {
    fn has_credentials(&self) -> bool;
    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
    async fn delete_credentials(&self, cx: &mut AppContext);
}
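Callers typically unpack the returned ProviderCredential before issuing a request. A minimal sketch of that dispatch (the helper name is illustrative, not part of this commit):

use ai2::auth::ProviderCredential;
use anyhow::anyhow;

// Returns Some(key) when a key is required and present, None when the
// provider needs no authentication, and an error when a key is missing.
fn api_key(credential: &ProviderCredential) -> anyhow::Result<Option<String>> {
    match credential {
        ProviderCredential::Credentials { api_key } => Ok(Some(api_key.clone())),
        ProviderCredential::NotNeeded => Ok(None),
        ProviderCredential::NoCredentials => Err(anyhow!("no API key configured")),
    }
}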
crates/ai2/src/completion.rs (new file, 23 lines)

@@ -0,0 +1,23 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};

use crate::{auth::CredentialProvider, models::LanguageModel};

pub trait CompletionRequest: Send + Sync {
    fn data(&self) -> serde_json::Result<String>;
}

pub trait CompletionProvider: CredentialProvider {
    fn base_model(&self) -> Box<dyn LanguageModel>;
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
    fn box_clone(&self) -> Box<dyn CompletionProvider>;
}

impl Clone for Box<dyn CompletionProvider> {
    fn clone(&self) -> Box<dyn CompletionProvider> {
        self.box_clone()
    }
}
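`complete` resolves to a stream of text chunks. A hedged sketch of draining that stream into a single string (the helper is illustrative, not part of this commit):

use ai2::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use futures::StreamExt;

async fn collect_completion(
    provider: &dyn CompletionProvider,
    request: Box<dyn CompletionRequest>,
) -> Result<String> {
    // The first await performs the request; the stream then yields chunks.
    let mut chunks = provider.complete(request).await?;
    let mut output = String::new();
    while let Some(chunk) = chunks.next().await {
        output.push_str(&chunk?);
    }
    Ok(output)
}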
crates/ai2/src/embedding.rs (new file, 123 lines)

@@ -0,0 +1,123 @@
use std::time::Instant;

use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;

use crate::auth::CredentialProvider;
use crate::models::LanguageModel;

#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);

// These SQL conversions are needed for the semantic index, and unfortunately
// they have to live wherever the `Embedding` struct is defined. Keeping them
// here introduces a `rusqlite` dependency into the AI crate, which is less
// than ideal.
impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
        if embedding.is_err() {
            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
        }
        Ok(Embedding(embedding.unwrap()))
    }
}

impl ToSql for Embedding {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        let bytes = bincode::serialize(&self.0)
            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
    }
}

impl From<Vec<f32>> for Embedding {
    fn from(value: Vec<f32>) -> Self {
        Embedding(value)
    }
}

impl Embedding {
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());

        let mut result = 0.0;
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}

#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
    fn base_model(&self) -> Box<dyn LanguageModel>;
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    fn max_tokens_per_batch(&self) -> usize;
    fn rate_limit_expiration(&self) -> Option<Instant>;
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    #[gpui2::test]
    fn test_similarity(mut rng: StdRng) {
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );

        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);

            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }

        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }

        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}
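`similarity` is a plain dot product, computed as a 1×n by n×1 matrix multiply through `matrixmultiply::sgemm`. A hand-checked example in the spirit of the unit tests above:

use ai2::embedding::Embedding;
use ordered_float::OrderedFloat;

let a = Embedding::from(vec![1.0, 2.0, 0.0]);
let b = Embedding::from(vec![3.0, 1.0, 4.0]);
// 1*3 + 2*1 + 0*4 = 5
assert_eq!(a.similarity(&b), OrderedFloat(5.0));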
crates/ai2/src/models.rs (new file, 16 lines)

@@ -0,0 +1,16 @@
pub enum TruncationDirection {
    Start,
    End,
}

pub trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}
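Every prompt template below leans on the same idiom against this trait: count tokens, then truncate from one end when over budget. A minimal sketch of that idiom as a free function (the name is ours, not part of the crate):

use ai2::models::{LanguageModel, TruncationDirection};

fn fit_to_budget(
    model: &dyn LanguageModel,
    content: &str,
    budget: usize,
) -> anyhow::Result<String> {
    if model.count_tokens(content)? <= budget {
        return Ok(content.to_string());
    }
    // TruncationDirection::End keeps the beginning of the text;
    // TruncationDirection::Start would keep the end instead.
    model.truncate(content, budget, TruncationDirection::End)
}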
crates/ai2/src/prompts/base.rs (new file, 330 lines)

@@ -0,0 +1,330 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;

use language2::BufferSnapshot;
use util::ResultExt;

use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;

pub(crate) enum PromptFileType {
    Text,
    Code,
}

// TODO: Set this up to manage defaults well
pub struct PromptArguments {
    pub model: Arc<dyn LanguageModel>,
    pub user_prompt: Option<String>,
    pub language_name: Option<String>,
    pub project_name: Option<String>,
    pub snippets: Vec<PromptCodeSnippet>,
    pub reserved_tokens: usize,
    pub buffer: Option<BufferSnapshot>,
    pub selected_range: Option<Range<usize>>,
}

impl PromptArguments {
    pub(crate) fn get_file_type(&self) -> PromptFileType {
        if self
            .language_name
            .as_ref()
            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
            .unwrap_or(true)
        {
            PromptFileType::Code
        } else {
            PromptFileType::Text
        }
    }
}

pub trait PromptTemplate {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}

#[repr(i8)]
#[derive(PartialEq, Eq, Ord)]
pub enum PromptPriority {
    Mandatory,                // Ignores truncation
    Ordered { order: usize }, // Truncates based on priority
}

impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
            (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
            (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
        }
    }
}

pub struct PromptChain {
    args: PromptArguments,
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}

impl PromptChain {
    pub fn new(
        args: PromptArguments,
        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
    ) -> Self {
        PromptChain { args, templates }
    }

    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
        // Argsort based on prompt priority
        let separator = "\n";
        let separator_tokens = self.args.model.count_tokens(separator)?;
        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));

        // When truncating, track the token budget that remains
        let mut tokens_outstanding = if truncate {
            Some(self.args.model.capacity()? - self.args.reserved_tokens)
        } else {
            None
        };

        let mut prompts = vec!["".to_string(); sorted_indices.len()];
        for idx in sorted_indices {
            let (_, template) = &self.templates[idx];

            if let Some((template_prompt, prompt_token_count)) =
                template.generate(&self.args, tokens_outstanding).log_err()
            {
                if template_prompt != "" {
                    prompts[idx] = template_prompt;

                    if let Some(remaining_tokens) = tokens_outstanding {
                        let new_tokens = prompt_token_count + separator_tokens;
                        tokens_outstanding = if remaining_tokens > new_tokens {
                            Some(remaining_tokens - new_tokens)
                        } else {
                            Some(0)
                        };
                    }
                }
            }
        }

        prompts.retain(|x| x != "");

        let full_prompt = prompts.join(separator);
        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
        anyhow::Ok((full_prompt, total_token_count))
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;

    use super::*;

    #[test]
    pub fn test_prompt_chain() {
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();

                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }

                anyhow::Ok((content, token_count))
            }
        }

        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with truncation off:
        // should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(false).unwrap();

        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );

        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);

        // Testing with truncation on:
        // should truncate the prompts to fit within the model's capacity
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);

        // Change ordering of prompts based on priority
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);

        let (prompt, token_count) = chain.generate(true).unwrap();

        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}
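Tying the pieces together, a caller assembles a PromptChain from concrete templates and generates with truncation on. A hedged sketch using templates shipped in this commit (the user prompt and the reserved token count are made-up values):

use std::sync::Arc;

use ai2::models::LanguageModel;
use ai2::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai2::prompts::generate::GenerateInlineContent;
use ai2::prompts::preamble::EngineerPreamble;
use ai2::providers::open_ai::OpenAILanguageModel;

fn build_prompt() -> anyhow::Result<(String, usize)> {
    let model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load("gpt-4"));
    let args = PromptArguments {
        model,
        user_prompt: Some("Add doc comments to this function".to_string()),
        language_name: Some("Rust".to_string()),
        project_name: None,
        snippets: Vec::new(),
        reserved_tokens: 1000, // leave room for the model's reply
        buffer: None,
        selected_range: None,
    };
    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
        // Mandatory templates are generated first and never truncated away.
        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
        (
            PromptPriority::Ordered { order: 0 },
            Box::new(GenerateInlineContent {}),
        ),
    ];
    // `true` truncates lower-priority templates to fit the model's capacity.
    PromptChain::new(args, templates).generate(true)
}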
crates/ai2/src/prompts/file_context.rs (new file, 164 lines)

@@ -0,0 +1,164 @@
use anyhow::anyhow;
use language2::BufferSnapshot;
use language2::ToOffset;

use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;

fn retrieve_context(
    buffer: &BufferSnapshot,
    selected_range: &Option<Range<usize>>,
    model: Arc<dyn LanguageModel>,
    max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
    let mut prompt = String::new();
    let mut truncated = false;
    if let Some(selected_range) = selected_range {
        let start = selected_range.start.to_offset(buffer);
        let end = selected_range.end.to_offset(buffer);

        let start_window = buffer.text_for_range(0..start).collect::<String>();

        let mut selected_window = String::new();
        if start == end {
            write!(selected_window, "<|START|>").unwrap();
        } else {
            write!(selected_window, "<|START|").unwrap();
        }

        write!(
            selected_window,
            "{}",
            buffer.text_for_range(start..end).collect::<String>()
        )
        .unwrap();

        if start != end {
            write!(selected_window, "|END|>").unwrap();
        }

        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();

        if let Some(max_token_count) = max_token_count {
            let selected_tokens = model.count_tokens(&selected_window)?;
            if selected_tokens > max_token_count {
                return Err(anyhow!(
                    "selected range is greater than model context window, truncation not possible"
                ));
            };

            let mut remaining_tokens = max_token_count - selected_tokens;
            let start_window_tokens = model.count_tokens(&start_window)?;
            let end_window_tokens = model.count_tokens(&end_window)?;
            let outside_tokens = start_window_tokens + end_window_tokens;
            if outside_tokens > remaining_tokens {
                let (start_goal_tokens, end_goal_tokens) =
                    if start_window_tokens < end_window_tokens {
                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
                        remaining_tokens -= start_goal_tokens;
                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
                        (start_goal_tokens, end_goal_tokens)
                    } else {
                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
                        remaining_tokens -= end_goal_tokens;
                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
                        (start_goal_tokens, end_goal_tokens)
                    };

                let truncated_start_window =
                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
                let truncated_end_window =
                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
                writeln!(
                    prompt,
                    "{truncated_start_window}{selected_window}{truncated_end_window}"
                )
                .unwrap();
                truncated = true;
            } else {
                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
            }
        } else {
            // If we don't have a selected range, include the entire file.
            writeln!(prompt, "{}", &buffer.text()).unwrap();

            // Dumb truncation strategy
            if let Some(max_token_count) = max_token_count {
                if model.count_tokens(&prompt)? > max_token_count {
                    truncated = true;
                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
                }
            }
        }
    }

    let token_count = model.count_tokens(&prompt)?;
    anyhow::Ok((prompt, token_count, truncated))
}

pub struct FileContext {}

impl PromptTemplate for FileContext {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        if let Some(buffer) = &args.buffer {
            let mut prompt = String::new();
            // Add initial preamble
            // TODO: Do we want to add the path in here?
            writeln!(
                prompt,
                "The file you are currently working on has the following content:"
            )
            .unwrap();

            let language_name = args
                .language_name
                .clone()
                .unwrap_or("".to_string())
                .to_lowercase();

            let (context, _, truncated) = retrieve_context(
                buffer,
                &args.selected_range,
                args.model.clone(),
                max_token_length,
            )?;
            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();

            if truncated {
                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
            }

            if let Some(selected_range) = &args.selected_range {
                let start = selected_range.start.to_offset(buffer);
                let end = selected_range.end.to_offset(buffer);

                if start == end {
                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
                } else {
                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
                }
            }

            // Really dumb truncation strategy
            if let Some(max_tokens) = max_token_length {
                prompt = args
                    .model
                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
            }

            let token_count = args.model.count_tokens(&prompt)?;
            anyhow::Ok((prompt, token_count))
        } else {
            Err(anyhow!("no buffer provided to retrieve file context from"))
        }
    }
}
crates/ai2/src/prompts/generate.rs (new file, 99 lines)

@@ -0,0 +1,99 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;

pub fn capitalize(s: &str) -> String {
    let mut c = s.chars();
    match c.next() {
        None => String::new(),
        Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
    }
}

pub struct GenerateInlineContent {}

impl PromptTemplate for GenerateInlineContent {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        let Some(user_prompt) = &args.user_prompt else {
            return Err(anyhow!("user prompt not provided"));
        };

        let file_type = args.get_file_type();
        let content_type = match &file_type {
            PromptFileType::Code => "code",
            PromptFileType::Text => "text",
        };

        let mut prompt = String::new();

        if let Some(selected_range) = &args.selected_range {
            if selected_range.start == selected_range.end {
                writeln!(
                    prompt,
                    "Assume the cursor is located where the `<|START|>` span is."
                )
                .unwrap();
                writeln!(
                    prompt,
                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
                    capitalize(content_type)
                )
                .unwrap();
                writeln!(
                    prompt,
                    "Generate {content_type} based on the user's prompt: {user_prompt}",
                )
                .unwrap();
            } else {
                writeln!(prompt, "Modify the user's selected {content_type} based upon the user's prompt: '{user_prompt}'").unwrap();
                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|>' spans").unwrap();
            }
        } else {
            writeln!(
                prompt,
                "Generate {content_type} based on the user's prompt: {user_prompt}"
            )
            .unwrap();
        }

        if let Some(language_name) = &args.language_name {
            writeln!(
                prompt,
                "Your answer MUST always and only be valid {}.",
                language_name
            )
            .unwrap();
        }
        writeln!(prompt, "Never make remarks about the output.").unwrap();
        writeln!(
            prompt,
            "Do not return anything else, except the generated {content_type}."
        )
        .unwrap();

        match file_type {
            PromptFileType::Code => {
                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
            }
            _ => {}
        }

        // Really dumb truncation strategy
        if let Some(max_tokens) = max_token_length {
            prompt = args.model.truncate(
                &prompt,
                max_tokens,
                crate::models::TruncationDirection::End,
            )?;
        }

        let token_count = args.model.count_tokens(&prompt)?;

        anyhow::Ok((prompt, token_count))
    }
}
crates/ai2/src/prompts/mod.rs (new file, 5 lines)

@@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;
crates/ai2/src/prompts/preamble.rs (new file, 52 lines)

@@ -0,0 +1,52 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;

pub struct EngineerPreamble {}

impl PromptTemplate for EngineerPreamble {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        let mut prompts = Vec::new();

        match args.get_file_type() {
            PromptFileType::Code => {
                prompts.push(format!(
                    "You are an expert {}engineer.",
                    args.language_name.clone().unwrap_or("".to_string()) + " "
                ));
            }
            PromptFileType::Text => {
                prompts.push("You are an expert engineer.".to_string());
            }
        }

        if let Some(project_name) = args.project_name.clone() {
            prompts.push(format!(
                "You are currently working inside the '{project_name}' project in code editor Zed."
            ));
        }

        if let Some(mut remaining_tokens) = max_token_length {
            let mut prompt = String::new();
            let mut total_count = 0;
            for prompt_piece in prompts {
                let prompt_token_count =
                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
                if remaining_tokens > prompt_token_count {
                    writeln!(prompt, "{prompt_piece}").unwrap();
                    remaining_tokens -= prompt_token_count;
                    total_count += prompt_token_count;
                }
            }

            anyhow::Ok((prompt, total_count))
        } else {
            let prompt = prompts.join("\n");
            let token_count = args.model.count_tokens(&prompt)?;
            anyhow::Ok((prompt, token_count))
        }
    }
}
crates/ai2/src/prompts/repository_context.rs (new file, 98 lines)

@@ -0,0 +1,98 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};

use gpui2::{AsyncAppContext, Handle};
use language2::{Anchor, Buffer};

#[derive(Clone)]
pub struct PromptCodeSnippet {
    path: Option<PathBuf>,
    language_name: Option<String>,
    content: String,
}

impl PromptCodeSnippet {
    pub fn new(
        buffer: Handle<Buffer>,
        range: Range<Anchor>,
        cx: &mut AsyncAppContext,
    ) -> anyhow::Result<Self> {
        let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
            let snapshot = buffer.snapshot();
            let content = snapshot.text_for_range(range.clone()).collect::<String>();

            let language_name = buffer
                .language()
                .and_then(|language| Some(language.name().to_string().to_lowercase()));

            let file_path = buffer
                .file()
                .and_then(|file| Some(file.path().to_path_buf()));

            (content, language_name, file_path)
        })?;

        anyhow::Ok(PromptCodeSnippet {
            path: file_path,
            language_name,
            content,
        })
    }
}

impl ToString for PromptCodeSnippet {
    fn to_string(&self) -> String {
        let path = self
            .path
            .as_ref()
            .and_then(|path| Some(path.to_string_lossy().to_string()))
            .unwrap_or("".to_string());
        let language_name = self.language_name.clone().unwrap_or("".to_string());
        let content = self.content.clone();

        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
    }
}

pub struct RepositoryContext {}

impl PromptTemplate for RepositoryContext {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
        let mut prompt = String::new();

        let mut remaining_tokens = max_token_length;
        let separator_token_length = args.model.count_tokens("\n")?;
        for snippet in &args.snippets {
            let mut snippet_prompt = template.to_string();
            let content = snippet.to_string();
            writeln!(snippet_prompt, "{content}").unwrap();

            let token_count = args.model.count_tokens(&snippet_prompt)?;
            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
                if let Some(tokens_left) = remaining_tokens {
                    if tokens_left >= token_count {
                        writeln!(prompt, "{snippet_prompt}").unwrap();
                        remaining_tokens = if tokens_left >= (token_count + separator_token_length)
                        {
                            Some(tokens_left - token_count - separator_token_length)
                        } else {
                            Some(0)
                        };
                    }
                } else {
                    writeln!(prompt, "{snippet_prompt}").unwrap();
                }
            }
        }

        let total_token_count = args.model.count_tokens(&prompt)?;
        anyhow::Ok((prompt, total_token_count))
    }
}
crates/ai2/src/providers/mod.rs (new file, 1 line)

@@ -0,0 +1 @@
pub mod open_ai;
crates/ai2/src/providers/open_ai/completion.rs (new file, 306 lines)

@@ -0,0 +1,306 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::{
    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
    Stream, StreamExt,
};
use gpui2::{AppContext, Executor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
    env,
    fmt::{self, Display},
    io,
    sync::Arc,
};
use util::ResultExt;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    models::LanguageModel,
};

use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "User"),
            Role::Assistant => write!(f, "Assistant"),
            Role::System => write!(f, "System"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}

#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl CompletionRequest for OpenAIRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAIUsage>,
}

pub async fn stream_completion(
    credential: ProviderCredential,
    executor: Arc<Executor>,
    request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    let api_key = match credential {
        ProviderCredential::Credentials { api_key } => api_key,
        _ => {
            return Err(anyhow!("no credentials provided for completion"));
        }
    };

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();

    let json_data = request.data()?;
    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }

                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAIResponse {
            error: OpenAIError,
        }

        #[derive(Deserialize)]
        struct OpenAIError {
            message: String,
        }

        match serde_json::from_str::<OpenAIResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Clone)]
pub struct OpenAICompletionProvider {
    model: OpenAILanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    executor: Arc<Executor>,
}

impl OpenAICompletionProvider {
    pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
        let model = OpenAILanguageModel::load(model_name);
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
        Self {
            model,
            credential,
            executor,
        }
    }
}

#[async_trait]
impl CredentialProvider for OpenAICompletionProvider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read() {
            ProviderCredential::Credentials { .. } => true,
            _ => false,
        }
    }
    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
        let existing_credential = self.credential.read().clone();

        let retrieved_credential = cx
            .run_on_main(move |cx| match existing_credential {
                ProviderCredential::Credentials { .. } => {
                    return existing_credential.clone();
                }
                _ => {
                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                        return ProviderCredential::Credentials { api_key };
                    }

                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
                    {
                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
                            return ProviderCredential::Credentials { api_key };
                        } else {
                            return ProviderCredential::NoCredentials;
                        }
                    } else {
                        return ProviderCredential::NoCredentials;
                    }
                }
            })
            .await;

        *self.credential.write() = retrieved_credential.clone();
        retrieved_credential
    }

    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
        *self.credential.write() = credential.clone();
        let credential = credential.clone();
        cx.run_on_main(move |cx| match credential {
            ProviderCredential::Credentials { api_key } => {
                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
            }
            _ => {}
        })
        .await;
    }
    async fn delete_credentials(&self, cx: &mut AppContext) {
        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
            .await;
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}

impl CompletionProvider for OpenAICompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
        // This means the model is determined by the CompletionRequest and not the
        // CompletionProvider, which carries its own language model. At some point
        // in the future we should rectify this.
        let credential = self.credential.read().clone();
        let request = stream_completion(credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}
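A hedged sketch of building the request that `complete` expects (the model name and message content are placeholders, and `provider` is assumed to be an OpenAICompletionProvider with credentials already retrieved):

use ai2::completion::{CompletionProvider, CompletionRequest};
use ai2::providers::open_ai::{OpenAIRequest, RequestMessage, Role};

let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
    model: "gpt-4".to_string(),
    messages: vec![RequestMessage {
        role: Role::User,
        content: "Summarize this diff".to_string(),
    }],
    stream: true, // `stream_completion` consumes server-sent events
    stop: Vec::new(),
    temperature: 1.0,
});
let chunks = provider.complete(request); // await, then drain the stream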
crates/ai2/src/providers/open_ai/embedding.rs (new file, 313 lines)

@@ -0,0 +1,313 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui2::Executor;
use gpui2::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;

use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;

use crate::providers::open_ai::OPENAI_API_URL;

lazy_static! {
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Executor>,
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

impl OpenAIEmbeddingProvider {
    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

        let model = OpenAILanguageModel::load("text-embedding-ada-002");
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));

        OpenAIEmbeddingProvider {
            model,
            credential,
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }

    fn get_api_key(&self) -> Result<String> {
        match self.credential.read().clone() {
            ProviderCredential::Credentials { api_key } => Ok(api_key),
            _ => Err(anyhow!("api credentials not provided")),
        }
    }

    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();

        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }

        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }

    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();

        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };

        log::trace!("updating rate limit time: {:?}", updated_time);

        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl CredentialProvider for OpenAIEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read() {
            ProviderCredential::Credentials { .. } => true,
            _ => false,
        }
    }
    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
        let existing_credential = self.credential.read().clone();

        let retrieved_credential = cx
            .run_on_main(move |cx| match existing_credential {
                ProviderCredential::Credentials { .. } => {
                    return existing_credential.clone();
                }
                _ => {
                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                        return ProviderCredential::Credentials { api_key };
                    }

                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
                    {
                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
                            return ProviderCredential::Credentials { api_key };
                        } else {
                            return ProviderCredential::NoCredentials;
                        }
                    } else {
                        return ProviderCredential::NoCredentials;
                    }
                }
            })
            .await;

        *self.credential.write() = retrieved_credential.clone();
        retrieved_credential
    }

    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
        *self.credential.write() = credential.clone();
        let credential = credential.clone();
        cx.run_on_main(move |cx| match credential {
            ProviderCredential::Credentials { api_key } => {
                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
            }
            _ => {}
        })
        .await;
    }
    async fn delete_credentials(&self, cx: &mut AppContext) {
        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
            .await;
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_key = self.get_api_key()?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate-limited,
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
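Usage follows the credential flow defined above; nothing works until a credential has been retrieved. An illustrative fragment, not a definitive recipe — `client`, `executor`, and `cx` are assumed to come from the running gpui2 app, inside an async context:

let provider = OpenAIEmbeddingProvider::new(client, executor);
// Populates the shared credential from OPENAI_API_KEY or the keychain.
provider.retrieve_credentials(cx).await;
let embeddings = provider.embed_batch(vec!["fn main() {}".to_string()]).await?;
assert_eq!(embeddings.len(), 1);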
crates/ai2/src/providers/open_ai/mod.rs (new file, 9 lines)

@@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;

pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;

pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
crates/ai2/src/providers/open_ai/model.rs (new file, 57 lines)

@@ -0,0 +1,57 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;

use crate::models::{LanguageModel, TruncationDirection};

#[derive(Clone)]
pub struct OpenAILanguageModel {
    name: String,
    bpe: Option<CoreBPE>,
}

impl OpenAILanguageModel {
    pub fn load(model_name: &str) -> Self {
        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
        OpenAILanguageModel {
            name: model_name.to_string(),
            bpe,
        }
    }
}

impl LanguageModel for OpenAILanguageModel {
    fn name(&self) -> String {
        self.name.clone()
    }
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        if let Some(bpe) = &self.bpe {
            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        if let Some(bpe) = &self.bpe {
            let tokens = bpe.encode_with_special_tokens(content);
            if tokens.len() > length {
                match direction {
                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
                }
            } else {
                bpe.decode(tokens)
            }
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}
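A short sketch of the model in isolation; exact token counts depend on the tiktoken vocabulary, so only the relationship between the counts is asserted here:

use ai2::models::{LanguageModel, TruncationDirection};
use ai2::providers::open_ai::OpenAILanguageModel;

fn demo() -> anyhow::Result<()> {
    let model = OpenAILanguageModel::load("gpt-3.5-turbo");
    let text = "fn main() { println!(\"hello\"); }";
    let total = model.count_tokens(text)?;
    // Keep only the first four tokens' worth of text.
    let prefix = model.truncate(text, 4, TruncationDirection::End)?;
    assert!(model.count_tokens(&prefix)? <= total);
    Ok(())
}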
crates/ai2/src/providers/open_ai/new.rs (new file, 11 lines)

@@ -0,0 +1,11 @@
use crate::models::TruncationDirection;

pub trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}
crates/ai2/src/test.rs (new file, 193 lines)

@@ -0,0 +1,193 @@
use std::{
    sync::atomic::{self, AtomicUsize, Ordering},
    time::Instant,
};

use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui2::AppContext;
use parking_lot::Mutex;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    embedding::{Embedding, EmbeddingProvider},
    models::{LanguageModel, TruncationDirection},
};

#[derive(Clone)]
pub struct FakeLanguageModel {
    pub capacity: usize,
}

impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "dummy".to_string()
    }
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
    }
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        println!("TRYING TO TRUNCATE: {:?}", length);

        if length > self.count_tokens(content)? {
            println!("NOT TRUNCATING");
            return anyhow::Ok(content.to_string());
        }

        anyhow::Ok(match direction {
            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                .into_iter()
                .collect::<String>(),
            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                .into_iter()
                .collect::<String>(),
        })
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(self.capacity)
    }
}

pub struct FakeEmbeddingProvider {
    pub embedding_count: AtomicUsize,
}

impl Clone for FakeEmbeddingProvider {
    fn clone(&self) -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
        }
    }
}

impl Default for FakeEmbeddingProvider {
    fn default() -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::default(),
        }
    }
}

impl FakeEmbeddingProvider {
    pub fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }

    pub fn embed_sync(&self, span: &str) -> Embedding {
        let mut result = vec![1.0; 26];
        for letter in span.chars() {
            let letter = letter.to_ascii_lowercase();
            if letter as u32 >= 'a' as u32 {
                let ix = (letter as u32) - ('a' as u32);
                if ix < 26 {
                    result[ix as usize] += 1.0;
                }
            }
        }

        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut result {
            *x /= norm;
        }

        result.into()
    }
}

#[async_trait]
impl CredentialProvider for FakeEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        true
    }
    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
    async fn delete_credentials(&self, _cx: &mut AppContext) {}
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        Box::new(FakeLanguageModel { capacity: 1000 })
    }
    fn max_tokens_per_batch(&self) -> usize {
        1000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }

    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);

        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
    }
}

pub struct FakeCompletionProvider {
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}

impl Clone for FakeCompletionProvider {
    fn clone(&self) -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }
}

impl FakeCompletionProvider {
    pub fn new() -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }

    pub fn send_completion(&self, completion: impl Into<String>) {
        let mut tx = self.last_completion_tx.lock();
        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
    }

    pub fn finish_completion(&self) {
        self.last_completion_tx.lock().take().unwrap();
    }
}

#[async_trait]
impl CredentialProvider for FakeCompletionProvider {
    fn has_credentials(&self) -> bool {
        true
    }
    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
    async fn delete_credentials(&self, _cx: &mut AppContext) {}
}

impl CompletionProvider for FakeCompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
        model
    }
    fn complete(
        &self,
        _prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
        let (tx, rx) = mpsc::channel(1);
        *self.last_completion_tx.lock() = Some(tx);
        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}
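A hedged sketch of driving FakeCompletionProvider in a test; the request type is a stand-in, since the fake ignores its prompt entirely:

use ai2::completion::{CompletionProvider, CompletionRequest};
use ai2::test::FakeCompletionProvider;
use futures::StreamExt;

struct DummyRequest;

impl CompletionRequest for DummyRequest {
    fn data(&self) -> serde_json::Result<String> {
        Ok("{}".to_string())
    }
}

async fn exercise_fake() {
    let provider = FakeCompletionProvider::new();
    let mut chunks = provider.complete(Box::new(DummyRequest)).await.unwrap();
    provider.send_completion("Hello, ");
    provider.send_completion("world!");
    provider.finish_completion(); // drops the sender, ending the stream
    assert_eq!(chunks.next().await.unwrap().unwrap(), "Hello, ");
    assert_eq!(chunks.next().await.unwrap().unwrap(), "world!");
    assert!(chunks.next().await.is_none());
}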
crates/zed2/Cargo.toml (1 addition)

@@ -15,6 +15,7 @@ name = "Zed"
path = "src/main.rs"

[dependencies]
ai2 = { path = "../ai2" }
# audio = { path = "../audio" }
# activity_indicator = { path = "../activity_indicator" }
# auto_update = { path = "../auto_update" }