diff --git a/assets/prompts/assistant_system_prompt_reminder.hbs b/assets/prompts/assistant_system_prompt_reminder.hbs new file mode 100644 index 0000000000..998adfb5ff --- /dev/null +++ b/assets/prompts/assistant_system_prompt_reminder.hbs @@ -0,0 +1 @@ +In your response, make sure to remember and follow my instructions about how to format code blocks (and don't mention that you are remembering it, just follow the instructions). diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 32b82898c4..4167bfbbd5 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -997,6 +997,20 @@ impl Thread { self.attached_tracked_files_state(&mut request.messages, cx); + // Add reminder to the last user message about code blocks + if let Some(last_user_message) = request + .messages + .iter_mut() + .rev() + .find(|msg| msg.role == Role::User) + { + last_user_message + .content + .push(MessageContent::Text(system_prompt_reminder( + &self.prompt_builder, + ))); + } + request } @@ -1810,6 +1824,12 @@ impl Thread { } } +pub fn system_prompt_reminder(prompt_builder: &prompt_store::PromptBuilder) -> String { + prompt_builder + .generate_assistant_system_prompt_reminder() + .unwrap_or_default() +} + #[derive(Debug, Clone)] pub enum ThreadError { PaymentRequired, @@ -1879,7 +1899,7 @@ mod tests { ) .await; - let (_workspace, _thread_store, thread, context_store) = + let (_workspace, _thread_store, thread, context_store, prompt_builder) = setup_test_environment(cx, project.clone()).await; add_file_to_context(&project, &context_store, "test/code.rs", cx) @@ -1933,8 +1953,14 @@ fn main() {{ }); assert_eq!(request.messages.len(), 1); - let expected_full_message = format!("{}Please explain this code", expected_context); - assert_eq!(request.messages[0].string_contents(), expected_full_message); + let actual_message = request.messages[0].string_contents(); + let expected_content = format!( + "{}Please explain this code{}", + expected_context, + system_prompt_reminder(&prompt_builder) + ); + + assert_eq!(actual_message, expected_content); } #[gpui::test] @@ -1951,7 +1977,7 @@ fn main() {{ ) .await; - let (_, _thread_store, thread, context_store) = + let (_, _thread_store, thread, context_store, _prompt_builder) = setup_test_environment(cx, project.clone()).await; // Open files individually @@ -2051,7 +2077,7 @@ fn main() {{ ) .await; - let (_, _thread_store, thread, _context_store) = + let (_, _thread_store, thread, _context_store, prompt_builder) = setup_test_environment(cx, project.clone()).await; // Insert user message without any context (empty context vector) @@ -2077,11 +2103,14 @@ fn main() {{ }); assert_eq!(request.messages.len(), 1); - assert_eq!( - request.messages[0].string_contents(), - "What is the best way to learn Rust?" + let actual_message = request.messages[0].string_contents(); + let expected_content = format!( + "What is the best way to learn Rust?{}", + system_prompt_reminder(&prompt_builder) ); + assert_eq!(actual_message, expected_content); + // Add second message, also without context let message2_id = thread.update(cx, |thread, cx| { thread.insert_user_message("Are there any good books?", vec![], None, cx) @@ -2097,14 +2126,17 @@ fn main() {{ }); assert_eq!(request.messages.len(), 2); - assert_eq!( - request.messages[0].string_contents(), - "What is the best way to learn Rust?" - ); - assert_eq!( - request.messages[1].string_contents(), - "Are there any good books?" + // First message should be the system prompt + assert_eq!(request.messages[0].role, Role::User); + + // Second message should be the user message with prompt reminder + let actual_message = request.messages[1].string_contents(); + let expected_content = format!( + "Are there any good books?{}", + system_prompt_reminder(&prompt_builder) ); + + assert_eq!(actual_message, expected_content); } #[gpui::test] @@ -2117,7 +2149,7 @@ fn main() {{ ) .await; - let (_workspace, _thread_store, thread, context_store) = + let (_workspace, _thread_store, thread, context_store, prompt_builder) = setup_test_environment(cx, project.clone()).await; // Open buffer and add it to context @@ -2177,11 +2209,14 @@ fn main() {{ // The last message should be the stale buffer notification assert_eq!(last_message.role, Role::User); - // Check the exact content of the message - let expected_content = "These files changed since last read:\n- code.rs\n"; + let actual_message = last_message.string_contents(); + let expected_content = format!( + "These files changed since last read:\n- code.rs\n{}", + system_prompt_reminder(&prompt_builder) + ); + assert_eq!( - last_message.string_contents(), - expected_content, + actual_message, expected_content, "Last message should be exactly the stale buffer notification" ); } @@ -2219,24 +2254,27 @@ fn main() {{ Entity, Entity, Entity, + Arc, ) { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let thread_store = cx.update(|_, cx| { - ThreadStore::new( - project.clone(), - Arc::default(), - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - .unwrap() + ThreadStore::new(project.clone(), Arc::default(), prompt_builder.clone(), cx).unwrap() }); let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - (workspace, thread_store, thread, context_store) + ( + workspace, + thread_store, + thread, + context_store, + prompt_builder, + ) } async fn add_file_to_context( diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 8577a21a1e..3bff73defd 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -261,6 +261,12 @@ impl PromptBuilder { .render("assistant_system_prompt", context) } + pub fn generate_assistant_system_prompt_reminder(&self) -> Result { + self.handlebars + .lock() + .render("assistant_system_prompt_reminder", &()) + } + pub fn generate_inline_transformation_prompt( &self, user_prompt: String,