Merge 24c2a465bb
into a102b08743
This commit is contained in:
commit
b5ba14dd06
2 changed files with 80 additions and 47 deletions
|
@ -11,8 +11,8 @@ use language_model::{
|
||||||
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
|
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
|
||||||
};
|
};
|
||||||
use ollama::{
|
use ollama::{
|
||||||
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
|
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionCall,
|
||||||
OllamaToolCall, get_models, show_model, stream_chat_completion,
|
OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
|
||||||
};
|
};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -282,59 +282,85 @@ impl OllamaLanguageModel {
|
||||||
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
||||||
let supports_vision = self.model.supports_vision.unwrap_or(false);
|
let supports_vision = self.model.supports_vision.unwrap_or(false);
|
||||||
|
|
||||||
ChatRequest {
|
let mut messages = Vec::with_capacity(request.messages.len());
|
||||||
model: self.model.name.clone(),
|
|
||||||
messages: request
|
|
||||||
.messages
|
|
||||||
.into_iter()
|
|
||||||
.map(|msg| {
|
|
||||||
let images = if supports_vision {
|
|
||||||
msg.content
|
|
||||||
.iter()
|
|
||||||
.filter_map(|content| match content {
|
|
||||||
MessageContent::Image(image) => Some(image.source.to_string()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
} else {
|
|
||||||
vec![]
|
|
||||||
};
|
|
||||||
|
|
||||||
match msg.role {
|
for mut msg in request.messages.into_iter() {
|
||||||
Role::User => ChatMessage::User {
|
let images = if supports_vision {
|
||||||
|
msg.content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|content| match content {
|
||||||
|
MessageContent::Image(image) => Some(image.source.to_string()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
match msg.role {
|
||||||
|
Role::User => {
|
||||||
|
for tool_result in msg
|
||||||
|
.content
|
||||||
|
.extract_if(.., |x| matches!(x, MessageContent::ToolResult(..)))
|
||||||
|
{
|
||||||
|
match tool_result {
|
||||||
|
MessageContent::ToolResult(tool_result) => {
|
||||||
|
messages.push(ChatMessage::Tool {
|
||||||
|
tool_name: tool_result.tool_name.to_string(),
|
||||||
|
content: tool_result.content.to_str().unwrap_or("").to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => unreachable!("Only tool result should be extracted"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !msg.content.is_empty() {
|
||||||
|
messages.push(ChatMessage::User {
|
||||||
content: msg.string_contents(),
|
content: msg.string_contents(),
|
||||||
images: if images.is_empty() {
|
images: if images.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(images)
|
Some(images)
|
||||||
},
|
},
|
||||||
},
|
})
|
||||||
Role::Assistant => {
|
|
||||||
let content = msg.string_contents();
|
|
||||||
let thinking =
|
|
||||||
msg.content.into_iter().find_map(|content| match content {
|
|
||||||
MessageContent::Thinking { text, .. } if !text.is_empty() => {
|
|
||||||
Some(text)
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
});
|
|
||||||
ChatMessage::Assistant {
|
|
||||||
content,
|
|
||||||
tool_calls: None,
|
|
||||||
images: if images.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(images)
|
|
||||||
},
|
|
||||||
thinking,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Role::System => ChatMessage::System {
|
|
||||||
content: msg.string_contents(),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
.collect(),
|
Role::Assistant => {
|
||||||
|
let content = msg.string_contents();
|
||||||
|
let mut thinking = None;
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
for content in msg.content.into_iter() {
|
||||||
|
match content {
|
||||||
|
MessageContent::Thinking { text, .. } if !text.is_empty() => {
|
||||||
|
thinking = Some(text)
|
||||||
|
}
|
||||||
|
MessageContent::ToolUse(tool_use) => {
|
||||||
|
tool_calls.push(OllamaToolCall::Function(OllamaFunctionCall {
|
||||||
|
name: tool_use.name.to_string(),
|
||||||
|
arguments: tool_use.input,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages.push(ChatMessage::Assistant {
|
||||||
|
content,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
images: if images.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(images)
|
||||||
|
},
|
||||||
|
thinking,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Role::System => messages.push(ChatMessage::System {
|
||||||
|
content: msg.string_contents(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ChatRequest {
|
||||||
|
model: self.model.name.clone(),
|
||||||
|
messages,
|
||||||
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
|
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
|
||||||
stream: true,
|
stream: true,
|
||||||
options: Some(ChatOptions {
|
options: Some(ChatOptions {
|
||||||
|
@ -479,6 +505,9 @@ fn map_to_language_model_completion_events(
|
||||||
ChatMessage::System { content } => {
|
ChatMessage::System { content } => {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||||
}
|
}
|
||||||
|
ChatMessage::Tool { content, .. } => {
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||||
|
}
|
||||||
ChatMessage::Assistant {
|
ChatMessage::Assistant {
|
||||||
content,
|
content,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
|
|
|
@ -117,6 +117,10 @@ pub enum ChatMessage {
|
||||||
System {
|
System {
|
||||||
content: String,
|
content: String,
|
||||||
},
|
},
|
||||||
|
Tool {
|
||||||
|
tool_name: String,
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue