Treat invalid JSON in tool calls as failed tool calls (#29375)
Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
a98c648201
commit
720dfee803
17 changed files with 374 additions and 168 deletions
|
@ -17,10 +17,10 @@ use gpui::{
|
|||
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
|
||||
TokenUsage,
|
||||
};
|
||||
|
@ -1275,9 +1275,30 @@ impl Thread {
|
|||
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
|
||||
}
|
||||
|
||||
let event = event?;
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input: invalid_input_json,
|
||||
json_parse_error,
|
||||
}) => {
|
||||
thread.receive_invalid_tool_json(
|
||||
id,
|
||||
tool_name,
|
||||
invalid_input_json,
|
||||
json_parse_error,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
Err(LanguageModelCompletionError::Other(error)) => {
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
request_assistant_message_id = Some(thread.insert_message(
|
||||
|
@ -1390,7 +1411,8 @@ impl Thread {
|
|||
cx.notify();
|
||||
|
||||
thread.auto_capture_telemetry(cx);
|
||||
})?;
|
||||
Ok(())
|
||||
})??;
|
||||
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
|
@ -1681,6 +1703,41 @@ impl Thread {
|
|||
pending_tool_uses
|
||||
}
|
||||
|
||||
pub fn receive_invalid_tool_json(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
invalid_json: Arc<str>,
|
||||
error: String,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
log::error!("The model returned invalid input JSON: {invalid_json}");
|
||||
|
||||
let pending_tool_use = self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
Err(anyhow!("Error parsing input JSON: {error}")),
|
||||
cx,
|
||||
);
|
||||
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
|
||||
pending_tool_use.ui_text.clone()
|
||||
} else {
|
||||
log::error!(
|
||||
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
|
||||
);
|
||||
format!("Unknown tool {}", tool_use_id).into()
|
||||
};
|
||||
|
||||
cx.emit(ThreadEvent::InvalidToolInput {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json: invalid_json,
|
||||
});
|
||||
|
||||
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
}
|
||||
|
||||
pub fn run_tool(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
|
|||
ui_text: Arc<str>,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
InvalidToolInput {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: Arc<str>,
|
||||
invalid_input_json: Arc<str>,
|
||||
},
|
||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue