parent
846ed6adf9
commit
f63036548c
2 changed files with 140 additions and 3 deletions
|
@ -16,6 +16,7 @@ use language_model::{
|
||||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
||||||
Role, StopReason, fake_provider::FakeLanguageModel,
|
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||||
};
|
};
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
use reqwest_client::ReqwestClient;
|
use reqwest_client::ReqwestClient;
|
||||||
|
@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
// Send initial user message and verify it's cached
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages[1..],
|
||||||
|
vec![LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Message 1".into()],
|
||||||
|
cache: true
|
||||||
|
}]
|
||||||
|
);
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
|
||||||
|
"Response to Message 1".into(),
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// Send another user message and verify only the latest is cached
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Message 2"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages[1..],
|
||||||
|
vec![
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Message 1".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec!["Response to Message 1".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Message 2".into()],
|
||||||
|
cache: true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
);
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
|
||||||
|
"Response to Message 2".into(),
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// Simulate a tool call and verify that the latest tool result is cached
|
||||||
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let tool_use = LanguageModelToolUse {
|
||||||
|
id: "tool_1".into(),
|
||||||
|
name: EchoTool.name().into(),
|
||||||
|
raw_input: json!({"text": "test"}).to_string(),
|
||||||
|
input: json!({"text": "test"}),
|
||||||
|
is_input_complete: true,
|
||||||
|
};
|
||||||
|
fake_model
|
||||||
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
let tool_result = LanguageModelToolResult {
|
||||||
|
tool_use_id: "tool_1".into(),
|
||||||
|
tool_name: EchoTool.name().into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "test".into(),
|
||||||
|
output: Some("test".into()),
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages[1..],
|
||||||
|
vec![
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Message 1".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec!["Response to Message 1".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Message 2".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec!["Response to Message 2".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Use the echo tool".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec![MessageContent::ToolUse(tool_use)],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![MessageContent::ToolResult(tool_result)],
|
||||||
|
cache: true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
#[ignore = "can't run on CI yet"]
|
#[ignore = "can't run on CI yet"]
|
||||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
|
@ -440,7 +569,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
LanguageModelRequestMessage {
|
LanguageModelRequestMessage {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
content: vec![MessageContent::ToolResult(tool_result.clone())],
|
content: vec![MessageContent::ToolResult(tool_result.clone())],
|
||||||
cache: false
|
cache: true
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
@ -481,7 +610,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
LanguageModelRequestMessage {
|
LanguageModelRequestMessage {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
content: vec!["Continue where you left off".into()],
|
content: vec!["Continue where you left off".into()],
|
||||||
cache: false
|
cache: true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
@ -574,7 +703,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
LanguageModelRequestMessage {
|
LanguageModelRequestMessage {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
content: vec!["ghi".into()],
|
content: vec!["ghi".into()],
|
||||||
cache: false
|
cache: true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
|
@ -1041,6 +1041,14 @@ impl Thread {
|
||||||
messages.extend(message.to_request());
|
messages.extend(message.to_request());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(last_user_message) = messages
|
||||||
|
.iter_mut()
|
||||||
|
.rev()
|
||||||
|
.find(|message| message.role == Role::User)
|
||||||
|
{
|
||||||
|
last_user_message.cache = true;
|
||||||
|
}
|
||||||
|
|
||||||
messages
|
messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue