diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 004a9ead7b..d3d9a62f78 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -24,7 +24,7 @@ use language_model::{ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, - StopReason, TokenUsage, WrappedTextContent, + StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::Project; @@ -891,10 +891,7 @@ impl Thread { pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { match &self.tool_use.tool_result(id)?.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => { - Some(text) - } + LanguageModelToolResultContent::Text(text) => Some(text), LanguageModelToolResultContent::Image(_) => { // TODO: We should display image None @@ -2593,11 +2590,7 @@ impl Thread { writeln!(markdown, "**\n")?; match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { writeln!(markdown, "{text}")?; } LanguageModelToolResultContent::Image(image) => { diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 78326c0a69..955421576d 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -9,7 +9,7 @@ use handlebars::Handlebars; use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolResultContent, MessageContent, Role, TokenUsage, WrappedTextContent, + LanguageModelToolResultContent, MessageContent, Role, TokenUsage, }; use project::lsp_store::OpenLspBufferHandle; use project::{DiagnosticSummary, Project, ProjectPath}; @@ -967,11 +967,7 @@ impl RequestMarkdown { } match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { writeln!(messages, "{text}\n").ok(); } LanguageModelToolResultContent::Image(image) => { diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 1a6c695192..e997a2ec58 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -18,7 +18,7 @@ use zed_llm_client::CompletionMode; pub struct LanguageModelImage { /// A base64-encoded PNG image. pub source: SharedString, - size: Size, + pub size: Size, } impl LanguageModelImage { @@ -29,6 +29,41 @@ impl LanguageModelImage { pub fn is_empty(&self) -> bool { self.source.is_empty() } + + // Parse Self from a JSON object with case-insensitive field names + pub fn from_json(obj: &serde_json::Map) -> Option { + let mut source = None; + let mut size_obj = None; + + // Find source and size fields (case-insensitive) + for (k, v) in obj.iter() { + match k.to_lowercase().as_str() { + "source" => source = v.as_str(), + "size" => size_obj = v.as_object(), + _ => {} + } + } + + let source = source?; + let size_obj = size_obj?; + + let mut width = None; + let mut height = None; + + // Find width and height in size object (case-insensitive) + for (k, v) in size_obj.iter() { + match k.to_lowercase().as_str() { + "width" => width = v.as_i64().map(|w| w as i32), + "height" => height = v.as_i64().map(|h| h as i32), + _ => {} + } + } + + Some(Self { + size: size(DevicePixels(width?), DevicePixels(height?)), + source: SharedString::from(source.to_string()), + }) + } } impl std::fmt::Debug for LanguageModelImage { @@ -148,34 +183,102 @@ pub struct LanguageModelToolResult { pub output: Option, } -#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)] -#[serde(untagged)] +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] pub enum LanguageModelToolResultContent { Text(Arc), Image(LanguageModelImage), - WrappedText(WrappedTextContent), } -#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)] -pub struct WrappedTextContent { - #[serde(rename = "type")] - pub content_type: String, - pub text: Arc, +impl<'de> Deserialize<'de> for LanguageModelToolResultContent { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let value = serde_json::Value::deserialize(deserializer)?; + + // Models can provide these responses in several styles. Try each in order. + + // 1. Try as plain string + if let Ok(text) = serde_json::from_value::(value.clone()) { + return Ok(Self::Text(Arc::from(text))); + } + + // 2. Try as object + if let Some(obj) = value.as_object() { + // get a JSON field case-insensitively + fn get_field<'a>( + obj: &'a serde_json::Map, + field: &str, + ) -> Option<&'a serde_json::Value> { + obj.iter() + .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) + .map(|(_, v)| v) + } + + // Accept wrapped text format: { "type": "text", "text": "..." } + if let (Some(type_value), Some(text_value)) = + (get_field(&obj, "type"), get_field(&obj, "text")) + { + if let Some(type_str) = type_value.as_str() { + if type_str.to_lowercase() == "text" { + if let Some(text) = text_value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + } + } + + // Check for wrapped Text variant: { "text": "..." } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") { + if obj.len() == 1 { + // Only one field, and it's "text" (case-insensitive) + if let Some(text) = value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + } + + // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") { + if obj.len() == 1 { + // Only one field, and it's "image" (case-insensitive) + // Try to parse the nested image object + if let Some(image_obj) = value.as_object() { + if let Some(image) = LanguageModelImage::from_json(image_obj) { + return Ok(Self::Image(image)); + } + } + } + } + + // Try as direct Image (object with "source" and "size" fields) + if let Some(image) = LanguageModelImage::from_json(&obj) { + return Ok(Self::Image(image)); + } + } + + // If none of the variants match, return an error with the problematic JSON + Err(D::Error::custom(format!( + "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ + an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}", + serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) + ))) + } } impl LanguageModelToolResultContent { pub fn to_str(&self) -> Option<&str> { match self { - Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => Some(&text), + Self::Text(text) => Some(&text), Self::Image(_) => None, } } pub fn is_empty(&self) -> bool { match self { - Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => { - text.chars().all(|c| c.is_whitespace()) - } + Self::Text(text) => text.chars().all(|c| c.is_whitespace()), Self::Image(_) => false, } } @@ -294,3 +397,168 @@ pub struct LanguageModelResponseMessage { pub role: Option, pub content: Option, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_model_tool_result_content_deserialization() { + let json = r#""This is plain text""#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("This is plain text".into()) + ); + + let json = r#"{"type": "text", "text": "This is wrapped text"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("This is wrapped text".into()) + ); + + let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Case insensitive".into()) + ); + + let json = r#"{"Text": "Wrapped variant"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Wrapped variant".into()) + ); + + let json = r#"{"text": "Lowercase wrapped"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Lowercase wrapped".into()) + ); + + // Test image deserialization + let json = r#"{ + "source": "base64encodedimagedata", + "size": { + "width": 100, + "height": 200 + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "base64encodedimagedata"); + assert_eq!(image.size.width.0, 100); + assert_eq!(image.size.height.0, 200); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped Image variant + let json = r#"{ + "Image": { + "source": "wrappedimagedata", + "size": { + "width": 50, + "height": 75 + } + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "wrappedimagedata"); + assert_eq!(image.size.width.0, 50); + assert_eq!(image.size.height.0, 75); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped Image variant with case insensitive + let json = r#"{ + "image": { + "Source": "caseinsensitive", + "SIZE": { + "width": 30, + "height": 40 + } + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "caseinsensitive"); + assert_eq!(image.size.width.0, 30); + assert_eq!(image.size.height.0, 40); + } + _ => panic!("Expected Image variant"), + } + + // Test that wrapped text with wrong type fails + let json = r#"{"type": "blahblah", "text": "This should fail"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + // Test that malformed JSON fails + let json = r#"{"invalid": "structure"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + // Test edge cases + let json = r#""""#; // Empty string + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!(result, LanguageModelToolResultContent::Text("".into())); + + // Test with extra fields in wrapped text (should be ignored) + let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into())); + + // Test direct image with case-insensitive fields + let json = r#"{ + "SOURCE": "directimage", + "Size": { + "width": 200, + "height": 300 + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "directimage"); + assert_eq!(image.size.width.0, 200); + assert_eq!(image.size.height.0, 300); + } + _ => panic!("Expected Image variant"), + } + + // Test that multiple fields prevent wrapped variant interpretation + let json = r#"{"Text": "not wrapped", "extra": "field"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + // Test wrapped text with uppercase TEXT variant + let json = r#"{"TEXT": "Uppercase variant"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Uppercase variant".into()) + ); + + // Test that numbers and other JSON values fail gracefully + let json = r#"123"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + let json = r#"null"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + let json = r#"[1, 2, 3]"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } +} diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index f9dc7af3dc..055bdc52e2 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -19,7 +19,7 @@ use language_model::{ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolResultContent, MessageContent, RateLimiter, Role, WrappedTextContent, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -350,11 +350,7 @@ pub fn count_anthropic_tokens( // TODO: Estimate token usage from tool uses. } MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { string_contents.push_str(text); } LanguageModelToolResultContent::Image(image) => { @@ -592,10 +588,9 @@ pub fn into_anthropic( tool_use_id: tool_result.tool_use_id.to_string(), is_error: tool_result.is_error, content: match tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText( - WrappedTextContent { text, .. }, - ) => ToolResultContent::Plain(text.to_string()), + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } LanguageModelToolResultContent::Image(image) => { ToolResultContent::Multipart(vec![ToolResultPart::Image { source: anthropic::ImageSource { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 7dd524c8fe..d2dc26009e 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -37,7 +37,7 @@ use language_model::{ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role, - TokenUsage, WrappedTextContent, + TokenUsage, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -639,8 +639,7 @@ pub fn into_bedrock( BedrockToolResultBlock::builder() .tool_use_id(tool_result.tool_use_id.to_string()) .content(match tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => { + LanguageModelToolResultContent::Text(text) => { BedrockToolResultContentBlock::Text(text.to_string()) } LanguageModelToolResultContent::Image(_) => { @@ -775,11 +774,7 @@ pub fn get_bedrock_tokens( // TODO: Estimate token usage from tool uses. } MessageContent::ToolResult(tool_result) => match tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { string_contents.push_str(&text); } LanguageModelToolResultContent::Image(image) => { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 78b23af805..25f97ffd59 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -23,7 +23,7 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role, - StopReason, WrappedTextContent, + StopReason, }; use settings::SettingsStore; use std::time::Duration; @@ -455,11 +455,7 @@ fn into_copilot_chat( for content in &message.content { if let MessageContent::ToolResult(tool_result) = content { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => text.to_string().into(), + LanguageModelToolResultContent::Text(text) => text.to_string().into(), LanguageModelToolResultContent::Image(image) => { if model.supports_vision() { ChatMessageContent::Multipart(vec![ChatMessagePart::Image { diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 2203dc261f..73ee095c92 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -426,10 +426,7 @@ pub fn into_google( } language_model::MessageContent::ToolResult(tool_result) => { match tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) - | language_model::LanguageModelToolResultContent::WrappedText( - language_model::WrappedTextContent { text, .. }, - ) => { + language_model::LanguageModelToolResultContent::Text(text) => { vec![Part::FunctionResponsePart( google_ai::FunctionResponsePart { function_response: google_ai::FunctionResponse { diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index cf9ca366ab..2966c3fad3 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -13,7 +13,7 @@ use language_model::{ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, WrappedTextContent, + RateLimiter, Role, StopReason, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -428,11 +428,7 @@ pub fn into_mistral( } MessageContent::ToolResult(tool_result) => { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => text.to_string(), + LanguageModelToolResultContent::Text(text) => text.to_string(), LanguageModelToolResultContent::Image(_) => { // TODO: Mistral image support "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index c843b736a0..ab2627f780 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -13,7 +13,7 @@ use language_model::{ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, WrappedTextContent, + RateLimiter, Role, StopReason, }; use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; @@ -407,11 +407,7 @@ pub fn into_open_ai( } MessageContent::ToolResult(tool_result) => { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { vec![open_ai::MessagePart::Text { text: text.to_string(), }]