add base model to EmbeddingProvider, not yet leveraged for truncation
This commit is contained in:
parent
d1dec8314a
commit
2b780ee7b2
5 changed files with 57 additions and 2 deletions
|
@ -5,6 +5,8 @@ use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||||
use rusqlite::ToSql;
|
use rusqlite::ToSql;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
|
||||||
/// Newtype around a `Vec<f32>`; `EmbeddingProvider::embed_batch` returns a list of these.
/// NOTE(review): presumably each value holds the components of one embedding vector — confirm.
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
pub struct Embedding(pub Vec<f32>);
|
pub struct Embedding(pub Vec<f32>);
|
||||||
|
|
||||||
|
@ -66,6 +68,7 @@ impl Embedding {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait EmbeddingProvider: Sync + Send {
|
pub trait EmbeddingProvider: Sync + Send {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||||
fn is_authenticated(&self) -> bool;
|
fn is_authenticated(&self) -> bool;
|
||||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||||
fn max_tokens_per_batch(&self) -> usize;
|
fn max_tokens_per_batch(&self) -> usize;
|
||||||
|
|
|
@ -3,10 +3,42 @@ use std::time::Instant;
|
||||||
use crate::{
|
use crate::{
|
||||||
completion::CompletionRequest,
|
completion::CompletionRequest,
|
||||||
embedding::{Embedding, EmbeddingProvider},
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
|
models::{LanguageModel, TruncationDirection},
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
|
pub struct DummyLanguageModel {}
|
||||||
|
|
||||||
|
impl LanguageModel for DummyLanguageModel {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"dummy".to_string()
|
||||||
|
}
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(1000)
|
||||||
|
}
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: crate::models::TruncationDirection,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let truncated = match direction {
|
||||||
|
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||||
|
.iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
|
||||||
|
.iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
anyhow::Ok(truncated)
|
||||||
|
}
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub struct DummyCompletionRequest {
|
pub struct DummyCompletionRequest {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
@ -22,6 +54,9 @@ pub struct DummyEmbeddingProvider {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EmbeddingProvider for DummyEmbeddingProvider {
|
impl EmbeddingProvider for DummyEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
Box::new(DummyLanguageModel {})
|
||||||
|
}
|
||||||
fn is_authenticated(&self) -> bool {
|
fn is_authenticated(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@ use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||||
use util::http::{HttpClient, Request};
|
use util::http::{HttpClient, Request};
|
||||||
|
|
||||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
use crate::providers::open_ai::OpenAILanguageModel;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
||||||
|
@ -27,6 +29,7 @@ lazy_static! {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OpenAIEmbeddingProvider {
|
pub struct OpenAIEmbeddingProvider {
|
||||||
|
model: OpenAILanguageModel,
|
||||||
pub client: Arc<dyn HttpClient>,
|
pub client: Arc<dyn HttpClient>,
|
||||||
pub executor: Arc<Background>,
|
pub executor: Arc<Background>,
|
||||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||||
|
@ -65,7 +68,10 @@ impl OpenAIEmbeddingProvider {
|
||||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||||
|
|
||||||
|
let model = OpenAILanguageModel::load("text-embedding-ada-002");
|
||||||
|
|
||||||
OpenAIEmbeddingProvider {
|
OpenAIEmbeddingProvider {
|
||||||
|
model,
|
||||||
client,
|
client,
|
||||||
executor,
|
executor,
|
||||||
rate_limit_count_rx,
|
rate_limit_count_rx,
|
||||||
|
@ -131,6 +137,10 @@ impl OpenAIEmbeddingProvider {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||||
|
/// Returns a boxed clone of the provider's stored model as a `LanguageModel` trait object.
fn base_model(&self) -> Box<dyn LanguageModel> {
    Box::new(self.model.clone())
}
|
||||||
fn is_authenticated(&self) -> bool {
|
fn is_authenticated(&self) -> bool {
|
||||||
OPENAI_API_KEY.as_ref().is_some()
|
OPENAI_API_KEY.as_ref().is_some()
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ use util::ResultExt;
|
||||||
|
|
||||||
use crate::models::{LanguageModel, TruncationDirection};
|
use crate::models::{LanguageModel, TruncationDirection};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct OpenAILanguageModel {
|
pub struct OpenAILanguageModel {
|
||||||
name: String,
|
name: String,
|
||||||
bpe: Option<CoreBPE>,
|
bpe: Option<CoreBPE>,
|
||||||
|
|
|
@ -4,8 +4,11 @@ use crate::{
|
||||||
semantic_index_settings::SemanticIndexSettings,
|
semantic_index_settings::SemanticIndexSettings,
|
||||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||||
};
|
};
|
||||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
|
||||||
use ai::providers::dummy::DummyEmbeddingProvider;
|
use ai::{
|
||||||
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
|
models::LanguageModel,
|
||||||
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||||
|
@ -1282,6 +1285,9 @@ impl FakeEmbeddingProvider {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
Box::new(DummyLanguageModel {})
|
||||||
|
}
|
||||||
fn is_authenticated(&self) -> bool {
|
fn is_authenticated(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue