Allow using a custom model when using zed.dev (#14933)
Release Notes: - N/A
This commit is contained in:
parent
a334c69e05
commit
0155435142
6 changed files with 114 additions and 110 deletions
|
@ -20,6 +20,12 @@ pub enum Model {
|
|||
Claude3Sonnet,
|
||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||
Claude3Haiku,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
max_tokens: Option<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Model {
|
||||
|
@ -33,30 +39,41 @@ impl Model {
|
|||
} else if id.starts_with("claude-3-haiku") {
|
||||
Ok(Self::Claude3Haiku)
|
||||
} else {
|
||||
Err(anyhow!("Invalid model id: {}", id))
|
||||
Ok(Self::Custom {
|
||||
name: id.to_string(),
|
||||
max_tokens: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
|
||||
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||
Model::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
200_000
|
||||
match self {
|
||||
Self::Claude3_5Sonnet
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3Haiku => 200_000,
|
||||
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,6 +107,7 @@ impl From<Role> for String {
|
|||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Request {
|
||||
#[serde(serialize_with = "serialize_request_model")]
|
||||
pub model: Model,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
|
@ -97,6 +115,13 @@ pub struct Request {
|
|||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&model.id())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
|
|
|
@ -668,7 +668,11 @@ mod tests {
|
|||
"version": "1",
|
||||
"provider": {
|
||||
"name": "zed.dev",
|
||||
"default_model": "custom"
|
||||
"default_model": {
|
||||
"custom": {
|
||||
"name": "custom-provider"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
|
@ -679,7 +683,10 @@ mod tests {
|
|||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::ZedDotDev {
|
||||
model: CloudModel::Custom("custom".into())
|
||||
model: CloudModel::Custom {
|
||||
name: "custom-provider".into(),
|
||||
max_tokens: None
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
|
@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
|
|||
}
|
||||
|
||||
async fn complete_with_language_model(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
mut request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
open_ai_api_key: Option<Arc<str>>,
|
||||
|
@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
|
|||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.await?;
|
||||
|
||||
if request.model.starts_with("gpt") {
|
||||
let api_key =
|
||||
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
|
||||
complete_with_open_ai(request, response, session, api_key).await?;
|
||||
} else if request.model.starts_with("gemini") {
|
||||
let api_key = google_ai_api_key
|
||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
||||
complete_with_google_ai(request, response, session, api_key).await?;
|
||||
} else if request.model.starts_with("claude") {
|
||||
let api_key = anthropic_api_key
|
||||
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
|
||||
complete_with_anthropic(request, response, session, api_key).await?;
|
||||
let mut provider_and_model = request.model.split('/');
|
||||
let (provider, model) = match (
|
||||
provider_and_model.next().unwrap(),
|
||||
provider_and_model.next(),
|
||||
) {
|
||||
(provider, Some(model)) => (provider, model),
|
||||
(model, None) => {
|
||||
if model.starts_with("gpt") {
|
||||
("openai", model)
|
||||
} else if model.starts_with("gemini") {
|
||||
("google", model)
|
||||
} else if model.starts_with("claude") {
|
||||
("anthropic", model)
|
||||
} else {
|
||||
("unknown", model)
|
||||
}
|
||||
}
|
||||
};
|
||||
let provider = provider.to_string();
|
||||
request.model = model.to_string();
|
||||
|
||||
match provider.as_str() {
|
||||
"openai" => {
|
||||
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
|
||||
complete_with_open_ai(request, response, session, api_key).await?;
|
||||
}
|
||||
"anthropic" => {
|
||||
let api_key =
|
||||
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
|
||||
complete_with_anthropic(request, response, session, api_key).await?;
|
||||
}
|
||||
"google" => {
|
||||
let api_key =
|
||||
google_ai_api_key.context("no Google AI API key configured on the server")?;
|
||||
complete_with_google_ai(request, response, session, api_key).await?;
|
||||
}
|
||||
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -54,15 +54,15 @@ impl CloudCompletionProvider {
|
|||
|
||||
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
|
||||
Some(custom_model)
|
||||
let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
|
||||
Some(self.model.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
CloudModel::iter()
|
||||
.filter_map(move |model| {
|
||||
if let CloudModel::Custom(_) = model {
|
||||
Some(CloudModel::Custom(custom_model.take()?))
|
||||
if let CloudModel::Custom { .. } = model {
|
||||
custom_model.take()
|
||||
} else {
|
||||
Some(model)
|
||||
}
|
||||
|
@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
|||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(CloudModel::Custom(model)) => {
|
||||
LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model,
|
||||
model: name,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
|
|
|
@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
|
|||
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
||||
| LanguageModel::Cloud(CloudModel::Custom { .. })
|
||||
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
|
|
|
@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
|
|||
pub use anthropic::Model as AnthropicModel;
|
||||
pub use ollama::Model as OllamaModel;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::{
|
||||
schema::{InstanceType, Metadata, Schema, SchemaObject},
|
||||
JsonSchema,
|
||||
};
|
||||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
};
|
||||
use std::fmt;
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
|
||||
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
|
||||
pub enum CloudModel {
|
||||
#[serde(rename = "gpt-3.5-turbo")]
|
||||
Gpt3Point5Turbo,
|
||||
#[serde(rename = "gpt-4")]
|
||||
Gpt4,
|
||||
#[serde(rename = "gpt-4-turbo-preview")]
|
||||
Gpt4Turbo,
|
||||
#[serde(rename = "gpt-4o")]
|
||||
#[default]
|
||||
Gpt4Omni,
|
||||
#[serde(rename = "gpt-4o-mini")]
|
||||
Gpt4OmniMini,
|
||||
#[serde(rename = "claude-3-5-sonnet")]
|
||||
Claude3_5Sonnet,
|
||||
#[serde(rename = "claude-3-opus")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-haiku")]
|
||||
Claude3Haiku,
|
||||
#[serde(rename = "gemini-1.5-pro")]
|
||||
Gemini15Pro,
|
||||
#[serde(rename = "gemini-1.5-flash")]
|
||||
Gemini15Flash,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl Serialize for CloudModel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.id())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for CloudModel {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ZedDotDevModelVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
||||
type Value = CloudModel;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
let model = CloudModel::iter()
|
||||
.find(|model| model.id() == value)
|
||||
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(ZedDotDevModelVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for CloudModel {
|
||||
fn schema_name() -> String {
|
||||
"ZedDotDevModel".to_owned()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
let variants = CloudModel::iter()
|
||||
.filter_map(|model| {
|
||||
let id = model.id();
|
||||
if id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(id.to_string())
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
|
||||
metadata: Some(Box::new(Metadata {
|
||||
title: Some("ZedDotDevModel".to_owned()),
|
||||
default: Some(CloudModel::default().id().into()),
|
||||
examples: variants.into_iter().map(Into::into).collect(),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
max_tokens: Option<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
impl CloudModel {
|
||||
|
@ -112,7 +52,7 @@ impl CloudModel {
|
|||
Self::Claude3Haiku => "claude-3-haiku",
|
||||
Self::Gemini15Pro => "gemini-1.5-pro",
|
||||
Self::Gemini15Flash => "gemini-1.5-flash",
|
||||
Self::Custom(id) => id,
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,7 +69,7 @@ impl CloudModel {
|
|||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Gemini15Pro => "Gemini 1.5 Pro",
|
||||
Self::Gemini15Flash => "Gemini 1.5 Flash",
|
||||
Self::Custom(id) => id.as_str(),
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -145,14 +85,20 @@ impl CloudModel {
|
|||
| Self::Claude3Haiku => 200000,
|
||||
Self::Gemini15Pro => 128000,
|
||||
Self::Gemini15Flash => 32000,
|
||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
|
||||
request.preprocess_anthropic()
|
||||
Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3Haiku
|
||||
| Self::Claude3_5Sonnet => {
|
||||
request.preprocess_anthropic();
|
||||
}
|
||||
Self::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
request.preprocess_anthropic();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue