From 500af6d7754adf1a60f245200271e4dd40d7fb8f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 16 Oct 2023 18:47:10 -0400 Subject: [PATCH] progress on prompt chains --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/src/prompts.rs | 149 ++++++++++++++++++ crates/ai/src/templates.rs | 76 --------- crates/ai/src/templates/base.rs | 112 +++++++++++++ crates/ai/src/templates/mod.rs | 3 + crates/ai/src/templates/preamble.rs | 34 ++++ crates/ai/src/templates/repository_context.rs | 49 ++++++ 8 files changed, 349 insertions(+), 76 deletions(-) create mode 100644 crates/ai/src/prompts.rs delete mode 100644 crates/ai/src/templates.rs create mode 100644 crates/ai/src/templates/base.rs create mode 100644 crates/ai/src/templates/mod.rs create mode 100644 crates/ai/src/templates/preamble.rs create mode 100644 crates/ai/src/templates/repository_context.rs diff --git a/Cargo.lock b/Cargo.lock index cd9dee0bda..9938c5d2fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,7 @@ dependencies = [ "futures 0.3.28", "gpui", "isahc", + "language", "lazy_static", "log", "matrixmultiply", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 542d7f422f..b24c4e5ece 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -11,6 +11,7 @@ doctest = false [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } +language = { path = "../language" } async-trait.workspace = true anyhow.workspace = true futures.workspace = true diff --git a/crates/ai/src/prompts.rs b/crates/ai/src/prompts.rs new file mode 100644 index 0000000000..6d2c0629fa --- /dev/null +++ b/crates/ai/src/prompts.rs @@ -0,0 +1,149 @@ +use gpui::{AsyncAppContext, ModelHandle}; +use language::{Anchor, Buffer}; +use std::{fmt::Write, ops::Range, path::PathBuf}; + +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { + let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + }); + + PromptCodeSnippet { + path: file_path, + language_name, + content, + } + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +enum PromptFileType { + Text, + Code, +} + +#[derive(Default)] +struct PromptArguments { + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub model_name: String, +} + +impl PromptArguments { + pub fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +trait PromptTemplate { + fn generate(args: PromptArguments, max_token_length: Option) -> String; +} + +struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate(args: PromptArguments, max_token_length: Option) -> String { + let mut prompt = String::new(); + + match args.get_file_type() { + PromptFileType::Code => { + writeln!( + prompt, + "You are an expert {} engineer.", + args.language_name.unwrap_or("".to_string()) + ) + .unwrap(); + } + PromptFileType::Text => { + writeln!(prompt, "You are an expert engineer.").unwrap(); + } + } + + if let Some(project_name) = args.project_name { + writeln!( + prompt, + "You are currently working inside the '{project_name}' in Zed the code editor." + ) + .unwrap(); + } + + prompt + } +} + +struct RepositorySnippets {} + +impl PromptTemplate for RepositorySnippets { + fn generate(args: PromptArguments, max_token_length: Option) -> String { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; + let mut prompt = String::new(); + + if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) { + let default_token_count = + tiktoken_rs::model::get_context_size(args.model_name.as_str()); + let mut remaining_token_count = max_token_length.unwrap_or(default_token_count); + + for snippet in args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = encoding + .encode_with_special_tokens(snippet_prompt.as_str()) + .len(); + if token_count <= remaining_token_count { + if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_token_count -= token_count; + template = ""; + } + } else { + break; + } + } + } + + prompt + } +} diff --git a/crates/ai/src/templates.rs b/crates/ai/src/templates.rs deleted file mode 100644 index d9771ce569..0000000000 --- a/crates/ai/src/templates.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::fmt::Write; - -pub struct PromptCodeSnippet { - path: Option, - language_name: Option, - content: String, -} - -enum PromptFileType { - Text, - Code, -} - -#[derive(Default)] -struct PromptArguments { - pub language_name: Option, - pub project_name: Option, - pub snippets: Vec, -} - -impl PromptArguments { - pub fn get_file_type(&self) -> PromptFileType { - if self - .language_name - .as_ref() - .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) - .unwrap_or(true) - { - PromptFileType::Code - } else { - PromptFileType::Text - } - } -} - -trait PromptTemplate { - fn generate(args: PromptArguments) -> String; -} - -struct EngineerPreamble {} - -impl PromptTemplate for EngineerPreamble { - fn generate(args: PromptArguments) -> String { - let mut prompt = String::new(); - - match args.get_file_type() { - PromptFileType::Code => { - writeln!( - prompt, - "You are an expert {} engineer.", - args.language_name.unwrap_or("".to_string()) - ) - .unwrap(); - } - PromptFileType::Text => { - writeln!(prompt, "You are an expert engineer.").unwrap(); - } - } - - if let Some(project_name) = args.project_name { - writeln!( - prompt, - "You are currently working inside the '{project_name}' in Zed the code editor." - ) - .unwrap(); - } - - prompt - } -} - -struct RepositorySnippets {} - -impl PromptTemplate for RepositorySnippets { - fn generate(args: PromptArguments) -> String {} -} diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs new file mode 100644 index 0000000000..3d8479e512 --- /dev/null +++ b/crates/ai/src/templates/base.rs @@ -0,0 +1,112 @@ +use std::cmp::Reverse; + +use crate::templates::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +#[derive(Default)] +pub struct PromptArguments { + pub model_name: String, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub enum PromptPriority { + Low, + Medium, + High, +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + // templates.sort_by(|a, b| a.0.cmp(&b.0)); + + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result { + // Argsort based on Prompt Priority + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + println!("{:?}", sorted_indices); + + let mut prompts = Vec::new(); + for (_, template) in &self.templates { + prompts.push(template.generate(&self.args, None)); + } + + anyhow::Ok(prompts.join("\n")) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { + "This is a test prompt template".to_string() + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { + "This is a low priority test prompt template".to_string() + } + } + + let args = PromptArguments { + model_name: "gpt-4".to_string(), + ..Default::default() + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::High, Box::new(TestPromptTemplate {})), + (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), + ]; + let chain = PromptChain::new(args, templates); + + let prompt = chain.generate(false); + println!("{:?}", prompt); + panic!(); + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs new file mode 100644 index 0000000000..62cb600eca --- /dev/null +++ b/crates/ai/src/templates/mod.rs @@ -0,0 +1,3 @@ +pub mod base; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs new file mode 100644 index 0000000000..b1d33f885e --- /dev/null +++ b/crates/ai/src/templates/preamble.rs @@ -0,0 +1,34 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { + let mut prompt = String::new(); + + match args.get_file_type() { + PromptFileType::Code => { + writeln!( + prompt, + "You are an expert {} engineer.", + args.language_name.clone().unwrap_or("".to_string()) + ) + .unwrap(); + } + PromptFileType::Text => { + writeln!(prompt, "You are an expert engineer.").unwrap(); + } + } + + if let Some(project_name) = args.project_name.clone() { + writeln!( + prompt, + "You are currently working inside the '{project_name}' in Zed the code editor." + ) + .unwrap(); + } + + prompt + } +} diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs new file mode 100644 index 0000000000..f9c2253c65 --- /dev/null +++ b/crates/ai/src/templates/repository_context.rs @@ -0,0 +1,49 @@ +use std::{ops::Range, path::PathBuf}; + +use gpui::{AsyncAppContext, ModelHandle}; +use language::{Anchor, Buffer}; + +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { + let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + }); + + PromptCodeSnippet { + path: file_path, + language_name, + content, + } + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +}