Allow using a custom model when using zed.dev (#14933)
Release Notes: - N/A
This commit is contained in:
parent
a334c69e05
commit
0155435142
6 changed files with 114 additions and 110 deletions
|
@ -20,6 +20,12 @@ pub enum Model {
|
||||||
Claude3Sonnet,
|
Claude3Sonnet,
|
||||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||||
Claude3Haiku,
|
Claude3Haiku,
|
||||||
|
#[serde(rename = "custom")]
|
||||||
|
Custom {
|
||||||
|
name: String,
|
||||||
|
#[serde(default)]
|
||||||
|
max_tokens: Option<usize>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
|
@ -33,30 +39,41 @@ impl Model {
|
||||||
} else if id.starts_with("claude-3-haiku") {
|
} else if id.starts_with("claude-3-haiku") {
|
||||||
Ok(Self::Claude3Haiku)
|
Ok(Self::Claude3Haiku)
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow!("Invalid model id: {}", id))
|
Ok(Self::Custom {
|
||||||
|
name: id.to_string(),
|
||||||
|
max_tokens: None,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> &'static str {
|
pub fn id(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
|
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
|
||||||
Model::Claude3Opus => "claude-3-opus-20240229",
|
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||||
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||||
Model::Claude3Haiku => "claude-3-opus-20240307",
|
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||||
|
Model::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn display_name(&self) -> &'static str {
|
pub fn display_name(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||||
Self::Claude3Opus => "Claude 3 Opus",
|
Self::Claude3Opus => "Claude 3 Opus",
|
||||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||||
|
Self::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn max_token_count(&self) -> usize {
|
pub fn max_token_count(&self) -> usize {
|
||||||
200_000
|
match self {
|
||||||
|
Self::Claude3_5Sonnet
|
||||||
|
| Self::Claude3Opus
|
||||||
|
| Self::Claude3Sonnet
|
||||||
|
| Self::Claude3Haiku => 200_000,
|
||||||
|
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,6 +107,7 @@ impl From<Role> for String {
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct Request {
|
pub struct Request {
|
||||||
|
#[serde(serialize_with = "serialize_request_model")]
|
||||||
pub model: Model,
|
pub model: Model,
|
||||||
pub messages: Vec<RequestMessage>,
|
pub messages: Vec<RequestMessage>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
|
@ -97,6 +115,13 @@ pub struct Request {
|
||||||
pub max_tokens: u32,
|
pub max_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
serializer.serialize_str(&model.id())
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
pub struct RequestMessage {
|
pub struct RequestMessage {
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
|
|
|
@ -668,7 +668,11 @@ mod tests {
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"provider": {
|
"provider": {
|
||||||
"name": "zed.dev",
|
"name": "zed.dev",
|
||||||
"default_model": "custom"
|
"default_model": {
|
||||||
|
"custom": {
|
||||||
|
"name": "custom-provider"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}"#,
|
}"#,
|
||||||
|
@ -679,7 +683,10 @@ mod tests {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
AssistantSettings::get_global(cx).provider,
|
AssistantSettings::get_global(cx).provider,
|
||||||
AssistantProvider::ZedDotDev {
|
AssistantProvider::ZedDotDev {
|
||||||
model: CloudModel::Custom("custom".into())
|
model: CloudModel::Custom {
|
||||||
|
name: "custom-provider".into(),
|
||||||
|
max_tokens: None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn complete_with_language_model(
|
async fn complete_with_language_model(
|
||||||
request: proto::CompleteWithLanguageModel,
|
mut request: proto::CompleteWithLanguageModel,
|
||||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||||
session: Session,
|
session: Session,
|
||||||
open_ai_api_key: Option<Arc<str>>,
|
open_ai_api_key: Option<Arc<str>>,
|
||||||
|
@ -4530,19 +4530,44 @@ async fn complete_with_language_model(
|
||||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if request.model.starts_with("gpt") {
|
let mut provider_and_model = request.model.split('/');
|
||||||
let api_key =
|
let (provider, model) = match (
|
||||||
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
|
provider_and_model.next().unwrap(),
|
||||||
|
provider_and_model.next(),
|
||||||
|
) {
|
||||||
|
(provider, Some(model)) => (provider, model),
|
||||||
|
(model, None) => {
|
||||||
|
if model.starts_with("gpt") {
|
||||||
|
("openai", model)
|
||||||
|
} else if model.starts_with("gemini") {
|
||||||
|
("google", model)
|
||||||
|
} else if model.starts_with("claude") {
|
||||||
|
("anthropic", model)
|
||||||
|
} else {
|
||||||
|
("unknown", model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let provider = provider.to_string();
|
||||||
|
request.model = model.to_string();
|
||||||
|
|
||||||
|
match provider.as_str() {
|
||||||
|
"openai" => {
|
||||||
|
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
|
||||||
complete_with_open_ai(request, response, session, api_key).await?;
|
complete_with_open_ai(request, response, session, api_key).await?;
|
||||||
} else if request.model.starts_with("gemini") {
|
}
|
||||||
let api_key = google_ai_api_key
|
"anthropic" => {
|
||||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
let api_key =
|
||||||
complete_with_google_ai(request, response, session, api_key).await?;
|
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
|
||||||
} else if request.model.starts_with("claude") {
|
|
||||||
let api_key = anthropic_api_key
|
|
||||||
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
|
|
||||||
complete_with_anthropic(request, response, session, api_key).await?;
|
complete_with_anthropic(request, response, session, api_key).await?;
|
||||||
}
|
}
|
||||||
|
"google" => {
|
||||||
|
let api_key =
|
||||||
|
google_ai_api_key.context("no Google AI API key configured on the server")?;
|
||||||
|
complete_with_google_ai(request, response, session, api_key).await?;
|
||||||
|
}
|
||||||
|
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,15 +54,15 @@ impl CloudCompletionProvider {
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
||||||
fn available_models(&self) -> Vec<LanguageModel> {
|
fn available_models(&self) -> Vec<LanguageModel> {
|
||||||
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
|
let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
|
||||||
Some(custom_model)
|
Some(self.model.clone())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
CloudModel::iter()
|
CloudModel::iter()
|
||||||
.filter_map(move |model| {
|
.filter_map(move |model| {
|
||||||
if let CloudModel::Custom(_) = model {
|
if let CloudModel::Custom { .. } = model {
|
||||||
Some(CloudModel::Custom(custom_model.take()?))
|
custom_model.take()
|
||||||
} else {
|
} else {
|
||||||
Some(model)
|
Some(model)
|
||||||
}
|
}
|
||||||
|
@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
||||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
count_open_ai_tokens(request, cx.background_executor())
|
||||||
}
|
}
|
||||||
LanguageModel::Cloud(CloudModel::Custom(model)) => {
|
LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
|
||||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||||
model,
|
model: name,
|
||||||
messages: request
|
messages: request
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
||||||
|
| LanguageModel::Cloud(CloudModel::Custom { .. })
|
||||||
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
||||||
// Tiktoken doesn't yet support these models, so we manually use the
|
// Tiktoken doesn't yet support these models, so we manually use the
|
||||||
// same tokenizer as GPT-4.
|
// same tokenizer as GPT-4.
|
||||||
|
|
|
@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
|
||||||
pub use anthropic::Model as AnthropicModel;
|
pub use anthropic::Model as AnthropicModel;
|
||||||
pub use ollama::Model as OllamaModel;
|
pub use ollama::Model as OllamaModel;
|
||||||
pub use open_ai::Model as OpenAiModel;
|
pub use open_ai::Model as OpenAiModel;
|
||||||
use schemars::{
|
use schemars::JsonSchema;
|
||||||
schema::{InstanceType, Metadata, Schema, SchemaObject},
|
use serde::{Deserialize, Serialize};
|
||||||
JsonSchema,
|
use strum::EnumIter;
|
||||||
};
|
|
||||||
use serde::{
|
|
||||||
de::{self, Visitor},
|
|
||||||
Deserialize, Deserializer, Serialize, Serializer,
|
|
||||||
};
|
|
||||||
use std::fmt;
|
|
||||||
use strum::{EnumIter, IntoEnumIterator};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
|
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
|
||||||
pub enum CloudModel {
|
pub enum CloudModel {
|
||||||
|
#[serde(rename = "gpt-3.5-turbo")]
|
||||||
Gpt3Point5Turbo,
|
Gpt3Point5Turbo,
|
||||||
|
#[serde(rename = "gpt-4")]
|
||||||
Gpt4,
|
Gpt4,
|
||||||
|
#[serde(rename = "gpt-4-turbo-preview")]
|
||||||
Gpt4Turbo,
|
Gpt4Turbo,
|
||||||
|
#[serde(rename = "gpt-4o")]
|
||||||
#[default]
|
#[default]
|
||||||
Gpt4Omni,
|
Gpt4Omni,
|
||||||
|
#[serde(rename = "gpt-4o-mini")]
|
||||||
Gpt4OmniMini,
|
Gpt4OmniMini,
|
||||||
|
#[serde(rename = "claude-3-5-sonnet")]
|
||||||
Claude3_5Sonnet,
|
Claude3_5Sonnet,
|
||||||
|
#[serde(rename = "claude-3-opus")]
|
||||||
Claude3Opus,
|
Claude3Opus,
|
||||||
|
#[serde(rename = "claude-3-sonnet")]
|
||||||
Claude3Sonnet,
|
Claude3Sonnet,
|
||||||
|
#[serde(rename = "claude-3-haiku")]
|
||||||
Claude3Haiku,
|
Claude3Haiku,
|
||||||
|
#[serde(rename = "gemini-1.5-pro")]
|
||||||
Gemini15Pro,
|
Gemini15Pro,
|
||||||
|
#[serde(rename = "gemini-1.5-flash")]
|
||||||
Gemini15Flash,
|
Gemini15Flash,
|
||||||
Custom(String),
|
#[serde(rename = "custom")]
|
||||||
}
|
Custom {
|
||||||
|
name: String,
|
||||||
impl Serialize for CloudModel {
|
max_tokens: Option<usize>,
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
},
|
||||||
where
|
|
||||||
S: Serializer,
|
|
||||||
{
|
|
||||||
serializer.serialize_str(self.id())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'de> Deserialize<'de> for CloudModel {
|
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
struct ZedDotDevModelVisitor;
|
|
||||||
|
|
||||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
|
||||||
type Value = CloudModel;
|
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: de::Error,
|
|
||||||
{
|
|
||||||
let model = CloudModel::iter()
|
|
||||||
.find(|model| model.id() == value)
|
|
||||||
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
|
|
||||||
Ok(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
deserializer.deserialize_str(ZedDotDevModelVisitor)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl JsonSchema for CloudModel {
|
|
||||||
fn schema_name() -> String {
|
|
||||||
"ZedDotDevModel".to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
|
||||||
let variants = CloudModel::iter()
|
|
||||||
.filter_map(|model| {
|
|
||||||
let id = model.id();
|
|
||||||
if id.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(id.to_string())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
Schema::Object(SchemaObject {
|
|
||||||
instance_type: Some(InstanceType::String.into()),
|
|
||||||
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
|
|
||||||
metadata: Some(Box::new(Metadata {
|
|
||||||
title: Some("ZedDotDevModel".to_owned()),
|
|
||||||
default: Some(CloudModel::default().id().into()),
|
|
||||||
examples: variants.into_iter().map(Into::into).collect(),
|
|
||||||
..Default::default()
|
|
||||||
})),
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CloudModel {
|
impl CloudModel {
|
||||||
|
@ -112,7 +52,7 @@ impl CloudModel {
|
||||||
Self::Claude3Haiku => "claude-3-haiku",
|
Self::Claude3Haiku => "claude-3-haiku",
|
||||||
Self::Gemini15Pro => "gemini-1.5-pro",
|
Self::Gemini15Pro => "gemini-1.5-pro",
|
||||||
Self::Gemini15Flash => "gemini-1.5-flash",
|
Self::Gemini15Flash => "gemini-1.5-flash",
|
||||||
Self::Custom(id) => id,
|
Self::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,7 +69,7 @@ impl CloudModel {
|
||||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||||
Self::Gemini15Pro => "Gemini 1.5 Pro",
|
Self::Gemini15Pro => "Gemini 1.5 Pro",
|
||||||
Self::Gemini15Flash => "Gemini 1.5 Flash",
|
Self::Gemini15Flash => "Gemini 1.5 Flash",
|
||||||
Self::Custom(id) => id.as_str(),
|
Self::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,14 +85,20 @@ impl CloudModel {
|
||||||
| Self::Claude3Haiku => 200000,
|
| Self::Claude3Haiku => 200000,
|
||||||
Self::Gemini15Pro => 128000,
|
Self::Gemini15Pro => 128000,
|
||||||
Self::Gemini15Flash => 32000,
|
Self::Gemini15Flash => 32000,
|
||||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
|
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
|
||||||
match self {
|
match self {
|
||||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
|
Self::Claude3Opus
|
||||||
request.preprocess_anthropic()
|
| Self::Claude3Sonnet
|
||||||
|
| Self::Claude3Haiku
|
||||||
|
| Self::Claude3_5Sonnet => {
|
||||||
|
request.preprocess_anthropic();
|
||||||
|
}
|
||||||
|
Self::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||||
|
request.preprocess_anthropic();
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue