assistant: Make it easier to define custom models (#15442)

This PR makes it easier to specify custom models for the Google, OpenAI,
and Anthropic providers:

Before (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-google-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "name": "my-custom-google-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}
```

Before (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-anthropic-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "version": "1",
      "available_models": [
        {
          "name": "my-custom-anthropic-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}

```

Existing settings are upgraded automatically, so the old format will
continue to work (except for Google, since that provider has not been
released yet).

/cc @as-cii 

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-07-30 15:46:39 +02:00 committed by GitHub
parent 13dcb42c1c
commit 2ada2964c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 256 additions and 47 deletions

View file

@ -37,6 +37,7 @@ menu.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
proto = { workspace = true, features = ["test-support"] }
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -13,14 +13,15 @@ use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
pub use model::*;
use project::Fs;
pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings::init(cx);
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
settings::init(fs, cx);
registry::init(client, cx);
}

View file

@ -12,6 +12,8 @@ use gpui::{
WhiteSpace,
};
use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
@ -26,7 +28,14 @@ const PROVIDER_NAME: &str = "Anthropic";
pub struct AnthropicSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<anthropic::Model>,
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
/// A custom Anthropic model as written in the user's settings.
///
/// The provider turns each entry into an `anthropic::Model::Custom` with the
/// same `name` and `max_tokens`, keyed by `name`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// Model name; also used as the key under which the model is registered.
pub name: String,
/// Forwarded unchanged as the custom model's `max_tokens`.
pub max_tokens: usize,
}
pub struct AnthropicLanguageModelProvider {
@ -84,7 +93,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.available_models
.iter()
{
models.insert(model.id().to_string(), model.clone());
models.insert(
model.name.clone(),
anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
},
);
}
models

View file

@ -8,6 +8,8 @@ use gpui::{
WhiteSpace,
};
use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
@ -28,7 +30,13 @@ const PROVIDER_NAME: &str = "Google AI";
pub struct GoogleSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<google_ai::Model>,
pub available_models: Vec<AvailableModel>,
}
/// A custom Google AI model as written in the user's settings.
///
/// The provider turns each entry into a `google_ai::Model::Custom` with the
/// same `name` and `max_tokens`, keyed by `name`.
///
/// The fields are public so this type matches the equivalent
/// `AvailableModel` structs of the Anthropic and OpenAI providers and can be
/// constructed or inspected from the settings module.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model name; also used as the key under which the model is registered.
    pub name: String,
    /// Forwarded unchanged as the custom model's `max_tokens`.
    pub max_tokens: usize,
}
pub struct GoogleLanguageModelProvider {
@ -86,7 +94,13 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
.google
.available_models
{
models.insert(model.id().to_string(), model.clone());
models.insert(
model.name.clone(),
google_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
},
);
}
models

View file

@ -8,6 +8,8 @@ use gpui::{
};
use http_client::HttpClient;
use open_ai::stream_completion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
@ -28,7 +30,14 @@ const PROVIDER_NAME: &str = "OpenAI";
pub struct OpenAiSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<open_ai::Model>,
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
/// A custom OpenAI model as written in the user's settings.
///
/// The provider turns each entry into an `open_ai::Model::Custom` with the
/// same `name` and `max_tokens`, keyed by `name`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// Model name; also used as the key under which the model is registered.
pub name: String,
/// Forwarded unchanged as the custom model's `max_tokens`.
pub max_tokens: usize,
}
pub struct OpenAiLanguageModelProvider {
@ -86,7 +95,13 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
.openai
.available_models
{
models.insert(model.id().to_string(), model.clone());
models.insert(
model.name.clone(),
open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
},
);
}
models

View file

