assistant: Make it easier to define custom models (#15442)

This PR makes it easier to specify custom models for the Google, OpenAI,
and Anthropic provider:

Before (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-google-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "name": "my-custom-google-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}
```

Before (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-anthropic-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "version": "1",
      "available_models": [
        {
          "name": "my-custom-anthropic-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}

```

The settings will be auto-upgraded so the old versions will continue to
work (except for Google since that one has not been released).

/cc @as-cii 

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-07-30 15:46:39 +02:00 committed by GitHub
parent 13dcb42c1c
commit 2ada2964c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 256 additions and 47 deletions

View file

@ -1,12 +1,14 @@
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use anyhow::Result;
use gpui::AppContext;
use project::Fs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use settings::{update_settings_file, Settings, SettingsSources};
use crate::provider::{
self,
anthropic::AnthropicSettings,
cloud::{self, ZedDotDevSettings},
copilot_chat::CopilotChatSettings,
@ -16,8 +18,36 @@ use crate::provider::{
};
/// Initializes the language model settings.
pub fn init(cx: &mut AppContext) {
pub fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
AllLanguageModelSettings::register(cx);
if AllLanguageModelSettings::get_global(cx)
.openai
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
if let Some(settings) = setting.openai.clone() {
let (newest_version, _) = settings.upgrade();
setting.openai = Some(OpenAiSettingsContent::Versioned(
VersionedOpenAiSettingsContent::V1(newest_version),
));
}
});
}
if AllLanguageModelSettings::get_global(cx)
.anthropic
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
if let Some(settings) = setting.anthropic.clone() {
let (newest_version, _) = settings.upgrade();
setting.anthropic = Some(AnthropicSettingsContent::Versioned(
VersionedAnthropicSettingsContent::V1(newest_version),
));
}
});
}
}
#[derive(Default)]
@ -41,31 +71,129 @@ pub struct AllLanguageModelSettingsContent {
pub copilot_chat: Option<CopilotChatSettingsContent>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum AnthropicSettingsContent {
Legacy(LegacyAnthropicSettingsContent),
Versioned(VersionedAnthropicSettingsContent),
}
impl AnthropicSettingsContent {
pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
match self {
AnthropicSettingsContent::Legacy(content) => (
AnthropicSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
anthropic::Model::Custom { name, max_tokens } => {
Some(provider::anthropic::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect()
}),
},
true,
),
AnthropicSettingsContent::Versioned(content) => match content {
VersionedAnthropicSettingsContent::V1(content) => (content, false),
},
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyAnthropicSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<anthropic::Model>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedAnthropicSettingsContent {
#[serde(rename = "1")]
V1(AnthropicSettingsContentV1),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContentV1 {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OllamaSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContent {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum OpenAiSettingsContent {
Legacy(LegacyOpenAiSettingsContent),
Versioned(VersionedOpenAiSettingsContent),
}
impl OpenAiSettingsContent {
pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
match self {
OpenAiSettingsContent::Legacy(content) => (
OpenAiSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
open_ai::Model::Custom { name, max_tokens } => {
Some(provider::open_ai::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect()
}),
},
true,
),
OpenAiSettingsContent::Versioned(content) => match content {
VersionedOpenAiSettingsContent::V1(content) => (content, false),
},
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyOpenAiSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<open_ai::Model>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedOpenAiSettingsContent {
#[serde(rename = "1")]
V1(OpenAiSettingsContentV1),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContentV1 {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<google_ai::Model>>,
pub available_models: Option<Vec<provider::google::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -81,6 +209,8 @@ pub struct CopilotChatSettingsContent {
impl settings::Settings for AllLanguageModelSettings {
const KEY: Option<&'static str> = Some("language_models");
const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]);
type FileContent = AllLanguageModelSettingsContent;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
@ -93,12 +223,21 @@ impl settings::Settings for AllLanguageModelSettings {
let mut settings = AllLanguageModelSettings::default();
for value in sources.defaults_and_customizations() {
// Anthropic
let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.anthropic.needs_setting_migration = true;
}
merge(
&mut settings.anthropic.api_url,
value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
anthropic.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.anthropic
if let Some(low_speed_timeout_in_seconds) = anthropic
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
@ -107,10 +246,7 @@ impl settings::Settings for AllLanguageModelSettings {
}
merge(
&mut settings.anthropic.available_models,
value
.anthropic
.as_ref()
.and_then(|s| s.available_models.clone()),
anthropic.as_ref().and_then(|s| s.available_models.clone()),
);
merge(
@ -126,24 +262,29 @@ impl settings::Settings for AllLanguageModelSettings {
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
// OpenAI
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.openai.needs_setting_migration = true;
}
merge(
&mut settings.openai.api_url,
value.openai.as_ref().and_then(|s| s.api_url.clone()),
openai.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.openai
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
if let Some(low_speed_timeout_in_seconds) =
openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.openai.available_models,
value
.openai
.as_ref()
.and_then(|s| s.available_models.clone()),
openai.as_ref().and_then(|s| s.available_models.clone()),
);
merge(