Fix tool hallucination event order
This commit is contained in:
parent
342247f60f
commit
0d24686a9c
2 changed files with 78 additions and 16 deletions
|
@ -286,6 +286,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_id_1".into(),
|
||||||
|
name: "nonexistent_tool".into(),
|
||||||
|
raw_input: "{}".into(),
|
||||||
|
input: json!({}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
let tool_call = expect_tool_call(&mut events).await;
|
||||||
|
assert_eq!(tool_call.title, "nonexistent_tool");
|
||||||
|
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
|
||||||
|
let update = expect_tool_call_update(&mut events).await;
|
||||||
|
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn expect_tool_call(
|
||||||
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
|
) -> acp::ToolCall {
|
||||||
|
let event = events
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.expect("no tool call authorization event received")
|
||||||
|
.unwrap();
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
||||||
|
event => {
|
||||||
|
panic!("Unexpected event {event:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn expect_tool_call_update(
|
||||||
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
|
) -> acp::ToolCallUpdate {
|
||||||
|
let event = events
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.expect("no tool call authorization event received")
|
||||||
|
.unwrap();
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
|
||||||
|
event => {
|
||||||
|
panic!("Unexpected event {event:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn next_tool_call_authorization(
|
async fn next_tool_call_authorization(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
) -> ToolCallAuthorization {
|
) -> ToolCallAuthorization {
|
||||||
|
|
|
@ -434,23 +434,10 @@ impl Thread {
|
||||||
event_stream: &AgentResponseEventStream,
|
event_stream: &AgentResponseEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Option<Task<LanguageModelToolResult>> {
|
||||||
let Some(tool) = self.tools.get(tool_use.name.as_ref()).cloned() else {
|
|
||||||
if tool_use.is_input_complete {
|
|
||||||
let content = format!("No tool named {} exists", tool_use.name);
|
|
||||||
return Some(Task::ready(LanguageModelToolResult {
|
|
||||||
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
|
||||||
tool_use_id: tool_use.id,
|
|
||||||
tool_name: tool_use.name,
|
|
||||||
is_error: true,
|
|
||||||
output: None,
|
|
||||||
}));
|
|
||||||
} else {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
||||||
|
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||||
|
|
||||||
self.pending_tool_uses
|
self.pending_tool_uses
|
||||||
.insert(tool_use.id.clone(), tool_use.clone());
|
.insert(tool_use.id.clone(), tool_use.clone());
|
||||||
let last_message = self.last_assistant_message();
|
let last_message = self.last_assistant_message();
|
||||||
|
@ -468,8 +455,15 @@ impl Thread {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if push_new_tool_use {
|
if push_new_tool_use {
|
||||||
event_stream.send_tool_call(&tool_use, tool.kind());
|
event_stream.send_tool_call(
|
||||||
|
&tool_use,
|
||||||
|
// todo! add default
|
||||||
|
tool.as_ref()
|
||||||
|
.map(|t| t.kind())
|
||||||
|
.unwrap_or(acp::ToolKind::Other),
|
||||||
|
);
|
||||||
last_message
|
last_message
|
||||||
.content
|
.content
|
||||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||||
|
@ -487,6 +481,17 @@ impl Thread {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let Some(tool) = tool else {
|
||||||
|
let content = format!("No tool named {} exists", tool_use.name);
|
||||||
|
return Some(Task::ready(LanguageModelToolResult {
|
||||||
|
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
||||||
|
tool_use_id: tool_use.id,
|
||||||
|
tool_name: tool_use.name,
|
||||||
|
is_error: true,
|
||||||
|
output: None,
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||||
Some(cx.foreground_executor().spawn(async move {
|
Some(cx.foreground_executor().spawn(async move {
|
||||||
match tool_result.await {
|
match tool_result.await {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue