agent2: Port Zed AI features (#36172)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2025-08-15 13:17:17 +02:00 committed by GitHub
parent f8b0105258
commit 6f3cd42411
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 994 additions and 358 deletions

View file

@ -12,9 +12,9 @@ use gpui::{
};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
fake_provider::FakeLanguageModel,
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
};
use project::Project;
use prompt_store::ProjectContext;
@ -394,8 +394,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
name: EchoTool.name().into(),
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "def".into(),
output: Some("def".into()),
};
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use.clone())],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result.clone())],
cache: false
},
]
);
// Simulate reaching tool use limit.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
last_event
.unwrap_err()
.is::<language_model::ToolUseLimitReachedError>()
);
let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false
}
]
);
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
fake_model.end_last_completion_stream();
events.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.last_message().unwrap().to_markdown(),
indoc! {"
## Assistant
Done
"}
)
});
// Ensure we error if calling resume when tool use limit was *not* reached.
let error = thread
.update(cx, |thread, cx| thread.resume(cx))
.unwrap_err();
assert_eq!(
error.to_string(),
"can only resume after tool use limit is reached"
)
}
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
name: EchoTool.name().into(),
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
};
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "def".into(),
output: Some("def".into()),
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
last_event
.unwrap_err()
.is::<language_model::ToolUseLimitReachedError>()
);
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), vec!["ghi"], cx)
});
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["ghi".into()],
cache: false
}
]
);
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
let event = events
.next()
@ -411,7 +597,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@ -429,7 +615,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@ -1007,9 +1193,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> Vec<acp::StopReason> {
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {