agent2: Implement prompt caching (#36236)

Release Notes:

- N/A
Bennet Bo Fenner 2025-08-15 15:17:56 +02:00 committed by GitHub
parent 846ed6adf9
commit f63036548c
2 changed files with 140 additions and 3 deletions
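
The substance of the change lands in agent2's `Thread::to_request` (second file below): after the request messages are assembled, the most recent user message is flagged with `cache: true`, and the new `test_prompt_caching` test checks that only the latest user message (or tool result, which is also a user-role message) in each request carries the flag. As a rough sketch with simplified stand-in types (`mark_cache_breakpoint` and the structs are hypothetical names for illustration; the real code uses `LanguageModelRequestMessage` and `Role` from the `language_model` crate), the core logic is:

// Sketch only: stand-in types for illustration.
#[derive(PartialEq)]
enum Role {
    User,
    Assistant,
}

struct RequestMessage {
    role: Role,
    cache: bool,
}

// Walk the assembled request from the end and mark the last user message,
// so a provider can place a prompt-cache breakpoint there and reuse the
// conversation prefix on the next turn.
fn mark_cache_breakpoint(messages: &mut [RequestMessage]) {
    if let Some(last_user_message) = messages
        .iter_mut()
        .rev()
        .find(|message| message.role == Role::User)
    {
        last_user_message.cache = true;
    }
}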

@@ -16,6 +16,7 @@ use language_model::{
    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
    Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
@@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
    );
}

#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Message 1"], cx)
    });
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Message 2"], cx)
    });
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
    });
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
@@ -440,7 +569,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
-                cache: false
+                cache: true
            },
        ]
    );
@@ -481,7 +610,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
-                cache: false
+                cache: true
            }
        ]
    );
@@ -574,7 +703,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
-                cache: false
+                cache: true
            }
        ]
    );

@@ -1041,6 +1041,14 @@ impl Thread {
            messages.extend(message.to_request());
        }

        if let Some(last_user_message) = messages
            .iter_mut()
            .rev()
            .find(|message| message.role == Role::User)
        {
            last_user_message.cache = true;
        }
        messages
    }
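
The `cache` flag itself is provider-agnostic; each provider integration decides how to express the breakpoint on the wire. As an illustration only (not code from this PR; `Message` and `to_provider_messages` are hypothetical names), an Anthropic-style adapter could translate the flag into a `cache_control` entry on the flagged message's content block, so the prompt prefix up to and including that block is cached:

use serde_json::{json, Value};

// Hypothetical, simplified message type mirroring the fields exercised in the tests above.
struct Message {
    role: &'static str,
    text: String,
    cache: bool,
}

// Illustration: build an Anthropic-style `messages` array, attaching
// `cache_control` to any message flagged with `cache: true` so the provider
// caches the prompt prefix up to and including that block.
fn to_provider_messages(messages: &[Message]) -> Vec<Value> {
    messages
        .iter()
        .map(|message| {
            let mut block = json!({ "type": "text", "text": message.text });
            if message.cache {
                block["cache_control"] = json!({ "type": "ephemeral" });
            }
            json!({ "role": message.role, "content": [block] })
        })
        .collect()
}

fn main() {
    let request = vec![
        Message { role: "user", text: "Message 1".into(), cache: false },
        Message { role: "assistant", text: "Response to Message 1".into(), cache: false },
        Message { role: "user", text: "Message 2".into(), cache: true },
    ];
    // Only the final user message carries a breakpoint, matching `test_prompt_caching`.
    println!("{}", serde_json::to_string_pretty(&to_provider_messages(&request)).unwrap());
}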