@ -1,12 +1,14 @@
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use anyhow::Result;
use gpui::AppContext;
use project::Fs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use settings::{update_settings_file, Settings, SettingsSources};
use crate::provider::{
self,
anthropic::AnthropicSettings,
cloud::{self, ZedDotDevSettings},
copilot_chat::CopilotChatSettings,
@ -16,8 +18,36 @@ use crate::provider::{
};
/// Initializes the language model settings.
///
/// Registers `AllLanguageModelSettings` and then, for the OpenAI and
/// Anthropic providers, checks the `needs_setting_migration` flag of the
/// loaded settings. When the flag is set, `update_settings_file` is invoked
/// with a closure that replaces the provider's entry with its upgraded,
/// versioned (`V1`) form.
pub fn init(cx: &mut AppContext) {
pub fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
AllLanguageModelSettings::register(cx);
// OpenAI: migrate only when loading reported legacy-format content.
if AllLanguageModelSettings::get_global(cx)
.openai
.needs_setting_migration
{
// `fs` is cloned here because it is needed again for the Anthropic
// migration below.
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
// Nothing to rewrite if the file has no `openai` section.
if let Some(settings) = setting.openai.clone() {
// The "was upgraded" flag is irrelevant here; only the content is stored.
let (newest_version, _) = settings.upgrade();
setting.openai = Some(OpenAiSettingsContent::Versioned(
VersionedOpenAiSettingsContent::V1(newest_version),
));
}
});
}
// Anthropic: same migration as above; this call consumes `fs`.
if AllLanguageModelSettings::get_global(cx)
.anthropic
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
// Nothing to rewrite if the file has no `anthropic` section.
if let Some(settings) = setting.anthropic.clone() {
let (newest_version, _) = settings.upgrade();
setting.anthropic = Some(AnthropicSettingsContent::Versioned(
VersionedAnthropicSettingsContent::V1(newest_version),
));
}
});
}
}
#[derive(Default)]
@ -41,31 +71,129 @@ pub struct AllLanguageModelSettingsContent {
pub copilot_chat: Option<CopilotChatSettingsContent>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContent {
/// On-disk representation of the Anthropic settings, in either the current
/// versioned format or the older unversioned one.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that deserializes. `Versioned` is listed first
/// because it requires a `"version"` key. `Legacy` consists solely of
/// optional fields, so if it were tried first it would also accept a
/// versioned object that has no `available_models`, and that object would
/// then be reported as needing migration. Unversioned input fails the
/// `Versioned` attempt (no tag) and falls through to `Legacy`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum AnthropicSettingsContent {
    /// Content carrying an explicit `"version"` tag.
    Versioned(VersionedAnthropicSettingsContent),
    /// Content written before versioning was introduced.
    Legacy(LegacyAnthropicSettingsContent),
}
impl AnthropicSettingsContent {
pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
match self {
AnthropicSettingsContent::Legacy(content) => (
AnthropicSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
anthropic::Model::Custom { name, max_tokens } => {
Some(provider::anthropic::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect()
}),
},
true,
),
AnthropicSettingsContent::Versioned(content) => match content {
VersionedAnthropicSettingsContent::V1(content) => (content, false),
},
}
}
}
/// Unversioned Anthropic settings, as written before the `"version"` key
/// existed. Custom models are stored as `anthropic::Model` values.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyAnthropicSettingsContent {
/// Optional API URL.
pub api_url: Option<String>,
/// Optional low-speed timeout, expressed in seconds.
pub low_speed_timeout_in_seconds: Option<u64>,
/// Optional model list in the old representation; `upgrade` keeps only the
/// `Custom` entries when converting to `V1`.
pub available_models: Option<Vec<anthropic::Model>>,
}
/// Anthropic settings that carry an explicit format version.
///
/// Serialized as an internally tagged enum: the `"version"` key selects the
/// variant and the variant's fields sit alongside it in the same object.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedAnthropicSettingsContent {
/// Selected by `"version": "1"`.
#[serde(rename = "1")]
V1(AnthropicSettingsContentV1),
}
/// Version 1 of the Anthropic settings. It differs from the legacy shape in
/// that custom models are plain `AvailableModel` entries (name + max tokens).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContentV1 {
/// Optional API URL.
pub api_url: Option<String>,
/// Optional low-speed timeout, expressed in seconds.
pub low_speed_timeout_in_seconds: Option<u64>,
/// Optional list of user-defined models.
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
}
/// Ollama settings as stored in the settings file. Every field is optional.
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OllamaSettingsContent {
/// Optional API URL.
pub api_url: Option<String>,
/// Optional low-speed timeout, expressed in seconds.
pub low_speed_timeout_in_seconds: Option<u64>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContent {
/// On-disk representation of the OpenAI settings, in either the current
/// versioned format or the older unversioned one.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that deserializes. `Versioned` is listed first
/// because it requires a `"version"` key. `Legacy` consists solely of
/// optional fields, so if it were tried first it would also accept a
/// versioned object that has no `available_models`, and that object would
/// then be reported as needing migration. Unversioned input fails the
/// `Versioned` attempt (no tag) and falls through to `Legacy`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum OpenAiSettingsContent {
    /// Content carrying an explicit `"version"` tag.
    Versioned(VersionedOpenAiSettingsContent),
    /// Content written before versioning was introduced.
    Legacy(LegacyOpenAiSettingsContent),
}
impl OpenAiSettingsContent {
pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
match self {
OpenAiSettingsContent::Legacy(content) => (
OpenAiSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
open_ai::Model::Custom { name, max_tokens } => {
Some(provider::open_ai::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect()
}),
},
true,
),
OpenAiSettingsContent::Versioned(content) => match content {
VersionedOpenAiSettingsContent::V1(content) => (content, false),
},
}
}
}
/// Unversioned OpenAI settings, as written before the `"version"` key
/// existed. Custom models are stored as `open_ai::Model` values.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyOpenAiSettingsContent {
/// Optional API URL.
pub api_url: Option<String>,
/// Optional low-speed timeout, expressed in seconds.
pub low_speed_timeout_in_seconds: Option<u64>,
/// Optional model list in the old representation; `upgrade` keeps only the
/// `Custom` entries when converting to `V1`.
pub available_models: Option<Vec<open_ai::Model>>,
}
/// OpenAI settings that carry an explicit format version.
///
/// Serialized as an internally tagged enum: the `"version"` key selects the
/// variant and the variant's fields sit alongside it in the same object.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedOpenAiSettingsContent {
/// Selected by `"version": "1"`.
#[serde(rename = "1")]
V1(OpenAiSettingsContentV1),
}
/// Version 1 of the OpenAI settings. It differs from the legacy shape in
/// that custom models are plain `AvailableModel` entries (name + max tokens).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContentV1 {
/// Optional API URL.
pub api_url: Option<String>,
/// Optional low-speed timeout, expressed in seconds.
pub low_speed_timeout_in_seconds: Option<u64>,
/// Optional list of user-defined models.
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<google_ai::Model>>,
pub available_models: Option<Vec<provider::google::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -81,6 +209,8 @@ pub struct CopilotChatSettingsContent {
impl settings::Settings for AllLanguageModelSettings {
const KEY: Option<&'static str> = Some("language_models");
const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]);
type FileContent = AllLanguageModelSettingsContent;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
@ -93,12 +223,21 @@ impl settings::Settings for AllLanguageModelSettings {
let mut settings = AllLanguageModelSettings::default();
for value in sources.defaults_and_customizations() {
// Anthropic
let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.anthropic.needs_setting_migration = true;
}
merge(
&mut settings.anthropic.api_url,
value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
anthropic.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.anthropic
if let Some(low_speed_timeout_in_seconds) = anthropic
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
@ -107,10 +246,7 @@ impl settings::Settings for AllLanguageModelSettings {
}
merge(
&mut settings.anthropic.available_models,
value
.anthropic
.as_ref()
.and_then(|s| s.available_models.clone()),
anthropic.as_ref().and_then(|s| s.available_models.clone()),
);
merge(
@ -126,24 +262,29 @@ impl settings::Settings for AllLanguageModelSettings {
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
// OpenAI
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.openai.needs_setting_migration = true;
}
merge(
&mut settings.openai.api_url,
value.openai.as_ref().and_then(|s| s.api_url.clone()),
openai.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.openai
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
if let Some(low_speed_timeout_in_seconds) =
openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.openai.available_models,
value
.openai
.as_ref()
.and_then(|s| s.available_models.clone()),
openai.as_ref().and_then(|s| s.available_models.clone()),
);
merge(