Treat invalid JSON in tool calls as failed tool calls (#29375)
Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
a98c648201
commit
720dfee803
17 changed files with 374 additions and 168 deletions
|
@ -43,6 +43,7 @@ use ui::{
|
|||
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use util::markdown::MarkdownString;
|
||||
use workspace::{OpenOptions, Workspace};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
|
@ -769,7 +770,7 @@ impl ActiveThread {
|
|||
this.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
tool_use.status.text(),
|
||||
cx,
|
||||
);
|
||||
|
@ -870,7 +871,7 @@ impl ActiveThread {
|
|||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_label: impl Into<SharedString>,
|
||||
tool_input: &serde_json::Value,
|
||||
tool_input: &str,
|
||||
tool_output: SharedString,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
@ -893,11 +894,10 @@ impl ActiveThread {
|
|||
this.replace(tool_label, cx);
|
||||
});
|
||||
rendered.input.update(cx, |this, cx| {
|
||||
let input = format!(
|
||||
"```json\n{}\n```",
|
||||
serde_json::to_string_pretty(tool_input).unwrap_or_default()
|
||||
this.replace(
|
||||
MarkdownString::code_block("json", tool_input).to_string(),
|
||||
cx,
|
||||
);
|
||||
this.replace(input, cx);
|
||||
});
|
||||
rendered.output.update(cx, |this, cx| {
|
||||
this.replace(tool_output, cx);
|
||||
|
@ -988,7 +988,7 @@ impl ActiveThread {
|
|||
self.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
"".into(),
|
||||
cx,
|
||||
);
|
||||
|
@ -1002,7 +1002,7 @@ impl ActiveThread {
|
|||
self.render_tool_use_markdown(
|
||||
tool_use_id.clone(),
|
||||
ui_text.clone(),
|
||||
input,
|
||||
&serde_json::to_string_pretty(&input).unwrap_or_default(),
|
||||
"".into(),
|
||||
cx,
|
||||
);
|
||||
|
@ -1014,7 +1014,7 @@ impl ActiveThread {
|
|||
self.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
self.thread
|
||||
.read(cx)
|
||||
.output_for_tool(&tool_use.id)
|
||||
|
@ -1026,6 +1026,23 @@ impl ActiveThread {
|
|||
}
|
||||
ThreadEvent::CheckpointChanged => cx.notify(),
|
||||
ThreadEvent::ReceivedTextChunk => {}
|
||||
ThreadEvent::InvalidToolInput {
|
||||
tool_use_id,
|
||||
ui_text,
|
||||
invalid_input_json,
|
||||
} => {
|
||||
self.render_tool_use_markdown(
|
||||
tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json,
|
||||
self.thread
|
||||
.read(cx)
|
||||
.output_for_tool(tool_use_id)
|
||||
.map(|output| output.clone().into())
|
||||
.unwrap_or("".into()),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,9 @@ use anyhow::Result;
|
|||
use client::telemetry::Telemetry;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
|
||||
};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
|
@ -508,7 +510,9 @@ impl CodegenAlternative {
|
|||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks = StripInvalidSpans::new(
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
|
|
@ -17,10 +17,10 @@ use gpui::{
|
|||
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
|
||||
TokenUsage,
|
||||
};
|
||||
|
@ -1275,9 +1275,30 @@ impl Thread {
|
|||
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
|
||||
}
|
||||
|
||||
let event = event?;
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input: invalid_input_json,
|
||||
json_parse_error,
|
||||
}) => {
|
||||
thread.receive_invalid_tool_json(
|
||||
id,
|
||||
tool_name,
|
||||
invalid_input_json,
|
||||
json_parse_error,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
Err(LanguageModelCompletionError::Other(error)) => {
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
request_assistant_message_id = Some(thread.insert_message(
|
||||
|
@ -1390,7 +1411,8 @@ impl Thread {
|
|||
cx.notify();
|
||||
|
||||
thread.auto_capture_telemetry(cx);
|
||||
})?;
|
||||
Ok(())
|
||||
})??;
|
||||
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
|
@ -1681,6 +1703,41 @@ impl Thread {
|
|||
pending_tool_uses
|
||||
}
|
||||
|
||||
/// Handles a tool call whose input could not be parsed as JSON by treating it
/// as a failed tool call.
///
/// The steps, in order:
/// 1. Logs the raw invalid JSON at error level.
/// 2. Stores an `Err` output ("Error parsing input JSON: {error}") for the
///    tool use via `self.tool_use.insert_tool_output`.
/// 3. Emits `ThreadEvent::InvalidToolInput` carrying the tool use id, a UI
///    label, and the raw invalid JSON text.
/// 4. Calls `self.tool_finished` for this tool use.
///
/// Parameters:
/// - `tool_use_id`: id of the tool use whose input was invalid.
/// - `tool_name`: name of the tool, forwarded to `insert_tool_output`.
/// - `invalid_json`: the raw input text as received; logged and included in
///   the emitted event.
/// - `error`: the JSON parse error message, embedded in the stored error.
/// - `window`: forwarded unchanged to `tool_finished`.
/// - `cx`: the thread's context, used to store the output and emit the event.
pub fn receive_invalid_tool_json(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
invalid_json: Arc<str>,
|
||||
error: String,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
log::error!("The model returned invalid input JSON: {invalid_json}");
|
||||
|
||||
// Record the parse failure as this tool use's output. The return value is
// matched as an `Option` below; presumably it is the pending tool use that
// was waiting on this id, if any (verify against `insert_tool_output`).
let pending_tool_use = self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
Err(anyhow!("Error parsing input JSON: {error}")),
|
||||
cx,
|
||||
);
|
||||
// Prefer the label of the pending tool use; when there is none, log the
// inconsistency and fall back to a placeholder label built from the id.
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
|
||||
pending_tool_use.ui_text.clone()
|
||||
} else {
|
||||
log::error!(
|
||||
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
|
||||
);
|
||||
format!("Unknown tool {}", tool_use_id).into()
|
||||
};
|
||||
|
||||
cx.emit(ThreadEvent::InvalidToolInput {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json: invalid_json,
|
||||
});
|
||||
|
||||
// NOTE(review): the meaning of the `false` argument is not visible from
// here (possibly a "canceled" flag) — confirm against `tool_finished`.
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
}
|
||||
|
||||
pub fn run_tool(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
|
|||
ui_text: Arc<str>,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
InvalidToolInput {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: Arc<str>,
|
||||
invalid_input_json: Arc<str>,
|
||||
},
|
||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue