assistant: Make it easier to define custom models (#15442)
This PR makes it easier to specify custom models for the Google, OpenAI, and Anthropic providers.

Before (google):
```json
{ "language_models": { "google": { "available_models": [ { "custom": { "name": "my-custom-google-model", "max_tokens": 12345 } } ] } } }
```
After (google):
```json
{ "language_models": { "google": { "available_models": [ { "name": "my-custom-google-model", "max_tokens": 12345 } ] } } }
```
Before (anthropic):
```json
{ "language_models": { "anthropic": { "available_models": [ { "custom": { "name": "my-custom-anthropic-model", "max_tokens": 12345 } } ] } } }
```
After (anthropic):
```json
{ "language_models": { "anthropic": { "version": "1", "available_models": [ { "name": "my-custom-anthropic-model", "max_tokens": 12345 } ] } } }
```
The settings will be auto-upgraded, so the old versions will continue to work (except for Google, since that one has not been released).

/cc @as-cii

Release Notes:
- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
parent
13dcb42c1c
commit
2ada2964c5
10 changed files with 256 additions and 47 deletions
|
@ -37,6 +37,7 @@ menu.workspace = true
|
|||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
proto = { workspace = true, features = ["test-support"] }
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
|
@ -13,14 +13,15 @@ use futures::{future::BoxFuture, stream::BoxStream};
|
|||
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
|
||||
|
||||
pub use model::*;
|
||||
use project::Fs;
|
||||
pub use registry::*;
|
||||
pub use request::*;
|
||||
pub use role::*;
|
||||
use schemars::JsonSchema;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
settings::init(cx);
|
||||
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
||||
settings::init(fs, cx);
|
||||
registry::init(client, cx);
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ use gpui::{
|
|||
WhiteSpace,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
|
@ -26,7 +28,14 @@ const PROVIDER_NAME: &str = "Anthropic";
|
|||
pub struct AnthropicSettings {
|
||||
pub api_url: String,
|
||||
pub low_speed_timeout: Option<Duration>,
|
||||
pub available_models: Vec<anthropic::Model>,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
pub needs_setting_migration: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub max_tokens: usize,
|
||||
}
|
||||
|
||||
pub struct AnthropicLanguageModelProvider {
|
||||
|
@ -84,7 +93,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
.available_models
|
||||
.iter()
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
anthropic::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
models
|
||||
|
|
|
@ -8,6 +8,8 @@ use gpui::{
|
|||
WhiteSpace,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{future, sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
|
@ -28,7 +30,13 @@ const PROVIDER_NAME: &str = "Google AI";
|
|||
pub struct GoogleSettings {
|
||||
pub api_url: String,
|
||||
pub low_speed_timeout: Option<Duration>,
|
||||
pub available_models: Vec<google_ai::Model>,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
name: String,
|
||||
max_tokens: usize,
|
||||
}
|
||||
|
||||
pub struct GoogleLanguageModelProvider {
|
||||
|
@ -86,7 +94,13 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||
.google
|
||||
.available_models
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
google_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
models
|
||||
|
|
|
@ -8,6 +8,8 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use open_ai::stream_completion;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{future, sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
|
@ -28,7 +30,14 @@ const PROVIDER_NAME: &str = "OpenAI";
|
|||
pub struct OpenAiSettings {
|
||||
pub api_url: String,
|
||||
pub low_speed_timeout: Option<Duration>,
|
||||
pub available_models: Vec<open_ai::Model>,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
pub needs_setting_migration: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub max_tokens: usize,
|
||||
}
|
||||
|
||||
pub struct OpenAiLanguageModelProvider {
|
||||
|
@ -86,7 +95,13 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
.openai
|
||||
.available_models
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
open_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
models
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
use std::time::Duration;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::AppContext;
|
||||
use project::Fs;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use settings::{update_settings_file, Settings, SettingsSources};
|
||||
|
||||
use crate::provider::{
|
||||
self,
|
||||
anthropic::AnthropicSettings,
|
||||
cloud::{self, ZedDotDevSettings},
|
||||
copilot_chat::CopilotChatSettings,
|
||||
|
@ -16,8 +18,36 @@ use crate::provider::{
|
|||
};
|
||||
|
||||
/// Initializes the language model settings.
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
pub fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
||||
AllLanguageModelSettings::register(cx);
|
||||
|
||||
if AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.needs_setting_migration
|
||||
{
|
||||
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
|
||||
if let Some(settings) = setting.openai.clone() {
|
||||
let (newest_version, _) = settings.upgrade();
|
||||
setting.openai = Some(OpenAiSettingsContent::Versioned(
|
||||
VersionedOpenAiSettingsContent::V1(newest_version),
|
||||
));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.needs_setting_migration
|
||||
{
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
|
||||
if let Some(settings) = setting.anthropic.clone() {
|
||||
let (newest_version, _) = settings.upgrade();
|
||||
setting.anthropic = Some(AnthropicSettingsContent::Versioned(
|
||||
VersionedAnthropicSettingsContent::V1(newest_version),
|
||||
));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
@ -41,31 +71,129 @@ pub struct AllLanguageModelSettingsContent {
|
|||
pub copilot_chat: Option<CopilotChatSettingsContent>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AnthropicSettingsContent {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum AnthropicSettingsContent {
|
||||
Legacy(LegacyAnthropicSettingsContent),
|
||||
Versioned(VersionedAnthropicSettingsContent),
|
||||
}
|
||||
|
||||
impl AnthropicSettingsContent {
|
||||
/// Converts this settings content into the V1 shape.
///
/// Returns the V1 content together with a flag that is `true` when the
/// input was in the legacy form (and so had to be converted) and `false`
/// when it was already versioned.
pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
|
||||
match self {
|
||||
AnthropicSettingsContent::Legacy(content) => (
|
||||
AnthropicSettingsContentV1 {
|
||||
api_url: content.api_url,
|
||||
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
|
||||
available_models: content.available_models.map(|models| {
|
||||
models
|
||||
.into_iter()
|
||||
// Only `Custom` entries carry over to the new format; every
// other legacy model variant is dropped by the `_ => None` arm.
.filter_map(|model| match model {
|
||||
anthropic::Model::Custom { name, max_tokens } => {
|
||||
Some(provider::anthropic::AvailableModel { name, max_tokens })
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
},
|
||||
// Legacy input: report that a conversion took place.
true,
|
||||
),
|
||||
AnthropicSettingsContent::Versioned(content) => match content {
|
||||
// Already V1: pass through unchanged and report no conversion.
VersionedAnthropicSettingsContent::V1(content) => (content, false),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct LegacyAnthropicSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<anthropic::Model>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
#[serde(tag = "version")]
|
||||
pub enum VersionedAnthropicSettingsContent {
|
||||
#[serde(rename = "1")]
|
||||
V1(AnthropicSettingsContentV1),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AnthropicSettingsContentV1 {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OllamaSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OpenAiSettingsContent {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum OpenAiSettingsContent {
|
||||
Legacy(LegacyOpenAiSettingsContent),
|
||||
Versioned(VersionedOpenAiSettingsContent),
|
||||
}
|
||||
|
||||
impl OpenAiSettingsContent {
|
||||
/// Converts this settings content into the V1 shape.
///
/// Returns the V1 content together with a flag that is `true` when the
/// input was in the legacy form (and so had to be converted) and `false`
/// when it was already versioned.
pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
|
||||
match self {
|
||||
OpenAiSettingsContent::Legacy(content) => (
|
||||
OpenAiSettingsContentV1 {
|
||||
api_url: content.api_url,
|
||||
low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
|
||||
available_models: content.available_models.map(|models| {
|
||||
models
|
||||
.into_iter()
|
||||
// Only `Custom` entries carry over to the new format; every
// other legacy model variant is dropped by the `_ => None` arm.
.filter_map(|model| match model {
|
||||
open_ai::Model::Custom { name, max_tokens } => {
|
||||
Some(provider::open_ai::AvailableModel { name, max_tokens })
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
},
|
||||
// Legacy input: report that a conversion took place.
true,
|
||||
),
|
||||
OpenAiSettingsContent::Versioned(content) => match content {
|
||||
// Already V1: pass through unchanged and report no conversion.
VersionedOpenAiSettingsContent::V1(content) => (content, false),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct LegacyOpenAiSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<open_ai::Model>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
#[serde(tag = "version")]
|
||||
pub enum VersionedOpenAiSettingsContent {
|
||||
#[serde(rename = "1")]
|
||||
V1(OpenAiSettingsContentV1),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OpenAiSettingsContentV1 {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct GoogleSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<google_ai::Model>>,
|
||||
pub available_models: Option<Vec<provider::google::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
|
@ -81,6 +209,8 @@ pub struct CopilotChatSettingsContent {
|
|||
impl settings::Settings for AllLanguageModelSettings {
|
||||
const KEY: Option<&'static str> = Some("language_models");
|
||||
|
||||
const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]);
|
||||
|
||||
type FileContent = AllLanguageModelSettingsContent;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
|
@ -93,12 +223,21 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
let mut settings = AllLanguageModelSettings::default();
|
||||
|
||||
for value in sources.defaults_and_customizations() {
|
||||
// Anthropic
|
||||
let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
|
||||
Some((content, upgraded)) => (Some(content), upgraded),
|
||||
None => (None, false),
|
||||
};
|
||||
|
||||
if upgraded {
|
||||
settings.anthropic.needs_setting_migration = true;
|
||||
}
|
||||
|
||||
merge(
|
||||
&mut settings.anthropic.api_url,
|
||||
value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
|
||||
anthropic.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
if let Some(low_speed_timeout_in_seconds) = value
|
||||
.anthropic
|
||||
if let Some(low_speed_timeout_in_seconds) = anthropic
|
||||
.as_ref()
|
||||
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
{
|
||||
|
@ -107,10 +246,7 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
}
|
||||
merge(
|
||||
&mut settings.anthropic.available_models,
|
||||
value
|
||||
.anthropic
|
||||
.as_ref()
|
||||
.and_then(|s| s.available_models.clone()),
|
||||
anthropic.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
merge(
|
||||
|
@ -126,24 +262,29 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||
}
|
||||
|
||||
// OpenAI
|
||||
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
|
||||
Some((content, upgraded)) => (Some(content), upgraded),
|
||||
None => (None, false),
|
||||
};
|
||||
|
||||
if upgraded {
|
||||
settings.openai.needs_setting_migration = true;
|
||||
}
|
||||
|
||||
merge(
|
||||
&mut settings.openai.api_url,
|
||||
value.openai.as_ref().and_then(|s| s.api_url.clone()),
|
||||
openai.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
if let Some(low_speed_timeout_in_seconds) = value
|
||||
.openai
|
||||
.as_ref()
|
||||
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
if let Some(low_speed_timeout_in_seconds) =
|
||||
openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
{
|
||||
settings.openai.low_speed_timeout =
|
||||
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||
}
|
||||
merge(
|
||||
&mut settings.openai.available_models,
|
||||
value
|
||||
.openai
|
||||
.as_ref()
|
||||
.and_then(|s| s.available_models.clone()),
|
||||
openai.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
merge(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue