diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs
index b13b1cbe1a..297453f1e6 100644
--- a/crates/agent2/src/tests/mod.rs
+++ b/crates/agent2/src/tests/mod.rs
@@ -286,6 +286,76 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
     );
 }
 
+/// Verifies that a tool use naming a tool the thread is not expected to know
+/// is still surfaced as a pending tool call and then reported as failed.
+#[gpui::test]
+async fn test_tool_hallucination(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
+    cx.run_until_parked();
+
+    // Have the fake model ask for a tool under a made-up name, with complete input.
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "tool_id_1".into(),
+            name: "nonexistent_tool".into(),
+            raw_input: "{}".into(),
+            input: json!({}),
+            is_input_complete: true,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+
+    // The call must be announced first, and only then updated to `Failed`.
+    let tool_call = expect_tool_call(&mut events).await;
+    assert_eq!(tool_call.title, "nonexistent_tool");
+    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
+    let update = expect_tool_call_update(&mut events).await;
+    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
+}
+
+/// Awaits the next event on `events` and returns it as a tool call.
+///
+/// # Panics
+///
+/// Panics if the stream has ended, if the next item is an error, or if the
+/// event is anything other than `AgentResponseEvent::ToolCall`.
+async fn expect_tool_call(
+    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+) -> acp::ToolCall {
+    let event = events
+        .next()
+        .await
+        .expect("no tool call event received")
+        .expect("tool call event should not be an error");
+    match event {
+        AgentResponseEvent::ToolCall(tool_call) => tool_call,
+        event => panic!("Unexpected event {event:?}"),
+    }
+}
+
+/// Awaits the next event on `events` and returns it as a tool call update.
+///
+/// # Panics
+///
+/// Panics if the stream has ended, if the next item is an error, or if the
+/// event is anything other than `AgentResponseEvent::ToolCallUpdate`.
+async fn expect_tool_call_update(
+    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+) -> acp::ToolCallUpdate {
+    let event = events
+        .next()
+        .await
+        .expect("no tool call update event received")
+        .expect("tool call update event should not be an error");
+    match event {
+        AgentResponseEvent::ToolCallUpdate(tool_call_update) => tool_call_update,
+        event => panic!("Unexpected event {event:?}"),
+    }
+}
+
 async fn next_tool_call_authorization(
     events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 ) -> ToolCallAuthorization {
diff
--git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index efc58d12b1..8f8fae5c67 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -434,23 +434,10 @@ impl Thread { event_stream: &AgentResponseEventStream, cx: &mut Context, ) -> Option> { - let Some(tool) = self.tools.get(tool_use.name.as_ref()).cloned() else { - if tool_use.is_input_complete { - let content = format!("No tool named {} exists", tool_use.name); - return Some(Task::ready(LanguageModelToolResult { - content: LanguageModelToolResultContent::Text(Arc::from(content)), - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: true, - output: None, - })); - } else { - return None; - } - }; - cx.notify(); + let tool = self.tools.get(tool_use.name.as_ref()).cloned(); + self.pending_tool_uses .insert(tool_use.id.clone(), tool_use.clone()); let last_message = self.last_assistant_message(); @@ -468,8 +455,15 @@ impl Thread { true } }); + if push_new_tool_use { - event_stream.send_tool_call(&tool_use, tool.kind()); + event_stream.send_tool_call( + &tool_use, + // todo! add default + tool.as_ref() + .map(|t| t.kind()) + .unwrap_or(acp::ToolKind::Other), + ); last_message .content .push(MessageContent::ToolUse(tool_use.clone())); @@ -487,6 +481,17 @@ impl Thread { return None; } + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; + let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx); Some(cx.foreground_executor().spawn(async move { match tool_result.await {