use std::io::{Cursor, Write}; use std::sync::Arc; use anyhow::Result; use base64::write::EncoderWriter; use cloud_llm_client::{CompletionIntent, CompletionMode}; use gpui::{ App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task, point, px, size, }; use image::codecs::png::PngEncoder; use serde::{Deserialize, Serialize}; use util::ResultExt; use crate::role::Role; use crate::{LanguageModelToolUse, LanguageModelToolUseId}; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct LanguageModelImage { /// A base64-encoded PNG image. pub source: SharedString, pub size: Size, } impl LanguageModelImage { pub fn len(&self) -> usize { self.source.len() } pub fn is_empty(&self) -> bool { self.source.is_empty() } // Parse Self from a JSON object with case-insensitive field names pub fn from_json(obj: &serde_json::Map) -> Option { let mut source = None; let mut size_obj = None; // Find source and size fields (case-insensitive) for (k, v) in obj.iter() { match k.to_lowercase().as_str() { "source" => source = v.as_str(), "size" => size_obj = v.as_object(), _ => {} } } let source = source?; let size_obj = size_obj?; let mut width = None; let mut height = None; // Find width and height in size object (case-insensitive) for (k, v) in size_obj.iter() { match k.to_lowercase().as_str() { "width" => width = v.as_i64().map(|w| w as i32), "height" => height = v.as_i64().map(|h| h as i32), _ => {} } } Some(Self { size: size(DevicePixels(width?), DevicePixels(height?)), source: SharedString::from(source.to_string()), }) } } impl std::fmt::Debug for LanguageModelImage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LanguageModelImage") .field("source", &format!("<{} bytes>", self.source.len())) .field("size", &self.size) .finish() } } /// Anthropic wants uploaded images to be smaller than this in both dimensions. const ANTHROPIC_SIZE_LIMT: f32 = 1568.; impl LanguageModelImage { pub fn empty() -> Self { Self { source: "".into(), size: size(DevicePixels(0), DevicePixels(0)), } } pub fn from_image(data: Arc, cx: &mut App) -> Task> { cx.background_spawn(async move { let image_bytes = Cursor::new(data.bytes()); let dynamic_image = match data.format() { ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes) .and_then(image::DynamicImage::from_decoder), ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes) .and_then(image::DynamicImage::from_decoder), ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes) .and_then(image::DynamicImage::from_decoder), ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes) .and_then(image::DynamicImage::from_decoder), _ => return None, } .log_err()?; let width = dynamic_image.width(); let height = dynamic_image.height(); let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32)); let base64_image = { if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32 || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32 { let new_bounds = ObjectFit::ScaleDown.get_bounds( gpui::Bounds { origin: point(px(0.0), px(0.0)), size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)), }, image_size, ); let resized_image = dynamic_image.resize( new_bounds.size.width.0 as u32, new_bounds.size.height.0 as u32, image::imageops::FilterType::Triangle, ); encode_as_base64(data, resized_image) } else { encode_as_base64(data, dynamic_image) } } .log_err()?; // SAFETY: The base64 encoder should not produce non-UTF8. let source = unsafe { String::from_utf8_unchecked(base64_image) }; Some(LanguageModelImage { size: image_size, source: source.into(), }) }) } pub fn estimate_tokens(&self) -> usize { let width = self.size.width.0.unsigned_abs() as usize; let height = self.size.height.0.unsigned_abs() as usize; // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this, // so this method is more of a rough guess. (width * height) / 750 } pub fn to_base64_url(&self) -> String { format!("data:image/png;base64,{}", self.source) } } fn encode_as_base64(data: Arc, image: image::DynamicImage) -> Result> { let mut base64_image = Vec::new(); { let mut base64_encoder = EncoderWriter::new( Cursor::new(&mut base64_image), &base64::engine::general_purpose::STANDARD, ); if data.format() == ImageFormat::Png { base64_encoder.write_all(data.bytes())?; } else { let mut png = Vec::new(); image.write_with_encoder(PngEncoder::new(&mut png))?; base64_encoder.write_all(png.as_slice())?; } } Ok(base64_image) } #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub struct LanguageModelToolResult { pub tool_use_id: LanguageModelToolUseId, pub tool_name: Arc, pub is_error: bool, pub content: LanguageModelToolResultContent, pub output: Option, } #[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] pub enum LanguageModelToolResultContent { Text(Arc), Image(LanguageModelImage), } impl<'de> Deserialize<'de> for LanguageModelToolResultContent { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { use serde::de::Error; let value = serde_json::Value::deserialize(deserializer)?; // Models can provide these responses in several styles. Try each in order. // 1. Try as plain string if let Ok(text) = serde_json::from_value::(value.clone()) { return Ok(Self::Text(Arc::from(text))); } // 2. Try as object if let Some(obj) = value.as_object() { // get a JSON field case-insensitively fn get_field<'a>( obj: &'a serde_json::Map, field: &str, ) -> Option<&'a serde_json::Value> { obj.iter() .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) .map(|(_, v)| v) } // Accept wrapped text format: { "type": "text", "text": "..." } if let (Some(type_value), Some(text_value)) = (get_field(obj, "type"), get_field(obj, "text")) && let Some(type_str) = type_value.as_str() && type_str.to_lowercase() == "text" && let Some(text) = text_value.as_str() { return Ok(Self::Text(Arc::from(text))); } // Check for wrapped Text variant: { "text": "..." } if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") && obj.len() == 1 { // Only one field, and it's "text" (case-insensitive) if let Some(text) = value.as_str() { return Ok(Self::Text(Arc::from(text))); } } // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") && obj.len() == 1 { // Only one field, and it's "image" (case-insensitive) // Try to parse the nested image object if let Some(image_obj) = value.as_object() && let Some(image) = LanguageModelImage::from_json(image_obj) { return Ok(Self::Image(image)); } } // Try as direct Image (object with "source" and "size" fields) if let Some(image) = LanguageModelImage::from_json(obj) { return Ok(Self::Image(image)); } } // If none of the variants match, return an error with the problematic JSON Err(D::Error::custom(format!( "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}", serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) ))) } } impl LanguageModelToolResultContent { pub fn to_str(&self) -> Option<&str> { match self { Self::Text(text) => Some(text), Self::Image(_) => None, } } pub fn is_empty(&self) -> bool { match self { Self::Text(text) => text.chars().all(|c| c.is_whitespace()), Self::Image(_) => false, } } } impl From<&str> for LanguageModelToolResultContent { fn from(value: &str) -> Self { Self::Text(Arc::from(value)) } } impl From for LanguageModelToolResultContent { fn from(value: String) -> Self { Self::Text(Arc::from(value)) } } impl From for LanguageModelToolResultContent { fn from(image: LanguageModelImage) -> Self { Self::Image(image) } } #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String), Thinking { text: String, signature: Option, }, RedactedThinking(String), Image(LanguageModelImage), ToolUse(LanguageModelToolUse), ToolResult(LanguageModelToolResult), } impl MessageContent { pub fn to_str(&self) -> Option<&str> { match self { MessageContent::Text(text) => Some(text.as_str()), MessageContent::Thinking { text, .. } => Some(text.as_str()), MessageContent::RedactedThinking(_) => None, MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), MessageContent::ToolUse(_) | MessageContent::Image(_) => None, } } pub fn is_empty(&self) -> bool { match self { MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), MessageContent::RedactedThinking(_) | MessageContent::ToolUse(_) | MessageContent::Image(_) => false, } } } impl From for MessageContent { fn from(value: String) -> Self { MessageContent::Text(value) } } impl From<&str> for MessageContent { fn from(value: &str) -> Self { MessageContent::Text(value.to_string()) } } #[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] pub struct LanguageModelRequestMessage { pub role: Role, pub content: Vec, pub cache: bool, } impl LanguageModelRequestMessage { pub fn string_contents(&self) -> String { let mut buffer = String::new(); for string in self.content.iter().filter_map(|content| content.to_str()) { buffer.push_str(string); } buffer } pub fn contents_empty(&self) -> bool { self.content.iter().all(|content| content.is_empty()) } } #[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] pub struct LanguageModelRequestTool { pub name: String, pub description: String, pub input_schema: serde_json::Value, } #[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] pub enum LanguageModelToolChoice { Auto, Any, None, } #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub struct LanguageModelRequest { pub thread_id: Option, pub prompt_id: Option, pub intent: Option, pub mode: Option, pub messages: Vec, pub tools: Vec, pub tool_choice: Option, pub stop: Vec, pub temperature: Option, pub thinking_allowed: bool, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct LanguageModelResponseMessage { pub role: Option, pub content: Option, } #[cfg(test)] mod tests { use super::*; #[test] fn test_language_model_tool_result_content_deserialization() { let json = r#""This is plain text""#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!( result, LanguageModelToolResultContent::Text("This is plain text".into()) ); let json = r#"{"type": "text", "text": "This is wrapped text"}"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!( result, LanguageModelToolResultContent::Text("This is wrapped text".into()) ); let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!( result, LanguageModelToolResultContent::Text("Case insensitive".into()) ); let json = r#"{"Text": "Wrapped variant"}"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!( result, LanguageModelToolResultContent::Text("Wrapped variant".into()) ); let json = r#"{"text": "Lowercase wrapped"}"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!( result, LanguageModelToolResultContent::Text("Lowercase wrapped".into()) ); // Test image deserialization let json = r#"{ "source": "base64encodedimagedata", "size": { "width": 100, "height": 200 } }"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); match result { LanguageModelToolResultContent::Image(image) => { assert_eq!(image.source.as_ref(), "base64encodedimagedata"); assert_eq!(image.size.width.0, 100); assert_eq!(image.size.height.0, 200); } _ => panic!("Expected Image variant"), } // Test wrapped Image variant let json = r#"{ "Image": { "source": "wrappedimagedata", "size": { "width": 50, "height": 75 } } }"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); match result { LanguageModelToolResultContent::Image(image) => { assert_eq!(image.source.as_ref(), "wrappedimagedata"); assert_eq!(image.size.width.0, 50); assert_eq!(image.size.height.0, 75); } _ => panic!("Expected Image variant"), } // Test wrapped Image variant with case insensitive let json = r#"{ "image": { "Source": "caseinsensitive", "SIZE": { "width": 30, "height": 40 } } }"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); match result { LanguageModelToolResultContent::Image(image) => { assert_eq!(image.source.as_ref(), "caseinsensitive"); assert_eq!(image.size.width.0, 30); assert_eq!(image.size.height.0, 40); } _ => panic!("Expected Image variant"), } // Test that wrapped text with wrong type fails let json = r#"{"type": "blahblah", "text": "This should fail"}"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); // Test that malformed JSON fails let json = r#"{"invalid": "structure"}"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); // Test edge cases let json = r#""""#; // Empty string let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!(result, LanguageModelToolResultContent::Text("".into())); // Test with extra fields in wrapped text (should be ignored) let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into())); // Test direct image with case-insensitive fields let json = r#"{ "SOURCE": "directimage", "Size": { "width": 200, "height": 300 } }"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); match result { LanguageModelToolResultContent::Image(image) => { assert_eq!(image.source.as_ref(), "directimage"); assert_eq!(image.size.width.0, 200); assert_eq!(image.size.height.0, 300); } _ => panic!("Expected Image variant"), } // Test that multiple fields prevent wrapped variant interpretation let json = r#"{"Text": "not wrapped", "extra": "field"}"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); // Test wrapped text with uppercase TEXT variant let json = r#"{"TEXT": "Uppercase variant"}"#; let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); assert_eq!( result, LanguageModelToolResultContent::Text("Uppercase variant".into()) ); // Test that numbers and other JSON values fail gracefully let json = r#"123"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); let json = r#"null"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); let json = r#"[1, 2, 3]"#; let result: Result = serde_json::from_str(json); assert!(result.is_err()); } }