updated authentication for embedding provider

KCaverly 2023-10-26 11:18:16 +02:00
parent 71bc35d241
commit 3447a9478c
16 changed files with 277 additions and 271 deletions
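
This commit swaps the raw Option<String> API key that was threaded through the semantic index for the ProviderCredential type from the ai crate's auth module, so authentication state is modeled explicitly instead of as None/Some. Judging only from the variants that appear in this diff (NoCredentials, and a Credentials { .. } struct variant matched in is_authenticated), the enum presumably looks roughly like the sketch below; the api_key field name and any additional variants are assumptions, not part of this commit.

// Hypothetical shape of ai::auth::ProviderCredential, inferred from its usage in this diff.
#[derive(Clone)]
pub enum ProviderCredential {
    // Initial state: no credential has been retrieved yet.
    NoCredentials,
    // A retrieved credential; the field name here is an assumption.
    Credentials { api_key: String },
}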

View file

@@ -42,6 +42,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
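
The new dev-dependency on the ai crate with its test-support feature presumably exists so this crate's tests can use the shared ai::test::FakeEmbeddingProvider referenced later in the diff. Under that assumption, the ai crate would gate its test helpers along these lines (not shown in this commit):

// Assumed in the ai crate's lib.rs: test helpers compiled only for tests and test-support builds.
#[cfg(any(test, feature = "test-support"))]
pub mod test;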

View file

@@ -1,5 +1,5 @@
use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
@@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
api_key: Option<String>,
provider_credential: ProviderCredential,
}
#[derive(Clone)]
@@ -54,7 +54,7 @@ impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: Arc<Background>,
api_key: Option<String>,
provider_credential: ProviderCredential,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
@@ -64,12 +64,12 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
api_key,
provider_credential,
}
}
pub fn set_api_key(&mut self, api_key: Option<String>) {
self.api_key = api_key
pub fn set_credential(&mut self, credential: ProviderCredential) {
self.provider_credential = credential
}
pub fn push(&mut self, file: FileToEmbed) {
@@ -118,7 +118,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
let credential = self.provider_credential.clone();
self.executor
.spawn(async move {
@@ -143,7 +143,7 @@ impl EmbeddingQueue {
return;
};
match embedding_provider.embed_batch(spans, api_key).await {
match embedding_provider.embed_batch(spans, credential).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
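
To summarize the EmbeddingQueue changes above: the queue now stores a ProviderCredential instead of an Option<String>, set_api_key becomes set_credential, and the cloned credential is passed to embed_batch when a batch is flushed. A minimal usage sketch under those assumptions (provider and executor setup elided; the Credentials field name follows the hypothetical enum sketch near the top of this commit):

// Sketch only: start the queue unauthenticated, then swap a credential in later.
let mut queue = EmbeddingQueue::new(
    embedding_provider.clone(),        // Arc<dyn EmbeddingProvider>
    cx.background().clone(),           // Arc<Background> executor
    ProviderCredential::NoCredentials, // nothing retrieved yet
);
queue.set_credential(ProviderCredential::Credentials {
    api_key: "sk-fake".into(), // hypothetical value; real retrieval goes through the provider
});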

View file

@@ -7,6 +7,7 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use ai::auth::ProviderCredential;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
@@ -124,7 +125,7 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
api_key: Option<String>,
provider_credential: ProviderCredential,
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
@@ -279,18 +280,27 @@ impl SemanticIndex {
}
}
pub fn authenticate(&mut self, cx: &AppContext) {
if self.api_key.is_none() {
self.api_key = self.embedding_provider.retrieve_credentials(cx);
self.embedding_queue
.lock()
.set_api_key(self.api_key.clone());
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
let credential = self.provider_credential.clone();
match credential {
ProviderCredential::NoCredentials => {
let credential = self.embedding_provider.retrieve_credentials(cx);
self.provider_credential = credential;
}
_ => {}
}
self.embedding_queue.lock().set_credential(credential);
self.is_authenticated()
}
pub fn is_authenticated(&self) -> bool {
self.api_key.is_some()
let credential = &self.provider_credential;
match credential {
&ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
pub fn enabled(cx: &AppContext) -> bool {
@@ -340,7 +350,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@@ -405,7 +415,7 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
api_key: None,
provider_credential: ProviderCredential::NoCredentials,
embedding_queue
}
}))
@@ -721,13 +731,14 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
let credential = self.provider_credential.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let query = embedding_provider
.embed_batch(vec![query], api_key)
.embed_batch(vec![query], credential)
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@@ -945,7 +956,7 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
let api_key = self.api_key.clone();
let credential = self.provider_credential.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@@ -964,7 +975,7 @@ impl SemanticIndex {
&mut spans,
embedding_provider.as_ref(),
&db,
api_key.clone(),
credential.clone(),
)
.await
.log_err()
@@ -1008,9 +1019,8 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if self.api_key.is_none() {
self.authenticate(cx);
if self.api_key.is_none() {
if !self.is_authenticated() {
if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated")));
}
}
@@ -1193,7 +1203,7 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
api_key: Option<String>,
credential: ProviderCredential,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@@ -1216,7 +1226,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), api_key.clone())
.embed_batch(mem::take(&mut batch), credential.clone())
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@@ -1228,7 +1238,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), api_key)
.embed_batch(mem::take(&mut batch), credential)
.await?;
embeddings.extend(batch_embeddings);
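
The net effect of the SemanticIndex changes above: authenticate(cx) now returns a bool, retrieving credentials from the embedding provider only when none are held and updating the embedding queue's credential, while is_authenticated() checks for the Credentials variant instead of api_key.is_some(). Callers that need the provider follow the guard pattern introduced in index_project; a condensed sketch of that call site, as it appears in this diff:

// Guard before any work that requires the embedding provider (as in index_project above).
if !self.is_authenticated() {
    if !self.authenticate(cx) {
        return Task::ready(Err(anyhow!("user is not authenticated")));
    }
}
// From here on, search_project and the indexing routines clone self.provider_credential
// and pass it to embed_batch in place of the old api_key.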

View file

@@ -4,14 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
use ai::{
embedding::{Embedding, EmbeddingProvider},
models::LanguageModel,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
use ai::test::FakeEmbeddingProvider;
use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
use std::{
path::Path,
sync::{
atomic::{self, AtomicUsize},
Arc,
},
time::{Instant, SystemTime},
};
use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::RandomCharIter;
@@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
let mut queue = EmbeddingQueue::new(
embedding_provider.clone(),
cx.background(),
ai::auth::ProviderCredential::NoCredentials,
);
for file in &files {
queue.push(file.clone());
}
@@ -284,7 +276,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -470,7 +462,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
}
impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(DummyLanguageModel {})
}
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
Some("Fake Credentials".to_string())
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(
&self,
spans: Vec<String>,
_api_key: Option<String>,
) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
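
The hand-rolled FakeEmbeddingProvider deleted above is replaced by the shared ai::test::FakeEmbeddingProvider imported at the top of this file. Putting the call sites in this commit together, that helper and OpenAIEmbeddingProvider presumably now implement an EmbeddingProvider trait roughly like the reconstruction below; this is inferred from usage (retrieve_credentials(cx) and embed_batch(spans, credential)), not copied from the ai crate, and the real trait may carry further methods:

// Assumed post-change trait surface. Types referenced: ProviderCredential (ai::auth),
// Embedding (ai::embedding), LanguageModel (ai::models), AppContext (gpui),
// Instant (std::time), Result (anyhow), async_trait (async-trait crate).
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    fn base_model(&self) -> Box<dyn LanguageModel>;
    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
    fn max_tokens_per_batch(&self) -> usize;
    fn rate_limit_expiration(&self) -> Option<Instant>;
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        credential: ProviderCredential,
    ) -> Result<Vec<Embedding>>;
}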
fn js_lang() -> Arc<Language> {
Arc::new(
Language::new(