Update casing of "OpenAI" in identifiers to match Rust conventions (#6940)

This PR updates the casing of "OpenAI" when used in Rust identifiers to
match the [Rust naming
guidelines](https://rust-lang.github.io/api-guidelines/naming.html):

> In `UpperCamelCase`, acronyms and contractions of compound words count
> as one word: use `Uuid` rather than `UUID`, `Usize` rather than `USize`
> or `Stdin` rather than `StdIn`.
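
For illustration, here is a minimal before/after sketch of how the guideline applies to the identifiers renamed in this diff:

```rust
// Before: the acronym is fully capitalized, reading as several words.
pub struct OpenAILanguageModel { /* … */ }
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

// After: "OpenAI" counts as one word, so only its first letter is capitalized
// in UpperCamelCase, and the SCREAMING_SNAKE_CASE constant gains an underscore.
pub struct OpenAiLanguageModel { /* … */ }
pub const OPEN_AI_API_URL: &'static str = "https://api.openai.com/v1";
```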

Release Notes:

- N/A
Marshall Bowers 2024-01-28 12:01:10 -05:00 committed by GitHub
parent e8bf06fc42
commit 027f055841
11 changed files with 85 additions and 96 deletions


@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAiLanguageModel;
+
+pub const OPEN_AI_API_URL: &'static str = "https://api.openai.com/v1";


@@ -21,7 +21,7 @@ use crate::{
     models::LanguageModel,
 };
 
-use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
 
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
@@ -58,7 +58,7 @@ pub struct RequestMessage {
 }
 
 #[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
+pub struct OpenAiRequest {
     pub model: String,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
@@ -66,7 +66,7 @@ pub struct OpenAIRequest {
     pub temperature: f32,
 }
 
-impl CompletionRequest for OpenAIRequest {
+impl CompletionRequest for OpenAiRequest {
     fn data(&self) -> serde_json::Result<String> {
         serde_json::to_string(self)
     }
@@ -79,7 +79,7 @@ pub struct ResponseMessage {
 }
 
 #[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
+pub struct OpenAiUsage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
     pub total_tokens: u32,
@@ -93,20 +93,20 @@ pub struct ChatChoiceDelta {
 }
 
 #[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
+pub struct OpenAiResponseStreamEvent {
     pub id: Option<String>,
     pub object: String,
     pub created: u32,
     pub model: String,
     pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAIUsage>,
+    pub usage: Option<OpenAiUsage>,
 }
 
 pub async fn stream_completion(
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
     let api_key = match credential {
         ProviderCredential::Credentials { api_key } => api_key,
         _ => {
@@ -114,10 +114,10 @@ pub async fn stream_completion(
         }
     };
 
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
@@ -132,7 +132,7 @@ pub async fn stream_completion(
 
         fn parse_line(
             line: Result<String, io::Error>,
-        ) -> Result<Option<OpenAIResponseStreamEvent>> {
+        ) -> Result<Option<OpenAiResponseStreamEvent>> {
             if let Some(data) = line?.strip_prefix("data: ") {
                 let event = serde_json::from_str(data)?;
                 Ok(Some(event))
@@ -169,16 +169,16 @@ pub async fn stream_completion(
         response.body_mut().read_to_string(&mut body).await?;
 
         #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
+        struct OpenAiResponse {
+            error: OpenAiError,
         }
 
         #[derive(Deserialize)]
-        struct OpenAIError {
+        struct OpenAiError {
             message: String,
         }
 
-        match serde_json::from_str::<OpenAIResponse>(&body) {
+        match serde_json::from_str::<OpenAiResponse>(&body) {
             Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                 "Failed to connect to OpenAI API: {}",
                 response.error.message,
@@ -194,16 +194,16 @@ pub async fn stream_completion(
 }
 
 #[derive(Clone)]
-pub struct OpenAICompletionProvider {
-    model: OpenAILanguageModel,
+pub struct OpenAiCompletionProvider {
+    model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
-impl OpenAICompletionProvider {
+impl OpenAiCompletionProvider {
     pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
-            .spawn(async move { OpenAILanguageModel::load(&model_name) })
+            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
@@ -214,7 +214,7 @@ impl OpenAICompletionProvider {
     }
 }
 
-impl CredentialProvider for OpenAICompletionProvider {
+impl CredentialProvider for OpenAiCompletionProvider {
     fn has_credentials(&self) -> bool {
         match *self.credential.read() {
             ProviderCredential::Credentials { .. } => true,
@@ -232,7 +232,7 @@ impl CredentialProvider for OpenAICompletionProvider {
         if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
             async move { ProviderCredential::Credentials { api_key } }.boxed()
         } else {
-            let credentials = cx.read_credentials(OPENAI_API_URL);
+            let credentials = cx.read_credentials(OPEN_AI_API_URL);
             async move {
                 if let Some(Some((_, api_key))) = credentials.await.log_err() {
                     if let Some(api_key) = String::from_utf8(api_key).log_err() {
@@ -266,7 +266,7 @@ impl CredentialProvider for OpenAICompletionProvider {
         let credential = credential.clone();
         let write_credentials = match credential {
             ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
+                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
             }
             _ => None,
         };
@@ -281,7 +281,7 @@ impl CredentialProvider for OpenAICompletionProvider {
     fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
         *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
+        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
         async move {
             delete_credentials.await.log_err();
         }
@@ -289,7 +289,7 @@ impl CredentialProvider for OpenAICompletionProvider {
     }
 }
 
-impl CompletionProvider for OpenAICompletionProvider {
+impl CompletionProvider for OpenAiCompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
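
Note on scope: the rename deliberately touches only Rust identifiers. The `OPENAI_API_KEY` environment variable and user-facing strings such as the "Failed to connect to OpenAI API" error message keep their existing spelling, since neither is governed by the Rust naming guidelines. A minimal sketch of the distinction (the "timeout" message here is a hypothetical placeholder):

```rust
fn main() {
    // Unchanged by this PR: the environment variable keeps the product's own
    // spelling; only Rust identifiers follow the `Uuid`-style casing rule.
    let _api_key = std::env::var("OPENAI_API_KEY").ok();

    // Also unchanged: user-facing strings spell the product name as "OpenAI".
    let _message = format!("Failed to connect to OpenAI API: {}", "timeout");
}
```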


@@ -25,17 +25,17 @@ use util::ResultExt;
 use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::embedding::{Embedding, EmbeddingProvider};
 use crate::models::LanguageModel;
-use crate::providers::open_ai::OpenAILanguageModel;
-use crate::providers::open_ai::OPENAI_API_URL;
+use crate::providers::open_ai::OpenAiLanguageModel;
+use crate::providers::open_ai::OPEN_AI_API_URL;
 
 lazy_static! {
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+    static ref OPEN_AI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
 #[derive(Clone)]
-pub struct OpenAIEmbeddingProvider {
-    model: OpenAILanguageModel,
+pub struct OpenAiEmbeddingProvider {
+    model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
     pub executor: BackgroundExecutor,
@@ -44,42 +44,42 @@ pub struct OpenAIEmbeddingProvider {
 }
 
 #[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
+struct OpenAiEmbeddingRequest<'a> {
     model: &'static str,
     input: Vec<&'a str>,
 }
 
 #[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
+struct OpenAiEmbeddingResponse {
+    data: Vec<OpenAiEmbedding>,
+    usage: OpenAiEmbeddingUsage,
 }
 
 #[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
+struct OpenAiEmbedding {
     embedding: Vec<f32>,
     index: usize,
     object: String,
 }
 
 #[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
+struct OpenAiEmbeddingUsage {
     prompt_tokens: usize,
     total_tokens: usize,
 }
 
-impl OpenAIEmbeddingProvider {
+impl OpenAiEmbeddingProvider {
     pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         // Loading the model is expensive, so ensure this runs off the main thread.
         let model = executor
-            .spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
+            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-        OpenAIEmbeddingProvider {
+        OpenAiEmbeddingProvider {
             model,
             credential,
             client,
@@ -140,7 +140,7 @@ impl OpenAIEmbeddingProvider {
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", api_key))
             .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
+                serde_json::to_string(&OpenAiEmbeddingRequest {
                     input: spans.clone(),
                     model: "text-embedding-ada-002",
                 })
@@ -152,7 +152,7 @@ impl OpenAIEmbeddingProvider {
     }
 }
 
-impl CredentialProvider for OpenAIEmbeddingProvider {
+impl CredentialProvider for OpenAiEmbeddingProvider {
     fn has_credentials(&self) -> bool {
         match *self.credential.read() {
             ProviderCredential::Credentials { .. } => true,
@@ -170,7 +170,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
         if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
             async move { ProviderCredential::Credentials { api_key } }.boxed()
         } else {
-            let credentials = cx.read_credentials(OPENAI_API_URL);
+            let credentials = cx.read_credentials(OPEN_AI_API_URL);
             async move {
                 if let Some(Some((_, api_key))) = credentials.await.log_err() {
                     if let Some(api_key) = String::from_utf8(api_key).log_err() {
@@ -204,7 +204,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
         let credential = credential.clone();
         let write_credentials = match credential {
             ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
+                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
             }
             _ => None,
         };
@@ -219,7 +219,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
     fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
         *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
+        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
         async move {
             delete_credentials.await.log_err();
         }
@@ -228,7 +228,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
 }
 
 #[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddingProvider {
+impl EmbeddingProvider for OpenAiEmbeddingProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
@@ -270,7 +270,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
             StatusCode::OK => {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
-                let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+                let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
                 log::trace!(
                     "openai embedding completed. tokens: {:?}",


@@ -1,9 +0,0 @@
-pub mod completion;
-pub mod embedding;
-pub mod model;
-
-pub use completion::*;
-pub use embedding::*;
-pub use model::OpenAILanguageModel;
-
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";


@@ -5,22 +5,22 @@ use util::ResultExt;
 use crate::models::{LanguageModel, TruncationDirection};
 
 #[derive(Clone)]
-pub struct OpenAILanguageModel {
+pub struct OpenAiLanguageModel {
     name: String,
     bpe: Option<CoreBPE>,
 }
 
-impl OpenAILanguageModel {
+impl OpenAiLanguageModel {
     pub fn load(model_name: &str) -> Self {
         let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
+        OpenAiLanguageModel {
             name: model_name.to_string(),
             bpe,
         }
     }
 }
 
-impl LanguageModel for OpenAILanguageModel {
+impl LanguageModel for OpenAiLanguageModel {
     fn name(&self) -> String {
         self.name.clone()
     }


@@ -1,11 +0,0 @@
-pub trait LanguageModel {
-    fn name(&self) -> String;
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String>;
-    fn capacity(&self) -> anyhow::Result<usize>;
-}


@@ -7,7 +7,7 @@ mod streaming_diff;
 use ai::providers::open_ai::Role;
 use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
-use assistant_settings::OpenAIModel;
+use assistant_settings::OpenAiModel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
@@ -68,7 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
-    model: OpenAIModel,
+    model: OpenAiModel,
 }
 
 impl SavedConversation {


@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
+    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAiModel},
     codegen::{self, Codegen, CodegenKind},
     prompts::generate_content_prompt,
     Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus,
@@ -10,7 +10,7 @@ use ai::prompts::repository_context::PromptCodeSnippet;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
+    providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -123,7 +123,7 @@ impl AssistantPanel {
             .unwrap_or_default();
 
         // Defaulting currently to GPT4, allow for this to be set via config.
         let completion_provider =
-            OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+            OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
                 .await;
 
         // TODO: deserialize state.
@@ -717,7 +717,7 @@ impl AssistantPanel {
             content: prompt,
         });
 
-        let request = Box::new(OpenAIRequest {
+        let request = Box::new(OpenAiRequest {
             model: model.full_name().into(),
             messages,
             stream: true,
@@ -1393,7 +1393,7 @@ struct Conversation {
     pending_summary: Task<Option<()>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
-    model: OpenAIModel,
+    model: OpenAiModel,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1501,7 +1501,7 @@ impl Conversation {
         };
 
         let model = saved_conversation.model;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
-            OpenAICompletionProvider::new(
+            OpenAiCompletionProvider::new(
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
@@ -1626,7 +1626,7 @@ impl Conversation {
         Some(self.max_token_count as isize - self.token_count? as isize)
     }
 
-    fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
+    fn set_model(&mut self, model: OpenAiModel, cx: &mut ModelContext<Self>) {
         self.model = model;
         self.count_remaining_tokens(cx);
         cx.notify();
@@ -1679,7 +1679,7 @@ impl Conversation {
             return Default::default();
         }
 
-        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
            model: self.model.full_name().to_string(),
            messages: self
                .messages(cx)
@@ -1962,7 +1962,7 @@ impl Conversation {
             content: "Summarize the conversation into a short title without punctuation"
                 .into(),
         }));
-        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
            model: self.model.full_name().to_string(),
            messages: messages.collect(),
            stream: true,


@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
 use settings::Settings;
 
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-pub enum OpenAIModel {
+pub enum OpenAiModel {
     #[serde(rename = "gpt-3.5-turbo-0613")]
     ThreePointFiveTurbo,
     #[serde(rename = "gpt-4-0613")]
@@ -14,28 +14,28 @@ pub enum OpenAIModel {
     FourTurbo,
 }
 
-impl OpenAIModel {
+impl OpenAiModel {
     pub fn full_name(&self) -> &'static str {
         match self {
-            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
-            OpenAIModel::Four => "gpt-4-0613",
-            OpenAIModel::FourTurbo => "gpt-4-1106-preview",
+            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            OpenAiModel::Four => "gpt-4-0613",
+            OpenAiModel::FourTurbo => "gpt-4-1106-preview",
         }
     }
 
     pub fn short_name(&self) -> &'static str {
         match self {
-            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
-            OpenAIModel::Four => "gpt-4",
-            OpenAIModel::FourTurbo => "gpt-4-turbo",
+            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            OpenAiModel::Four => "gpt-4",
+            OpenAiModel::FourTurbo => "gpt-4-turbo",
         }
     }
 
     pub fn cycle(&self) -> Self {
         match self {
-            OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
-            OpenAIModel::Four => OpenAIModel::FourTurbo,
-            OpenAIModel::FourTurbo => OpenAIModel::ThreePointFiveTurbo,
+            OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
+            OpenAiModel::Four => OpenAiModel::FourTurbo,
+            OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
         }
     }
 }
@@ -54,7 +54,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: Pixels,
     pub default_height: Pixels,
-    pub default_open_ai_model: OpenAIModel,
+    pub default_open_ai_model: OpenAiModel,
 }
 
 /// Assistant panel settings
@@ -79,7 +79,7 @@ pub struct AssistantSettingsContent {
     /// The default OpenAI model to use when starting new conversations.
     ///
     /// Default: gpt-4-1106-preview
-    pub default_open_ai_model: Option<OpenAIModel>,
+    pub default_open_ai_model: Option<OpenAiModel>,
 }
 
 impl Settings for AssistantSettings {
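
One property of this rename that is easy to miss: because every variant of the enum carries an explicit `#[serde(rename = "…")]` attribute, renaming `OpenAIModel` to `OpenAiModel` does not change what is written to or read from settings files. A minimal standalone sketch that would confirm this (assuming `serde` and `serde_json` as dependencies, with the `JsonSchema` derive dropped for brevity):

```rust
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum OpenAiModel {
    #[serde(rename = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-1106-preview")]
    FourTurbo,
}

fn main() {
    // The Rust identifier never appears in the serialized form; only the
    // rename string does, so existing settings files keep round-tripping.
    let json = serde_json::to_string(&OpenAiModel::Four).unwrap();
    assert_eq!(json, "\"gpt-4-0613\"");

    let model: OpenAiModel = serde_json::from_str("\"gpt-4-1106-preview\"").unwrap();
    assert_eq!(model, OpenAiModel::FourTurbo);
}
```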


@@ -4,7 +4,7 @@ use ai::prompts::file_context::FileContext;
 use ai::prompts::generate::GenerateInlineContent;
 use ai::prompts::preamble::EngineerPreamble;
 use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
-use ai::providers::open_ai::OpenAILanguageModel;
+use ai::providers::open_ai::OpenAiLanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;
@@ -131,7 +131,7 @@ pub fn generate_content_prompt(
     project_name: Option<String>,
 ) -> anyhow::Result<String> {
     // Using new Prompt Templates
-    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
     let lang_name = if let Some(language_name) = language_name {
         Some(language_name.to_string())
     } else {


@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAIEmbeddingProvider;
+use ai::providers::open_ai::OpenAiEmbeddingProvider;
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,7 +91,7 @@ pub fn init(
     cx.spawn(move |cx| async move {
         let embedding_provider =
-            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,