language_models: Add support for images to Mistral models (#32154)
Tested with the following models. They hallucinate on white-outlined images, such as the white-lined Zed logo, but work fine with the black-outlined Zed logo: Pixtral 12B (pixtral-12b-latest), Pixtral Large (pixtral-large-latest), Mistral Medium (mistral-medium-latest), Mistral Small (mistral-small-latest). After this PR, almost all of Zed's LLM providers that support images are now covered. The only remaining one is LMStudio. Hopefully we will get that one as well soon. Release Notes: - Add support for images to Mistral models --------- Signed-off-by: Umesh Yadav <git@umesh.dev> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de> Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
This commit is contained in:
parent
4ac7935589
commit
0bc9478b46
3 changed files with 257 additions and 92 deletions
|
@ -60,6 +60,10 @@ pub enum Model {
|
|||
OpenCodestralMamba,
|
||||
#[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
|
||||
DevstralSmallLatest,
|
||||
#[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
|
||||
Pixtral12BLatest,
|
||||
#[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
|
||||
PixtralLargeLatest,
|
||||
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
|
@ -70,6 +74,7 @@ pub enum Model {
|
|||
max_output_tokens: Option<u32>,
|
||||
max_completion_tokens: Option<u32>,
|
||||
supports_tools: Option<bool>,
|
||||
supports_images: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -86,6 +91,9 @@ impl Model {
|
|||
"mistral-small-latest" => Ok(Self::MistralSmallLatest),
|
||||
"open-mistral-nemo" => Ok(Self::OpenMistralNemo),
|
||||
"open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
|
||||
"devstral-small-latest" => Ok(Self::DevstralSmallLatest),
|
||||
"pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
|
||||
"pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
|
||||
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
|
||||
}
|
||||
}
|
||||
|
@ -99,6 +107,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => "open-mistral-nemo",
|
||||
Self::OpenCodestralMamba => "open-codestral-mamba",
|
||||
Self::DevstralSmallLatest => "devstral-small-latest",
|
||||
Self::Pixtral12BLatest => "pixtral-12b-latest",
|
||||
Self::PixtralLargeLatest => "pixtral-large-latest",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
@ -112,6 +122,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => "open-mistral-nemo",
|
||||
Self::OpenCodestralMamba => "open-codestral-mamba",
|
||||
Self::DevstralSmallLatest => "devstral-small-latest",
|
||||
Self::Pixtral12BLatest => "pixtral-12b-latest",
|
||||
Self::PixtralLargeLatest => "pixtral-large-latest",
|
||||
Self::Custom {
|
||||
name, display_name, ..
|
||||
} => display_name.as_ref().unwrap_or(name),
|
||||
|
@ -127,6 +139,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => 131000,
|
||||
Self::OpenCodestralMamba => 256000,
|
||||
Self::DevstralSmallLatest => 262144,
|
||||
Self::Pixtral12BLatest => 128000,
|
||||
Self::PixtralLargeLatest => 128000,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
|
@ -148,10 +162,29 @@ impl Model {
|
|||
| Self::MistralSmallLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba
|
||||
| Self::DevstralSmallLatest => true,
|
||||
| Self::DevstralSmallLatest
|
||||
| Self::Pixtral12BLatest
|
||||
| Self::PixtralLargeLatest => true,
|
||||
Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_images(&self) -> bool {
|
||||
match self {
|
||||
Self::Pixtral12BLatest
|
||||
| Self::PixtralLargeLatest
|
||||
| Self::MistralMediumLatest
|
||||
| Self::MistralSmallLatest => true,
|
||||
Self::CodestralLatest
|
||||
| Self::MistralLargeLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba
|
||||
| Self::DevstralSmallLatest => false,
|
||||
Self::Custom {
|
||||
supports_images, ..
|
||||
} => supports_images.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -231,7 +264,8 @@ pub enum RequestMessage {
|
|||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
#[serde(flatten)]
|
||||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
|
@ -242,6 +276,54 @@ pub enum RequestMessage {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
#[serde(rename = "content")]
|
||||
Plain { content: String },
|
||||
#[serde(rename = "content")]
|
||||
Multipart { content: Vec<MessagePart> },
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn empty() -> Self {
|
||||
Self::Plain {
|
||||
content: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_part(&mut self, part: MessagePart) {
|
||||
match self {
|
||||
Self::Plain { content } => match part {
|
||||
MessagePart::Text { text } => {
|
||||
content.push_str(&text);
|
||||
}
|
||||
part => {
|
||||
let mut parts = if content.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
vec![MessagePart::Text {
|
||||
text: content.clone(),
|
||||
}]
|
||||
};
|
||||
parts.push(part);
|
||||
*self = Self::Multipart { content: parts };
|
||||
}
|
||||
},
|
||||
Self::Multipart { content } => {
|
||||
content.push(part);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue