catchup with main
This commit is contained in:
commit
71bc35d241
84 changed files with 6026 additions and 3636 deletions
|
@ -41,6 +41,7 @@ pub struct EmbeddingQueue {
|
|||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -50,7 +51,11 @@ pub struct FileFragmentToEmbed {
|
|||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||
pub fn new(
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
executor: Arc<Background>,
|
||||
api_key: Option<String>,
|
||||
) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
|
@ -59,9 +64,14 @@ impl EmbeddingQueue {
|
|||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
finished_files_rx,
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_api_key(&mut self, api_key: Option<String>) {
|
||||
self.api_key = api_key
|
||||
}
|
||||
|
||||
pub fn push(&mut self, file: FileToEmbed) {
|
||||
if file.spans.is_empty() {
|
||||
self.finished_files_tx.try_send(file).unwrap();
|
||||
|
@ -108,6 +118,7 @@ impl EmbeddingQueue {
|
|||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
|
@ -132,7 +143,7 @@ impl EmbeddingQueue {
|
|||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
match embedding_provider.embed_batch(spans, api_key).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
|
|
|
@ -124,6 +124,8 @@ pub struct SemanticIndex {
|
|||
_embedding_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
api_key: Option<String>,
|
||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
||||
}
|
||||
|
||||
struct ProjectState {
|
||||
|
@ -269,7 +271,7 @@ pub struct SearchResult {
|
|||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
|
||||
pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
|
||||
if cx.has_global::<ModelHandle<Self>>() {
|
||||
Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
|
||||
} else {
|
||||
|
@ -277,12 +279,26 @@ impl SemanticIndex {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn authenticate(&mut self, cx: &AppContext) {
|
||||
if self.api_key.is_none() {
|
||||
self.api_key = self.embedding_provider.retrieve_credentials(cx);
|
||||
|
||||
self.embedding_queue
|
||||
.lock()
|
||||
.set_api_key(self.api_key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &AppContext) -> bool {
|
||||
settings::get::<SemanticIndexSettings>(cx).enabled
|
||||
}
|
||||
|
||||
pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
|
||||
if !self.embedding_provider.is_authenticated() {
|
||||
if !self.is_authenticated() {
|
||||
return SemanticIndexStatus::NotAuthenticated;
|
||||
}
|
||||
|
||||
|
@ -324,7 +340,7 @@ impl SemanticIndex {
|
|||
Ok(cx.add_model(|cx| {
|
||||
let t0 = Instant::now();
|
||||
let embedding_queue =
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
|
||||
let _embedding_task = cx.background().spawn({
|
||||
let embedded_files = embedding_queue.finished_files();
|
||||
let db = db.clone();
|
||||
|
@ -389,6 +405,8 @@ impl SemanticIndex {
|
|||
_embedding_task,
|
||||
_parsing_files_tasks,
|
||||
projects: Default::default(),
|
||||
api_key: None,
|
||||
embedding_queue
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
@ -703,12 +721,13 @@ impl SemanticIndex {
|
|||
|
||||
let index = self.index_project(project.clone(), cx);
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
let t0 = Instant::now();
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query])
|
||||
.embed_batch(vec![query], api_key)
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
|
@ -926,6 +945,7 @@ impl SemanticIndex {
|
|||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
let api_key = self.api_key.clone();
|
||||
cx.background().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
|
@ -940,10 +960,15 @@ impl SemanticIndex {
|
|||
.parse_file_with_template(None, &snapshot.text(), language)
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
if Self::embed_spans(
|
||||
&mut spans,
|
||||
embedding_provider.as_ref(),
|
||||
&db,
|
||||
api_key.clone(),
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
for span in spans {
|
||||
let similarity = span.embedding.unwrap().similarity(&query);
|
||||
|
@ -983,8 +1008,11 @@ impl SemanticIndex {
|
|||
project: ModelHandle<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if !self.embedding_provider.is_authenticated() {
|
||||
return Task::ready(Err(anyhow!("user is not authenticated")));
|
||||
if self.api_key.is_none() {
|
||||
self.authenticate(cx);
|
||||
if self.api_key.is_none() {
|
||||
return Task::ready(Err(anyhow!("user is not authenticated")));
|
||||
}
|
||||
}
|
||||
|
||||
if !self.projects.contains_key(&project.downgrade()) {
|
||||
|
@ -1165,6 +1193,7 @@ impl SemanticIndex {
|
|||
spans: &mut [Span],
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
db: &VectorDatabase,
|
||||
api_key: Option<String>,
|
||||
) -> Result<()> {
|
||||
let mut batch = Vec::new();
|
||||
let mut batch_tokens = 0;
|
||||
|
@ -1187,7 +1216,7 @@ impl SemanticIndex {
|
|||
|
||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.embed_batch(mem::take(&mut batch), api_key.clone())
|
||||
.await?;
|
||||
embeddings.extend(batch_embeddings);
|
||||
batch_tokens = 0;
|
||||
|
@ -1199,7 +1228,7 @@ impl SemanticIndex {
|
|||
|
||||
if !batch.is_empty() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.embed_batch(mem::take(&mut batch), api_key)
|
||||
.await?;
|
||||
|
||||
embeddings.extend(batch_embeddings);
|
||||
|
|
|
@ -11,7 +11,7 @@ use ai::{
|
|||
};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
@ -232,7 +232,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
|||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
|
@ -1288,8 +1288,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
|||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
Box::new(DummyLanguageModel {})
|
||||
}
|
||||
fn is_authenticated(&self) -> bool {
|
||||
true
|
||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
||||
Some("Fake Credentials".to_string())
|
||||
}
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
1000
|
||||
|
@ -1299,7 +1299,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
|||
None
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
spans: Vec<String>,
|
||||
_api_key: Option<String>,
|
||||
) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue