moved authentication for the semantic index into the EmbeddingProvider

parent 1e8b23d8fb
commit a2c3971ad6

14 changed files with 200 additions and 206 deletions
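Summary of the change shown in the hunks below: the standalone credential-provider objects are removed, `CompletionProvider` and `EmbeddingProvider` now extend `CredentialProvider`, and `embed_batch` no longer takes a `ProviderCredential` argument; each provider caches its own credential internally. The sketch below is illustrative only, not code from this repository: `Ctx` stands in for gpui's `AppContext`, the embedding result type is simplified to `Vec<f32>`, and the real `embed_batch` is async.

// Hypothetical, simplified sketch of the post-commit shape (assumptions noted above).
#[derive(Clone, Debug, PartialEq)]
pub enum ProviderCredential {
    Credentials { api_key: String },
    NoCredentials,
    NotNeeded,
}

pub struct Ctx; // stand-in for gpui::AppContext

pub trait CredentialProvider: Send + Sync {
    fn has_credentials(&self) -> bool;
    fn retrieve_credentials(&self, cx: &Ctx) -> ProviderCredential;
    fn save_credentials(&self, cx: &Ctx, credential: ProviderCredential);
    fn delete_credentials(&self, cx: &Ctx);
}

// After this commit the embedding trait extends CredentialProvider and
// embed_batch no longer receives a ProviderCredential parameter.
pub trait EmbeddingProvider: CredentialProvider {
    fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Vec<f32>>>;
}

// Callers authenticate through the provider itself, then embed.
pub fn authenticate_and_embed(
    provider: &dyn EmbeddingProvider,
    cx: &Ctx,
    spans: Vec<String>,
) -> anyhow::Result<Vec<Vec<f32>>> {
    if !provider.has_credentials() {
        provider.retrieve_credentials(cx);
    }
    provider.embed_batch(spans)
}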
@@ -8,17 +8,8 @@ pub enum ProviderCredential {
 }
 
 pub trait CredentialProvider: Send + Sync {
+    fn has_credentials(&self) -> bool;
     fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
     fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
     fn delete_credentials(&self, cx: &AppContext);
 }
-
-#[derive(Clone)]
-pub struct NullCredentialProvider;
-impl CredentialProvider for NullCredentialProvider {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
-        ProviderCredential::NotNeeded
-    }
-    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
-    fn delete_credentials(&self, cx: &AppContext) {}
-}
@@ -1,28 +1,14 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
-use gpui::AppContext;
 
-use crate::{
-    auth::{CredentialProvider, ProviderCredential},
-    models::LanguageModel,
-};
+use crate::{auth::CredentialProvider, models::LanguageModel};
 
 pub trait CompletionRequest: Send + Sync {
     fn data(&self) -> serde_json::Result<String>;
 }
 
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
     fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
-    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
-        self.credential_provider().retrieve_credentials(cx)
-    }
-    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
-        self.credential_provider().save_credentials(cx, credential);
-    }
-    fn delete_credentials(&self, cx: &AppContext) {
-        self.credential_provider().delete_credentials(cx);
-    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
@@ -2,12 +2,11 @@ use std::time::Instant;
 
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::AppContext;
 use ordered_float::OrderedFloat;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 
-use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::auth::CredentialProvider;
 use crate::models::LanguageModel;
 
 #[derive(Debug, PartialEq, Clone)]
@@ -70,17 +69,9 @@ impl Embedding {
 }
 
 #[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
+pub trait EmbeddingProvider: CredentialProvider {
     fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
-    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
-        self.credential_provider().retrieve_credentials(cx)
-    }
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        credential: ProviderCredential,
-    ) -> Result<Vec<Embedding>>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
@@ -1,46 +0,0 @@
-use std::env;
-
-use gpui::AppContext;
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::providers::open_ai::OPENAI_API_URL;
-
-#[derive(Clone)]
-pub struct OpenAICredentialProvider {}
-
-impl CredentialProvider for OpenAICredentialProvider {
-    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        };
-
-        if let Some(api_key) = api_key {
-            ProviderCredential::Credentials { api_key }
-        } else {
-            ProviderCredential::NoCredentials
-        }
-    }
-    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
-        match credential {
-            ProviderCredential::Credentials { api_key } => {
-                cx.platform()
-                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
-                    .log_err();
-            }
-            _ => {}
-        }
-    }
-    fn delete_credentials(&self, cx: &AppContext) {
-        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
-    }
-}
@@ -3,14 +3,17 @@ use futures::{
     future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
     Stream, StreamExt,
 };
-use gpui::executor::Background;
+use gpui::{executor::Background, AppContext};
 use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
 use serde::{Deserialize, Serialize};
 use std::{
+    env,
     fmt::{self, Display},
     io,
     sync::Arc,
 };
+use util::ResultExt;
 
 use crate::{
     auth::{CredentialProvider, ProviderCredential},

@@ -18,9 +21,7 @@ use crate::{
     models::LanguageModel,
 };
 
-use super::{auth::OpenAICredentialProvider, OpenAILanguageModel};
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
 
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
@@ -194,42 +195,83 @@ pub async fn stream_completion(
 
 pub struct OpenAICompletionProvider {
     model: OpenAILanguageModel,
-    credential_provider: OpenAICredentialProvider,
-    credential: ProviderCredential,
+    credential: Arc<RwLock<ProviderCredential>>,
     executor: Arc<Background>,
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(
-        model_name: &str,
-        credential: ProviderCredential,
-        executor: Arc<Background>,
-    ) -> Self {
+    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
         let model = OpenAILanguageModel::load(model_name);
-        let credential_provider = OpenAICredentialProvider {};
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             model,
-            credential_provider,
             credential,
             executor,
         }
     }
 }
 
+impl CredentialProvider for OpenAICompletionProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
 impl CompletionProvider for OpenAICompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        let provider: Box<dyn CredentialProvider> = Box::new(self.credential_provider.clone());
-        provider
-    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let credential = self.credential.clone();
+        let credential = self.credential.read().clone();
         let request = stream_completion(credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
@@ -2,27 +2,29 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, AppContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
-use parking_lot::Mutex;
+use parking_lot::{Mutex, RwLock};
 use parse_duration::parse;
 use postage::watch;
 use serde::{Deserialize, Serialize};
+use std::env;
 use std::ops::Add;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
+use util::ResultExt;
 
 use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::embedding::{Embedding, EmbeddingProvider};
 use crate::models::LanguageModel;
 use crate::providers::open_ai::OpenAILanguageModel;
 
-use crate::providers::open_ai::auth::OpenAICredentialProvider;
+use crate::providers::open_ai::OPENAI_API_URL;
 
 lazy_static! {
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();

@@ -31,7 +33,7 @@ lazy_static! {
 #[derive(Clone)]
 pub struct OpenAIEmbeddingProvider {
     model: OpenAILanguageModel,
-    credential_provider: OpenAICredentialProvider,
+    credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,

@@ -69,10 +71,11 @@ impl OpenAIEmbeddingProvider {
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAIEmbeddingProvider {
             model,
-            credential_provider: OpenAICredentialProvider {},
+            credential,
             client,
             executor,
             rate_limit_count_rx,

@@ -80,6 +83,13 @@ impl OpenAIEmbeddingProvider {
         }
     }
 
+    fn get_api_key(&self) -> Result<String> {
+        match self.credential.read().clone() {
+            ProviderCredential::Credentials { api_key } => Ok(api_key),
+            _ => Err(anyhow!("api credentials not provided")),
+        }
+    }
+
     fn resolve_rate_limit(&self) {
         let reset_time = *self.rate_limit_count_tx.lock().borrow();
 
@@ -136,6 +146,57 @@ impl OpenAIEmbeddingProvider {
     }
 }
 
+impl CredentialProvider for OpenAIEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddingProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {

@@ -143,12 +204,6 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
         model
     }
 
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        let credential_provider: Box<dyn CredentialProvider> =
-            Box::new(self.credential_provider.clone());
-        credential_provider
-    }
-
     fn max_tokens_per_batch(&self) -> usize {
         50000
     }

@@ -157,18 +212,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
         *self.rate_limit_count_rx.borrow()
     }
 
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        credential: ProviderCredential,
-    ) -> Result<Vec<Embedding>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = match credential {
-            ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
-            _ => Err(anyhow!("no api key provided")),
-        }?;
+        let api_key = self.get_api_key()?;
 
         let mut request_number = 0;
         let mut rate_limiting = false;
@@ -1,4 +1,3 @@
-pub mod auth;
 pub mod completion;
 pub mod embedding;
 pub mod model;

@@ -6,3 +5,5 @@ pub mod model;
 pub use completion::*;
 pub use embedding::*;
 pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
crates/ai/src/providers/open_ai/new.rs (new file, 11 lines)
@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}
@@ -5,10 +5,11 @@ use std::{
 
 use async_trait::async_trait;
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::AppContext;
 use parking_lot::Mutex;
 
 use crate::{
-    auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
+    auth::{CredentialProvider, ProviderCredential},
     completion::{CompletionProvider, CompletionRequest},
     embedding::{Embedding, EmbeddingProvider},
     models::{LanguageModel, TruncationDirection},

@@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel {
 
 pub struct FakeEmbeddingProvider {
     pub embedding_count: AtomicUsize,
-    pub credential_provider: NullCredentialProvider,
 }
 
 impl Clone for FakeEmbeddingProvider {
     fn clone(&self) -> Self {
         FakeEmbeddingProvider {
             embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
-            credential_provider: self.credential_provider.clone(),
         }
     }
 }

@@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider {
     fn default() -> Self {
         FakeEmbeddingProvider {
             embedding_count: AtomicUsize::default(),
-            credential_provider: NullCredentialProvider {},
         }
     }
 }

@@ -99,16 +97,22 @@ impl FakeEmbeddingProvider {
     }
 }
 
+impl CredentialProvider for FakeEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+    fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         Box::new(FakeLanguageModel { capacity: 1000 })
     }
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        let credential_provider: Box<dyn CredentialProvider> =
-            Box::new(self.credential_provider.clone());
-        credential_provider
-    }
     fn max_tokens_per_batch(&self) -> usize {
         1000
     }

@@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         None
     }
 
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _credential: ProviderCredential,
-    ) -> anyhow::Result<Vec<Embedding>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
 
@@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
     }
 }
 
-pub struct TestCompletionProvider {
+pub struct FakeCompletionProvider {
     last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
 }
 
-impl TestCompletionProvider {
+impl FakeCompletionProvider {
     pub fn new() -> Self {
         Self {
             last_completion_tx: Mutex::new(None),

@@ -150,14 +150,22 @@ impl TestCompletionProvider {
     }
 }
 
-impl CompletionProvider for TestCompletionProvider {
+impl CredentialProvider for FakeCompletionProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+    fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
         model
     }
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        Box::new(NullCredentialProvider {})
-    }
     fn complete(
         &self,
         _prompt: Box<dyn CompletionRequest>,
@@ -10,7 +10,7 @@ use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
     providers::open_ai::{
-        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage,
     },
 };
 

@@ -48,7 +48,7 @@ use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::SettingsStore;
 use std::{
     cell::{Cell, RefCell},
-    cmp, env,
+    cmp,
     fmt::Write,
     iter,
     ops::Range,

@@ -210,7 +210,6 @@ impl AssistantPanel {
         // Defaulting currently to GPT4, allow for this to be set via config.
         let completion_provider = Box::new(OpenAICompletionProvider::new(
             "gpt-4",
-            ProviderCredential::NoCredentials,
             cx.background().clone(),
         ));
 

@@ -298,7 +297,6 @@ impl AssistantPanel {
         cx: &mut ViewContext<Self>,
         project: &ModelHandle<Project>,
     ) {
-        let credential = self.credential.borrow().clone();
         let selection = editor.read(cx).selections.newest_anchor().clone();
         if selection.start.excerpt_id() != selection.end.excerpt_id() {
             return;

@@ -330,7 +328,6 @@ impl AssistantPanel {
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
             "gpt-4",
-            credential,
             cx.background().clone(),
         ));
 
@@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use ai::test::TestCompletionProvider;
+    use ai::test::FakeCompletionProvider;
     use futures::stream::{self};
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;

@@ -379,7 +379,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),

@@ -445,7 +445,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),

@@ -511,7 +511,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -1,5 +1,5 @@
 use crate::{parsing::Span, JobHandle};
-use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
+use ai::embedding::EmbeddingProvider;
 use gpui::executor::Background;
 use parking_lot::Mutex;
 use smol::channel;

@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    pub provider_credential: ProviderCredential,
 }
 
 #[derive(Clone)]

@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        executor: Arc<Background>,
-        provider_credential: ProviderCredential,
-    ) -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,

@@ -64,14 +59,9 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
-            provider_credential,
         }
     }
 
-    pub fn set_credential(&mut self, credential: ProviderCredential) {
-        self.provider_credential = credential;
-    }
-
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();

@@ -118,7 +108,6 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        let credential = self.provider_credential.clone();
 
         self.executor
             .spawn(async move {

@@ -143,7 +132,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans, credential).await {
+                match embedding_provider.embed_batch(spans).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {
@@ -7,7 +7,6 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::auth::ProviderCredential;
 use ai::embedding::{Embedding, EmbeddingProvider};
 use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};

@@ -125,8 +124,6 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
-    provider_credential: ProviderCredential,
-    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {

@@ -281,24 +278,17 @@ impl SemanticIndex {
     }
 
     pub fn authenticate(&mut self, cx: &AppContext) -> bool {
-        let existing_credential = self.provider_credential.clone();
-        let credential = match existing_credential {
-            ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
-            _ => existing_credential,
-        };
+        if !self.embedding_provider.has_credentials() {
+            self.embedding_provider.retrieve_credentials(cx);
+        } else {
+            return true;
+        }
 
-        self.provider_credential = credential.clone();
-        self.embedding_queue.lock().set_credential(credential);
-        self.is_authenticated()
+        self.embedding_provider.has_credentials()
     }
 
     pub fn is_authenticated(&self) -> bool {
-        let credential = &self.provider_credential;
-        match credential {
-            &ProviderCredential::Credentials { .. } => true,
-            &ProviderCredential::NotNeeded => true,
-            _ => false,
-        }
+        self.embedding_provider.has_credentials()
     }
 
     pub fn enabled(cx: &AppContext) -> bool {

@@ -348,7 +338,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();

@@ -413,8 +403,6 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
-                provider_credential: ProviderCredential::NoCredentials,
-                embedding_queue
             }
         }))
     }

@@ -729,14 +717,13 @@ impl SemanticIndex {
 
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let credential = self.provider_credential.clone();
 
         cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();
 
            let query = embedding_provider
-                .embed_batch(vec![query], credential)
+                .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;

@@ -954,7 +941,6 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
-        let credential = self.provider_credential.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();

@@ -969,12 +955,7 @@ impl SemanticIndex {
                         .parse_file_with_template(None, &snapshot.text(), language)
                         .log_err()
                         .unwrap_or_default();
-                    if Self::embed_spans(
-                        &mut spans,
-                        embedding_provider.as_ref(),
-                        &db,
-                        credential.clone(),
-                    )
+                    if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                         .await
                         .log_err()
                         .is_some()

@@ -1201,7 +1182,6 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
-        credential: ProviderCredential,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;

@@ -1224,7 +1204,7 @@ impl SemanticIndex {
 
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch), credential.clone())
+                    .embed_batch(mem::take(&mut batch))
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;

@@ -1236,7 +1216,7 @@ impl SemanticIndex {
 
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch), credential)
+                .embed_batch(mem::take(&mut batch))
                 .await?;
 
             embeddings.extend(batch_embeddings);
@@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(
-        embedding_provider.clone(),
-        cx.background(),
-        ai::auth::ProviderCredential::NoCredentials,
-    );
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
     for file in &files {
         queue.push(file.clone());
     }