193 lines
5.4 KiB
Rust
193 lines
5.4 KiB
Rust
use std::{
|
|
sync::atomic::{self, AtomicUsize, Ordering},
|
|
time::Instant,
|
|
};
|
|
|
|
use async_trait::async_trait;
|
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
|
use gpui2::AppContext;
|
|
use parking_lot::Mutex;
|
|
|
|
use crate::{
|
|
auth::{CredentialProvider, ProviderCredential},
|
|
completion::{CompletionProvider, CompletionRequest},
|
|
embedding::{Embedding, EmbeddingProvider},
|
|
models::{LanguageModel, TruncationDirection},
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct FakeLanguageModel {
|
|
pub capacity: usize,
|
|
}
|
|
|
|
impl LanguageModel for FakeLanguageModel {
|
|
fn name(&self) -> String {
|
|
"dummy".to_string()
|
|
}
|
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
|
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
|
}
|
|
fn truncate(
|
|
&self,
|
|
content: &str,
|
|
length: usize,
|
|
direction: TruncationDirection,
|
|
) -> anyhow::Result<String> {
|
|
println!("TRYING TO TRUNCATE: {:?}", length.clone());
|
|
|
|
if length > self.count_tokens(content)? {
|
|
println!("NOT TRUNCATING");
|
|
return anyhow::Ok(content.to_string());
|
|
}
|
|
|
|
anyhow::Ok(match direction {
|
|
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
|
.into_iter()
|
|
.collect::<String>(),
|
|
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
|
.into_iter()
|
|
.collect::<String>(),
|
|
})
|
|
}
|
|
fn capacity(&self) -> anyhow::Result<usize> {
|
|
anyhow::Ok(self.capacity)
|
|
}
|
|
}
|
|
|
|
pub struct FakeEmbeddingProvider {
|
|
pub embedding_count: AtomicUsize,
|
|
}
|
|
|
|
impl Clone for FakeEmbeddingProvider {
|
|
fn clone(&self) -> Self {
|
|
FakeEmbeddingProvider {
|
|
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for FakeEmbeddingProvider {
|
|
fn default() -> Self {
|
|
FakeEmbeddingProvider {
|
|
embedding_count: AtomicUsize::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FakeEmbeddingProvider {
|
|
pub fn embedding_count(&self) -> usize {
|
|
self.embedding_count.load(atomic::Ordering::SeqCst)
|
|
}
|
|
|
|
pub fn embed_sync(&self, span: &str) -> Embedding {
|
|
let mut result = vec![1.0; 26];
|
|
for letter in span.chars() {
|
|
let letter = letter.to_ascii_lowercase();
|
|
if letter as u32 >= 'a' as u32 {
|
|
let ix = (letter as u32) - ('a' as u32);
|
|
if ix < 26 {
|
|
result[ix as usize] += 1.0;
|
|
}
|
|
}
|
|
}
|
|
|
|
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
for x in &mut result {
|
|
*x /= norm;
|
|
}
|
|
|
|
result.into()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl CredentialProvider for FakeEmbeddingProvider {
|
|
fn has_credentials(&self) -> bool {
|
|
true
|
|
}
|
|
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
|
|
ProviderCredential::NotNeeded
|
|
}
|
|
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
|
|
async fn delete_credentials(&self, _cx: &mut AppContext) {}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl EmbeddingProvider for FakeEmbeddingProvider {
|
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
|
Box::new(FakeLanguageModel { capacity: 1000 })
|
|
}
|
|
fn max_tokens_per_batch(&self) -> usize {
|
|
1000
|
|
}
|
|
|
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
|
None
|
|
}
|
|
|
|
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
|
|
self.embedding_count
|
|
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
|
|
|
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
|
}
|
|
}
|
|
|
|
pub struct FakeCompletionProvider {
|
|
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
|
}
|
|
|
|
impl Clone for FakeCompletionProvider {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
last_completion_tx: Mutex::new(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FakeCompletionProvider {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
last_completion_tx: Mutex::new(None),
|
|
}
|
|
}
|
|
|
|
pub fn send_completion(&self, completion: impl Into<String>) {
|
|
let mut tx = self.last_completion_tx.lock();
|
|
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
|
}
|
|
|
|
pub fn finish_completion(&self) {
|
|
self.last_completion_tx.lock().take().unwrap();
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl CredentialProvider for FakeCompletionProvider {
|
|
fn has_credentials(&self) -> bool {
|
|
true
|
|
}
|
|
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
|
|
ProviderCredential::NotNeeded
|
|
}
|
|
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
|
|
async fn delete_credentials(&self, _cx: &mut AppContext) {}
|
|
}
|
|
|
|
impl CompletionProvider for FakeCompletionProvider {
|
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
|
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
|
model
|
|
}
|
|
fn complete(
|
|
&self,
|
|
_prompt: Box<dyn CompletionRequest>,
|
|
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
|
let (tx, rx) = mpsc::channel(1);
|
|
*self.last_completion_tx.lock() = Some(tx);
|
|
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
|
}
|
|
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
|
Box::new((*self).clone())
|
|
}
|
|
}
|