language_models: Add thinking to Mistral Provider (#32476)

Tested prompt:

John is one of 4 children. The first sister is 4 years old. Next year,
the second sister will be twice as old as the first sister. The third
sister is two years older than the second sister. The third sister is
half the age of her older brother. How old is John? Return your thinking
inside <think></think>

Release Notes:

- Add thinking to Mistral Provider

---------

Signed-off-by: Umesh Yadav <git@umesh.dev>
Co-authored-by: Peter Tripp <peter@zed.dev>
This commit is contained in:
Umesh Yadav 2025-08-10 00:55:47 +05:30 committed by GitHub
parent 021681d456
commit ce39644cbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 126 additions and 59 deletions

View file

@ -86,6 +86,7 @@ pub enum Model {
max_completion_tokens: Option<u64>,
supports_tools: Option<bool>,
supports_images: Option<bool>,
supports_thinking: Option<bool>,
},
}
@ -214,6 +215,16 @@ impl Model {
} => supports_images.unwrap_or(false),
}
}
/// Reports whether this model supports thinking.
///
/// `MagistralMediumLatest` and `MagistralSmallLatest` always do. A `Custom`
/// model does only when its `supports_thinking` option is `Some(true)`; an
/// unset option counts as `false`. Every other model returns `false`.
pub fn supports_thinking(&self) -> bool {
    // Custom models carry their own opt-in flag, so settle them first.
    if let Self::Custom {
        supports_thinking, ..
    } = self
    {
        return supports_thinking.unwrap_or(false);
    }

    matches!(
        self,
        Self::MagistralMediumLatest | Self::MagistralSmallLatest
    )
}
}
#[derive(Debug, Serialize, Deserialize)]
@ -288,7 +299,9 @@ pub enum ToolChoice {
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
Assistant {
content: Option<String>,
#[serde(flatten)]
#[serde(default, skip_serializing_if = "Option::is_none")]
content: Option<MessageContent>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
@ -297,7 +310,8 @@ pub enum RequestMessage {
content: MessageContent,
},
System {
content: String,
#[serde(flatten)]
content: MessageContent,
},
Tool {
content: String,
@ -305,7 +319,7 @@ pub enum RequestMessage {
},
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(untagged)]
pub enum MessageContent {
#[serde(rename = "content")]
@ -346,11 +360,21 @@ impl MessageContent {
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
/// A single typed content part.
///
/// Serialized with an internal `type` tag in snake_case, so the variants are
/// tagged `"text"`, `"image_url"` and `"thinking"` respectively.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
/// Text carried in the `text` field.
Text { text: String },
/// An image reference carried as a string in the `image_url` field.
ImageUrl { image_url: String },
/// A list of `ThinkingPart`s carried in the `thinking` field.
Thinking { thinking: Vec<ThinkingPart> },
}
/// Backwards-compatibility alias: provider code that still refers to
/// `ContentPart` resolves to `MessagePart`.
pub type ContentPart = MessagePart;
/// One element of the `thinking` list held by `MessagePart::Thinking`.
///
/// Serialized with an internal `type` tag in snake_case; the only variant is
/// `Text`, tagged `"text"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ThinkingPart {
/// Text carried in the `text` field.
Text { text: String },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@ -418,24 +442,30 @@ pub struct StreamChoice {
pub finish_reason: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
/// A streamed delta: optional role, content, tool-call chunks and reasoning
/// content.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamDelta {
/// Role attached to this delta, if any.
pub role: Option<Role>,
// NOTE(review): `content` is declared twice in this view — `Option<String>`
// here and `Option<MessageContentDelta>` just below. This looks like the
// before/after pair of the diff; confirm which one the real file keeps.
pub content: Option<String>,
/// Content of the delta, either plain text or a list of parts. Missing on
/// input becomes `None`; `None` is omitted on output.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<MessageContentDelta>,
/// Tool-call chunks in this delta. Missing on input becomes `None`; `None`
/// is omitted on output.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallChunk>>,
/// Reasoning text supplied as a separate string field. Missing on input
/// becomes `None`; `None` is omitted on output.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
/// Content of a `StreamDelta`: either a bare string or a list of
/// `MessagePart`s.
///
/// The enum is `untagged`, so no discriminator is written; on input serde
/// tries the variants in declaration order and keeps the first that matches.
/// Do not reorder the variants without checking that behaviour.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(untagged)]
pub enum MessageContentDelta {
/// Content given as a single string.
Text(String),
/// Content given as a sequence of typed parts.
Parts(Vec<MessagePart>),
}
/// One tool-call chunk as carried in `StreamDelta::tool_calls`.
///
/// `index` is always present; `id` and `function` are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ToolCallChunk {
// NOTE(review): presumably identifies which tool call this chunk belongs
// to — confirm against the code that consumes the stream.
pub index: usize,
/// Tool-call id, when the chunk includes one.
pub id: Option<String>,
/// Function name/arguments fragment, when the chunk includes one.
pub function: Option<FunctionChunk>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct FunctionChunk {
pub name: Option<String>,
pub arguments: Option<String>,