diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index b92f8e2042..ec306f1b69 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -581,7 +581,7 @@ async fn stream_completion(
     api_key: String,
     request: Request,
 ) -> Result>> {
-    let is_vision_request = request.messages.last().map_or(false, |message| match message {
+    let is_vision_request = request.messages.iter().any(|message| match message {
         ChatMessage::User { content }
         | ChatMessage::Assistant { content, .. }
         | ChatMessage::Tool { content, .. } => {
@@ -736,4 +736,96 @@ mod tests {
         assert_eq!(schema.data[0].id, "gpt-4");
         assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
     }
+
+    #[test]
+    fn test_vision_request_detection() {
+        // Local stand-in for the per-message check that `stream_completion` runs via
+        // `request.messages.iter().any(..)`. The production closure is not called from
+        // this test, so keep the two in sync.
+        fn message_contains_image(message: &ChatMessage) -> bool {
+            match message {
+                ChatMessage::User { content }
+                | ChatMessage::Assistant { content, .. }
+                | ChatMessage::Tool { content, .. } => {
+                    matches!(content, ChatMessageContent::Multipart(parts) if
+                        parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
+                }
+                _ => false,
+            }
+        }
+
+        // A request counts as a vision request when any message carries an image part.
+        fn is_vision_request(request: &Request) -> bool {
+            request.messages.iter().any(message_contains_image)
+        }
+
+        // Builds a fixture request; only the messages vary between the cases below.
+        fn request_with(messages: Vec<ChatMessage>) -> Request {
+            Request {
+                intent: true,
+                n: 1,
+                stream: true,
+                temperature: 0.1,
+                model: "claude-3.7-sonnet".to_string(),
+                messages,
+                tools: vec![],
+                tool_choice: None,
+            }
+        }
+
+        fn user_text(text: &str) -> ChatMessage {
+            ChatMessage::User {
+                content: ChatMessageContent::Plain(text.to_string()),
+            }
+        }
+
+        fn assistant_text(text: &str) -> ChatMessage {
+            ChatMessage::Assistant {
+                content: ChatMessageContent::Plain(text.to_string()),
+                tool_calls: vec![],
+            }
+        }
+
+        // User turn with one text part and one image part. The URL stays empty because
+        // the check above matches on the part variant only.
+        fn user_image_question() -> ChatMessage {
+            ChatMessage::User {
+                content: ChatMessageContent::Multipart(vec![
+                    ChatMessagePart::Text {
+                        text: "What's in this image?".to_string(),
+                    },
+                    ChatMessagePart::Image {
+                        image_url: ImageUrl {
+                            url: "".to_string(),
+                        },
+                    },
+                ]),
+            }
+        }
+
+        let request_with_image_in_last = request_with(vec![
+            user_text("Hello"),
+            assistant_text("How can I help?"),
+            user_image_question(),
+        ]);
+
+        // Regression case: the image is in an earlier turn and the last message is plain text.
+        let request_with_image_in_earlier = request_with(vec![
+            user_text("Hello"),
+            user_image_question(),
+            assistant_text("I see a cat in the image."),
+            user_text("What color is it?"),
+        ]);
+
+        let request_with_no_images = request_with(vec![
+            user_text("Hello"),
+            assistant_text("How can I help?"),
+            user_text("Tell me about Rust."),
+        ]);
+
+        assert!(is_vision_request(&request_with_image_in_last));
+        assert!(is_vision_request(&request_with_image_in_earlier));
+
+        assert!(!is_vision_request(&request_with_no_images));
+    }
 }