updated authentication for embedding provider

This commit is contained in:
KCaverly 2023-10-26 11:18:16 +02:00
parent 71bc35d241
commit 3447a9478c
16 changed files with 277 additions and 271 deletions

View file

@ -4,14 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
use ai::{
embedding::{Embedding, EmbeddingProvider},
models::LanguageModel,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
use ai::test::FakeEmbeddingProvider;
use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
use std::{
path::Path,
sync::{
atomic::{self, AtomicUsize},
Arc,
},
time::{Instant, SystemTime},
};
use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::RandomCharIter;
@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
let mut queue = EmbeddingQueue::new(
embedding_provider.clone(),
cx.background(),
ai::auth::ProviderCredential::NoCredentials,
);
for file in &files {
queue.push(file.clone());
}
@ -284,7 +276,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -470,7 +462,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
}
impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(DummyLanguageModel {})
}
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
Some("Fake Credentials".to_string())
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(
&self,
spans: Vec<String>,
_api_key: Option<String>,
) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
fn js_lang() -> Arc<Language> {
Arc::new(
Language::new(