diff --git a/assets/settings/default.json b/assets/settings/default.json
index 08faedbed6..24412b883b 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -138,7 +138,13 @@
     // Default width when the assistant is docked to the left or right.
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
-    "default_height": 320
+    "default_height": 320,
+    // The default OpenAI model to use when starting new conversations. This
+    // setting can take two values:
+    //
+    // 1. "gpt-3.5-turbo-0613"
+    // 2. "gpt-4-0613"
+    "default_open_ai_model": "gpt-4-0613"
   },
   // Whether the screen sharing icon is shown in the os status bar.
   "show_call_status_icon": true,
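The two strings listed in the comment are exactly the serde names of the `OpenAIModel` enum introduced in `assistant_settings.rs` below, so any other value fails to deserialize. A minimal round-trip sketch (illustrative only; the enum body is copied from the hunk further down, with the `JsonSchema` derive dropped so the sketch needs nothing beyond serde and serde_json):

```rust
use serde::{Deserialize, Serialize};

// Copied from the assistant_settings.rs hunk below; `JsonSchema` omitted so
// this sketch compiles with only serde + serde_json.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum OpenAIModel {
    #[serde(rename = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4-0613")]
    Four,
}

fn main() -> serde_json::Result<()> {
    // The default shipped in default.json parses to the gpt-4 variant.
    let model: OpenAIModel = serde_json::from_str("\"gpt-4-0613\"")?;
    assert_eq!(model, OpenAIModel::Four);

    // Any string other than the two listed in the comment is rejected,
    // which is what makes the setting's doc comment exhaustive.
    assert!(serde_json::from_str::<OpenAIModel>("\"gpt-4\"").is_err());
    Ok(())
}
```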
diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index 7cc5f08f7c..d2be651bd5 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -3,6 +3,7 @@ mod assistant_settings;
 
 use anyhow::Result;
 pub use assistant::AssistantPanel;
+use assistant_settings::OpenAIModel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
@@ -60,7 +61,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
-    model: String,
+    model: OpenAIModel,
 }
 
 impl SavedConversation {
diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs
index e5026182ed..81299bbdc2 100644
--- a/crates/ai/src/assistant.rs
+++ b/crates/ai/src/assistant.rs
@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings},
+    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
     RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
 };
@@ -833,7 +833,7 @@ struct Conversation {
     pending_summary: Task<Option<()>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
-    model: String,
+    model: OpenAIModel,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -853,7 +853,6 @@ impl Conversation {
         language_registry: Arc<LanguageRegistry>,
        cx: &mut ModelContext<Self>,
     ) -> Self {
-        let model = "gpt-3.5-turbo-0613";
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {
             let mut buffer = Buffer::new(0, "", cx);
@@ -872,6 +871,9 @@
             buffer
         });
 
+        let settings = settings::get::<AssistantSettings>(cx);
+        let model = settings.default_open_ai_model.clone();
+
         let mut this = Self {
             message_anchors: Default::default(),
             messages_metadata: Default::default(),
@@ -881,9 +883,9 @@
             completion_count: Default::default(),
             pending_completions: Default::default(),
             token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(model),
+            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
-            model: model.into(),
+            model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -977,7 +979,7 @@
             completion_count: Default::default(),
             pending_completions: Default::default(),
             token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(&model),
+            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
@@ -1031,13 +1033,16 @@
             cx.background().timer(Duration::from_millis(200)).await;
             let token_count = cx
                 .background()
-                .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
+                .spawn(async move {
+                    tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages)
+                })
                 .await?;
 
             this.upgrade(&cx)
                 .ok_or_else(|| anyhow!("conversation was dropped"))?
                 .update(&mut cx, |this, cx| {
-                    this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
+                    this.max_token_count =
+                        tiktoken_rs::model::get_context_size(&this.model.full_name());
                     this.token_count = Some(token_count);
                     cx.notify()
                 });
@@ -1051,7 +1056,7 @@
         Some(self.max_token_count as isize - self.token_count? as isize)
     }
 
-    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
+    fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
         self.model = model;
         self.count_remaining_tokens(cx);
         cx.notify();
@@ -1093,7 +1098,7 @@
             }
         } else {
             let request = OpenAIRequest {
-                model: self.model.clone(),
+                model: self.model.full_name().to_string(),
                 messages: self
                     .messages(cx)
                     .filter(|message| matches!(message.status, MessageStatus::Done))
@@ -1419,7 +1424,7 @@
                 .into(),
         }));
         let request = OpenAIRequest {
-            model: self.model.clone(),
+            model: self.model.full_name().to_string(),
             messages: messages.collect(),
             stream: true,
         };
@@ -2023,11 +2028,8 @@ impl ConversationEditor {
 
     fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
         self.conversation.update(cx, |conversation, cx| {
-            let new_model = match conversation.model.as_str() {
-                "gpt-4-0613" => "gpt-3.5-turbo-0613",
-                _ => "gpt-4-0613",
-            };
-            conversation.set_model(new_model.into(), cx);
+            let new_model = conversation.model.cycle();
+            conversation.set_model(new_model, cx);
         });
     }
 
@@ -2049,7 +2051,8 @@
 
         MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
             let style = style.model.style_for(state);
-            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
+            let model_display_name = self.conversation.read(cx).model.short_name();
+            Label::new(model_display_name, style.text.clone())
                 .contained()
                 .with_style(style.container)
         })
@@ -2238,6 +2241,8 @@ mod tests {
 
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2364,6 +2369,8 @@
 
     #[gpui::test]
     fn test_message_splitting(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2458,6 +2465,8 @@
 
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2538,6 +2547,8 @@
 
     #[gpui::test]
     fn test_serialization(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation =
             cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
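Two things are worth noting in the hunks above: `cycle()` replaces the stringly-typed `match` with a two-state toggle, and every tiktoken-rs call now goes through `full_name()`, because tiktoken-rs keys its lookups on the full, dated model string rather than the short display name shown in the UI. A hedged sketch of that lookup (assuming a 2023-era tiktoken-rs where `get_context_size` matches on model-name prefixes; the window size is the crate's value, not guaranteed here):

```rust
// Cargo.toml (assumption): tiktoken-rs = "0.4"

fn main() {
    // get_context_size expects the full model name that full_name() returns;
    // the short UI label would only resolve via the crate's prefix fallback.
    let max_tokens = tiktoken_rs::model::get_context_size("gpt-4-0613");
    println!("context window: {max_tokens} tokens"); // 8192 at time of writing

    // This is the value Conversation caches in max_token_count and diffs
    // against the running token count to compute remaining_tokens().
}
```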
diff --git a/crates/ai/src/assistant_settings.rs b/crates/ai/src/assistant_settings.rs
index 04ba8fb946..05d8d9ffeb 100644
--- a/crates/ai/src/assistant_settings.rs
+++ b/crates/ai/src/assistant_settings.rs
@@ -3,6 +3,37 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Setting;
 
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+pub enum OpenAIModel {
+    #[serde(rename = "gpt-3.5-turbo-0613")]
+    ThreePointFiveTurbo,
+    #[serde(rename = "gpt-4-0613")]
+    Four,
+}
+
+impl OpenAIModel {
+    pub fn full_name(&self) -> &'static str {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            OpenAIModel::Four => "gpt-4-0613",
+        }
+    }
+
+    pub fn short_name(&self) -> &'static str {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            OpenAIModel::Four => "gpt-4",
+        }
+    }
+
+    pub fn cycle(&self) -> Self {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
+            OpenAIModel::Four => OpenAIModel::ThreePointFiveTurbo,
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum AssistantDockPosition {
@@ -17,6 +48,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: f32,
     pub default_height: f32,
+    pub default_open_ai_model: OpenAIModel,
 }
 
 #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
@@ -25,6 +57,7 @@
     pub dock: Option<AssistantDockPosition>,
     pub default_width: Option<f32>,
     pub default_height: Option<f32>,
+    pub default_open_ai_model: Option<OpenAIModel>,
 }
 
 impl Setting for AssistantSettings {
diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs
index 1188018cd8..da84074d2a 100644
--- a/crates/settings/src/settings_store.rs
+++ b/crates/settings/src/settings_store.rs
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use collections::{btree_map, hash_map, BTreeMap, HashMap};
 use gpui::AppContext;
 use lazy_static::lazy_static;
@@ -162,6 +162,7 @@
 
         if let Some(setting) = setting_value
             .load_setting(&default_settings, &user_values_stack, cx)
+            .context("A default setting must be added to the `default.json` file")
             .log_err()
         {
             setting_value.set_global_value(setting);
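The `settings_store.rs` change only decorates the error path: if a newly registered setting (such as `default_open_ai_model`) has no value in `default.json`, `load_setting` fails, and the added `.context(...)` makes the logged error name the actual fix. A standalone sketch of that anyhow pattern (`fail_to_load` is a hypothetical stand-in for `load_setting`, and Zed's `log_err()` helper is replaced here by printing the error directly):

```rust
use anyhow::{anyhow, Context, Result};

// Hypothetical stand-in for SettingsValue::load_setting failing because the
// key is missing from default.json.
fn fail_to_load(key: &str) -> Result<()> {
    Err(anyhow!("missing value for `{key}`"))
}

fn main() {
    let err = fail_to_load("default_open_ai_model")
        .context("A default setting must be added to the `default.json` file")
        .unwrap_err();

    // anyhow's Debug output prints the context first, then the cause chain:
    //
    //   A default setting must be added to the `default.json` file
    //
    //   Caused by:
    //       missing value for `default_open_ai_model`
    println!("{err:?}");
}
```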