assistant: Make it easier to define custom models (#15442)

This PR makes it easier to specify custom models for the Google, OpenAI,
and Anthropic providers:

Before (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-google-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "name": "my-custom-google-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}
```

Before (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-anthropic-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "version": "1",
      "available_models": [
        {
          "name": "my-custom-anthropic-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}
```

The settings will be auto-upgraded so the old versions will continue to
work (except for Google since that one has not been released).

/cc @as-cii 

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-07-30 15:46:39 +02:00 committed by GitHub
parent 13dcb42c1c
commit 2ada2964c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 256 additions and 47 deletions

View file

@ -865,16 +865,18 @@
// Different settings for specific language models. // Different settings for specific language models.
"language_models": { "language_models": {
"anthropic": { "anthropic": {
"version": "1",
"api_url": "https://api.anthropic.com" "api_url": "https://api.anthropic.com"
}, },
"openai": {
"api_url": "https://api.openai.com/v1"
},
"google": { "google": {
"api_url": "https://generativelanguage.googleapis.com" "api_url": "https://generativelanguage.googleapis.com"
}, },
"ollama": { "ollama": {
"api_url": "http://localhost:11434" "api_url": "http://localhost:11434"
},
"openai": {
"version": "1",
"api_url": "https://api.openai.com/v1"
} }
}, },
// Zed's Prettier integration settings. // Zed's Prettier integration settings.

View file

@ -110,11 +110,15 @@ impl AssistantSettingsContent {
move |content, _| { move |content, _| {
if content.anthropic.is_none() { if content.anthropic.is_none() {
content.anthropic = content.anthropic =
Some(language_model::settings::AnthropicSettingsContent { Some(language_model::settings::AnthropicSettingsContent::Versioned(
api_url, language_model::settings::VersionedAnthropicSettingsContent::V1(
low_speed_timeout_in_seconds, language_model::settings::AnthropicSettingsContentV1 {
..Default::default() api_url,
}); low_speed_timeout_in_seconds,
available_models: None
}
)
));
} }
}, },
), ),
@ -145,12 +149,27 @@ impl AssistantSettingsContent {
cx, cx,
move |content, _| { move |content, _| {
if content.openai.is_none() { if content.openai.is_none() {
let available_models = available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
open_ai::Model::Custom { name, max_tokens } => {
Some(language_model::provider::open_ai::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect::<Vec<_>>()
});
content.openai = content.openai =
Some(language_model::settings::OpenAiSettingsContent { Some(language_model::settings::OpenAiSettingsContent::Versioned(
api_url, language_model::settings::VersionedOpenAiSettingsContent::V1(
low_speed_timeout_in_seconds, language_model::settings::OpenAiSettingsContentV1 {
available_models, api_url,
}); low_speed_timeout_in_seconds,
available_models
}
)
));
} }
}, },
), ),
@ -377,6 +396,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
schemars::schema::SchemaObject { schemars::schema::SchemaObject {
enum_values: Some(vec![ enum_values: Some(vec![
"anthropic".into(), "anthropic".into(),
"google".into(),
"ollama".into(), "ollama".into(),
"openai".into(), "openai".into(),
"zed.dev".into(), "zed.dev".into(),

View file

@ -37,6 +37,7 @@ menu.workspace = true
ollama = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] }
proto = { workspace = true, features = ["test-support"] } proto = { workspace = true, features = ["test-support"] }
project.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true

View file

@ -13,14 +13,15 @@ use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext}; use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
pub use model::*; pub use model::*;
use project::Fs;
pub use registry::*; pub use registry::*;
pub use request::*; pub use request::*;
pub use role::*; pub use role::*;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
pub fn init(client: Arc<Client>, cx: &mut AppContext) { pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
settings::init(cx); settings::init(fs, cx);
registry::init(client, cx); registry::init(client, cx);
} }

View file

@ -12,6 +12,8 @@ use gpui::{
WhiteSpace, WhiteSpace,
}; };
use http_client::HttpClient; use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
@ -26,7 +28,14 @@ const PROVIDER_NAME: &str = "Anthropic";
pub struct AnthropicSettings { pub struct AnthropicSettings {
pub api_url: String, pub api_url: String,
pub low_speed_timeout: Option<Duration>, pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<anthropic::Model>, pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
} }
pub struct AnthropicLanguageModelProvider { pub struct AnthropicLanguageModelProvider {
@ -84,7 +93,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.available_models .available_models
.iter() .iter()
{ {
models.insert(model.id().to_string(), model.clone()); models.insert(
model.name.clone(),
anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
},
);
} }
models models

View file

@ -8,6 +8,8 @@ use gpui::{
WhiteSpace, WhiteSpace,
}; };
use http_client::HttpClient; use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration}; use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
@ -28,7 +30,13 @@ const PROVIDER_NAME: &str = "Google AI";
pub struct GoogleSettings { pub struct GoogleSettings {
pub api_url: String, pub api_url: String,
pub low_speed_timeout: Option<Duration>, pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<google_ai::Model>, pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
name: String,
max_tokens: usize,
} }
pub struct GoogleLanguageModelProvider { pub struct GoogleLanguageModelProvider {
@ -86,7 +94,13 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
.google .google
.available_models .available_models
{ {
models.insert(model.id().to_string(), model.clone()); models.insert(
model.name.clone(),
google_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
},
);
} }
models models

View file

@ -8,6 +8,8 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use open_ai::stream_completion; use open_ai::stream_completion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration}; use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
@ -28,7 +30,14 @@ const PROVIDER_NAME: &str = "OpenAI";
pub struct OpenAiSettings { pub struct OpenAiSettings {
pub api_url: String, pub api_url: String,
pub low_speed_timeout: Option<Duration>, pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<open_ai::Model>, pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
} }
pub struct OpenAiLanguageModelProvider { pub struct OpenAiLanguageModelProvider {
@ -86,7 +95,13 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
.openai .openai
.available_models .available_models
{ {
models.insert(model.id().to_string(), model.clone()); models.insert(
model.name.clone(),
open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
},
);
} }
models models

View file

@ -1,12 +1,14 @@
use std::time::Duration; use std::{sync::Arc, time::Duration};
use anyhow::Result; use anyhow::Result;
use gpui::AppContext; use gpui::AppContext;
use project::Fs;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources}; use settings::{update_settings_file, Settings, SettingsSources};
use crate::provider::{ use crate::provider::{
self,
anthropic::AnthropicSettings, anthropic::AnthropicSettings,
cloud::{self, ZedDotDevSettings}, cloud::{self, ZedDotDevSettings},
copilot_chat::CopilotChatSettings, copilot_chat::CopilotChatSettings,
@ -16,8 +18,36 @@ use crate::provider::{
}; };
/// Initializes the language model settings. /// Initializes the language model settings.
pub fn init(cx: &mut AppContext) { pub fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
AllLanguageModelSettings::register(cx); AllLanguageModelSettings::register(cx);
if AllLanguageModelSettings::get_global(cx)
.openai
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
if let Some(settings) = setting.openai.clone() {
let (newest_version, _) = settings.upgrade();
setting.openai = Some(OpenAiSettingsContent::Versioned(
VersionedOpenAiSettingsContent::V1(newest_version),
));
}
});
}
if AllLanguageModelSettings::get_global(cx)
.anthropic
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
if let Some(settings) = setting.anthropic.clone() {
let (newest_version, _) = settings.upgrade();
setting.anthropic = Some(AnthropicSettingsContent::Versioned(
VersionedAnthropicSettingsContent::V1(newest_version),
));
}
});
}
} }
#[derive(Default)] #[derive(Default)]
@ -41,31 +71,129 @@ pub struct AllLanguageModelSettingsContent {
pub copilot_chat: Option<CopilotChatSettingsContent>, pub copilot_chat: Option<CopilotChatSettingsContent>,
} }
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContent { #[serde(untagged)]
pub enum AnthropicSettingsContent {
Legacy(LegacyAnthropicSettingsContent),
Versioned(VersionedAnthropicSettingsContent),
}
impl AnthropicSettingsContent {
pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
match self {
AnthropicSettingsContent::Legacy(content) => (
AnthropicSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
anthropic::Model::Custom { name, max_tokens } => {
Some(provider::anthropic::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect()
}),
},
true,
),
AnthropicSettingsContent::Versioned(content) => match content {
VersionedAnthropicSettingsContent::V1(content) => (content, false),
},
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyAnthropicSettingsContent {
pub api_url: Option<String>, pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>, pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<anthropic::Model>>, pub available_models: Option<Vec<anthropic::Model>>,
} }
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedAnthropicSettingsContent {
#[serde(rename = "1")]
V1(AnthropicSettingsContentV1),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContentV1 {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OllamaSettingsContent { pub struct OllamaSettingsContent {
pub api_url: Option<String>, pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>, pub low_speed_timeout_in_seconds: Option<u64>,
} }
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContent { #[serde(untagged)]
pub enum OpenAiSettingsContent {
Legacy(LegacyOpenAiSettingsContent),
Versioned(VersionedOpenAiSettingsContent),
}
impl OpenAiSettingsContent {
pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
match self {
OpenAiSettingsContent::Legacy(content) => (
OpenAiSettingsContentV1 {
api_url: content.api_url,
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
open_ai::Model::Custom { name, max_tokens } => {
Some(provider::open_ai::AvailableModel { name, max_tokens })
}
_ => None,
})
.collect()
}),
},
true,
),
OpenAiSettingsContent::Versioned(content) => match content {
VersionedOpenAiSettingsContent::V1(content) => (content, false),
},
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyOpenAiSettingsContent {
pub api_url: Option<String>, pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>, pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<open_ai::Model>>, pub available_models: Option<Vec<open_ai::Model>>,
} }
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedOpenAiSettingsContent {
#[serde(rename = "1")]
V1(OpenAiSettingsContentV1),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContentV1 {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent { pub struct GoogleSettingsContent {
pub api_url: Option<String>, pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>, pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<google_ai::Model>>, pub available_models: Option<Vec<provider::google::AvailableModel>>,
} }
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -81,6 +209,8 @@ pub struct CopilotChatSettingsContent {
impl settings::Settings for AllLanguageModelSettings { impl settings::Settings for AllLanguageModelSettings {
const KEY: Option<&'static str> = Some("language_models"); const KEY: Option<&'static str> = Some("language_models");
const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]);
type FileContent = AllLanguageModelSettingsContent; type FileContent = AllLanguageModelSettingsContent;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> { fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
@ -93,12 +223,21 @@ impl settings::Settings for AllLanguageModelSettings {
let mut settings = AllLanguageModelSettings::default(); let mut settings = AllLanguageModelSettings::default();
for value in sources.defaults_and_customizations() { for value in sources.defaults_and_customizations() {
// Anthropic
let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.anthropic.needs_setting_migration = true;
}
merge( merge(
&mut settings.anthropic.api_url, &mut settings.anthropic.api_url,
value.anthropic.as_ref().and_then(|s| s.api_url.clone()), anthropic.as_ref().and_then(|s| s.api_url.clone()),
); );
if let Some(low_speed_timeout_in_seconds) = value if let Some(low_speed_timeout_in_seconds) = anthropic
.anthropic
.as_ref() .as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds) .and_then(|s| s.low_speed_timeout_in_seconds)
{ {
@ -107,10 +246,7 @@ impl settings::Settings for AllLanguageModelSettings {
} }
merge( merge(
&mut settings.anthropic.available_models, &mut settings.anthropic.available_models,
value anthropic.as_ref().and_then(|s| s.available_models.clone()),
.anthropic
.as_ref()
.and_then(|s| s.available_models.clone()),
); );
merge( merge(
@ -126,24 +262,29 @@ impl settings::Settings for AllLanguageModelSettings {
Some(Duration::from_secs(low_speed_timeout_in_seconds)); Some(Duration::from_secs(low_speed_timeout_in_seconds));
} }
// OpenAI
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.openai.needs_setting_migration = true;
}
merge( merge(
&mut settings.openai.api_url, &mut settings.openai.api_url,
value.openai.as_ref().and_then(|s| s.api_url.clone()), openai.as_ref().and_then(|s| s.api_url.clone()),
); );
if let Some(low_speed_timeout_in_seconds) = value if let Some(low_speed_timeout_in_seconds) =
.openai openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds)
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{ {
settings.openai.low_speed_timeout = settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds)); Some(Duration::from_secs(low_speed_timeout_in_seconds));
} }
merge( merge(
&mut settings.openai.available_models, &mut settings.openai.available_models,
value openai.as_ref().and_then(|s| s.available_models.clone()),
.openai
.as_ref()
.and_then(|s| s.available_models.clone()),
); );
merge( merge(

View file

@ -174,7 +174,7 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) {
cx, cx,
); );
supermaven::init(app_state.client.clone(), cx); supermaven::init(app_state.client.clone(), cx);
language_model::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), app_state.fs.clone(), cx);
snippet_provider::init(cx); snippet_provider::init(cx);
inline_completion_registry::init(app_state.client.telemetry().clone(), cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);

View file

@ -3461,7 +3461,7 @@ mod tests {
app_state.client.http_client().clone(), app_state.client.http_client().clone(),
cx, cx,
); );
language_model::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), app_state.fs.clone(), cx);
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
repl::init( repl::init(
app_state.fs.clone(), app_state.fs.clone(),