move keychain access into semantic index as opposed to on init

KCaverly 2023-10-24 13:26:37 +02:00
parent 67e590202a
commit 8ffe5a3ec7
7 changed files with 114 additions and 92 deletions
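The trait changes in the `ai` crate are not part of this excerpt; the call sites in the diff below and the `FakeEmbeddingProvider` impl in the test file imply roughly the following shape for `EmbeddingProvider` (a sketch only, with other trait methods omitted):

// Sketch of the implied trait shape; not taken verbatim from this commit.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    // Replaces the old is_authenticated(): resolve credentials lazily when asked.
    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
    fn truncate(&self, span: &str) -> (String, usize);
    fn max_tokens_per_batch(&self) -> usize;
    // Callers now pass the key along with each batch.
    async fn embed_batch(&self, spans: Vec<String>, api_key: Option<String>) -> Result<Vec<Embedding>>;
}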


@@ -41,6 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
+    api_key: Option<String>,
 }
 
 #[derive(Clone)]
@@ -50,7 +51,11 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+    pub fn new(
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        executor: Arc<Background>,
+        api_key: Option<String>,
+    ) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -59,9 +64,14 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
+            api_key,
         }
     }
 
+    pub fn set_api_key(&mut self, api_key: Option<String>) {
+        self.api_key = api_key
+    }
+
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -108,6 +118,7 @@ impl EmbeddingQueue {
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();
 
         self.executor
             .spawn(async move {
@@ -132,7 +143,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans).await {
+                match embedding_provider.embed_batch(spans, api_key).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {

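With these changes the queue no longer needs credentials at construction time; a minimal illustration of the intended lifecycle (the surrounding variables are hypothetical):

// Illustration only: construct without credentials, inject them once they are known.
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), executor.clone(), None);
// Later, after SemanticIndex::authenticate has read the env var or keychain:
queue.set_api_key(Some(api_key));
// Every batch flushed from this point on calls embed_batch(spans, Some(api_key)).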

@@ -7,10 +7,7 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::{
-    completion::OPENAI_API_URL,
-    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
-};
+use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -58,19 +55,6 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
-    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-        Some(api_key)
-    } else if let Some((_, api_key)) = cx
-        .platform()
-        .read_credentials(OPENAI_API_URL)
-        .log_err()
-        .flatten()
-    {
-        String::from_utf8(api_key).log_err()
-    } else {
-        None
-    };
-
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -104,7 +88,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
             language_registry,
             cx.clone(),
         )
@@ -139,6 +123,8 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+    api_key: Option<String>,
+    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {
@@ -284,7 +270,7 @@ pub struct SearchResult {
 }
 
 impl SemanticIndex {
-    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
             Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
         } else {
@@ -292,12 +278,26 @@ impl SemanticIndex {
         }
     }
 
+    pub fn authenticate(&mut self, cx: &AppContext) {
+        if self.api_key.is_none() {
+            self.api_key = self.embedding_provider.retrieve_credentials(cx);
+            self.embedding_queue
+                .lock()
+                .set_api_key(self.api_key.clone());
+        }
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.api_key.is_some()
+    }
+
     pub fn enabled(cx: &AppContext) -> bool {
         settings::get::<SemanticIndexSettings>(cx).enabled
     }
 
     pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
-        if !self.embedding_provider.is_authenticated() {
+        if !self.is_authenticated() {
             return SemanticIndexStatus::NotAuthenticated;
         }
@@ -339,7 +339,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -404,6 +404,8 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
+                api_key: None,
+                embedding_queue
             }
         }))
     }
@@ -718,12 +720,13 @@ impl SemanticIndex {
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
             let query = embedding_provider
-                .embed_batch(vec![query])
+                .embed_batch(vec![query], api_key)
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -941,6 +944,7 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
+        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -955,10 +959,15 @@
                     .parse_file_with_template(None, &snapshot.text(), language)
                     .log_err()
                     .unwrap_or_default();
-                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
-                    .await
-                    .log_err()
-                    .is_some()
+                if Self::embed_spans(
+                    &mut spans,
+                    embedding_provider.as_ref(),
+                    &db,
+                    api_key.clone(),
+                )
+                .await
+                .log_err()
+                .is_some()
                 {
                     for span in spans {
                         let similarity = span.embedding.unwrap().similarity(&query);
@@ -998,8 +1007,11 @@
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if !self.embedding_provider.is_authenticated() {
-            return Task::ready(Err(anyhow!("user is not authenticated")));
+        if self.api_key.is_none() {
+            self.authenticate(cx);
+            if self.api_key.is_none() {
+                return Task::ready(Err(anyhow!("user is not authenticated")));
+            }
         }
 
         if !self.projects.contains_key(&project.downgrade()) {
@@ -1180,6 +1192,7 @@
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
+        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1202,7 +1215,7 @@
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch))
+                    .embed_batch(mem::take(&mut batch), api_key.clone())
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1214,7 +1227,7 @@
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch))
+                .embed_batch(mem::take(&mut batch), api_key)
                 .await?;
             embeddings.extend(batch_embeddings);

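The env-var/keychain lookup deleted from `init` above presumably moves into the provider itself behind `retrieve_credentials`; the `ai` crate is not part of this excerpt, but if the removed logic is reused as-is, the OpenAI implementation would look roughly like this (an assumption, not shown in this commit):

// Assumed sketch: the removed init() logic, relocated into the OpenAI provider.
// OPENAI_API_URL and log_err() are the same items the deleted block above used.
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
    if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
        Some(api_key)
    } else if let Some((_, api_key)) = cx
        .platform()
        .read_credentials(OPENAI_API_URL)
        .log_err()
        .flatten()
    {
        String::from_utf8(api_key).log_err()
    } else {
        None
    }
}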

@@ -7,7 +7,7 @@ use crate::{
 use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::{executor::Deterministic, Task, TestAppContext};
+use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -228,7 +228,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
     for file in &files {
         queue.push(file.clone());
     }
@@ -1281,8 +1281,8 @@ impl FakeEmbeddingProvider {
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Fake Credentials".to_string())
     }
 
     fn truncate(&self, span: &str) -> (String, usize) {
         (span.to_string(), 1)
@@ -1296,7 +1296,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         None
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
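
Taken together, nothing outside SemanticIndex has to resolve credentials up front any more; a hedged sketch of what a caller might do, assuming a `semantic_index: ModelHandle<SemanticIndex>` and a gpui `cx` in scope:

// Hypothetical call site, for illustration only.
semantic_index.update(cx, |index, cx| {
    // No-op if a key is already cached; otherwise asks the provider to read the
    // env var or keychain and forwards the key to the embedding queue.
    index.authenticate(cx);
    if !index.is_authenticated() {
        log::info!("no OpenAI credentials found; semantic search stays unauthenticated");
    }
});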