From f63036548c2229a4dfe1cd7576bf6cee5cd3f1ca Mon Sep 17 00:00:00 2001
From: Bennet Bo Fenner
Date: Fri, 15 Aug 2025 15:17:56 +0200
Subject: [PATCH] agent2: Implement prompt caching (#36236)

Release Notes:

- N/A
---
 crates/agent2/src/tests/mod.rs | 135 ++++++++++++++++++++++++++++++++-
 crates/agent2/src/thread.rs    |   8 ++
 2 files changed, 140 insertions(+), 3 deletions(-)

diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs
index cf90c8f650..cc8bd483bb 100644
--- a/crates/agent2/src/tests/mod.rs
+++ b/crates/agent2/src/tests/mod.rs
@@ -16,6 +16,7 @@ use language_model::{
     LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
     Role, StopReason, fake_provider::FakeLanguageModel,
 };
+use pretty_assertions::assert_eq;
 use project::Project;
 use prompt_store::ProjectContext;
 use reqwest_client::ReqwestClient;
@@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test]
+async fn test_prompt_caching(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // Send initial user message and verify it's cached
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Message 1"], cx)
+    });
+    cx.run_until_parked();
+
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec!["Message 1".into()],
+            cache: true
+        }]
+    );
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
+        "Response to Message 1".into(),
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Send another user message and verify only the latest is cached
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Message 2"], cx)
+    });
+    cx.run_until_parked();
+
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec!["Response to Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 2".into()],
+                cache: true
+            }
+        ]
+    );
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
+        "Response to Message 2".into(),
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Simulate a tool call and verify that the latest tool result is cached
+    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
+    });
+    cx.run_until_parked();
+
+    let tool_use = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: EchoTool.name().into(),
+        raw_input: json!({"text": "test"}).to_string(),
+        input: json!({"text": "test"}),
+        is_input_complete: true,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let completion = fake_model.pending_completions().pop().unwrap();
+    let tool_result = LanguageModelToolResult {
+        tool_use_id: "tool_1".into(),
+        tool_name: EchoTool.name().into(),
+        is_error: false,
+        content: "test".into(),
+        output: Some("test".into()),
+    };
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec!["Response to Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 2".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec!["Response to Message 2".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Use the echo tool".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![MessageContent::ToolUse(tool_use)],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::ToolResult(tool_result)],
+                cache: true
+            }
+        ]
+    );
+}
+
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_basic_tool_calls(cx: &mut TestAppContext) {
@@ -440,7 +569,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
             LanguageModelRequestMessage {
                 role: Role::User,
                 content: vec![MessageContent::ToolResult(tool_result.clone())],
-                cache: false
+                cache: true
             },
         ]
     );
@@ -481,7 +610,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
             LanguageModelRequestMessage {
                 role: Role::User,
                 content: vec!["Continue where you left off".into()],
-                cache: false
+                cache: true
             }
         ]
     );
@@ -574,7 +703,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
             LanguageModelRequestMessage {
                 role: Role::User,
                 content: vec!["ghi".into()],
-                cache: false
+                cache: true
             }
         ]
    );
diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs
index 231ee92dda..2fe2dc20bb 100644
--- a/crates/agent2/src/thread.rs
+++ b/crates/agent2/src/thread.rs
@@ -1041,6 +1041,14 @@ impl Thread {
             messages.extend(message.to_request());
         }
 
+        if let Some(last_user_message) = messages
+            .iter_mut()
+            .rev()
+            .find(|message| message.role == Role::User)
+        {
+            last_user_message.cache = true;
+        }
+
         messages
     }
 
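
The behavioral core of the patch is the eight-line block added to `Thread::to_completion_request` in thread.rs: after the thread has been flattened into request messages, the newest user message is flagged `cache: true` and every earlier message is left unflagged. The tests pin down that invariant after each turn, including the detail that a tool result counts as the newest user message, because tool results are sent back to the model under the `User` role.

Below is a minimal, self-contained sketch of that marking logic. `Role` and `RequestMessage` here are simplified stand-ins for illustration, not the real `language_model` types:

// Sketch of the cache-anchoring logic from the patch, using simplified
// stand-in types rather than the real language_model structs.

#[derive(Debug, PartialEq, Clone, Copy)]
enum Role {
    User,
    Assistant,
}

#[derive(Debug)]
struct RequestMessage {
    role: Role,
    cache: bool,
}

/// Walk the flattened request from the end and flag the newest user
/// message; everything up to and including it becomes the stable prefix
/// that a provider-side prompt cache can reuse on the next request.
fn mark_cache_anchor(messages: &mut [RequestMessage]) {
    if let Some(last_user_message) = messages
        .iter_mut()
        .rev()
        .find(|message| message.role == Role::User)
    {
        last_user_message.cache = true;
    }
}

fn main() {
    let mut messages = vec![
        RequestMessage { role: Role::User, cache: false },
        RequestMessage { role: Role::Assistant, cache: false },
        RequestMessage { role: Role::User, cache: false },
    ];
    mark_cache_anchor(&mut messages);
    // Only the newest user message carries the cache flag.
    assert!(messages[2].cache);
    assert!(!messages[0].cache && !messages[1].cache);
}

Flagging a single message rather than all of them matches how breakpoint-style prompt caches (e.g. Anthropic's cache_control markers) work: one marker at the end of the stable prefix is sufficient, and because the marker advances one user turn at a time, each request can reuse the cache entry written by the previous one.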