Make language model deserialization more resilient (#31311)

This expands our deserialization of JSON from models to be more tolerant
of different variations that the model may send, including
capitalization, wrapping things in objects vs. being plain strings, etc.

Also when deserialization fails, it reports the entire error in the JSON
so we can see what failed to deserialize. (Previously these errors were
very unhelpful at diagnosing the problem.)

Finally, also removes the `WrappedText` variant since the custom
deserializer just turns that style of JSON into a normal `Text` variant.

Release Notes:

- N/A
This commit is contained in:
Richard Feldman 2025-05-28 12:06:07 -04:00 committed by GitHub
parent 7443fde4e9
commit 00fd045844
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 301 additions and 69 deletions

View file

@ -18,7 +18,7 @@ use zed_llm_client::CompletionMode;
pub struct LanguageModelImage {
/// A base64-encoded PNG image.
pub source: SharedString,
size: Size<DevicePixels>,
pub size: Size<DevicePixels>,
}
impl LanguageModelImage {
@ -29,6 +29,41 @@ impl LanguageModelImage {
pub fn is_empty(&self) -> bool {
self.source.is_empty()
}
// Parse Self from a JSON object with case-insensitive field names
pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
let mut source = None;
let mut size_obj = None;
// Find source and size fields (case-insensitive)
for (k, v) in obj.iter() {
match k.to_lowercase().as_str() {
"source" => source = v.as_str(),
"size" => size_obj = v.as_object(),
_ => {}
}
}
let source = source?;
let size_obj = size_obj?;
let mut width = None;
let mut height = None;
// Find width and height in size object (case-insensitive)
for (k, v) in size_obj.iter() {
match k.to_lowercase().as_str() {
"width" => width = v.as_i64().map(|w| w as i32),
"height" => height = v.as_i64().map(|h| h as i32),
_ => {}
}
}
Some(Self {
size: size(DevicePixels(width?), DevicePixels(height?)),
source: SharedString::from(source.to_string()),
})
}
}
impl std::fmt::Debug for LanguageModelImage {
@ -148,34 +183,102 @@ pub struct LanguageModelToolResult {
pub output: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
#[serde(untagged)]
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
Text(Arc<str>),
Image(LanguageModelImage),
WrappedText(WrappedTextContent),
}
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
pub struct WrappedTextContent {
#[serde(rename = "type")]
pub content_type: String,
pub text: Arc<str>,
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value = serde_json::Value::deserialize(deserializer)?;
// Models can provide these responses in several styles. Try each in order.
// 1. Try as plain string
if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
return Ok(Self::Text(Arc::from(text)));
}
// 2. Try as object
if let Some(obj) = value.as_object() {
// get a JSON field case-insensitively
fn get_field<'a>(
obj: &'a serde_json::Map<String, serde_json::Value>,
field: &str,
) -> Option<&'a serde_json::Value> {
obj.iter()
.find(|(k, _)| k.to_lowercase() == field.to_lowercase())
.map(|(_, v)| v)
}
// Accept wrapped text format: { "type": "text", "text": "..." }
if let (Some(type_value), Some(text_value)) =
(get_field(&obj, "type"), get_field(&obj, "text"))
{
if let Some(type_str) = type_value.as_str() {
if type_str.to_lowercase() == "text" {
if let Some(text) = text_value.as_str() {
return Ok(Self::Text(Arc::from(text)));
}
}
}
}
// Check for wrapped Text variant: { "text": "..." }
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
if obj.len() == 1 {
// Only one field, and it's "text" (case-insensitive)
if let Some(text) = value.as_str() {
return Ok(Self::Text(Arc::from(text)));
}
}
}
// Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
if obj.len() == 1 {
// Only one field, and it's "image" (case-insensitive)
// Try to parse the nested image object
if let Some(image_obj) = value.as_object() {
if let Some(image) = LanguageModelImage::from_json(image_obj) {
return Ok(Self::Image(image));
}
}
}
}
// Try as direct Image (object with "source" and "size" fields)
if let Some(image) = LanguageModelImage::from_json(&obj) {
return Ok(Self::Image(image));
}
}
// If none of the variants match, return an error with the problematic JSON
Err(D::Error::custom(format!(
"data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
)))
}
}
impl LanguageModelToolResultContent {
pub fn to_str(&self) -> Option<&str> {
match self {
Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => Some(&text),
Self::Text(text) => Some(&text),
Self::Image(_) => None,
}
}
pub fn is_empty(&self) -> bool {
match self {
Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => {
text.chars().all(|c| c.is_whitespace())
}
Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
Self::Image(_) => false,
}
}
@ -294,3 +397,168 @@ pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_language_model_tool_result_content_deserialization() {
let json = r#""This is plain text""#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("This is plain text".into())
);
let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("This is wrapped text".into())
);
let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Case insensitive".into())
);
let json = r#"{"Text": "Wrapped variant"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Wrapped variant".into())
);
let json = r#"{"text": "Lowercase wrapped"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Lowercase wrapped".into())
);
// Test image deserialization
let json = r#"{
"source": "base64encodedimagedata",
"size": {
"width": 100,
"height": 200
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "base64encodedimagedata");
assert_eq!(image.size.width.0, 100);
assert_eq!(image.size.height.0, 200);
}
_ => panic!("Expected Image variant"),
}
// Test wrapped Image variant
let json = r#"{
"Image": {
"source": "wrappedimagedata",
"size": {
"width": 50,
"height": 75
}
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "wrappedimagedata");
assert_eq!(image.size.width.0, 50);
assert_eq!(image.size.height.0, 75);
}
_ => panic!("Expected Image variant"),
}
// Test wrapped Image variant with case insensitive
let json = r#"{
"image": {
"Source": "caseinsensitive",
"SIZE": {
"width": 30,
"height": 40
}
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "caseinsensitive");
assert_eq!(image.size.width.0, 30);
assert_eq!(image.size.height.0, 40);
}
_ => panic!("Expected Image variant"),
}
// Test that wrapped text with wrong type fails
let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
// Test that malformed JSON fails
let json = r#"{"invalid": "structure"}"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
// Test edge cases
let json = r#""""#; // Empty string
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(result, LanguageModelToolResultContent::Text("".into()));
// Test with extra fields in wrapped text (should be ignored)
let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));
// Test direct image with case-insensitive fields
let json = r#"{
"SOURCE": "directimage",
"Size": {
"width": 200,
"height": 300
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "directimage");
assert_eq!(image.size.width.0, 200);
assert_eq!(image.size.height.0, 300);
}
_ => panic!("Expected Image variant"),
}
// Test that multiple fields prevent wrapped variant interpretation
let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
// Test wrapped text with uppercase TEXT variant
let json = r#"{"TEXT": "Uppercase variant"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Uppercase variant".into())
);
// Test that numbers and other JSON values fail gracefully
let json = r#"123"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
let json = r#"null"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
let json = r#"[1, 2, 3]"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
}
}