From 00fd045844a4bfc902863c64a2a331df16fea629 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Wed, 28 May 2025 12:06:07 -0400 Subject: [PATCH] Make language model deserialization more resilient (#31311) This expands our deserialization of JSON from models to be more tolerant of different variations that the model may send, including capitalization, wrapping things in objects vs. being plain strings, etc. Also when deserialization fails, it reports the entire error in the JSON so we can see what failed to deserialize. (Previously these errors were very unhelpful at diagnosing the problem.) Finally, also removes the `WrappedText` variant since the custom deserializer just turns that style of JSON into a normal `Text` variant. Release Notes: - N/A --- crates/agent/src/thread.rs | 13 +- crates/eval/src/instance.rs | 8 +- crates/language_model/src/request.rs | 294 +++++++++++++++++- .../language_models/src/provider/anthropic.rs | 15 +- .../language_models/src/provider/bedrock.rs | 11 +- .../src/provider/copilot_chat.rs | 8 +- crates/language_models/src/provider/google.rs | 5 +- .../language_models/src/provider/mistral.rs | 8 +- .../language_models/src/provider/open_ai.rs | 8 +- 9 files changed, 301 insertions(+), 69 deletions(-) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 004a9ead7b..d3d9a62f78 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -24,7 +24,7 @@ use language_model::{ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, - StopReason, TokenUsage, WrappedTextContent, + StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::Project; @@ -891,10 +891,7 @@ impl Thread { pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { match &self.tool_use.tool_result(id)?.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => { - Some(text) - } + LanguageModelToolResultContent::Text(text) => Some(text), LanguageModelToolResultContent::Image(_) => { // TODO: We should display image None @@ -2593,11 +2590,7 @@ impl Thread { writeln!(markdown, "**\n")?; match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { writeln!(markdown, "{text}")?; } LanguageModelToolResultContent::Image(image) => { diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 78326c0a69..955421576d 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -9,7 +9,7 @@ use handlebars::Handlebars; use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolResultContent, MessageContent, Role, TokenUsage, WrappedTextContent, + LanguageModelToolResultContent, MessageContent, Role, TokenUsage, }; use project::lsp_store::OpenLspBufferHandle; use project::{DiagnosticSummary, Project, ProjectPath}; @@ -967,11 +967,7 @@ impl RequestMarkdown { } match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { writeln!(messages, "{text}\n").ok(); } LanguageModelToolResultContent::Image(image) => { diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 1a6c695192..e997a2ec58 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -18,7 +18,7 @@ use zed_llm_client::CompletionMode; pub struct LanguageModelImage { /// A base64-encoded PNG image. pub source: SharedString, - size: Size, + pub size: Size, } impl LanguageModelImage { @@ -29,6 +29,41 @@ impl LanguageModelImage { pub fn is_empty(&self) -> bool { self.source.is_empty() } + + // Parse Self from a JSON object with case-insensitive field names + pub fn from_json(obj: &serde_json::Map) -> Option { + let mut source = None; + let mut size_obj = None; + + // Find source and size fields (case-insensitive) + for (k, v) in obj.iter() { + match k.to_lowercase().as_str() { + "source" => source = v.as_str(), + "size" => size_obj = v.as_object(), + _ => {} + } + } + + let source = source?; + let size_obj = size_obj?; + + let mut width = None; + let mut height = None; + + // Find width and height in size object (case-insensitive) + for (k, v) in size_obj.iter() { + match k.to_lowercase().as_str() { + "width" => width = v.as_i64().map(|w| w as i32), + "height" => height = v.as_i64().map(|h| h as i32), + _ => {} + } + } + + Some(Self { + size: size(DevicePixels(width?), DevicePixels(height?)), + source: SharedString::from(source.to_string()), + }) + } } impl std::fmt::Debug for LanguageModelImage { @@ -148,34 +183,102 @@ pub struct LanguageModelToolResult { pub output: Option, } -#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)] -#[serde(untagged)] +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] pub enum LanguageModelToolResultContent { Text(Arc), Image(LanguageModelImage), - WrappedText(WrappedTextContent), } -#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)] -pub struct WrappedTextContent { - #[serde(rename = "type")] - pub content_type: String, - pub text: Arc, +impl<'de> Deserialize<'de> for LanguageModelToolResultContent { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let value = serde_json::Value::deserialize(deserializer)?; + + // Models can provide these responses in several styles. Try each in order. + + // 1. Try as plain string + if let Ok(text) = serde_json::from_value::(value.clone()) { + return Ok(Self::Text(Arc::from(text))); + } + + // 2. Try as object + if let Some(obj) = value.as_object() { + // get a JSON field case-insensitively + fn get_field<'a>( + obj: &'a serde_json::Map, + field: &str, + ) -> Option<&'a serde_json::Value> { + obj.iter() + .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) + .map(|(_, v)| v) + } + + // Accept wrapped text format: { "type": "text", "text": "..." } + if let (Some(type_value), Some(text_value)) = + (get_field(&obj, "type"), get_field(&obj, "text")) + { + if let Some(type_str) = type_value.as_str() { + if type_str.to_lowercase() == "text" { + if let Some(text) = text_value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + } + } + + // Check for wrapped Text variant: { "text": "..." } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") { + if obj.len() == 1 { + // Only one field, and it's "text" (case-insensitive) + if let Some(text) = value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + } + + // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") { + if obj.len() == 1 { + // Only one field, and it's "image" (case-insensitive) + // Try to parse the nested image object + if let Some(image_obj) = value.as_object() { + if let Some(image) = LanguageModelImage::from_json(image_obj) { + return Ok(Self::Image(image)); + } + } + } + } + + // Try as direct Image (object with "source" and "size" fields) + if let Some(image) = LanguageModelImage::from_json(&obj) { + return Ok(Self::Image(image)); + } + } + + // If none of the variants match, return an error with the problematic JSON + Err(D::Error::custom(format!( + "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ + an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}", + serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) + ))) + } } impl LanguageModelToolResultContent { pub fn to_str(&self) -> Option<&str> { match self { - Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => Some(&text), + Self::Text(text) => Some(&text), Self::Image(_) => None, } } pub fn is_empty(&self) -> bool { match self { - Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => { - text.chars().all(|c| c.is_whitespace()) - } + Self::Text(text) => text.chars().all(|c| c.is_whitespace()), Self::Image(_) => false, } } @@ -294,3 +397,168 @@ pub struct LanguageModelResponseMessage { pub role: Option, pub content: Option, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_model_tool_result_content_deserialization() { + let json = r#""This is plain text""#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("This is plain text".into()) + ); + + let json = r#"{"type": "text", "text": "This is wrapped text"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("This is wrapped text".into()) + ); + + let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Case insensitive".into()) + ); + + let json = r#"{"Text": "Wrapped variant"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Wrapped variant".into()) + ); + + let json = r#"{"text": "Lowercase wrapped"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Lowercase wrapped".into()) + ); + + // Test image deserialization + let json = r#"{ + "source": "base64encodedimagedata", + "size": { + "width": 100, + "height": 200 + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "base64encodedimagedata"); + assert_eq!(image.size.width.0, 100); + assert_eq!(image.size.height.0, 200); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped Image variant + let json = r#"{ + "Image": { + "source": "wrappedimagedata", + "size": { + "width": 50, + "height": 75 + } + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "wrappedimagedata"); + assert_eq!(image.size.width.0, 50); + assert_eq!(image.size.height.0, 75); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped Image variant with case insensitive + let json = r#"{ + "image": { + "Source": "caseinsensitive", + "SIZE": { + "width": 30, + "height": 40 + } + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "caseinsensitive"); + assert_eq!(image.size.width.0, 30); + assert_eq!(image.size.height.0, 40); + } + _ => panic!("Expected Image variant"), + } + + // Test that wrapped text with wrong type fails + let json = r#"{"type": "blahblah", "text": "This should fail"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + // Test that malformed JSON fails + let json = r#"{"invalid": "structure"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + // Test edge cases + let json = r#""""#; // Empty string + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!(result, LanguageModelToolResultContent::Text("".into())); + + // Test with extra fields in wrapped text (should be ignored) + let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into())); + + // Test direct image with case-insensitive fields + let json = r#"{ + "SOURCE": "directimage", + "Size": { + "width": 200, + "height": 300 + } + }"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + match result { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "directimage"); + assert_eq!(image.size.width.0, 200); + assert_eq!(image.size.height.0, 300); + } + _ => panic!("Expected Image variant"), + } + + // Test that multiple fields prevent wrapped variant interpretation + let json = r#"{"Text": "not wrapped", "extra": "field"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + // Test wrapped text with uppercase TEXT variant + let json = r#"{"TEXT": "Uppercase variant"}"#; + let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); + assert_eq!( + result, + LanguageModelToolResultContent::Text("Uppercase variant".into()) + ); + + // Test that numbers and other JSON values fail gracefully + let json = r#"123"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + let json = r#"null"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + let json = r#"[1, 2, 3]"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } +} diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index f9dc7af3dc..055bdc52e2 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -19,7 +19,7 @@ use language_model::{ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolResultContent, MessageContent, RateLimiter, Role, WrappedTextContent, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -350,11 +350,7 @@ pub fn count_anthropic_tokens( // TODO: Estimate token usage from tool uses. } MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { string_contents.push_str(text); } LanguageModelToolResultContent::Image(image) => { @@ -592,10 +588,9 @@ pub fn into_anthropic( tool_use_id: tool_result.tool_use_id.to_string(), is_error: tool_result.is_error, content: match tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText( - WrappedTextContent { text, .. }, - ) => ToolResultContent::Plain(text.to_string()), + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } LanguageModelToolResultContent::Image(image) => { ToolResultContent::Multipart(vec![ToolResultPart::Image { source: anthropic::ImageSource { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 7dd524c8fe..d2dc26009e 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -37,7 +37,7 @@ use language_model::{ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role, - TokenUsage, WrappedTextContent, + TokenUsage, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -639,8 +639,7 @@ pub fn into_bedrock( BedrockToolResultBlock::builder() .tool_use_id(tool_result.tool_use_id.to_string()) .content(match tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => { + LanguageModelToolResultContent::Text(text) => { BedrockToolResultContentBlock::Text(text.to_string()) } LanguageModelToolResultContent::Image(_) => { @@ -775,11 +774,7 @@ pub fn get_bedrock_tokens( // TODO: Estimate token usage from tool uses. } MessageContent::ToolResult(tool_result) => match tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { string_contents.push_str(&text); } LanguageModelToolResultContent::Image(image) => { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 78b23af805..25f97ffd59 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -23,7 +23,7 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role, - StopReason, WrappedTextContent, + StopReason, }; use settings::SettingsStore; use std::time::Duration; @@ -455,11 +455,7 @@ fn into_copilot_chat( for content in &message.content { if let MessageContent::ToolResult(tool_result) = content { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => text.to_string().into(), + LanguageModelToolResultContent::Text(text) => text.to_string().into(), LanguageModelToolResultContent::Image(image) => { if model.supports_vision() { ChatMessageContent::Multipart(vec![ChatMessagePart::Image { diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 2203dc261f..73ee095c92 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -426,10 +426,7 @@ pub fn into_google( } language_model::MessageContent::ToolResult(tool_result) => { match tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) - | language_model::LanguageModelToolResultContent::WrappedText( - language_model::WrappedTextContent { text, .. }, - ) => { + language_model::LanguageModelToolResultContent::Text(text) => { vec![Part::FunctionResponsePart( google_ai::FunctionResponsePart { function_response: google_ai::FunctionResponse { diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index cf9ca366ab..2966c3fad3 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -13,7 +13,7 @@ use language_model::{ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, WrappedTextContent, + RateLimiter, Role, StopReason, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -428,11 +428,7 @@ pub fn into_mistral( } MessageContent::ToolResult(tool_result) => { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => text.to_string(), + LanguageModelToolResultContent::Text(text) => text.to_string(), LanguageModelToolResultContent::Image(_) => { // TODO: Mistral image support "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index c843b736a0..ab2627f780 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -13,7 +13,7 @@ use language_model::{ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, WrappedTextContent, + RateLimiter, Role, StopReason, }; use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; @@ -407,11 +407,7 @@ pub fn into_open_ai( } MessageContent::ToolResult(tool_result) => { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) - | LanguageModelToolResultContent::WrappedText(WrappedTextContent { - text, - .. - }) => { + LanguageModelToolResultContent::Text(text) => { vec![open_ai::MessagePart::Text { text: text.to_string(), }]