move keychain access into semantic index as opposed to on init
parent 67e590202a
commit 8ffe5a3ec7
7 changed files with 114 additions and 92 deletions
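The change in a nutshell: credentials are no longer read from the keychain when OpenAIEmbeddings is constructed at startup. Instead, EmbeddingProvider exposes retrieve_credentials(&self, cx: &AppContext), embed_batch takes an explicit api_key: Option<String>, and SemanticIndex::authenticate fetches and caches the key lazily, pushing it into the EmbeddingQueue. Below is a minimal, self-contained Rust sketch of that flow using simplified stand-in types; it is an illustration of the pattern, not the Zed code itself.

// A minimal sketch of the credential flow after this commit, using simplified
// stand-ins for gpui's AppContext and the real provider types (names here are
// illustrative, not the production signatures).
use std::env;

trait EmbeddingProvider {
    // Providers now only know how to *look up* credentials...
    fn retrieve_credentials(&self) -> Option<String>;
    // ...and every embedding call receives the key explicitly.
    fn embed_batch(&self, spans: Vec<String>, api_key: Option<String>) -> Result<Vec<Vec<f32>>, String>;
}

struct OpenAiLikeProvider;

impl EmbeddingProvider for OpenAiLikeProvider {
    fn retrieve_credentials(&self) -> Option<String> {
        // Environment variable first; the real implementation falls back to the
        // OS keychain via cx.platform().read_credentials(OPENAI_API_URL).
        env::var("OPENAI_API_KEY").ok()
    }

    fn embed_batch(&self, spans: Vec<String>, api_key: Option<String>) -> Result<Vec<Vec<f32>>, String> {
        let _key = api_key.ok_or_else(|| "no open ai key provided".to_string())?;
        // Return dummy 1536-dimensional vectors in place of a network call.
        Ok(spans.iter().map(|_| vec![0.0_f32; 1536]).collect())
    }
}

struct SemanticIndexLike {
    provider: OpenAiLikeProvider,
    api_key: Option<String>,
}

impl SemanticIndexLike {
    // Called lazily (e.g. when indexing or searching), not when the app starts,
    // so the keychain is only touched once the feature is actually used.
    fn authenticate(&mut self) {
        if self.api_key.is_none() {
            self.api_key = self.provider.retrieve_credentials();
        }
    }
}

fn main() {
    let mut index = SemanticIndexLike {
        provider: OpenAiLikeProvider,
        api_key: None,
    };
    index.authenticate();
    let embedded = index
        .provider
        .embed_batch(vec!["fn main() {}".into()], index.api_key.clone());
    println!("authenticated: {}, embed ok: {}", index.api_key.is_some(), embedded.is_ok());
}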
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::{serde_json, ViewContext};
+use gpui::{serde_json, AppContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -89,7 +89,6 @@ impl Embedding {

 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
-    pub api_key: Option<String>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -123,8 +122,12 @@ struct OpenAIEmbeddingUsage {

 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    fn is_authenticated(&self) -> bool;
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        api_key: Option<String>,
+    ) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
     fn rate_limit_expiration(&self) -> Option<Instant>;
@@ -134,13 +137,17 @@ pub struct DummyEmbeddings {}

 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Dummy API KEY".to_string())
     }
     fn rate_limit_expiration(&self) -> Option<Instant> {
         None
     }
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
         let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
@@ -169,36 +176,11 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;

 impl OpenAIEmbeddings {
-    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
-        if self.api_key.is_none() {
-            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                Some(api_key)
-            } else if let Some((_, api_key)) = cx
-                .platform()
-                .read_credentials(OPENAI_API_URL)
-                .log_err()
-                .flatten()
-            {
-                String::from_utf8(api_key).log_err()
-            } else {
-                None
-            };
-
-            if let Some(api_key) = api_key {
-                self.api_key = Some(api_key);
-            }
-        }
-    }
-    pub fn new(
-        api_key: Option<String>,
-        client: Arc<dyn HttpClient>,
-        executor: Arc<Background>,
-    ) -> Self {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));

         OpenAIEmbeddings {
-            api_key,
             client,
             executor,
             rate_limit_count_rx,
@@ -264,8 +246,19 @@ impl OpenAIEmbeddings {

 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        self.api_key.is_some()
+    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
+        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        }
     }

     fn max_tokens_per_batch(&self) -> usize {
@@ -290,11 +283,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         (output, tokens.len())
     }

-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;

-        let Some(api_key) = self.api_key.clone() else {
+        let Some(api_key) = api_key else {
             return Err(anyhow!("no open ai key provided"));
         };

@@ -8489,6 +8489,18 @@ impl Project {
         }
     }

+    #[cfg(any(test, feature = "test-support"))]
+    fn install_default_formatters(
+        &self,
+        _worktree: Option<WorktreeId>,
+        _new_language: &Language,
+        _language_settings: &LanguageSettings,
+        _cx: &mut ModelContext<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        return Task::ready(Ok(()));
+    }
+
+    #[cfg(not(any(test, feature = "test-support")))]
     fn install_default_formatters(
         &self,
         worktree: Option<WorktreeId>,
@@ -41,6 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
+    api_key: Option<String>,
 }

 #[derive(Clone)]
@@ -50,7 +51,11 @@ pub struct FileFragmentToEmbed {
 }

 impl EmbeddingQueue {
-    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+    pub fn new(
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        executor: Arc<Background>,
+        api_key: Option<String>,
+    ) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -59,9 +64,14 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
+            api_key,
         }
     }

+    pub fn set_api_key(&mut self, api_key: Option<String>) {
+        self.api_key = api_key
+    }
+
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -108,6 +118,7 @@ impl EmbeddingQueue {

         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();

         self.executor
             .spawn(async move {
@@ -132,7 +143,7 @@ impl EmbeddingQueue {
                     return;
                 };

-                match embedding_provider.embed_batch(spans).await {
+                match embedding_provider.embed_batch(spans, api_key).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {
@@ -7,10 +7,7 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;

 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::{
-    completion::OPENAI_API_URL,
-    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
-};
+use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -58,19 +55,6 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");

-    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-        Some(api_key)
-    } else if let Some((_, api_key)) = cx
-        .platform()
-        .read_credentials(OPENAI_API_URL)
-        .log_err()
-        .flatten()
-    {
-        String::from_utf8(api_key).log_err()
-    } else {
-        None
-    };
-
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -104,7 +88,7 @@ pub fn init(
     let semantic_index = SemanticIndex::new(
         fs,
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
         language_registry,
         cx.clone(),
     )
@@ -139,6 +123,8 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+    api_key: Option<String>,
+    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }

 struct ProjectState {
@@ -284,7 +270,7 @@ pub struct SearchResult {
 }

 impl SemanticIndex {
-    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
             Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
         } else {
@@ -292,12 +278,26 @@ impl SemanticIndex {
         }
     }

+    pub fn authenticate(&mut self, cx: &AppContext) {
+        if self.api_key.is_none() {
+            self.api_key = self.embedding_provider.retrieve_credentials(cx);
+
+            self.embedding_queue
+                .lock()
+                .set_api_key(self.api_key.clone());
+        }
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.api_key.is_some()
+    }
+
     pub fn enabled(cx: &AppContext) -> bool {
         settings::get::<SemanticIndexSettings>(cx).enabled
     }

     pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
-        if !self.embedding_provider.is_authenticated() {
+        if !self.is_authenticated() {
             return SemanticIndexStatus::NotAuthenticated;
         }

@@ -339,7 +339,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -404,6 +404,8 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
+                api_key: None,
+                embedding_queue
             }
         }))
     }
@@ -718,12 +720,13 @@ impl SemanticIndex {

         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();

         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
             let query = embedding_provider
-                .embed_batch(vec![query])
+                .embed_batch(vec![query], api_key)
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -941,6 +944,7 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
+        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -955,10 +959,15 @@ impl SemanticIndex {
                         .parse_file_with_template(None, &snapshot.text(), language)
                         .log_err()
                         .unwrap_or_default();
-                    if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
-                        .await
-                        .log_err()
-                        .is_some()
+                    if Self::embed_spans(
+                        &mut spans,
+                        embedding_provider.as_ref(),
+                        &db,
+                        api_key.clone(),
+                    )
+                    .await
+                    .log_err()
+                    .is_some()
                     {
                         for span in spans {
                             let similarity = span.embedding.unwrap().similarity(&query);
@@ -998,8 +1007,11 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if !self.embedding_provider.is_authenticated() {
-            return Task::ready(Err(anyhow!("user is not authenticated")));
+        if self.api_key.is_none() {
+            self.authenticate(cx);
+            if self.api_key.is_none() {
+                return Task::ready(Err(anyhow!("user is not authenticated")));
+            }
         }

         if !self.projects.contains_key(&project.downgrade()) {
@@ -1180,6 +1192,7 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
+        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1202,7 +1215,7 @@ impl SemanticIndex {

             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch))
+                    .embed_batch(mem::take(&mut batch), api_key.clone())
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1214,7 +1227,7 @@ impl SemanticIndex {

         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch))
+                .embed_batch(mem::take(&mut batch), api_key)
                 .await?;

             embeddings.extend(batch_embeddings);
@@ -7,7 +7,7 @@ use crate::{
 use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::{executor::Deterministic, Task, TestAppContext};
+use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -228,7 +228,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {

     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());

-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
     for file in &files {
         queue.push(file.clone());
     }
@@ -1281,8 +1281,8 @@ impl FakeEmbeddingProvider {

 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Fake Credentials".to_string())
     }
     fn truncate(&self, span: &str) -> (String, usize) {
         (span.to_string(), 1)
@@ -1296,7 +1296,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         None
     }

-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
@@ -1,4 +1,3 @@
-use ai::completion::OPENAI_API_URL;
 use ai::embedding::OpenAIEmbeddings;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
@@ -18,7 +17,6 @@ use std::{cmp, env, fs};
 use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::http::{self};
 use util::paths::EMBEDDINGS_DIR;
-use util::ResultExt;
 use zed::languages;

 #[derive(Deserialize, Clone, Serialize)]
@@ -57,7 +55,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
         .as_path()
         .parent()
         .unwrap()
-        .join("crates/semantic_index/eval");
+        .join("zed/crates/semantic_index/eval");

     let mut repo_evals: Vec<RepoEval> = Vec::new();
     for entry in fs::read_dir(eval_folder)? {
@@ -472,25 +470,12 @@ fn main() {

         let languages = languages.clone();

-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        };
-
         let fs = fs.clone();
         cx.spawn(|mut cx| async move {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )
@@ -1,3 +1,3 @@
 #!/bin/bash

-RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release
+RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release