Allow using a custom model when using zed.dev (#14933)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-07-22 12:25:53 +02:00 committed by GitHub
parent a334c69e05
commit 0155435142
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 114 additions and 110 deletions

View file

@ -20,6 +20,12 @@ pub enum Model {
Claude3Sonnet,
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
#[serde(rename = "custom")]
Custom {
name: String,
#[serde(default)]
max_tokens: Option<usize>,
},
}
impl Model {
@ -33,30 +39,41 @@ impl Model {
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
Err(anyhow!("Invalid model id: {}", id))
Ok(Self::Custom {
name: id.to_string(),
max_tokens: None,
})
}
}
pub fn id(&self) -> &'static str {
pub fn id(&self) -> &str {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
Model::Claude3Opus => "claude-3-opus-20240229",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-opus-20240307",
Model::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &'static str {
pub fn display_name(&self) -> &str {
match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom { name, .. } => name,
}
}
pub fn max_token_count(&self) -> usize {
200_000
match self {
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200_000,
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
}
}
}
@ -90,6 +107,7 @@ impl From<Role> for String {
#[derive(Debug, Serialize)]
pub struct Request {
#[serde(serialize_with = "serialize_request_model")]
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
@ -97,6 +115,13 @@ pub struct Request {
pub max_tokens: u32,
}
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&model.id())
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,

View file

@ -668,7 +668,11 @@ mod tests {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": "custom"
"default_model": {
"custom": {
"name": "custom-provider"
}
}
}
}
}"#,
@ -679,7 +683,10 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: CloudModel::Custom("custom".into())
model: CloudModel::Custom {
name: "custom-provider".into(),
max_tokens: None
}
}
);
}

View file

@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
request: proto::CompleteWithLanguageModel,
mut request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("claude") {
let api_key = anthropic_api_key
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
complete_with_anthropic(request, response, session, api_key).await?;
let mut provider_and_model = request.model.split('/');
let (provider, model) = match (
provider_and_model.next().unwrap(),
provider_and_model.next(),
) {
(provider, Some(model)) => (provider, model),
(model, None) => {
if model.starts_with("gpt") {
("openai", model)
} else if model.starts_with("gemini") {
("google", model)
} else if model.starts_with("claude") {
("anthropic", model)
} else {
("unknown", model)
}
}
};
let provider = provider.to_string();
request.model = model.to_string();
match provider.as_str() {
"openai" => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
complete_with_open_ai(request, response, session, api_key).await?;
}
"anthropic" => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
complete_with_anthropic(request, response, session, api_key).await?;
}
"google" => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
complete_with_google_ai(request, response, session, api_key).await?;
}
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
}
Ok(())

View file

@ -54,15 +54,15 @@ impl CloudCompletionProvider {
impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
Some(self.model.clone())
} else {
None
};
CloudModel::iter()
.filter_map(move |model| {
if let CloudModel::Custom(_) = model {
Some(CloudModel::Custom(custom_model.take()?))
if let CloudModel::Custom { .. } = model {
custom_model.take()
} else {
Some(model)
}
@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::Cloud(CloudModel::Custom(model)) => {
LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
model: name,
messages: request
.messages
.iter()

View file

@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
| LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
| LanguageModel::Cloud(CloudModel::Custom { .. })
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.

View file

@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
pub use anthropic::Model as AnthropicModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use strum::{EnumIter, IntoEnumIterator};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum CloudModel {
#[serde(rename = "gpt-3.5-turbo")]
Gpt3Point5Turbo,
#[serde(rename = "gpt-4")]
Gpt4,
#[serde(rename = "gpt-4-turbo-preview")]
Gpt4Turbo,
#[serde(rename = "gpt-4o")]
#[default]
Gpt4Omni,
#[serde(rename = "gpt-4o-mini")]
Gpt4OmniMini,
#[serde(rename = "claude-3-5-sonnet")]
Claude3_5Sonnet,
#[serde(rename = "claude-3-opus")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku")]
Claude3Haiku,
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
#[serde(rename = "custom")]
Custom {
name: String,
max_tokens: Option<usize>,
},
}
impl CloudModel {
@ -112,7 +52,7 @@ impl CloudModel {
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
Self::Custom { name, .. } => name,
}
}
@ -129,7 +69,7 @@ impl CloudModel {
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
Self::Custom { name, .. } => name,
}
}
@ -145,14 +85,20 @@ impl CloudModel {
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
request.preprocess_anthropic()
Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku
| Self::Claude3_5Sonnet => {
request.preprocess_anthropic();
}
Self::Custom { name, .. } if name.starts_with("anthropic/") => {
request.preprocess_anthropic();
}
_ => {}
}