Fix issues with Claude in Assistant2 (#12619)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
afc0650a49
commit
3cd6719b30
11 changed files with 276 additions and 213 deletions
|
@ -14,10 +14,10 @@ use serde::{
|
|||
use settings::{Settings, SettingsSources};
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
|
||||
use crate::LanguageModel;
|
||||
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
|
||||
pub enum ZedDotDevModel {
|
||||
pub enum CloudModel {
|
||||
Gpt3Point5Turbo,
|
||||
Gpt4,
|
||||
Gpt4Turbo,
|
||||
|
@ -29,7 +29,7 @@ pub enum ZedDotDevModel {
|
|||
Custom(String),
|
||||
}
|
||||
|
||||
impl Serialize for ZedDotDevModel {
|
||||
impl Serialize for CloudModel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
|
@ -38,7 +38,7 @@ impl Serialize for ZedDotDevModel {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ZedDotDevModel {
|
||||
impl<'de> Deserialize<'de> for CloudModel {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
|
@ -46,7 +46,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
|||
struct ZedDotDevModelVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
||||
type Value = ZedDotDevModel;
|
||||
type Value = CloudModel;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
||||
|
@ -56,9 +56,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
|||
where
|
||||
E: de::Error,
|
||||
{
|
||||
let model = ZedDotDevModel::iter()
|
||||
let model = CloudModel::iter()
|
||||
.find(|model| model.id() == value)
|
||||
.unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
|
||||
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
@ -67,13 +67,13 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
|||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for ZedDotDevModel {
|
||||
impl JsonSchema for CloudModel {
|
||||
fn schema_name() -> String {
|
||||
"ZedDotDevModel".to_owned()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
let variants = ZedDotDevModel::iter()
|
||||
let variants = CloudModel::iter()
|
||||
.filter_map(|model| {
|
||||
let id = model.id();
|
||||
if id.is_empty() {
|
||||
|
@ -88,7 +88,7 @@ impl JsonSchema for ZedDotDevModel {
|
|||
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
|
||||
metadata: Some(Box::new(Metadata {
|
||||
title: Some("ZedDotDevModel".to_owned()),
|
||||
default: Some(ZedDotDevModel::default().id().into()),
|
||||
default: Some(CloudModel::default().id().into()),
|
||||
examples: variants.into_iter().map(Into::into).collect(),
|
||||
..Default::default()
|
||||
})),
|
||||
|
@ -97,7 +97,7 @@ impl JsonSchema for ZedDotDevModel {
|
|||
}
|
||||
}
|
||||
|
||||
impl ZedDotDevModel {
|
||||
impl CloudModel {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||
|
@ -133,6 +133,15 @@ impl ZedDotDevModel {
|
|||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
|
||||
preprocess_anthropic_request(request)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
|
@ -147,7 +156,7 @@ pub enum AssistantDockPosition {
|
|||
#[derive(Debug, PartialEq)]
|
||||
pub enum AssistantProvider {
|
||||
ZedDotDev {
|
||||
model: ZedDotDevModel,
|
||||
model: CloudModel,
|
||||
},
|
||||
OpenAi {
|
||||
model: OpenAiModel,
|
||||
|
@ -175,9 +184,7 @@ impl Default for AssistantProvider {
|
|||
#[serde(tag = "name", rename_all = "snake_case")]
|
||||
pub enum AssistantProviderContent {
|
||||
#[serde(rename = "zed.dev")]
|
||||
ZedDotDev {
|
||||
default_model: Option<ZedDotDevModel>,
|
||||
},
|
||||
ZedDotDev { default_model: Option<CloudModel> },
|
||||
#[serde(rename = "openai")]
|
||||
OpenAi {
|
||||
default_model: Option<OpenAiModel>,
|
||||
|
@ -281,7 +288,7 @@ impl AssistantSettingsContent {
|
|||
Some(AssistantProviderContent::ZedDotDev {
|
||||
default_model: model,
|
||||
}) => {
|
||||
if let LanguageModel::ZedDotDev(new_model) = new_model {
|
||||
if let LanguageModel::Cloud(new_model) = new_model {
|
||||
*model = Some(new_model);
|
||||
}
|
||||
}
|
||||
|
@ -302,7 +309,7 @@ impl AssistantSettingsContent {
|
|||
}
|
||||
}
|
||||
provider => match new_model {
|
||||
LanguageModel::ZedDotDev(model) => {
|
||||
LanguageModel::Cloud(model) => {
|
||||
*provider = Some(AssistantProviderContent::ZedDotDev {
|
||||
default_model: Some(model),
|
||||
})
|
||||
|
@ -613,7 +620,7 @@ mod tests {
|
|||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::ZedDotDev {
|
||||
model: ZedDotDevModel::Custom("custom".into())
|
||||
model: CloudModel::Custom("custom".into())
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue