diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs
index a162ce064e..ce7bd56047 100644
--- a/crates/agent_settings/src/agent_settings.rs
+++ b/crates/agent_settings/src/agent_settings.rs
@@ -372,6 +372,7 @@ impl AgentSettingsContent {
                         None,
                         None,
                         Some(language_model.supports_tools()),
+                        Some(language_model.supports_images()),
                         None,
                     )),
                     api_url,
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 78645cc1b9..ed5f10da23 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -4,14 +4,11 @@ use futures::{Stream, TryFutureExt, stream};
 use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
-    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, StopReason,
-};
-use language_model::{
-    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, RateLimiter, Role,
+    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
 };
 use ollama::{
     ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@@ -54,6 +51,8 @@ pub struct AvailableModel {
     pub keep_alive: Option<KeepAlive>,
     /// Whether the model supports tools
     pub supports_tools: Option<bool>,
+    /// Whether the model supports vision
+    pub supports_images: Option<bool>,
     /// Whether to enable think mode
     pub supports_thinking: Option<bool>,
 }
@@ -101,6 +100,7 @@ impl State {
                         None,
                         None,
                         Some(capabilities.supports_tools()),
+                        Some(capabilities.supports_vision()),
                         Some(capabilities.supports_thinking()),
                     );
                     Ok(ollama_model)
@@ -222,6 +222,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                         max_tokens: model.max_tokens,
                         keep_alive: model.keep_alive.clone(),
                         supports_tools: model.supports_tools,
+                        supports_vision: model.supports_images,
                         supports_thinking: model.supports_thinking,
                     },
                 );
