Make language model deserialization more resilient (#31311)
This expands our deserialization of JSON from models to be more tolerant of different variations that the model may send, including capitalization, wrapping things in objects vs. being plain strings, etc. Also when deserialization fails, it reports the entire error in the JSON so we can see what failed to deserialize. (Previously these errors were very unhelpful at diagnosing the problem.) Finally, also removes the `WrappedText` variant since the custom deserializer just turns that style of JSON into a normal `Text` variant. Release Notes: - N/A
This commit is contained in:
parent
7443fde4e9
commit
00fd045844
9 changed files with 301 additions and 69 deletions
|
@ -18,7 +18,7 @@ use zed_llm_client::CompletionMode;
|
|||
pub struct LanguageModelImage {
|
||||
/// A base64-encoded PNG image.
|
||||
pub source: SharedString,
|
||||
size: Size<DevicePixels>,
|
||||
pub size: Size<DevicePixels>,
|
||||
}
|
||||
|
||||
impl LanguageModelImage {
|
||||
|
@ -29,6 +29,41 @@ impl LanguageModelImage {
|
|||
pub fn is_empty(&self) -> bool {
|
||||
self.source.is_empty()
|
||||
}
|
||||
|
||||
// Parse Self from a JSON object with case-insensitive field names
|
||||
pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
|
||||
let mut source = None;
|
||||
let mut size_obj = None;
|
||||
|
||||
// Find source and size fields (case-insensitive)
|
||||
for (k, v) in obj.iter() {
|
||||
match k.to_lowercase().as_str() {
|
||||
"source" => source = v.as_str(),
|
||||
"size" => size_obj = v.as_object(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let source = source?;
|
||||
let size_obj = size_obj?;
|
||||
|
||||
let mut width = None;
|
||||
let mut height = None;
|
||||
|
||||
// Find width and height in size object (case-insensitive)
|
||||
for (k, v) in size_obj.iter() {
|
||||
match k.to_lowercase().as_str() {
|
||||
"width" => width = v.as_i64().map(|w| w as i32),
|
||||
"height" => height = v.as_i64().map(|h| h as i32),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
size: size(DevicePixels(width?), DevicePixels(height?)),
|
||||
source: SharedString::from(source.to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for LanguageModelImage {
|
||||
|
@ -148,34 +183,102 @@ pub struct LanguageModelToolResult {
|
|||
pub output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
|
||||
#[serde(untagged)]
|
||||
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
|
||||
pub enum LanguageModelToolResultContent {
|
||||
Text(Arc<str>),
|
||||
Image(LanguageModelImage),
|
||||
WrappedText(WrappedTextContent),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
|
||||
pub struct WrappedTextContent {
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: String,
|
||||
pub text: Arc<str>,
|
||||
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
let value = serde_json::Value::deserialize(deserializer)?;
|
||||
|
||||
// Models can provide these responses in several styles. Try each in order.
|
||||
|
||||
// 1. Try as plain string
|
||||
if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
|
||||
// 2. Try as object
|
||||
if let Some(obj) = value.as_object() {
|
||||
// get a JSON field case-insensitively
|
||||
fn get_field<'a>(
|
||||
obj: &'a serde_json::Map<String, serde_json::Value>,
|
||||
field: &str,
|
||||
) -> Option<&'a serde_json::Value> {
|
||||
obj.iter()
|
||||
.find(|(k, _)| k.to_lowercase() == field.to_lowercase())
|
||||
.map(|(_, v)| v)
|
||||
}
|
||||
|
||||
// Accept wrapped text format: { "type": "text", "text": "..." }
|
||||
if let (Some(type_value), Some(text_value)) =
|
||||
(get_field(&obj, "type"), get_field(&obj, "text"))
|
||||
{
|
||||
if let Some(type_str) = type_value.as_str() {
|
||||
if type_str.to_lowercase() == "text" {
|
||||
if let Some(text) = text_value.as_str() {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for wrapped Text variant: { "text": "..." }
|
||||
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
|
||||
if obj.len() == 1 {
|
||||
// Only one field, and it's "text" (case-insensitive)
|
||||
if let Some(text) = value.as_str() {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
|
||||
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
|
||||
if obj.len() == 1 {
|
||||
// Only one field, and it's "image" (case-insensitive)
|
||||
// Try to parse the nested image object
|
||||
if let Some(image_obj) = value.as_object() {
|
||||
if let Some(image) = LanguageModelImage::from_json(image_obj) {
|
||||
return Ok(Self::Image(image));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try as direct Image (object with "source" and "size" fields)
|
||||
if let Some(image) = LanguageModelImage::from_json(&obj) {
|
||||
return Ok(Self::Image(image));
|
||||
}
|
||||
}
|
||||
|
||||
// If none of the variants match, return an error with the problematic JSON
|
||||
Err(D::Error::custom(format!(
|
||||
"data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
|
||||
an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
|
||||
serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelToolResultContent {
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => Some(&text),
|
||||
Self::Text(text) => Some(&text),
|
||||
Self::Image(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => {
|
||||
text.chars().all(|c| c.is_whitespace())
|
||||
}
|
||||
Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
|
||||
Self::Image(_) => false,
|
||||
}
|
||||
}
|
||||
|
@ -294,3 +397,168 @@ pub struct LanguageModelResponseMessage {
|
|||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_language_model_tool_result_content_deserialization() {
|
||||
let json = r#""This is plain text""#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("This is plain text".into())
|
||||
);
|
||||
|
||||
let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("This is wrapped text".into())
|
||||
);
|
||||
|
||||
let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Case insensitive".into())
|
||||
);
|
||||
|
||||
let json = r#"{"Text": "Wrapped variant"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Wrapped variant".into())
|
||||
);
|
||||
|
||||
let json = r#"{"text": "Lowercase wrapped"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Lowercase wrapped".into())
|
||||
);
|
||||
|
||||
// Test image deserialization
|
||||
let json = r#"{
|
||||
"source": "base64encodedimagedata",
|
||||
"size": {
|
||||
"width": 100,
|
||||
"height": 200
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "base64encodedimagedata");
|
||||
assert_eq!(image.size.width.0, 100);
|
||||
assert_eq!(image.size.height.0, 200);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test wrapped Image variant
|
||||
let json = r#"{
|
||||
"Image": {
|
||||
"source": "wrappedimagedata",
|
||||
"size": {
|
||||
"width": 50,
|
||||
"height": 75
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "wrappedimagedata");
|
||||
assert_eq!(image.size.width.0, 50);
|
||||
assert_eq!(image.size.height.0, 75);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test wrapped Image variant with case insensitive
|
||||
let json = r#"{
|
||||
"image": {
|
||||
"Source": "caseinsensitive",
|
||||
"SIZE": {
|
||||
"width": 30,
|
||||
"height": 40
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "caseinsensitive");
|
||||
assert_eq!(image.size.width.0, 30);
|
||||
assert_eq!(image.size.height.0, 40);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test that wrapped text with wrong type fails
|
||||
let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test that malformed JSON fails
|
||||
let json = r#"{"invalid": "structure"}"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test edge cases
|
||||
let json = r#""""#; // Empty string
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result, LanguageModelToolResultContent::Text("".into()));
|
||||
|
||||
// Test with extra fields in wrapped text (should be ignored)
|
||||
let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));
|
||||
|
||||
// Test direct image with case-insensitive fields
|
||||
let json = r#"{
|
||||
"SOURCE": "directimage",
|
||||
"Size": {
|
||||
"width": 200,
|
||||
"height": 300
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "directimage");
|
||||
assert_eq!(image.size.width.0, 200);
|
||||
assert_eq!(image.size.height.0, 300);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test that multiple fields prevent wrapped variant interpretation
|
||||
let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test wrapped text with uppercase TEXT variant
|
||||
let json = r#"{"TEXT": "Uppercase variant"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Uppercase variant".into())
|
||||
);
|
||||
|
||||
// Test that numbers and other JSON values fail gracefully
|
||||
let json = r#"123"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
let json = r#"null"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
let json = r#"[1, 2, 3]"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue