catchup with main

This commit is contained in:
KCaverly 2023-10-25 16:31:00 +02:00
commit 71bc35d241
84 changed files with 6026 additions and 3636 deletions

View file

@ -2,6 +2,7 @@ use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use gpui::AppContext;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
@ -70,8 +71,12 @@ impl Embedding {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn is_authenticated(&self) -> bool;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
async fn embed_batch(
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}

View file

@ -6,6 +6,7 @@ use crate::{
models::{LanguageModel, TruncationDirection},
};
use async_trait::async_trait;
use gpui::AppContext;
use serde::Serialize;
pub struct DummyLanguageModel {}
@ -58,16 +59,20 @@ pub struct DummyEmbeddingProvider {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddingProvider {
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
Some("Dummy Credentials".to_string())
}
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(DummyLanguageModel {})
}
fn is_authenticated(&self) -> bool {
true
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
async fn embed_batch(
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> anyhow::Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
@ -17,11 +17,14 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use super::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
@ -135,13 +138,25 @@ impl OpenAIEmbeddingProvider {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
};
api_key
}
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn is_authenticated(&self) -> bool {
OPENAI_API_KEY.as_ref().is_some()
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
@ -164,7 +179,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
// (output, tokens.len())
// }
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
async fn embed_batch(
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;