move OpenAIEmbeddings to OpenAIEmbeddingProvider in providers folder

This commit is contained in:
KCaverly 2023-10-22 14:46:22 +02:00
parent d813ae8845
commit d1dec8314a
7 changed files with 308 additions and 299 deletions

View file

@ -7,7 +7,8 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@ -88,7 +89,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
language_registry,
cx.clone(),
)

View file

@ -4,7 +4,8 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::dummy::DummyEmbeddingProvider;
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, Task, TestAppContext};
@ -280,7 +281,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -382,7 +383,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -466,7 +467,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -565,7 +566,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -639,7 +640,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -756,7 +757,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -909,7 +910,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -1100,7 +1101,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(DummyEmbeddingProvider {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"