Update casing of "OpenAI" in identifiers to match Rust conventions (#6940)
This PR updates the casing of "OpenAI" when used in Rust identifiers to match the [Rust naming guidelines](https://rust-lang.github.io/api-guidelines/naming.html): > In `UpperCamelCase`, acronyms and contractions of compound words count as one word: use `Uuid` rather than `UUID`, `Usize` rather than `USize` or `Stdin` rather than `StdIn`. Release Notes: - N/A
This commit is contained in:
parent
e8bf06fc42
commit
027f055841
11 changed files with 85 additions and 96 deletions
9
crates/ai/src/providers/open_ai.rs
Normal file
9
crates/ai/src/providers/open_ai.rs
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
pub mod completion;
|
||||||
|
pub mod embedding;
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
|
pub use completion::*;
|
||||||
|
pub use embedding::*;
|
||||||
|
pub use model::OpenAiLanguageModel;
|
||||||
|
|
||||||
|
pub const OPEN_AI_API_URL: &'static str = "https://api.openai.com/v1";
|
|
@ -21,7 +21,7 @@ use crate::{
|
||||||
models::LanguageModel,
|
models::LanguageModel,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
|
use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
|
@ -58,7 +58,7 @@ pub struct RequestMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Serialize)]
|
#[derive(Debug, Default, Serialize)]
|
||||||
pub struct OpenAIRequest {
|
pub struct OpenAiRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub messages: Vec<RequestMessage>,
|
pub messages: Vec<RequestMessage>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
|
@ -66,7 +66,7 @@ pub struct OpenAIRequest {
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionRequest for OpenAIRequest {
|
impl CompletionRequest for OpenAiRequest {
|
||||||
fn data(&self) -> serde_json::Result<String> {
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
serde_json::to_string(self)
|
serde_json::to_string(self)
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ pub struct ResponseMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub struct OpenAIUsage {
|
pub struct OpenAiUsage {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
pub completion_tokens: u32,
|
pub completion_tokens: u32,
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
|
@ -93,20 +93,20 @@ pub struct ChatChoiceDelta {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub struct OpenAIResponseStreamEvent {
|
pub struct OpenAiResponseStreamEvent {
|
||||||
pub id: Option<String>,
|
pub id: Option<String>,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
pub created: u32,
|
pub created: u32,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub choices: Vec<ChatChoiceDelta>,
|
pub choices: Vec<ChatChoiceDelta>,
|
||||||
pub usage: Option<OpenAIUsage>,
|
pub usage: Option<OpenAiUsage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
credential: ProviderCredential,
|
credential: ProviderCredential,
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
request: Box<dyn CompletionRequest>,
|
request: Box<dyn CompletionRequest>,
|
||||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
|
||||||
let api_key = match credential {
|
let api_key = match credential {
|
||||||
ProviderCredential::Credentials { api_key } => api_key,
|
ProviderCredential::Credentials { api_key } => api_key,
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -114,10 +114,10 @@ pub async fn stream_completion(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
|
||||||
|
|
||||||
let json_data = request.data()?;
|
let json_data = request.data()?;
|
||||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
.body(json_data)?
|
.body(json_data)?
|
||||||
|
@ -132,7 +132,7 @@ pub async fn stream_completion(
|
||||||
|
|
||||||
fn parse_line(
|
fn parse_line(
|
||||||
line: Result<String, io::Error>,
|
line: Result<String, io::Error>,
|
||||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
) -> Result<Option<OpenAiResponseStreamEvent>> {
|
||||||
if let Some(data) = line?.strip_prefix("data: ") {
|
if let Some(data) = line?.strip_prefix("data: ") {
|
||||||
let event = serde_json::from_str(data)?;
|
let event = serde_json::from_str(data)?;
|
||||||
Ok(Some(event))
|
Ok(Some(event))
|
||||||
|
@ -169,16 +169,16 @@ pub async fn stream_completion(
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OpenAIResponse {
|
struct OpenAiResponse {
|
||||||
error: OpenAIError,
|
error: OpenAiError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OpenAIError {
|
struct OpenAiError {
|
||||||
message: String,
|
message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
match serde_json::from_str::<OpenAiResponse>(&body) {
|
||||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||||
"Failed to connect to OpenAI API: {}",
|
"Failed to connect to OpenAI API: {}",
|
||||||
response.error.message,
|
response.error.message,
|
||||||
|
@ -194,16 +194,16 @@ pub async fn stream_completion(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OpenAICompletionProvider {
|
pub struct OpenAiCompletionProvider {
|
||||||
model: OpenAILanguageModel,
|
model: OpenAiLanguageModel,
|
||||||
credential: Arc<RwLock<ProviderCredential>>,
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAICompletionProvider {
|
impl OpenAiCompletionProvider {
|
||||||
pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
|
pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
|
||||||
let model = executor
|
let model = executor
|
||||||
.spawn(async move { OpenAILanguageModel::load(&model_name) })
|
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
|
||||||
.await;
|
.await;
|
||||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
Self {
|
Self {
|
||||||
|
@ -214,7 +214,7 @@ impl OpenAICompletionProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CredentialProvider for OpenAICompletionProvider {
|
impl CredentialProvider for OpenAiCompletionProvider {
|
||||||
fn has_credentials(&self) -> bool {
|
fn has_credentials(&self) -> bool {
|
||||||
match *self.credential.read() {
|
match *self.credential.read() {
|
||||||
ProviderCredential::Credentials { .. } => true,
|
ProviderCredential::Credentials { .. } => true,
|
||||||
|
@ -232,7 +232,7 @@ impl CredentialProvider for OpenAICompletionProvider {
|
||||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||||
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
||||||
} else {
|
} else {
|
||||||
let credentials = cx.read_credentials(OPENAI_API_URL);
|
let credentials = cx.read_credentials(OPEN_AI_API_URL);
|
||||||
async move {
|
async move {
|
||||||
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
||||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||||
|
@ -266,7 +266,7 @@ impl CredentialProvider for OpenAICompletionProvider {
|
||||||
let credential = credential.clone();
|
let credential = credential.clone();
|
||||||
let write_credentials = match credential {
|
let write_credentials = match credential {
|
||||||
ProviderCredential::Credentials { api_key } => {
|
ProviderCredential::Credentials { api_key } => {
|
||||||
Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
|
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
@ -281,7 +281,7 @@ impl CredentialProvider for OpenAICompletionProvider {
|
||||||
|
|
||||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
||||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||||
let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
|
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
|
||||||
async move {
|
async move {
|
||||||
delete_credentials.await.log_err();
|
delete_credentials.await.log_err();
|
||||||
}
|
}
|
||||||
|
@ -289,7 +289,7 @@ impl CredentialProvider for OpenAICompletionProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionProvider for OpenAICompletionProvider {
|
impl CompletionProvider for OpenAiCompletionProvider {
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
model
|
model
|
||||||
|
|
|
@ -25,17 +25,17 @@ use util::ResultExt;
|
||||||
use crate::auth::{CredentialProvider, ProviderCredential};
|
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||||
use crate::models::LanguageModel;
|
use crate::models::LanguageModel;
|
||||||
use crate::providers::open_ai::OpenAILanguageModel;
|
use crate::providers::open_ai::OpenAiLanguageModel;
|
||||||
|
|
||||||
use crate::providers::open_ai::OPENAI_API_URL;
|
use crate::providers::open_ai::OPEN_AI_API_URL;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
static ref OPEN_AI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OpenAIEmbeddingProvider {
|
pub struct OpenAiEmbeddingProvider {
|
||||||
model: OpenAILanguageModel,
|
model: OpenAiLanguageModel,
|
||||||
credential: Arc<RwLock<ProviderCredential>>,
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
pub client: Arc<dyn HttpClient>,
|
pub client: Arc<dyn HttpClient>,
|
||||||
pub executor: BackgroundExecutor,
|
pub executor: BackgroundExecutor,
|
||||||
|
@ -44,42 +44,42 @@ pub struct OpenAIEmbeddingProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct OpenAIEmbeddingRequest<'a> {
|
struct OpenAiEmbeddingRequest<'a> {
|
||||||
model: &'static str,
|
model: &'static str,
|
||||||
input: Vec<&'a str>,
|
input: Vec<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OpenAIEmbeddingResponse {
|
struct OpenAiEmbeddingResponse {
|
||||||
data: Vec<OpenAIEmbedding>,
|
data: Vec<OpenAiEmbedding>,
|
||||||
usage: OpenAIEmbeddingUsage,
|
usage: OpenAiEmbeddingUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct OpenAIEmbedding {
|
struct OpenAiEmbedding {
|
||||||
embedding: Vec<f32>,
|
embedding: Vec<f32>,
|
||||||
index: usize,
|
index: usize,
|
||||||
object: String,
|
object: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OpenAIEmbeddingUsage {
|
struct OpenAiEmbeddingUsage {
|
||||||
prompt_tokens: usize,
|
prompt_tokens: usize,
|
||||||
total_tokens: usize,
|
total_tokens: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIEmbeddingProvider {
|
impl OpenAiEmbeddingProvider {
|
||||||
pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
|
pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
|
||||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||||
|
|
||||||
// Loading the model is expensive, so ensure this runs off the main thread.
|
// Loading the model is expensive, so ensure this runs off the main thread.
|
||||||
let model = executor
|
let model = executor
|
||||||
.spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
|
.spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
|
||||||
.await;
|
.await;
|
||||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
|
|
||||||
OpenAIEmbeddingProvider {
|
OpenAiEmbeddingProvider {
|
||||||
model,
|
model,
|
||||||
credential,
|
credential,
|
||||||
client,
|
client,
|
||||||
|
@ -140,7 +140,7 @@ impl OpenAIEmbeddingProvider {
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
.body(
|
.body(
|
||||||
serde_json::to_string(&OpenAIEmbeddingRequest {
|
serde_json::to_string(&OpenAiEmbeddingRequest {
|
||||||
input: spans.clone(),
|
input: spans.clone(),
|
||||||
model: "text-embedding-ada-002",
|
model: "text-embedding-ada-002",
|
||||||
})
|
})
|
||||||
|
@ -152,7 +152,7 @@ impl OpenAIEmbeddingProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CredentialProvider for OpenAIEmbeddingProvider {
|
impl CredentialProvider for OpenAiEmbeddingProvider {
|
||||||
fn has_credentials(&self) -> bool {
|
fn has_credentials(&self) -> bool {
|
||||||
match *self.credential.read() {
|
match *self.credential.read() {
|
||||||
ProviderCredential::Credentials { .. } => true,
|
ProviderCredential::Credentials { .. } => true,
|
||||||
|
@ -170,7 +170,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||||
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
||||||
} else {
|
} else {
|
||||||
let credentials = cx.read_credentials(OPENAI_API_URL);
|
let credentials = cx.read_credentials(OPEN_AI_API_URL);
|
||||||
async move {
|
async move {
|
||||||
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
||||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||||
|
@ -204,7 +204,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||||
let credential = credential.clone();
|
let credential = credential.clone();
|
||||||
let write_credentials = match credential {
|
let write_credentials = match credential {
|
||||||
ProviderCredential::Credentials { api_key } => {
|
ProviderCredential::Credentials { api_key } => {
|
||||||
Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
|
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
@ -219,7 +219,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||||
|
|
||||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
||||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||||
let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
|
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
|
||||||
async move {
|
async move {
|
||||||
delete_credentials.await.log_err();
|
delete_credentials.await.log_err();
|
||||||
}
|
}
|
||||||
|
@ -228,7 +228,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
impl EmbeddingProvider for OpenAiEmbeddingProvider {
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
model
|
model
|
||||||
|
@ -270,7 +270,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||||
StatusCode::OK => {
|
StatusCode::OK => {
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"openai embedding completed. tokens: {:?}",
|
"openai embedding completed. tokens: {:?}",
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
pub mod completion;
|
|
||||||
pub mod embedding;
|
|
||||||
pub mod model;
|
|
||||||
|
|
||||||
pub use completion::*;
|
|
||||||
pub use embedding::*;
|
|
||||||
pub use model::OpenAILanguageModel;
|
|
||||||
|
|
||||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
|
|
@ -5,22 +5,22 @@ use util::ResultExt;
|
||||||
use crate::models::{LanguageModel, TruncationDirection};
|
use crate::models::{LanguageModel, TruncationDirection};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OpenAILanguageModel {
|
pub struct OpenAiLanguageModel {
|
||||||
name: String,
|
name: String,
|
||||||
bpe: Option<CoreBPE>,
|
bpe: Option<CoreBPE>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAILanguageModel {
|
impl OpenAiLanguageModel {
|
||||||
pub fn load(model_name: &str) -> Self {
|
pub fn load(model_name: &str) -> Self {
|
||||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
||||||
OpenAILanguageModel {
|
OpenAiLanguageModel {
|
||||||
name: model_name.to_string(),
|
name: model_name.to_string(),
|
||||||
bpe,
|
bpe,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for OpenAILanguageModel {
|
impl LanguageModel for OpenAiLanguageModel {
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
self.name.clone()
|
self.name.clone()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
pub trait LanguageModel {
|
|
||||||
fn name(&self) -> String;
|
|
||||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
|
||||||
fn truncate(
|
|
||||||
&self,
|
|
||||||
content: &str,
|
|
||||||
length: usize,
|
|
||||||
direction: TruncationDirection,
|
|
||||||
) -> anyhow::Result<String>;
|
|
||||||
fn capacity(&self) -> anyhow::Result<usize>;
|
|
||||||
}
|
|
|
@ -7,7 +7,7 @@ mod streaming_diff;
|
||||||
use ai::providers::open_ai::Role;
|
use ai::providers::open_ai::Role;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
pub use assistant_panel::AssistantPanel;
|
pub use assistant_panel::AssistantPanel;
|
||||||
use assistant_settings::OpenAIModel;
|
use assistant_settings::OpenAiModel;
|
||||||
use chrono::{DateTime, Local};
|
use chrono::{DateTime, Local};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
|
@ -68,7 +68,7 @@ struct SavedConversation {
|
||||||
messages: Vec<SavedMessage>,
|
messages: Vec<SavedMessage>,
|
||||||
message_metadata: HashMap<MessageId, MessageMetadata>,
|
message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||||
summary: String,
|
summary: String,
|
||||||
model: OpenAIModel,
|
model: OpenAiModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SavedConversation {
|
impl SavedConversation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
|
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAiModel},
|
||||||
codegen::{self, Codegen, CodegenKind},
|
codegen::{self, Codegen, CodegenKind},
|
||||||
prompts::generate_content_prompt,
|
prompts::generate_content_prompt,
|
||||||
Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus,
|
Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus,
|
||||||
|
@ -10,7 +10,7 @@ use ai::prompts::repository_context::PromptCodeSnippet;
|
||||||
use ai::{
|
use ai::{
|
||||||
auth::ProviderCredential,
|
auth::ProviderCredential,
|
||||||
completion::{CompletionProvider, CompletionRequest},
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
|
providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use chrono::{DateTime, Local};
|
use chrono::{DateTime, Local};
|
||||||
|
@ -123,7 +123,7 @@ impl AssistantPanel {
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
// Defaulting currently to GPT4, allow for this to be set via config.
|
// Defaulting currently to GPT4, allow for this to be set via config.
|
||||||
let completion_provider =
|
let completion_provider =
|
||||||
OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
|
OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// TODO: deserialize state.
|
// TODO: deserialize state.
|
||||||
|
@ -717,7 +717,7 @@ impl AssistantPanel {
|
||||||
content: prompt,
|
content: prompt,
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = Box::new(OpenAIRequest {
|
let request = Box::new(OpenAiRequest {
|
||||||
model: model.full_name().into(),
|
model: model.full_name().into(),
|
||||||
messages,
|
messages,
|
||||||
stream: true,
|
stream: true,
|
||||||
|
@ -1393,7 +1393,7 @@ struct Conversation {
|
||||||
pending_summary: Task<Option<()>>,
|
pending_summary: Task<Option<()>>,
|
||||||
completion_count: usize,
|
completion_count: usize,
|
||||||
pending_completions: Vec<PendingCompletion>,
|
pending_completions: Vec<PendingCompletion>,
|
||||||
model: OpenAIModel,
|
model: OpenAiModel,
|
||||||
token_count: Option<usize>,
|
token_count: Option<usize>,
|
||||||
max_token_count: usize,
|
max_token_count: usize,
|
||||||
pending_token_count: Task<Option<()>>,
|
pending_token_count: Task<Option<()>>,
|
||||||
|
@ -1501,7 +1501,7 @@ impl Conversation {
|
||||||
};
|
};
|
||||||
let model = saved_conversation.model;
|
let model = saved_conversation.model;
|
||||||
let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
|
let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
|
||||||
OpenAICompletionProvider::new(
|
OpenAiCompletionProvider::new(
|
||||||
model.full_name().into(),
|
model.full_name().into(),
|
||||||
cx.background_executor().clone(),
|
cx.background_executor().clone(),
|
||||||
)
|
)
|
||||||
|
@ -1626,7 +1626,7 @@ impl Conversation {
|
||||||
Some(self.max_token_count as isize - self.token_count? as isize)
|
Some(self.max_token_count as isize - self.token_count? as isize)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
|
fn set_model(&mut self, model: OpenAiModel, cx: &mut ModelContext<Self>) {
|
||||||
self.model = model;
|
self.model = model;
|
||||||
self.count_remaining_tokens(cx);
|
self.count_remaining_tokens(cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -1679,7 +1679,7 @@ impl Conversation {
|
||||||
return Default::default();
|
return Default::default();
|
||||||
}
|
}
|
||||||
|
|
||||||
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
|
||||||
model: self.model.full_name().to_string(),
|
model: self.model.full_name().to_string(),
|
||||||
messages: self
|
messages: self
|
||||||
.messages(cx)
|
.messages(cx)
|
||||||
|
@ -1962,7 +1962,7 @@ impl Conversation {
|
||||||
content: "Summarize the conversation into a short title without punctuation"
|
content: "Summarize the conversation into a short title without punctuation"
|
||||||
.into(),
|
.into(),
|
||||||
}));
|
}));
|
||||||
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
|
||||||
model: self.model.full_name().to_string(),
|
model: self.model.full_name().to_string(),
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stream: true,
|
stream: true,
|
||||||
|
|
|
@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||||
pub enum OpenAIModel {
|
pub enum OpenAiModel {
|
||||||
#[serde(rename = "gpt-3.5-turbo-0613")]
|
#[serde(rename = "gpt-3.5-turbo-0613")]
|
||||||
ThreePointFiveTurbo,
|
ThreePointFiveTurbo,
|
||||||
#[serde(rename = "gpt-4-0613")]
|
#[serde(rename = "gpt-4-0613")]
|
||||||
|
@ -14,28 +14,28 @@ pub enum OpenAIModel {
|
||||||
FourTurbo,
|
FourTurbo,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIModel {
|
impl OpenAiModel {
|
||||||
pub fn full_name(&self) -> &'static str {
|
pub fn full_name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
|
OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
|
||||||
OpenAIModel::Four => "gpt-4-0613",
|
OpenAiModel::Four => "gpt-4-0613",
|
||||||
OpenAIModel::FourTurbo => "gpt-4-1106-preview",
|
OpenAiModel::FourTurbo => "gpt-4-1106-preview",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn short_name(&self) -> &'static str {
|
pub fn short_name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||||
OpenAIModel::Four => "gpt-4",
|
OpenAiModel::Four => "gpt-4",
|
||||||
OpenAIModel::FourTurbo => "gpt-4-turbo",
|
OpenAiModel::FourTurbo => "gpt-4-turbo",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cycle(&self) -> Self {
|
pub fn cycle(&self) -> Self {
|
||||||
match self {
|
match self {
|
||||||
OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
|
OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
|
||||||
OpenAIModel::Four => OpenAIModel::FourTurbo,
|
OpenAiModel::Four => OpenAiModel::FourTurbo,
|
||||||
OpenAIModel::FourTurbo => OpenAIModel::ThreePointFiveTurbo,
|
OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ pub struct AssistantSettings {
|
||||||
pub dock: AssistantDockPosition,
|
pub dock: AssistantDockPosition,
|
||||||
pub default_width: Pixels,
|
pub default_width: Pixels,
|
||||||
pub default_height: Pixels,
|
pub default_height: Pixels,
|
||||||
pub default_open_ai_model: OpenAIModel,
|
pub default_open_ai_model: OpenAiModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Assistant panel settings
|
/// Assistant panel settings
|
||||||
|
@ -79,7 +79,7 @@ pub struct AssistantSettingsContent {
|
||||||
/// The default OpenAI model to use when starting new conversations.
|
/// The default OpenAI model to use when starting new conversations.
|
||||||
///
|
///
|
||||||
/// Default: gpt-4-1106-preview
|
/// Default: gpt-4-1106-preview
|
||||||
pub default_open_ai_model: Option<OpenAIModel>,
|
pub default_open_ai_model: Option<OpenAiModel>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Settings for AssistantSettings {
|
impl Settings for AssistantSettings {
|
||||||
|
|
|
@ -4,7 +4,7 @@ use ai::prompts::file_context::FileContext;
|
||||||
use ai::prompts::generate::GenerateInlineContent;
|
use ai::prompts::generate::GenerateInlineContent;
|
||||||
use ai::prompts::preamble::EngineerPreamble;
|
use ai::prompts::preamble::EngineerPreamble;
|
||||||
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||||
use ai::providers::open_ai::OpenAILanguageModel;
|
use ai::providers::open_ai::OpenAiLanguageModel;
|
||||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||||
use std::cmp::{self, Reverse};
|
use std::cmp::{self, Reverse};
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
@ -131,7 +131,7 @@ pub fn generate_content_prompt(
|
||||||
project_name: Option<String>,
|
project_name: Option<String>,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
// Using new Prompt Templates
|
// Using new Prompt Templates
|
||||||
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
|
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
|
||||||
let lang_name = if let Some(language_name) = language_name {
|
let lang_name = if let Some(language_name) = language_name {
|
||||||
Some(language_name.to_string())
|
Some(language_name.to_string())
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -8,7 +8,7 @@ mod semantic_index_tests;
|
||||||
|
|
||||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||||
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
use ai::providers::open_ai::OpenAiEmbeddingProvider;
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use db::VectorDatabase;
|
use db::VectorDatabase;
|
||||||
|
@ -91,7 +91,7 @@ pub fn init(
|
||||||
|
|
||||||
cx.spawn(move |cx| async move {
|
cx.spawn(move |cx| async move {
|
||||||
let embedding_provider =
|
let embedding_provider =
|
||||||
OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
|
OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
|
||||||
let semantic_index = SemanticIndex::new(
|
let semantic_index = SemanticIndex::new(
|
||||||
fs,
|
fs,
|
||||||
db_file_path,
|
db_file_path,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue