Treat invalid JSON in tool calls as failed tool calls (#29375)

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Richard Feldman 2025-04-24 16:54:27 -04:00 committed by GitHub
parent a98c648201
commit 720dfee803
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 374 additions and 168 deletions

View file

@ -43,6 +43,7 @@ use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
@ -769,7 +770,7 @@ impl ActiveThread {
this.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
tool_use.status.text(),
cx,
);
@ -870,7 +871,7 @@ impl ActiveThread {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>,
tool_input: &serde_json::Value,
tool_input: &str,
tool_output: SharedString,
cx: &mut Context<Self>,
) {
@ -893,11 +894,10 @@ impl ActiveThread {
this.replace(tool_label, cx);
});
rendered.input.update(cx, |this, cx| {
let input = format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
this.replace(
MarkdownString::code_block("json", tool_input).to_string(),
cx,
);
this.replace(input, cx);
});
rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx);
@ -988,7 +988,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
"".into(),
cx,
);
@ -1002,7 +1002,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text.clone(),
input,
&serde_json::to_string_pretty(&input).unwrap_or_default(),
"".into(),
cx,
);
@ -1014,7 +1014,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread
.read(cx)
.output_for_tool(&tool_use.id)
@ -1026,6 +1026,23 @@ impl ActiveThread {
}
ThreadEvent::CheckpointChanged => cx.notify(),
ThreadEvent::ReceivedTextChunk => {}
ThreadEvent::InvalidToolInput {
tool_use_id,
ui_text,
invalid_input_json,
} => {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text,
invalid_input_json,
self.thread
.read(cx)
.output_for_tool(tool_use_id)
.map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
}
}
}

View file

@ -5,7 +5,9 @@ use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
@ -508,7 +510,9 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks = StripInvalidSpans::new(
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();

View file

@ -17,10 +17,10 @@ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
};
@ -1275,9 +1275,30 @@ impl Thread {
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
let event = event?;
thread.update(cx, |thread, cx| {
let event = match event {
Ok(event) => event,
Err(LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
}) => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
Err(LanguageModelCompletionError::Other(error)) => {
return Err(error);
}
};
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id = Some(thread.insert_message(
@ -1390,7 +1411,8 @@ impl Thread {
cx.notify();
thread.auto_capture_telemetry(cx);
})?;
Ok(())
})??;
smol::future::yield_now().await;
}
@ -1681,6 +1703,41 @@ impl Thread {
pending_tool_uses
}
/// Handles a tool call whose input JSON could not be parsed.
///
/// The raw input is logged, an `Err` output describing the parse failure is
/// recorded for `tool_use_id` through `self.tool_use.insert_tool_output`, a
/// [`ThreadEvent::InvalidToolInput`] carrying the raw JSON is emitted, and
/// finally `tool_finished` is invoked for the tool use.
///
/// If no pending tool use is returned for `tool_use_id`, an additional error
/// is logged and a placeholder label of the form `Unknown tool <id>` is used
/// as the event's `ui_text`.
pub fn receive_invalid_tool_json(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    tool_name: Arc<str>,
    invalid_json: Arc<str>,
    error: String,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) {
    log::error!("The model returned invalid input JSON: {invalid_json}");

    // Record the parse failure as this tool use's output.
    let pending_tool_use = self.tool_use.insert_tool_output(
        tool_use_id.clone(),
        tool_name,
        Err(anyhow!("Error parsing input JSON: {error}")),
        cx,
    );

    // Prefer the label of the pending tool use; fall back to a placeholder.
    let ui_text = match pending_tool_use.as_ref() {
        Some(pending) => pending.ui_text.clone(),
        None => {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        }
    };

    let invalid_input_event = ThreadEvent::InvalidToolInput {
        tool_use_id: tool_use_id.clone(),
        ui_text,
        invalid_input_json: invalid_json,
    };
    cx.emit(invalid_input_event);

    self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
pub fn run_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
ui_text: Arc<str>,
input: serde_json::Value,
},
InvalidToolInput {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
invalid_input_json: Arc<str>,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),