add base model to EmbeddingProvider, not yet leveraged for truncation

This commit is contained in:
KCaverly 2023-10-22 15:00:09 +02:00
parent d1dec8314a
commit 2b780ee7b2
5 changed files with 57 additions and 2 deletions

View file

@@ -5,6 +5,8 @@ use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql; use rusqlite::ToSql;
use std::time::Instant; use std::time::Instant;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>); pub struct Embedding(pub Vec<f32>);
@@ -66,6 +68,7 @@ impl Embedding {
#[async_trait] #[async_trait]
pub trait EmbeddingProvider: Sync + Send { pub trait EmbeddingProvider: Sync + Send {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn is_authenticated(&self) -> bool; fn is_authenticated(&self) -> bool;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>; async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize; fn max_tokens_per_batch(&self) -> usize;

View file

@@ -3,10 +3,42 @@ use std::time::Instant;
use crate::{ use crate::{
completion::CompletionRequest, completion::CompletionRequest,
embedding::{Embedding, EmbeddingProvider}, embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use serde::Serialize; use serde::Serialize;
/// A trivial `LanguageModel` for tests: every character counts as one token.
pub struct DummyLanguageModel {}

impl LanguageModel for DummyLanguageModel {
    fn name(&self) -> String {
        "dummy".to_string()
    }

    /// Fixed token capacity for the dummy model.
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(1000)
    }

    /// Truncate `content` to at most `length` characters.
    ///
    /// `TruncationDirection::End` keeps the first `length` characters
    /// (trimming the end); `TruncationDirection::Start` keeps the last
    /// `length` characters (trimming the start). If `content` already fits,
    /// it is returned unchanged.
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: crate::models::TruncationDirection,
    ) -> anyhow::Result<String> {
        let chars: Vec<char> = content.chars().collect();
        // Guard: slicing `[..length]` would panic when the content is
        // shorter than the requested length.
        if chars.len() <= length {
            return anyhow::Ok(content.to_string());
        }
        let truncated: String = match direction {
            // Keep the head of the string.
            TruncationDirection::End => chars[..length].iter().collect(),
            // Keep the tail of the string. (Previously this arm was
            // identical to `End`, making the direction argument a no-op.)
            TruncationDirection::Start => chars[chars.len() - length..].iter().collect(),
        };
        anyhow::Ok(truncated)
    }

    /// One token per character, matching `truncate`'s character-based units.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().count())
    }
}
#[derive(Serialize)] #[derive(Serialize)]
pub struct DummyCompletionRequest { pub struct DummyCompletionRequest {
pub name: String, pub name: String,
@@ -22,6 +54,9 @@ pub struct DummyEmbeddingProvider {}
#[async_trait] #[async_trait]
impl EmbeddingProvider for DummyEmbeddingProvider { impl EmbeddingProvider for DummyEmbeddingProvider {
/// The dummy provider is backed by the character-based `DummyLanguageModel`.
fn base_model(&self) -> Box<dyn LanguageModel> {
    let model = DummyLanguageModel {};
    Box::new(model)
}
fn is_authenticated(&self) -> bool { fn is_authenticated(&self) -> bool {
true true
} }

View file

@ -19,6 +19,8 @@ use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request}; use util::http::{HttpClient, Request};
use crate::embedding::{Embedding, EmbeddingProvider}; use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
lazy_static! { lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok(); static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
@@ -27,6 +29,7 @@ lazy_static! {
#[derive(Clone)] #[derive(Clone)]
pub struct OpenAIEmbeddingProvider { pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
pub client: Arc<dyn HttpClient>, pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>, pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>, rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -65,7 +68,10 @@ impl OpenAIEmbeddingProvider {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
OpenAIEmbeddingProvider { OpenAIEmbeddingProvider {
model,
client, client,
executor, executor,
rate_limit_count_rx, rate_limit_count_rx,
@@ -131,6 +137,10 @@ impl OpenAIEmbeddingProvider {
#[async_trait] #[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider { impl EmbeddingProvider for OpenAIEmbeddingProvider {
/// Hand out an owned copy of the provider's tokenizer-backed model.
///
/// `self.model` is cloned so the caller gets an independent handle; the
/// unsized coercion to `Box<dyn LanguageModel>` happens automatically at
/// the return position, so no intermediate annotated binding is needed.
fn base_model(&self) -> Box<dyn LanguageModel> {
    Box::new(self.model.clone())
}
fn is_authenticated(&self) -> bool { fn is_authenticated(&self) -> bool {
OPENAI_API_KEY.as_ref().is_some() OPENAI_API_KEY.as_ref().is_some()
} }

View file

@@ -4,6 +4,7 @@ use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection}; use crate::models::{LanguageModel, TruncationDirection};
#[derive(Clone)]
pub struct OpenAILanguageModel { pub struct OpenAILanguageModel {
name: String, name: String,
bpe: Option<CoreBPE>, bpe: Option<CoreBPE>,

View file

@@ -4,8 +4,11 @@ use crate::{
semantic_index_settings::SemanticIndexSettings, semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
}; };
use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
use ai::providers::dummy::DummyEmbeddingProvider; use ai::{
embedding::{Embedding, EmbeddingProvider},
models::LanguageModel,
};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use gpui::{executor::Deterministic, Task, TestAppContext}; use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -1282,6 +1285,9 @@ impl FakeEmbeddingProvider {
#[async_trait] #[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider { impl EmbeddingProvider for FakeEmbeddingProvider {
/// Tests use the character-counting `DummyLanguageModel` as the base model.
fn base_model(&self) -> Box<dyn LanguageModel> {
    let model = DummyLanguageModel {};
    Box::new(model)
}
fn is_authenticated(&self) -> bool { fn is_authenticated(&self) -> bool {
true true
} }