language_models: Add thinking to Mistral Provider (#32476)
Tested prompt: John is one of 4 children. The first sister is 4 years old. Next year, the second sister will be twice as old as the first sister. The third sister is two years older than the second sister. The third sister is half the age of her older brother. How old is John? Return your thinking inside <think></think> Release Notes: - Add thinking to Mistral Provider --------- Signed-off-by: Umesh Yadav <git@umesh.dev> Co-authored-by: Peter Tripp <peter@zed.dev>
This commit is contained in:
parent
021681d456
commit
ce39644cbd
2 changed files with 126 additions and 59 deletions
|
@ -47,6 +47,7 @@ pub struct AvailableModel {
|
|||
pub max_completion_tokens: Option<u64>,
|
||||
pub supports_tools: Option<bool>,
|
||||
pub supports_images: Option<bool>,
|
||||
pub supports_thinking: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct MistralLanguageModelProvider {
|
||||
|
@ -215,6 +216,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
|||
max_completion_tokens: model.max_completion_tokens,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_images,
|
||||
supports_thinking: model.supports_thinking,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -366,11 +368,7 @@ impl LanguageModel for MistralLanguageModel {
|
|||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let request = into_mistral(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
self.max_output_tokens(),
|
||||
);
|
||||
let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
|
||||
let stream = self.stream_completion(request, cx);
|
||||
|
||||
async move {
|
||||
|
@ -384,7 +382,7 @@ impl LanguageModel for MistralLanguageModel {
|
|||
|
||||
pub fn into_mistral(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
model: mistral::Model,
|
||||
max_output_tokens: Option<u64>,
|
||||
) -> mistral::Request {
|
||||
let stream = true;
|
||||
|
@ -401,13 +399,20 @@ pub fn into_mistral(
|
|||
.push_part(mistral::MessagePart::Text { text: text.clone() });
|
||||
}
|
||||
MessageContent::Image(image_content) => {
|
||||
message_content.push_part(mistral::MessagePart::ImageUrl {
|
||||
image_url: image_content.to_base64_url(),
|
||||
});
|
||||
if model.supports_images() {
|
||||
message_content.push_part(mistral::MessagePart::ImageUrl {
|
||||
image_url: image_content.to_base64_url(),
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
message_content
|
||||
.push_part(mistral::MessagePart::Text { text: text.clone() });
|
||||
if model.supports_thinking() {
|
||||
message_content.push_part(mistral::MessagePart::Thinking {
|
||||
thinking: vec![mistral::ThinkingPart::Text {
|
||||
text: text.clone(),
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::ToolUse(_) => {
|
||||
|
@ -437,12 +442,28 @@ pub fn into_mistral(
|
|||
Role::Assistant => {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
MessageContent::Text(text) => {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: Some(text.clone()),
|
||||
content: Some(mistral::MessageContent::Plain {
|
||||
content: text.clone(),
|
||||
}),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
if model.supports_thinking() {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: Some(mistral::MessageContent::Multipart {
|
||||
content: vec![mistral::MessagePart::Thinking {
|
||||
thinking: vec![mistral::ThinkingPart::Text {
|
||||
text: text.clone(),
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
|
@ -477,11 +498,26 @@ pub fn into_mistral(
|
|||
Role::System => {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
MessageContent::Text(text) => {
|
||||
messages.push(mistral::RequestMessage::System {
|
||||
content: text.clone(),
|
||||
content: mistral::MessageContent::Plain {
|
||||
content: text.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
if model.supports_thinking() {
|
||||
messages.push(mistral::RequestMessage::System {
|
||||
content: mistral::MessageContent::Multipart {
|
||||
content: vec![mistral::MessagePart::Thinking {
|
||||
thinking: vec![mistral::ThinkingPart::Text {
|
||||
text: text.clone(),
|
||||
}],
|
||||
}],
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_)
|
||||
| MessageContent::ToolUse(_)
|
||||
|
@ -494,37 +530,8 @@ pub fn into_mistral(
|
|||
}
|
||||
}
|
||||
|
||||
// The Mistral API requires that tool messages be followed by assistant messages,
|
||||
// not user messages. When we have a tool->user sequence in the conversation,
|
||||
// we need to insert a placeholder assistant message to maintain proper conversation
|
||||
// flow and prevent API errors. This is a Mistral-specific requirement that differs
|
||||
// from other language model APIs.
|
||||
let messages = {
|
||||
let mut fixed_messages = Vec::with_capacity(messages.len());
|
||||
let mut messages_iter = messages.into_iter().peekable();
|
||||
|
||||
while let Some(message) = messages_iter.next() {
|
||||
let is_tool_message = matches!(message, mistral::RequestMessage::Tool { .. });
|
||||
fixed_messages.push(message);
|
||||
|
||||
// Insert assistant message between tool and user messages
|
||||
if is_tool_message {
|
||||
if let Some(next_msg) = messages_iter.peek() {
|
||||
if matches!(next_msg, mistral::RequestMessage::User { .. }) {
|
||||
fixed_messages.push(mistral::RequestMessage::Assistant {
|
||||
content: Some(" ".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fixed_messages
|
||||
};
|
||||
|
||||
mistral::Request {
|
||||
model,
|
||||
model: model.id().to_string(),
|
||||
messages,
|
||||
stream,
|
||||
max_tokens: max_output_tokens,
|
||||
|
@ -595,8 +602,38 @@ impl MistralEventMapper {
|
|||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
if let Some(content) = choice.delta.content.clone() {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||
if let Some(content) = choice.delta.content.as_ref() {
|
||||
match content {
|
||||
mistral::MessageContentDelta::Text(text) => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
|
||||
}
|
||||
mistral::MessageContentDelta::Parts(parts) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
mistral::MessagePart::Text { text } => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
|
||||
}
|
||||
mistral::MessagePart::Thinking { thinking } => {
|
||||
for tp in thinking.iter().cloned() {
|
||||
match tp {
|
||||
mistral::ThinkingPart::Text { text } => {
|
||||
events.push(Ok(
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text,
|
||||
signature: None,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mistral::MessagePart::ImageUrl { .. } => {
|
||||
// We currently don't emit a separate event for images in responses.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
|
||||
|
@ -908,7 +945,7 @@ mod tests {
|
|||
thinking_allowed: true,
|
||||
};
|
||||
|
||||
let mistral_request = into_mistral(request, "mistral-small-latest".into(), None);
|
||||
let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);
|
||||
|
||||
assert_eq!(mistral_request.model, "mistral-small-latest");
|
||||
assert_eq!(mistral_request.temperature, Some(0.5));
|
||||
|
@ -941,7 +978,7 @@ mod tests {
|
|||
thinking_allowed: true,
|
||||
};
|
||||
|
||||
let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None);
|
||||
let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
|
||||
|
||||
assert_eq!(mistral_request.messages.len(), 1);
|
||||
assert!(matches!(
|
||||
|
|
|
@ -86,6 +86,7 @@ pub enum Model {
|
|||
max_completion_tokens: Option<u64>,
|
||||
supports_tools: Option<bool>,
|
||||
supports_images: Option<bool>,
|
||||
supports_thinking: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -214,6 +215,16 @@ impl Model {
|
|||
} => supports_images.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_thinking(&self) -> bool {
|
||||
match self {
|
||||
Self::MagistralMediumLatest | Self::MagistralSmallLatest => true,
|
||||
Self::Custom {
|
||||
supports_thinking, ..
|
||||
} => supports_thinking.unwrap_or(false),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -288,7 +299,9 @@ pub enum ToolChoice {
|
|||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum RequestMessage {
|
||||
Assistant {
|
||||
content: Option<String>,
|
||||
#[serde(flatten)]
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
content: Option<MessageContent>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
|
@ -297,7 +310,8 @@ pub enum RequestMessage {
|
|||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
#[serde(flatten)]
|
||||
content: MessageContent,
|
||||
},
|
||||
Tool {
|
||||
content: String,
|
||||
|
@ -305,7 +319,7 @@ pub enum RequestMessage {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
#[serde(rename = "content")]
|
||||
|
@ -346,11 +360,21 @@ impl MessageContent {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: String },
|
||||
Thinking { thinking: Vec<ThinkingPart> },
|
||||
}
|
||||
|
||||
// Backwards-compatibility alias for provider code that refers to ContentPart
|
||||
pub type ContentPart = MessagePart;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ThinkingPart {
|
||||
Text { text: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
@ -418,24 +442,30 @@ pub struct StreamChoice {
|
|||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct StreamDelta {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<MessageContentDelta>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallChunk>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContentDelta {
|
||||
Text(String),
|
||||
Parts(Vec<MessagePart>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ToolCallChunk {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub function: Option<FunctionChunk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct FunctionChunk {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue