moved authentication for the semantic index into the EmbeddingProvider

This commit is contained in:
KCaverly 2023-10-30 10:02:27 -04:00
parent 1e8b23d8fb
commit a2c3971ad6
14 changed files with 200 additions and 206 deletions

View file

@ -5,10 +5,11 @@ use std::{
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel {
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
pub credential_provider: NullCredentialProvider,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
credential_provider: self.credential_provider.clone(),
}
}
}
@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
credential_provider: NullCredentialProvider {},
}
}
}
@ -99,16 +97,22 @@ impl FakeEmbeddingProvider {
}
}
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn credential_provider(&self) -> Box<dyn CredentialProvider> {
let credential_provider: Box<dyn CredentialProvider> =
Box::new(self.credential_provider.clone());
credential_provider
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
None
}
async fn embed_batch(
&self,
spans: Vec<String>,
_credential: ProviderCredential,
) -> anyhow::Result<Vec<Embedding>> {
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
}
}
pub struct TestCompletionProvider {
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl TestCompletionProvider {
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
@ -150,14 +150,22 @@ impl TestCompletionProvider {
}
}
impl CompletionProvider for TestCompletionProvider {
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn credential_provider(&self) -> Box<dyn CredentialProvider> {
Box::new(NullCredentialProvider {})
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,