language_models: Add thinking to Mistral Provider (#32476)
Tested prompt: John is one of 4 children. The first sister is 4 years old. Next year, the second sister will be twice as old as the first sister. The third sister is two years older than the second sister. The third sister is half the age of her older brother. How old is John? Return your thinking inside <think></think> Release Notes: - Add thinking to Mistral Provider --------- Signed-off-by: Umesh Yadav <git@umesh.dev> Co-authored-by: Peter Tripp <peter@zed.dev>
This commit is contained in:
parent
021681d456
commit
ce39644cbd
2 changed files with 126 additions and 59 deletions
|
@ -86,6 +86,7 @@ pub enum Model {
|
|||
max_completion_tokens: Option<u64>,
|
||||
supports_tools: Option<bool>,
|
||||
supports_images: Option<bool>,
|
||||
supports_thinking: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -214,6 +215,16 @@ impl Model {
|
|||
} => supports_images.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_thinking(&self) -> bool {
|
||||
match self {
|
||||
Self::MagistralMediumLatest | Self::MagistralSmallLatest => true,
|
||||
Self::Custom {
|
||||
supports_thinking, ..
|
||||
} => supports_thinking.unwrap_or(false),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -288,7 +299,9 @@ pub enum ToolChoice {
|
|||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum RequestMessage {
|
||||
Assistant {
|
||||
content: Option<String>,
|
||||
#[serde(flatten)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
content: Option<MessageContent>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
|
@ -297,7 +310,8 @@ pub enum RequestMessage {
|
|||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
#[serde(flatten)]
|
||||
content: MessageContent,
|
||||
},
|
||||
Tool {
|
||||
content: String,
|
||||
|
@ -305,7 +319,7 @@ pub enum RequestMessage {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
#[serde(rename = "content")]
|
||||
|
@ -346,11 +360,21 @@ impl MessageContent {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: String },
|
||||
Thinking { thinking: Vec<ThinkingPart> },
|
||||
}
|
||||
|
||||
// Backwards-compatibility alias for provider code that refers to ContentPart
|
||||
pub type ContentPart = MessagePart;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ThinkingPart {
|
||||
Text { text: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
@ -418,24 +442,30 @@ pub struct StreamChoice {
|
|||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct StreamDelta {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<MessageContentDelta>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallChunk>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContentDelta {
|
||||
Text(String),
|
||||
Parts(Vec<MessagePart>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ToolCallChunk {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub function: Option<FunctionChunk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct FunctionChunk {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue