updated authentication for embedding provider
This commit is contained in:
parent
71bc35d241
commit
3447a9478c
16 changed files with 277 additions and 271 deletions
|
@ -4,14 +4,9 @@ use crate::{
|
|||
semantic_index_settings::SemanticIndexSettings,
|
||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||
};
|
||||
use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
|
||||
use ai::{
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::LanguageModel,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
|
||||
use ai::test::FakeEmbeddingProvider;
|
||||
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
|
|||
use rand::{rngs::StdRng, Rng};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{self, AtomicUsize},
|
||||
Arc,
|
||||
},
|
||||
time::{Instant, SystemTime},
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::SystemTime};
|
||||
use unindent::Unindent;
|
||||
use util::RandomCharIter;
|
||||
|
||||
|
@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
|||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
|
||||
let mut queue = EmbeddingQueue::new(
|
||||
embedding_provider.clone(),
|
||||
cx.background(),
|
||||
ai::auth::ProviderCredential::NoCredentials,
|
||||
);
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
|
@ -284,7 +276,7 @@ fn assert_search_results(
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_rust() {
|
||||
let language = rust_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
|
@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() {
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_json() {
|
||||
let language = json_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
|
@ -470,7 +462,7 @@ fn assert_documents_eq(
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_javascript() {
|
||||
let language = js_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
|
@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() {
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_lua() {
|
||||
let language = lua_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
|
@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() {
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_elixir() {
|
||||
let language = elixir_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
|
@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() {
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_cpp() {
|
||||
let language = cpp_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
|
@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() {
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_ruby() {
|
||||
let language = ruby_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
|
@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() {
|
|||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_php() {
|
||||
let language = php_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
|
@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() {
|
|||
);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl FakeEmbeddingProvider {
|
||||
fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
Box::new(DummyLanguageModel {})
|
||||
}
|
||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
||||
Some("Fake Credentials".to_string())
|
||||
}
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
1000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
spans: Vec<String>,
|
||||
_api_key: Option<String>,
|
||||
) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
|
||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn js_lang() -> Arc<Language> {
|
||||
Arc::new(
|
||||
Language::new(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue