From 473067db3173f6e43666f1283c850cff8d2b8cd5 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Wed, 18 Oct 2023 15:56:39 -0400
Subject: [PATCH] update PromptPriority to accommodate both Mandatory and
 Ordered prompts

---
 crates/ai/src/templates/base.rs               | 101 ++++++++++++++----
 crates/ai/src/templates/file_context.rs       |   2 -
 crates/ai/src/templates/repository_context.rs |   2 +-
 crates/assistant/src/prompts.rs               |  20 ++--
 4 files changed, 96 insertions(+), 29 deletions(-)

diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index aaf08d755e..2afcc87ff5 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -1,9 +1,9 @@
+use anyhow::anyhow;
 use std::cmp::Reverse;
 use std::ops::Range;
 use std::sync::Arc;
 
-use gpui::ModelHandle;
-use language::{Anchor, Buffer, BufferSnapshot, ToOffset};
+use language::BufferSnapshot;
 use util::ResultExt;
 
 use crate::models::LanguageModel;
@@ -50,11 +50,21 @@ pub trait PromptTemplate {
 }
 
 #[repr(i8)]
-#[derive(PartialEq, Eq, PartialOrd, Ord)]
+#[derive(PartialEq, Eq, Ord)]
 pub enum PromptPriority {
-    Low,
-    Medium,
-    High,
+    Mandatory,                // Ignores truncation
+    Ordered { order: usize }, // Truncates based on priority
+}
+
+impl PartialOrd for PromptPriority {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        match (self, other) {
+            (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
+            (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
+            (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
+            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
+        }
+    }
 }
 
 pub struct PromptChain {
@@ -86,14 +96,36 @@ impl PromptChain {
 
         let mut prompts = vec!["".to_string(); sorted_indices.len()];
         for idx in sorted_indices {
-            let (_, template) = &self.templates[idx];
+            let (priority, template) = &self.templates[idx];
+
+            // If PromptPriority is marked as mandatory, we ignore the tokens outstanding
+            // However, if a prompt is generated in excess of the available tokens,
+            // we raise an error outlining that a mandatory prompt has exceeded the available
+            // balance
+            let template_tokens = if let Some(template_tokens) = tokens_outstanding {
+                match priority {
+                    &PromptPriority::Mandatory => None,
+                    _ => Some(template_tokens),
+                }
+            } else {
+                None
+            };
+
             if let Some((template_prompt, prompt_token_count)) =
-                template.generate(&self.args, tokens_outstanding).log_err()
+                template.generate(&self.args, template_tokens).log_err()
             {
                 if template_prompt != "" {
                     prompts[idx] = template_prompt;
 
                     if let Some(remaining_tokens) = tokens_outstanding {
+                        if prompt_token_count > remaining_tokens
+                            && priority == &PromptPriority::Mandatory
+                        {
+                            return Err(anyhow!(
+                                "mandatory template added in excess of model capacity"
+                            ));
+                        }
+
                         let new_tokens = prompt_token_count + seperator_tokens;
                         tokens_outstanding = if remaining_tokens > new_tokens {
                             Some(remaining_tokens - new_tokens)
@@ -105,6 +137,8 @@ impl PromptChain {
             }
         }
 
+        prompts.retain(|x| x != "");
+
         let full_prompt = prompts.join(seperator);
         let total_token_count = self.args.model.count_tokens(&full_prompt)?;
         anyhow::Ok((prompts.join(seperator), total_token_count))
@@ -196,8 +230,14 @@ pub(crate) mod tests {
         };
 
         let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (PromptPriority::High, Box::new(TestPromptTemplate {})),
-            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
         ];
 
         let chain = PromptChain::new(args, templates);
@@ -226,8 +266,14 @@ pub(crate) mod tests {
         };
 
         let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (PromptPriority::High, Box::new(TestPromptTemplate {})),
-            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
         ];
 
         let chain = PromptChain::new(args, templates);
@@ -257,9 +303,18 @@ pub(crate) mod tests {
         };
 
         let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (PromptPriority::High, Box::new(TestPromptTemplate {})),
-            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
-            (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 2 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
         ];
 
         let chain = PromptChain::new(args, templates);
@@ -283,14 +338,22 @@ pub(crate) mod tests {
             user_prompt: None,
        };
         let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (PromptPriority::Medium, Box::new(TestPromptTemplate {})),
-            (PromptPriority::High, Box::new(TestLowPriorityTemplate {})),
-            (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})),
+            (
+                PromptPriority::Mandatory,
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
         ];
 
         let chain = PromptChain::new(args, templates);
         let (prompt, token_count) = chain.generate(true).unwrap();
-        println!("TOKEN COUNT: {:?}", token_count);
 
         assert_eq!(
             prompt,
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
index 6d06305049..94b194d9bf 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/templates/file_context.rs
@@ -30,8 +30,6 @@ impl PromptTemplate for FileContext {
         writeln!(prompt, "```{language_name}").unwrap();
 
         if let Some(buffer) = &args.buffer {
-            let mut content = String::new();
-
             if let Some(selected_range) = &args.selected_range {
                 let start = selected_range.start.to_offset(buffer);
                 let end = selected_range.end.to_offset(buffer);
diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs
index 7dd1647c44..a8e7f4b5af 100644
--- a/crates/ai/src/templates/repository_context.rs
+++ b/crates/ai/src/templates/repository_context.rs
@@ -60,7 +60,7 @@ impl PromptTemplate for RepositoryContext {
         max_token_length: Option<usize>,
     ) -> anyhow::Result<(String, usize)> {
         const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
         let mut prompt = String::new();
 
         let mut remaining_tokens = max_token_length.clone();
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 1457d28fff..dffcbc2923 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,4 +1,3 @@
-use crate::codegen::CodegenKind;
 use ai::models::{LanguageModel, OpenAILanguageModel};
 use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
 use ai::templates::file_context::FileContext;
@@ -7,10 +6,8 @@ use ai::templates::preamble::EngineerPreamble;
 use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
-use std::fmt::Write;
 use std::ops::Range;
 use std::sync::Arc;
-use tiktoken_rs::ChatCompletionRequestMessage;
 
 #[allow(dead_code)]
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -152,10 +149,19 @@ pub fn generate_content_prompt(
     };
 
     let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-        (PromptPriority::High, Box::new(EngineerPreamble {})),
-        (PromptPriority::Low, Box::new(RepositoryContext {})),
-        (PromptPriority::Medium, Box::new(FileContext {})),
-        (PromptPriority::High, Box::new(GenerateInlineContent {})),
+        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
+        (
+            PromptPriority::Ordered { order: 1 },
+            Box::new(RepositoryContext {}),
+        ),
+        (
+            PromptPriority::Ordered { order: 0 },
+            Box::new(FileContext {}),
+        ),
+        (
+            PromptPriority::Mandatory,
+            Box::new(GenerateInlineContent {}),
+        ),
     ];
     let chain = PromptChain::new(args, templates);
    let (prompt, _) = chain.generate(true)?;
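
Note on the new ordering (illustrative sketch only, not part of the patch): the standalone program below mirrors the manual PartialOrd logic added to PromptPriority in base.rs so the resulting rank can be checked in isolation. Mandatory entries sort ahead of every Ordered entry, and Ordered entries rank by ascending order. The compare helper and the main function here are assumptions made purely for demonstration.

use std::cmp::Ordering;

#[derive(Debug, PartialEq, Eq)]
enum PromptPriority {
    Mandatory,
    Ordered { order: usize },
}

// Same comparison as the patch's partial_cmp: Mandatory outranks Ordered,
// and a smaller `order` outranks a larger one.
fn compare(a: &PromptPriority, b: &PromptPriority) -> Ordering {
    match (a, b) {
        (PromptPriority::Mandatory, PromptPriority::Mandatory) => Ordering::Equal,
        (PromptPriority::Mandatory, PromptPriority::Ordered { .. }) => Ordering::Greater,
        (PromptPriority::Ordered { .. }, PromptPriority::Mandatory) => Ordering::Less,
        (PromptPriority::Ordered { order: a }, PromptPriority::Ordered { order: b }) => b.cmp(a),
    }
}

fn main() {
    let mut priorities = vec![
        PromptPriority::Ordered { order: 1 },
        PromptPriority::Mandatory,
        PromptPriority::Ordered { order: 0 },
    ];
    // Sort highest priority first, which is how the chain is expected to visit templates.
    priorities.sort_by(|a, b| compare(b, a));
    assert_eq!(
        priorities,
        vec![
            PromptPriority::Mandatory,
            PromptPriority::Ordered { order: 0 },
            PromptPriority::Ordered { order: 1 },
        ]
    );
}

This matches the intent documented in the enum comments: Mandatory templates ignore truncation entirely, while Ordered templates are generated, and truncated, from the smallest order upward.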