updated authentication for embedding provider
This commit is contained in:
parent
71bc35d241
commit
3447a9478c
16 changed files with 277 additions and 271 deletions
|
@ -8,6 +8,9 @@ publish = false
|
||||||
path = "src/ai.rs"
|
path = "src/ai.rs"
|
||||||
doctest = false
|
doctest = false
|
||||||
|
|
||||||
|
[features]
|
||||||
|
test-support = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
gpui = { path = "../gpui" }
|
gpui = { path = "../gpui" }
|
||||||
util = { path = "../util" }
|
util = { path = "../util" }
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
pub mod auth;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
pub mod prompts;
|
pub mod prompts;
|
||||||
pub mod providers;
|
pub mod providers;
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub mod test;
|
||||||
|
|
20
crates/ai/src/auth.rs
Normal file
20
crates/ai/src/auth.rs
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
use gpui::AppContext;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum ProviderCredential {
|
||||||
|
Credentials { api_key: String },
|
||||||
|
NoCredentials,
|
||||||
|
NotNeeded,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CredentialProvider: Send + Sync {
|
||||||
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct NullCredentialProvider;
|
||||||
|
impl CredentialProvider for NullCredentialProvider {
|
||||||
|
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||||
|
ProviderCredential::NotNeeded
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ use ordered_float::OrderedFloat;
|
||||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||||
use rusqlite::ToSql;
|
use rusqlite::ToSql;
|
||||||
|
|
||||||
|
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||||
use crate::models::LanguageModel;
|
use crate::models::LanguageModel;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
@ -71,11 +72,14 @@ impl Embedding {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait EmbeddingProvider: Sync + Send {
|
pub trait EmbeddingProvider: Sync + Send {
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||||
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
|
fn credential_provider(&self) -> Box<dyn CredentialProvider>;
|
||||||
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||||
|
self.credential_provider().retrieve_credentials(cx)
|
||||||
|
}
|
||||||
async fn embed_batch(
|
async fn embed_batch(
|
||||||
&self,
|
&self,
|
||||||
spans: Vec<String>,
|
spans: Vec<String>,
|
||||||
api_key: Option<String>,
|
credential: ProviderCredential,
|
||||||
) -> Result<Vec<Embedding>>;
|
) -> Result<Vec<Embedding>>;
|
||||||
fn max_tokens_per_batch(&self) -> usize;
|
fn max_tokens_per_batch(&self) -> usize;
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||||
|
|
|
@ -126,6 +126,7 @@ impl PromptChain {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
use crate::models::TruncationDirection;
|
use crate::models::TruncationDirection;
|
||||||
|
use crate::test::FakeLanguageModel;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -181,39 +182,7 @@ pub(crate) mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||||
struct DummyLanguageModel {
|
|
||||||
capacity: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModel for DummyLanguageModel {
|
|
||||||
fn name(&self) -> String {
|
|
||||||
"dummy".to_string()
|
|
||||||
}
|
|
||||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
|
||||||
}
|
|
||||||
fn truncate(
|
|
||||||
&self,
|
|
||||||
content: &str,
|
|
||||||
length: usize,
|
|
||||||
direction: TruncationDirection,
|
|
||||||
) -> anyhow::Result<String> {
|
|
||||||
anyhow::Ok(match direction {
|
|
||||||
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
|
||||||
.into_iter()
|
|
||||||
.collect::<String>(),
|
|
||||||
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
|
||||||
.into_iter()
|
|
||||||
.collect::<String>(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
fn capacity(&self) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(self.capacity)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
|
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
|
@ -249,7 +218,7 @@ pub(crate) mod tests {
|
||||||
|
|
||||||
// Testing with Truncation Off
|
// Testing with Truncation Off
|
||||||
// Should ignore capacity and return all prompts
|
// Should ignore capacity and return all prompts
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
|
@ -286,7 +255,7 @@ pub(crate) mod tests {
|
||||||
// Testing with Truncation Off
|
// Testing with Truncation Off
|
||||||
// Should ignore capacity and return all prompts
|
// Should ignore capacity and return all prompts
|
||||||
let capacity = 20;
|
let capacity = 20;
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
|
@ -322,7 +291,7 @@ pub(crate) mod tests {
|
||||||
// Change Ordering of Prompts Based on Priority
|
// Change Ordering of Prompts Based on Priority
|
||||||
let capacity = 120;
|
let capacity = 120;
|
||||||
let reserved_tokens = 10;
|
let reserved_tokens = 10;
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
|
|
|
@ -1,85 +0,0 @@
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
completion::CompletionRequest,
|
|
||||||
embedding::{Embedding, EmbeddingProvider},
|
|
||||||
models::{LanguageModel, TruncationDirection},
|
|
||||||
};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use gpui::AppContext;
|
|
||||||
use serde::Serialize;
|
|
||||||
|
|
||||||
pub struct DummyLanguageModel {}
|
|
||||||
|
|
||||||
impl LanguageModel for DummyLanguageModel {
|
|
||||||
fn name(&self) -> String {
|
|
||||||
"dummy".to_string()
|
|
||||||
}
|
|
||||||
fn capacity(&self) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(1000)
|
|
||||||
}
|
|
||||||
fn truncate(
|
|
||||||
&self,
|
|
||||||
content: &str,
|
|
||||||
length: usize,
|
|
||||||
direction: crate::models::TruncationDirection,
|
|
||||||
) -> anyhow::Result<String> {
|
|
||||||
if content.len() < length {
|
|
||||||
return anyhow::Ok(content.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let truncated = match direction {
|
|
||||||
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
|
||||||
.iter()
|
|
||||||
.collect::<String>(),
|
|
||||||
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
|
|
||||||
.iter()
|
|
||||||
.collect::<String>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
anyhow::Ok(truncated)
|
|
||||||
}
|
|
||||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct DummyCompletionRequest {
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionRequest for DummyCompletionRequest {
|
|
||||||
fn data(&self) -> serde_json::Result<String> {
|
|
||||||
serde_json::to_string(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DummyEmbeddingProvider {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl EmbeddingProvider for DummyEmbeddingProvider {
|
|
||||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
|
||||||
Some("Dummy Credentials".to_string())
|
|
||||||
}
|
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
|
||||||
Box::new(DummyLanguageModel {})
|
|
||||||
}
|
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
async fn embed_batch(
|
|
||||||
&self,
|
|
||||||
spans: Vec<String>,
|
|
||||||
api_key: Option<String>,
|
|
||||||
) -> anyhow::Result<Vec<Embedding>> {
|
|
||||||
// 1024 is the OpenAI Embeddings size for ada models.
|
|
||||||
// the model we will likely be starting with.
|
|
||||||
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
|
|
||||||
return Ok(vec![dummy_vec; spans.len()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn max_tokens_per_batch(&self) -> usize {
|
|
||||||
8190
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,2 +1 @@
|
||||||
pub mod dummy;
|
|
||||||
pub mod open_ai;
|
pub mod open_ai;
|
||||||
|
|
33
crates/ai/src/providers/open_ai/auth.rs
Normal file
33
crates/ai/src/providers/open_ai/auth.rs
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
use gpui::AppContext;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||||
|
use crate::providers::open_ai::OPENAI_API_URL;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAICredentialProvider {}
|
||||||
|
|
||||||
|
impl CredentialProvider for OpenAICredentialProvider {
|
||||||
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||||
|
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||||
|
Some(api_key)
|
||||||
|
} else if let Some((_, api_key)) = cx
|
||||||
|
.platform()
|
||||||
|
.read_credentials(OPENAI_API_URL)
|
||||||
|
.log_err()
|
||||||
|
.flatten()
|
||||||
|
{
|
||||||
|
String::from_utf8(api_key).log_err()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(api_key) = api_key {
|
||||||
|
ProviderCredential::Credentials { api_key }
|
||||||
|
} else {
|
||||||
|
ProviderCredential::NoCredentials
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::AsyncReadExt;
|
use futures::AsyncReadExt;
|
||||||
use gpui::executor::Background;
|
use gpui::executor::Background;
|
||||||
use gpui::{serde_json, AppContext};
|
use gpui::serde_json;
|
||||||
use isahc::http::StatusCode;
|
use isahc::http::StatusCode;
|
||||||
use isahc::prelude::Configurable;
|
use isahc::prelude::Configurable;
|
||||||
use isahc::{AsyncBody, Response};
|
use isahc::{AsyncBody, Response};
|
||||||
|
@ -17,13 +17,13 @@ use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||||
use util::http::{HttpClient, Request};
|
use util::http::{HttpClient, Request};
|
||||||
use util::ResultExt;
|
|
||||||
|
|
||||||
|
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||||
use crate::models::LanguageModel;
|
use crate::models::LanguageModel;
|
||||||
use crate::providers::open_ai::OpenAILanguageModel;
|
use crate::providers::open_ai::OpenAILanguageModel;
|
||||||
|
|
||||||
use super::OPENAI_API_URL;
|
use crate::providers::open_ai::auth::OpenAICredentialProvider;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
||||||
|
@ -33,6 +33,7 @@ lazy_static! {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OpenAIEmbeddingProvider {
|
pub struct OpenAIEmbeddingProvider {
|
||||||
model: OpenAILanguageModel,
|
model: OpenAILanguageModel,
|
||||||
|
credential_provider: OpenAICredentialProvider,
|
||||||
pub client: Arc<dyn HttpClient>,
|
pub client: Arc<dyn HttpClient>,
|
||||||
pub executor: Arc<Background>,
|
pub executor: Arc<Background>,
|
||||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||||
|
@ -73,6 +74,7 @@ impl OpenAIEmbeddingProvider {
|
||||||
|
|
||||||
OpenAIEmbeddingProvider {
|
OpenAIEmbeddingProvider {
|
||||||
model,
|
model,
|
||||||
|
credential_provider: OpenAICredentialProvider {},
|
||||||
client,
|
client,
|
||||||
executor,
|
executor,
|
||||||
rate_limit_count_rx,
|
rate_limit_count_rx,
|
||||||
|
@ -138,25 +140,17 @@ impl OpenAIEmbeddingProvider {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||||
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
|
|
||||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
|
||||||
Some(api_key)
|
|
||||||
} else if let Some((_, api_key)) = cx
|
|
||||||
.platform()
|
|
||||||
.read_credentials(OPENAI_API_URL)
|
|
||||||
.log_err()
|
|
||||||
.flatten()
|
|
||||||
{
|
|
||||||
String::from_utf8(api_key).log_err()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
api_key
|
|
||||||
}
|
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
model
|
model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn credential_provider(&self) -> Box<dyn CredentialProvider> {
|
||||||
|
let credential_provider: Box<dyn CredentialProvider> =
|
||||||
|
Box::new(self.credential_provider.clone());
|
||||||
|
credential_provider
|
||||||
|
}
|
||||||
|
|
||||||
fn max_tokens_per_batch(&self) -> usize {
|
fn max_tokens_per_batch(&self) -> usize {
|
||||||
50000
|
50000
|
||||||
}
|
}
|
||||||
|
@ -164,25 +158,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||||
*self.rate_limit_count_rx.borrow()
|
*self.rate_limit_count_rx.borrow()
|
||||||
}
|
}
|
||||||
// fn truncate(&self, span: &str) -> (String, usize) {
|
|
||||||
// let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
|
||||||
// let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
|
||||||
// tokens.truncate(OPENAI_INPUT_LIMIT);
|
|
||||||
// OPENAI_BPE_TOKENIZER
|
|
||||||
// .decode(tokens.clone())
|
|
||||||
// .ok()
|
|
||||||
// .unwrap_or_else(|| span.to_string())
|
|
||||||
// } else {
|
|
||||||
// span.to_string()
|
|
||||||
// };
|
|
||||||
|
|
||||||
// (output, tokens.len())
|
|
||||||
// }
|
|
||||||
|
|
||||||
async fn embed_batch(
|
async fn embed_batch(
|
||||||
&self,
|
&self,
|
||||||
spans: Vec<String>,
|
spans: Vec<String>,
|
||||||
api_key: Option<String>,
|
_credential: ProviderCredential,
|
||||||
) -> Result<Vec<Embedding>> {
|
) -> Result<Vec<Embedding>> {
|
||||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||||
const MAX_RETRIES: usize = 4;
|
const MAX_RETRIES: usize = 4;
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod auth;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
|
|
123
crates/ai/src/test.rs
Normal file
123
crates/ai/src/test.rs
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
use std::{
|
||||||
|
sync::atomic::{self, AtomicUsize, Ordering},
|
||||||
|
time::Instant,
|
||||||
|
};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
|
||||||
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
|
models::{LanguageModel, TruncationDirection},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FakeLanguageModel {
|
||||||
|
pub capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for FakeLanguageModel {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"dummy".to_string()
|
||||||
|
}
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||||
|
}
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
anyhow::Ok(match direction {
|
||||||
|
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||||
|
.into_iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
||||||
|
.into_iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(self.capacity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeEmbeddingProvider {
|
||||||
|
pub embedding_count: AtomicUsize,
|
||||||
|
pub credential_provider: NullCredentialProvider,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for FakeEmbeddingProvider {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
FakeEmbeddingProvider {
|
||||||
|
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
||||||
|
credential_provider: self.credential_provider.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FakeEmbeddingProvider {
|
||||||
|
fn default() -> Self {
|
||||||
|
FakeEmbeddingProvider {
|
||||||
|
embedding_count: AtomicUsize::default(),
|
||||||
|
credential_provider: NullCredentialProvider {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeEmbeddingProvider {
|
||||||
|
pub fn embedding_count(&self) -> usize {
|
||||||
|
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_sync(&self, span: &str) -> Embedding {
|
||||||
|
let mut result = vec![1.0; 26];
|
||||||
|
for letter in span.chars() {
|
||||||
|
let letter = letter.to_ascii_lowercase();
|
||||||
|
if letter as u32 >= 'a' as u32 {
|
||||||
|
let ix = (letter as u32) - ('a' as u32);
|
||||||
|
if ix < 26 {
|
||||||
|
result[ix as usize] += 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
for x in &mut result {
|
||||||
|
*x /= norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
Box::new(FakeLanguageModel { capacity: 1000 })
|
||||||
|
}
|
||||||
|
fn credential_provider(&self) -> Box<dyn CredentialProvider> {
|
||||||
|
let credential_provider: Box<dyn CredentialProvider> =
|
||||||
|
Box::new(self.credential_provider.clone());
|
||||||
|
credential_provider
|
||||||
|
}
|
||||||
|
fn max_tokens_per_batch(&self) -> usize {
|
||||||
|
1000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn embed_batch(
|
||||||
|
&self,
|
||||||
|
spans: Vec<String>,
|
||||||
|
_credential: ProviderCredential,
|
||||||
|
) -> anyhow::Result<Vec<Embedding>> {
|
||||||
|
self.embedding_count
|
||||||
|
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||||
|
|
||||||
|
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||||
|
}
|
||||||
|
}
|
|
@ -335,7 +335,6 @@ fn strip_markdown_codeblock(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ai::providers::dummy::DummyCompletionRequest;
|
|
||||||
use futures::{
|
use futures::{
|
||||||
future::BoxFuture,
|
future::BoxFuture,
|
||||||
stream::{self, BoxStream},
|
stream::{self, BoxStream},
|
||||||
|
@ -345,9 +344,21 @@ mod tests {
|
||||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
use serde::Serialize;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::future::FutureExt;
|
use smol::future::FutureExt;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct DummyCompletionRequest {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest for DummyCompletionRequest {
|
||||||
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
|
serde_json::to_string(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_transform_autoindent(
|
async fn test_transform_autoindent(
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
|
@ -381,6 +392,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = Box::new(DummyCompletionRequest {
|
let request = Box::new(DummyCompletionRequest {
|
||||||
name: "test".to_string(),
|
name: "test".to_string(),
|
||||||
});
|
});
|
||||||
|
|
|
@ -42,6 +42,7 @@ sha1 = "0.10.5"
|
||||||
ndarray = { version = "0.15.0" }
|
ndarray = { version = "0.15.0" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
ai = { path = "../ai", features = ["test-support"] }
|
||||||
collections = { path = "../collections", features = ["test-support"] }
|
collections = { path = "../collections", features = ["test-support"] }
|
||||||
gpui = { path = "../gpui", features = ["test-support"] }
|
gpui = { path = "../gpui", features = ["test-support"] }
|
||||||
language = { path = "../language", features = ["test-support"] }
|
language = { path = "../language", features = ["test-support"] }
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{parsing::Span, JobHandle};
|
use crate::{parsing::Span, JobHandle};
|
||||||
use ai::embedding::EmbeddingProvider;
|
use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
|
||||||
use gpui::executor::Background;
|
use gpui::executor::Background;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use smol::channel;
|
use smol::channel;
|
||||||
|
@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
|
||||||
pending_batch_token_count: usize,
|
pending_batch_token_count: usize,
|
||||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||||
api_key: Option<String>,
|
provider_credential: ProviderCredential,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -54,7 +54,7 @@ impl EmbeddingQueue {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||||
executor: Arc<Background>,
|
executor: Arc<Background>,
|
||||||
api_key: Option<String>,
|
provider_credential: ProviderCredential,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||||
Self {
|
Self {
|
||||||
|
@ -64,12 +64,12 @@ impl EmbeddingQueue {
|
||||||
pending_batch_token_count: 0,
|
pending_batch_token_count: 0,
|
||||||
finished_files_tx,
|
finished_files_tx,
|
||||||
finished_files_rx,
|
finished_files_rx,
|
||||||
api_key,
|
provider_credential,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_api_key(&mut self, api_key: Option<String>) {
|
pub fn set_credential(&mut self, credential: ProviderCredential) {
|
||||||
self.api_key = api_key
|
self.provider_credential = credential
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push(&mut self, file: FileToEmbed) {
|
pub fn push(&mut self, file: FileToEmbed) {
|
||||||
|
@ -118,7 +118,7 @@ impl EmbeddingQueue {
|
||||||
|
|
||||||
let finished_files_tx = self.finished_files_tx.clone();
|
let finished_files_tx = self.finished_files_tx.clone();
|
||||||
let embedding_provider = self.embedding_provider.clone();
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
let api_key = self.api_key.clone();
|
let credential = self.provider_credential.clone();
|
||||||
|
|
||||||
self.executor
|
self.executor
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
|
@ -143,7 +143,7 @@ impl EmbeddingQueue {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
match embedding_provider.embed_batch(spans, api_key).await {
|
match embedding_provider.embed_batch(spans, credential).await {
|
||||||
Ok(embeddings) => {
|
Ok(embeddings) => {
|
||||||
let mut embeddings = embeddings.into_iter();
|
let mut embeddings = embeddings.into_iter();
|
||||||
for fragment in batch {
|
for fragment in batch {
|
||||||
|
|
|
@ -7,6 +7,7 @@ pub mod semantic_index_settings;
|
||||||
mod semantic_index_tests;
|
mod semantic_index_tests;
|
||||||
|
|
||||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||||
|
use ai::auth::ProviderCredential;
|
||||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||||
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
@ -124,7 +125,7 @@ pub struct SemanticIndex {
|
||||||
_embedding_task: Task<()>,
|
_embedding_task: Task<()>,
|
||||||
_parsing_files_tasks: Vec<Task<()>>,
|
_parsing_files_tasks: Vec<Task<()>>,
|
||||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||||
api_key: Option<String>,
|
provider_credential: ProviderCredential,
|
||||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -279,18 +280,27 @@ impl SemanticIndex {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate(&mut self, cx: &AppContext) {
|
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||||
if self.api_key.is_none() {
|
let credential = self.provider_credential.clone();
|
||||||
self.api_key = self.embedding_provider.retrieve_credentials(cx);
|
match credential {
|
||||||
|
ProviderCredential::NoCredentials => {
|
||||||
self.embedding_queue
|
let credential = self.embedding_provider.retrieve_credentials(cx);
|
||||||
.lock()
|
self.provider_credential = credential;
|
||||||
.set_api_key(self.api_key.clone());
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.embedding_queue.lock().set_credential(credential);
|
||||||
|
|
||||||
|
self.is_authenticated()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_authenticated(&self) -> bool {
|
pub fn is_authenticated(&self) -> bool {
|
||||||
self.api_key.is_some()
|
let credential = &self.provider_credential;
|
||||||
|
match credential {
|
||||||
|
&ProviderCredential::Credentials { .. } => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn enabled(cx: &AppContext) -> bool {
|
pub fn enabled(cx: &AppContext) -> bool {
|
||||||
|
@ -340,7 +350,7 @@ impl SemanticIndex {
|
||||||
Ok(cx.add_model(|cx| {
|
Ok(cx.add_model(|cx| {
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
let embedding_queue =
|
let embedding_queue =
|
||||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
|
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
|
||||||
let _embedding_task = cx.background().spawn({
|
let _embedding_task = cx.background().spawn({
|
||||||
let embedded_files = embedding_queue.finished_files();
|
let embedded_files = embedding_queue.finished_files();
|
||||||
let db = db.clone();
|
let db = db.clone();
|
||||||
|
@ -405,7 +415,7 @@ impl SemanticIndex {
|
||||||
_embedding_task,
|
_embedding_task,
|
||||||
_parsing_files_tasks,
|
_parsing_files_tasks,
|
||||||
projects: Default::default(),
|
projects: Default::default(),
|
||||||
api_key: None,
|
provider_credential: ProviderCredential::NoCredentials,
|
||||||
embedding_queue
|
embedding_queue
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
@ -721,13 +731,14 @@ impl SemanticIndex {
|
||||||
|
|
||||||
let index = self.index_project(project.clone(), cx);
|
let index = self.index_project(project.clone(), cx);
|
||||||
let embedding_provider = self.embedding_provider.clone();
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
let api_key = self.api_key.clone();
|
let credential = self.provider_credential.clone();
|
||||||
|
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
index.await?;
|
index.await?;
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
|
|
||||||
let query = embedding_provider
|
let query = embedding_provider
|
||||||
.embed_batch(vec![query], api_key)
|
.embed_batch(vec![query], credential)
|
||||||
.await?
|
.await?
|
||||||
.pop()
|
.pop()
|
||||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||||
|
@ -945,7 +956,7 @@ impl SemanticIndex {
|
||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
let db_path = self.db.path().clone();
|
let db_path = self.db.path().clone();
|
||||||
let background = cx.background().clone();
|
let background = cx.background().clone();
|
||||||
let api_key = self.api_key.clone();
|
let credential = self.provider_credential.clone();
|
||||||
cx.background().spawn(async move {
|
cx.background().spawn(async move {
|
||||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||||
let mut results = Vec::<SearchResult>::new();
|
let mut results = Vec::<SearchResult>::new();
|
||||||
|
@ -964,7 +975,7 @@ impl SemanticIndex {
|
||||||
&mut spans,
|
&mut spans,
|
||||||
embedding_provider.as_ref(),
|
embedding_provider.as_ref(),
|
||||||
&db,
|
&db,
|
||||||
api_key.clone(),
|
credential.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.log_err()
|
.log_err()
|
||||||
|
@ -1008,9 +1019,8 @@ impl SemanticIndex {
|
||||||
project: ModelHandle<Project>,
|
project: ModelHandle<Project>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if self.api_key.is_none() {
|
if !self.is_authenticated() {
|
||||||
self.authenticate(cx);
|
if !self.authenticate(cx) {
|
||||||
if self.api_key.is_none() {
|
|
||||||
return Task::ready(Err(anyhow!("user is not authenticated")));
|
return Task::ready(Err(anyhow!("user is not authenticated")));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1193,7 +1203,7 @@ impl SemanticIndex {
|
||||||
spans: &mut [Span],
|
spans: &mut [Span],
|
||||||
embedding_provider: &dyn EmbeddingProvider,
|
embedding_provider: &dyn EmbeddingProvider,
|
||||||
db: &VectorDatabase,
|
db: &VectorDatabase,
|
||||||
api_key: Option<String>,
|
credential: ProviderCredential,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut batch = Vec::new();
|
let mut batch = Vec::new();
|
||||||
let mut batch_tokens = 0;
|
let mut batch_tokens = 0;
|
||||||
|
@ -1216,7 +1226,7 @@ impl SemanticIndex {
|
||||||
|
|
||||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||||
let batch_embeddings = embedding_provider
|
let batch_embeddings = embedding_provider
|
||||||
.embed_batch(mem::take(&mut batch), api_key.clone())
|
.embed_batch(mem::take(&mut batch), credential.clone())
|
||||||
.await?;
|
.await?;
|
||||||
embeddings.extend(batch_embeddings);
|
embeddings.extend(batch_embeddings);
|
||||||
batch_tokens = 0;
|
batch_tokens = 0;
|
||||||
|
@ -1228,7 +1238,7 @@ impl SemanticIndex {
|
||||||
|
|
||||||
if !batch.is_empty() {
|
if !batch.is_empty() {
|
||||||
let batch_embeddings = embedding_provider
|
let batch_embeddings = embedding_provider
|
||||||
.embed_batch(mem::take(&mut batch), api_key)
|
.embed_batch(mem::take(&mut batch), credential)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
embeddings.extend(batch_embeddings);
|
embeddings.extend(batch_embeddings);
|
||||||
|
|
|
@ -4,14 +4,9 @@ use crate::{
|
||||||
semantic_index_settings::SemanticIndexSettings,
|
semantic_index_settings::SemanticIndexSettings,
|
||||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||||
};
|
};
|
||||||
use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
|
use ai::test::FakeEmbeddingProvider;
|
||||||
use ai::{
|
|
||||||
embedding::{Embedding, EmbeddingProvider},
|
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||||
models::LanguageModel,
|
|
||||||
};
|
|
||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
|
|
||||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
|
||||||
use rand::{rngs::StdRng, Rng};
|
use rand::{rngs::StdRng, Rng};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use std::{
|
use std::{path::Path, sync::Arc, time::SystemTime};
|
||||||
path::Path,
|
|
||||||
sync::{
|
|
||||||
atomic::{self, AtomicUsize},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
time::{Instant, SystemTime},
|
|
||||||
};
|
|
||||||
use unindent::Unindent;
|
use unindent::Unindent;
|
||||||
use util::RandomCharIter;
|
use util::RandomCharIter;
|
||||||
|
|
||||||
|
@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||||
|
|
||||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
|
|
||||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
|
let mut queue = EmbeddingQueue::new(
|
||||||
|
embedding_provider.clone(),
|
||||||
|
cx.background(),
|
||||||
|
ai::auth::ProviderCredential::NoCredentials,
|
||||||
|
);
|
||||||
for file in &files {
|
for file in &files {
|
||||||
queue.push(file.clone());
|
queue.push(file.clone());
|
||||||
}
|
}
|
||||||
|
@ -284,7 +276,7 @@ fn assert_search_results(
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_rust() {
|
async fn test_code_context_retrieval_rust() {
|
||||||
let language = rust_lang();
|
let language = rust_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = "
|
let text = "
|
||||||
|
@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_json() {
|
async fn test_code_context_retrieval_json() {
|
||||||
let language = json_lang();
|
let language = json_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
|
@ -470,7 +462,7 @@ fn assert_documents_eq(
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_javascript() {
|
async fn test_code_context_retrieval_javascript() {
|
||||||
let language = js_lang();
|
let language = js_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = "
|
let text = "
|
||||||
|
@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_lua() {
|
async fn test_code_context_retrieval_lua() {
|
||||||
let language = lua_lang();
|
let language = lua_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
|
@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_elixir() {
|
async fn test_code_context_retrieval_elixir() {
|
||||||
let language = elixir_lang();
|
let language = elixir_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
|
@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_cpp() {
|
async fn test_code_context_retrieval_cpp() {
|
||||||
let language = cpp_lang();
|
let language = cpp_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = "
|
let text = "
|
||||||
|
@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_ruby() {
|
async fn test_code_context_retrieval_ruby() {
|
||||||
let language = ruby_lang();
|
let language = ruby_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
|
@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_php() {
|
async fn test_code_context_retrieval_php() {
|
||||||
let language = php_lang();
|
let language = php_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
|
@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct FakeEmbeddingProvider {
|
|
||||||
embedding_count: AtomicUsize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FakeEmbeddingProvider {
|
|
||||||
fn embedding_count(&self) -> usize {
|
|
||||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embed_sync(&self, span: &str) -> Embedding {
|
|
||||||
let mut result = vec![1.0; 26];
|
|
||||||
for letter in span.chars() {
|
|
||||||
let letter = letter.to_ascii_lowercase();
|
|
||||||
if letter as u32 >= 'a' as u32 {
|
|
||||||
let ix = (letter as u32) - ('a' as u32);
|
|
||||||
if ix < 26 {
|
|
||||||
result[ix as usize] += 1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
||||||
for x in &mut result {
|
|
||||||
*x /= norm;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
|
||||||
Box::new(DummyLanguageModel {})
|
|
||||||
}
|
|
||||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
|
||||||
Some("Fake Credentials".to_string())
|
|
||||||
}
|
|
||||||
fn max_tokens_per_batch(&self) -> usize {
|
|
||||||
1000
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn embed_batch(
|
|
||||||
&self,
|
|
||||||
spans: Vec<String>,
|
|
||||||
_api_key: Option<String>,
|
|
||||||
) -> Result<Vec<Embedding>> {
|
|
||||||
self.embedding_count
|
|
||||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
|
||||||
|
|
||||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn js_lang() -> Arc<Language> {
|
fn js_lang() -> Arc<Language> {
|
||||||
Arc::new(
|
Arc::new(
|
||||||
Language::new(
|
Language::new(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue