diff --git a/assets/settings/default.json b/assets/settings/default.json index 3f15860d4a..f3d6dfef58 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -865,16 +865,18 @@ // Different settings for specific language models. "language_models": { "anthropic": { + "version": "1", "api_url": "https://api.anthropic.com" }, - "openai": { - "api_url": "https://api.openai.com/v1" - }, "google": { "api_url": "https://generativelanguage.googleapis.com" }, "ollama": { "api_url": "http://localhost:11434" + }, + "openai": { + "version": "1", + "api_url": "https://api.openai.com/v1" } }, // Zed's Prettier integration settings. diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index a438e842d6..152a2d629d 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -110,11 +110,15 @@ impl AssistantSettingsContent { move |content, _| { if content.anthropic.is_none() { content.anthropic = - Some(language_model::settings::AnthropicSettingsContent { - api_url, - low_speed_timeout_in_seconds, - ..Default::default() - }); + Some(language_model::settings::AnthropicSettingsContent::Versioned( + language_model::settings::VersionedAnthropicSettingsContent::V1( + language_model::settings::AnthropicSettingsContentV1 { + api_url, + low_speed_timeout_in_seconds, + available_models: None + } + ) + )); } }, ), @@ -145,12 +149,27 @@ impl AssistantSettingsContent { cx, move |content, _| { if content.openai.is_none() { + let available_models = available_models.map(|models| { + models + .into_iter() + .filter_map(|model| match model { + open_ai::Model::Custom { name, max_tokens } => { + Some(language_model::provider::open_ai::AvailableModel { name, max_tokens }) + } + _ => None, + }) + .collect::>() + }); content.openai = - Some(language_model::settings::OpenAiSettingsContent { - api_url, - low_speed_timeout_in_seconds, - available_models, - }); + Some(language_model::settings::OpenAiSettingsContent::Versioned( + language_model::settings::VersionedOpenAiSettingsContent::V1( + language_model::settings::OpenAiSettingsContentV1 { + api_url, + low_speed_timeout_in_seconds, + available_models + } + ) + )); } }, ), @@ -377,6 +396,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema: schemars::schema::SchemaObject { enum_values: Some(vec![ "anthropic".into(), + "google".into(), "ollama".into(), "openai".into(), "zed.dev".into(), diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 9a5c60a0d8..d7b609dde0 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -37,6 +37,7 @@ menu.workspace = true ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } proto = { workspace = true, features = ["test-support"] } +project.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index b9f3262f30..0d7a003663 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -13,14 +13,15 @@ use futures::{future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext}; pub use model::*; +use project::Fs; pub use registry::*; pub use request::*; pub use role::*; use schemars::JsonSchema; use serde::de::DeserializeOwned; -pub fn init(client: Arc, cx: &mut AppContext) { - settings::init(cx); +pub fn init(client: Arc, fs: Arc, cx: &mut AppContext) { + settings::init(fs, cx); registry::init(client, cx); } diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index cfca9358a1..32932953e7 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -12,6 +12,8 @@ use gpui::{ WhiteSpace, }; use http_client::HttpClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; @@ -26,7 +28,14 @@ const PROVIDER_NAME: &str = "Anthropic"; pub struct AnthropicSettings { pub api_url: String, pub low_speed_timeout: Option, - pub available_models: Vec, + pub available_models: Vec, + pub needs_setting_migration: bool, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub max_tokens: usize, } pub struct AnthropicLanguageModelProvider { @@ -84,7 +93,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { .available_models .iter() { - models.insert(model.id().to_string(), model.clone()); + models.insert( + model.name.clone(), + anthropic::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }, + ); } models diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index 3a869773a3..e0969eda0b 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -8,6 +8,8 @@ use gpui::{ WhiteSpace, }; use http_client::HttpClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::{future, sync::Arc, time::Duration}; use strum::IntoEnumIterator; @@ -28,7 +30,13 @@ const PROVIDER_NAME: &str = "Google AI"; pub struct GoogleSettings { pub api_url: String, pub low_speed_timeout: Option, - pub available_models: Vec, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + name: String, + max_tokens: usize, } pub struct GoogleLanguageModelProvider { @@ -86,7 +94,13 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { .google .available_models { - models.insert(model.id().to_string(), model.clone()); + models.insert( + model.name.clone(), + google_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }, + ); } models diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index d2465e4446..6beec3d0f5 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -8,6 +8,8 @@ use gpui::{ }; use http_client::HttpClient; use open_ai::stream_completion; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::{future, sync::Arc, time::Duration}; use strum::IntoEnumIterator; @@ -28,7 +30,14 @@ const PROVIDER_NAME: &str = "OpenAI"; pub struct OpenAiSettings { pub api_url: String, pub low_speed_timeout: Option, - pub available_models: Vec, + pub available_models: Vec, + pub needs_setting_migration: bool, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub max_tokens: usize, } pub struct OpenAiLanguageModelProvider { @@ -86,7 +95,13 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { .openai .available_models { - models.insert(model.id().to_string(), model.clone()); + models.insert( + model.name.clone(), + open_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }, + ); } models diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 58e38e4971..3cb012860c 100644 --- a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -1,12 +1,14 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use anyhow::Result; use gpui::AppContext; +use project::Fs; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{update_settings_file, Settings, SettingsSources}; use crate::provider::{ + self, anthropic::AnthropicSettings, cloud::{self, ZedDotDevSettings}, copilot_chat::CopilotChatSettings, @@ -16,8 +18,36 @@ use crate::provider::{ }; /// Initializes the language model settings. -pub fn init(cx: &mut AppContext) { +pub fn init(fs: Arc, cx: &mut AppContext) { AllLanguageModelSettings::register(cx); + + if AllLanguageModelSettings::get_global(cx) + .openai + .needs_setting_migration + { + update_settings_file::(fs.clone(), cx, move |setting, _| { + if let Some(settings) = setting.openai.clone() { + let (newest_version, _) = settings.upgrade(); + setting.openai = Some(OpenAiSettingsContent::Versioned( + VersionedOpenAiSettingsContent::V1(newest_version), + )); + } + }); + } + + if AllLanguageModelSettings::get_global(cx) + .anthropic + .needs_setting_migration + { + update_settings_file::(fs, cx, move |setting, _| { + if let Some(settings) = setting.anthropic.clone() { + let (newest_version, _) = settings.upgrade(); + setting.anthropic = Some(AnthropicSettingsContent::Versioned( + VersionedAnthropicSettingsContent::V1(newest_version), + )); + } + }); + } } #[derive(Default)] @@ -41,31 +71,129 @@ pub struct AllLanguageModelSettingsContent { pub copilot_chat: Option, } -#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] -pub struct AnthropicSettingsContent { +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +#[serde(untagged)] +pub enum AnthropicSettingsContent { + Legacy(LegacyAnthropicSettingsContent), + Versioned(VersionedAnthropicSettingsContent), +} + +impl AnthropicSettingsContent { + pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) { + match self { + AnthropicSettingsContent::Legacy(content) => ( + AnthropicSettingsContentV1 { + api_url: content.api_url, + low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds, + available_models: content.available_models.map(|models| { + models + .into_iter() + .filter_map(|model| match model { + anthropic::Model::Custom { name, max_tokens } => { + Some(provider::anthropic::AvailableModel { name, max_tokens }) + } + _ => None, + }) + .collect() + }), + }, + true, + ), + AnthropicSettingsContent::Versioned(content) => match content { + VersionedAnthropicSettingsContent::V1(content) => (content, false), + }, + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct LegacyAnthropicSettingsContent { pub api_url: Option, pub low_speed_timeout_in_seconds: Option, pub available_models: Option>, } +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +#[serde(tag = "version")] +pub enum VersionedAnthropicSettingsContent { + #[serde(rename = "1")] + V1(AnthropicSettingsContentV1), +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct AnthropicSettingsContentV1 { + pub api_url: Option, + pub low_speed_timeout_in_seconds: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct OllamaSettingsContent { pub api_url: Option, pub low_speed_timeout_in_seconds: Option, } -#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] -pub struct OpenAiSettingsContent { +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +#[serde(untagged)] +pub enum OpenAiSettingsContent { + Legacy(LegacyOpenAiSettingsContent), + Versioned(VersionedOpenAiSettingsContent), +} + +impl OpenAiSettingsContent { + pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) { + match self { + OpenAiSettingsContent::Legacy(content) => ( + OpenAiSettingsContentV1 { + api_url: content.api_url, + low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds, + available_models: content.available_models.map(|models| { + models + .into_iter() + .filter_map(|model| match model { + open_ai::Model::Custom { name, max_tokens } => { + Some(provider::open_ai::AvailableModel { name, max_tokens }) + } + _ => None, + }) + .collect() + }), + }, + true, + ), + OpenAiSettingsContent::Versioned(content) => match content { + VersionedOpenAiSettingsContent::V1(content) => (content, false), + }, + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct LegacyOpenAiSettingsContent { pub api_url: Option, pub low_speed_timeout_in_seconds: Option, pub available_models: Option>, } +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +#[serde(tag = "version")] +pub enum VersionedOpenAiSettingsContent { + #[serde(rename = "1")] + V1(OpenAiSettingsContentV1), +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct OpenAiSettingsContentV1 { + pub api_url: Option, + pub low_speed_timeout_in_seconds: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct GoogleSettingsContent { pub api_url: Option, pub low_speed_timeout_in_seconds: Option, - pub available_models: Option>, + pub available_models: Option>, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -81,6 +209,8 @@ pub struct CopilotChatSettingsContent { impl settings::Settings for AllLanguageModelSettings { const KEY: Option<&'static str> = Some("language_models"); + const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]); + type FileContent = AllLanguageModelSettingsContent; fn load(sources: SettingsSources, _: &mut AppContext) -> Result { @@ -93,12 +223,21 @@ impl settings::Settings for AllLanguageModelSettings { let mut settings = AllLanguageModelSettings::default(); for value in sources.defaults_and_customizations() { + // Anthropic + let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) { + Some((content, upgraded)) => (Some(content), upgraded), + None => (None, false), + }; + + if upgraded { + settings.anthropic.needs_setting_migration = true; + } + merge( &mut settings.anthropic.api_url, - value.anthropic.as_ref().and_then(|s| s.api_url.clone()), + anthropic.as_ref().and_then(|s| s.api_url.clone()), ); - if let Some(low_speed_timeout_in_seconds) = value - .anthropic + if let Some(low_speed_timeout_in_seconds) = anthropic .as_ref() .and_then(|s| s.low_speed_timeout_in_seconds) { @@ -107,10 +246,7 @@ impl settings::Settings for AllLanguageModelSettings { } merge( &mut settings.anthropic.available_models, - value - .anthropic - .as_ref() - .and_then(|s| s.available_models.clone()), + anthropic.as_ref().and_then(|s| s.available_models.clone()), ); merge( @@ -126,24 +262,29 @@ impl settings::Settings for AllLanguageModelSettings { Some(Duration::from_secs(low_speed_timeout_in_seconds)); } + // OpenAI + let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) { + Some((content, upgraded)) => (Some(content), upgraded), + None => (None, false), + }; + + if upgraded { + settings.openai.needs_setting_migration = true; + } + merge( &mut settings.openai.api_url, - value.openai.as_ref().and_then(|s| s.api_url.clone()), + openai.as_ref().and_then(|s| s.api_url.clone()), ); - if let Some(low_speed_timeout_in_seconds) = value - .openai - .as_ref() - .and_then(|s| s.low_speed_timeout_in_seconds) + if let Some(low_speed_timeout_in_seconds) = + openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds) { settings.openai.low_speed_timeout = Some(Duration::from_secs(low_speed_timeout_in_seconds)); } merge( &mut settings.openai.available_models, - value - .openai - .as_ref() - .and_then(|s| s.available_models.clone()), + openai.as_ref().and_then(|s| s.available_models.clone()), ); merge( diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 4cd3c5539b..a6c4ea02fb 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -174,7 +174,7 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { cx, ); supermaven::init(app_state.client.clone(), cx); - language_model::init(app_state.client.clone(), cx); + language_model::init(app_state.client.clone(), app_state.fs.clone(), cx); snippet_provider::init(cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 46cd10467f..15a2013444 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -3461,7 +3461,7 @@ mod tests { app_state.client.http_client().clone(), cx, ); - language_model::init(app_state.client.clone(), cx); + language_model::init(app_state.client.clone(), app_state.fs.clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init( app_state.fs.clone(),