@@ -277,30 +278,59 @@ pub struct OllamaLanguageModel {
 
 impl OllamaLanguageModel {
     fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
+        let supports_vision = self.model.supports_vision.unwrap_or(false);
+
         ChatRequest {
             model: self.model.name.clone(),
             messages: request
                 .messages
                 .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => ChatMessage::User {
-                        content: msg.string_contents(),
-                    },
-                    Role::Assistant => {
-                        let content = msg.string_contents();
-                        let thinking = msg.content.into_iter().find_map(|content| match content {
-                            MessageContent::Thinking { text, .. } if !text.is_empty() => Some(text),
-                            _ => None,
-                        });
-                        ChatMessage::Assistant {
-                            content,
-                            tool_calls: None,
-                            thinking,
+                .map(|msg| {
+                    let images = if supports_vision {
+                        msg.content
+                            .iter()
+                            .filter_map(|content| match content {
+                                MessageContent::Image(image) => Some(image.source.to_string()),
+                                _ => None,
+                            })
+                            .collect::<Vec<String>>()
+                    } else {
+                        vec![]
+                    };
+
+                    match msg.role {
+                        Role::User => ChatMessage::User {
+                            content: msg.string_contents(),
+                            images: if images.is_empty() {
+                                None
+                            } else {
+                                Some(images)
+                            },
+                        },
+                        Role::Assistant => {
+                            let content = msg.string_contents();
+                            let thinking =
+                                msg.content.into_iter().find_map(|content| match content {
+                                    MessageContent::Thinking { text, .. } if !text.is_empty() => {
+                                        Some(text)
+                                    }
+                                    _ => None,
+                                });
+                            ChatMessage::Assistant {
+                                content,
+                                tool_calls: None,
+                                images: if images.is_empty() {
+                                    None
+                                } else {
+                                    Some(images)
+                                },
+                                thinking,
+                            }
                         }
+                        Role::System => ChatMessage::System {
+                            content: msg.string_contents(),
+                        },
                     }
-                    Role::System => ChatMessage::System {
-                        content: msg.string_contents(),
-                    },
                 })
                 .collect(),
             keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
@@ -339,7 +369,7 @@ impl LanguageModel for OllamaLanguageModel {
     }
 
     fn supports_images(&self) -> bool {
-        false
+        self.model.supports_vision.unwrap_or(false)
     }
 
     fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -437,7 +467,7 @@ fn map_to_language_model_completion_events(
                     let mut events = Vec::new();
 
                     match delta.message {
-                        ChatMessage::User { content } => {
+                        ChatMessage::User { content, images: _ } => {
                             events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                         }
                         ChatMessage::System { content } => {
@@ -446,6 +476,7 @@ fn map_to_language_model_completion_events(
                         ChatMessage::Assistant {
                             content,
                             tool_calls,
+                            images: _,
                             thinking,
                         } => {
                             if let Some(text) = thinking {
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index b52df6e4ce..1e68d58b96 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -38,6 +38,7 @@ pub struct Model {
     pub max_tokens: usize,
     pub keep_alive: Option<KeepAlive>,
     pub supports_tools: Option<bool>,
+    pub supports_vision: Option<bool>,
     pub supports_thinking: Option<bool>,
 }
 
@@ -68,6 +69,7 @@ impl Model {
         display_name: Option<&str>,
         max_tokens: Option<usize>,
         supports_tools: Option<bool>,
+        supports_vision: Option<bool>,
         supports_thinking: Option<bool>,
     ) -> Self {
         Self {
@@ -78,6 +80,7 @@ impl Model {
             max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
             keep_alive: Some(KeepAlive::indefinite()),
             supports_tools,
+            supports_vision,
             supports_thinking,
         }
     }
@@ -101,10 +104,14 @@ pub enum ChatMessage {
     Assistant {
         content: String,
         tool_calls: Option<Vec<OllamaToolCall>>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        images: Option<Vec<String>>,
         thinking: Option<String>,
     },
     User {
         content: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        images: Option<Vec<String>>,
     },
     System {
         content: String,
@@ -221,6 +228,10 @@ impl ModelShow {
         self.capabilities.iter().any(|v| v == "tools")
     }
 
+    pub fn supports_vision(&self) -> bool {
+        self.capabilities.iter().any(|v| v == "vision")
+    }
+
     pub fn supports_thinking(&self) -> bool {
         self.capabilities.iter().any(|v| v == "thinking")
     }
@@ -468,6 +479,7 @@ mod tests {
            ChatMessage::Assistant {
                content,
                tool_calls,
+                images: _,
                thinking,
            } => {
                assert!(content.is_empty());
@@ -534,4 +546,70 @@ mod tests {
         assert!(result.capabilities.contains(&"tools".to_string()));
         assert!(result.capabilities.contains(&"completion".to_string()));
     }
+
+    #[test]
+    fn serialize_chat_request_with_images() {
+        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";
+
+        let request = ChatRequest {
+            model: "llava".to_string(),
+            messages: vec![ChatMessage::User {
+                content: "What do you see in this image?".to_string(),
+                images: Some(vec![base64_image.to_string()]),
+            }],
+            stream: false,
+            keep_alive: KeepAlive::default(),
+            options: None,
+            think: None,
+            tools: vec![],
+        };
+
+        let serialized = serde_json::to_string(&request).unwrap();
+        assert!(serialized.contains("images"));
+        assert!(serialized.contains(base64_image));
+    }
+
+    #[test]
+    fn serialize_chat_request_without_images() {
+        let request = ChatRequest {
+            model: "llama3.2".to_string(),
+            messages: vec![ChatMessage::User {
+                content: "Hello, world!".to_string(),
+                images: None,
+            }],
+            stream: false,
+            keep_alive: KeepAlive::default(),
+            options: None,
+            think: None,
+            tools: vec![],
+        };
+
+        let serialized = serde_json::to_string(&request).unwrap();
+        assert!(!serialized.contains("images"));
+    }
+
+    #[test]
+    fn test_json_format_with_images() {
+        let base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==";
+
+        let request = ChatRequest {
+            model: "llava".to_string(),
+            messages: vec![ChatMessage::User {
+                content: "What do you see?".to_string(),
+                images: Some(vec![base64_image.to_string()]),
+            }],
+            stream: false,
+            keep_alive: KeepAlive::default(),
+            options: None,
+            think: None,
+            tools: vec![],
+        };
+
+        let serialized = serde_json::to_string(&request).unwrap();
+
+        let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
+        let message_images = parsed["messages"][0]["images"].as_array().unwrap();
+        assert_eq!(message_images.len(), 1);
+        assert_eq!(message_images[0].as_str().unwrap(), base64_image);
+    }
 }