progress on prompt chains
This commit is contained in:
parent
40755961ea
commit
500af6d775
8 changed files with 349 additions and 76 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -91,6 +91,7 @@ dependencies = [
|
|||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
"language",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
|
|
|
@ -11,6 +11,7 @@ doctest = false
|
|||
[dependencies]
|
||||
gpui = { path = "../gpui" }
|
||||
util = { path = "../util" }
|
||||
language = { path = "../language" }
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
|
|
149
crates/ai/src/prompts.rs
Normal file
149
crates/ai/src/prompts.rs
Normal file
|
@ -0,0 +1,149 @@
|
|||
use gpui::{AsyncAppContext, ModelHandle};
|
||||
use language::{Anchor, Buffer};
|
||||
use std::{fmt::Write, ops::Range, path::PathBuf};
|
||||
|
||||
/// A piece of buffer text together with the metadata used to present it in a prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCodeSnippet {
    /// Path of the file backing the buffer, if the buffer has one.
    path: Option<PathBuf>,
    /// Name of the buffer's language, if one is assigned.
    language_name: Option<String>,
    /// The snippet text itself.
    content: String,
}
|
||||
|
||||
impl PromptCodeSnippet {
|
||||
pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
|
||||
let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let content = snapshot.text_for_range(range.clone()).collect::<String>();
|
||||
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.and_then(|language| Some(language.name().to_string()));
|
||||
|
||||
let file_path = buffer
|
||||
.file()
|
||||
.and_then(|file| Some(file.path().to_path_buf()));
|
||||
|
||||
(content, language_name, file_path)
|
||||
});
|
||||
|
||||
PromptCodeSnippet {
|
||||
path: file_path,
|
||||
language_name,
|
||||
content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for PromptCodeSnippet {
|
||||
fn to_string(&self) -> String {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.and_then(|path| Some(path.to_string_lossy().to_string()))
|
||||
.unwrap_or("".to_string());
|
||||
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||
let content = self.content.clone();
|
||||
|
||||
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||
}
|
||||
}
|
||||
|
||||
/// Broad classification of the file a prompt is generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PromptFileType {
    /// Natural-language text rather than source code.
    Text,
    /// Source code.
    Code,
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct PromptArguments {
|
||||
pub language_name: Option<String>,
|
||||
pub project_name: Option<String>,
|
||||
pub snippets: Vec<PromptCodeSnippet>,
|
||||
pub model_name: String,
|
||||
}
|
||||
|
||||
impl PromptArguments {
|
||||
pub fn get_file_type(&self) -> PromptFileType {
|
||||
if self
|
||||
.language_name
|
||||
.as_ref()
|
||||
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
PromptFileType::Code
|
||||
} else {
|
||||
PromptFileType::Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait PromptTemplate {
|
||||
fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String;
|
||||
}
|
||||
|
||||
struct EngineerPreamble {}
|
||||
|
||||
impl PromptTemplate for EngineerPreamble {
|
||||
fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
match args.get_file_type() {
|
||||
PromptFileType::Code => {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are an expert {} engineer.",
|
||||
args.language_name.unwrap_or("".to_string())
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
PromptFileType::Text => {
|
||||
writeln!(prompt, "You are an expert engineer.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project_name) = args.project_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' in Zed the code editor."
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
||||
|
||||
struct RepositorySnippets {}
|
||||
|
||||
impl PromptTemplate for RepositorySnippets {
|
||||
fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
|
||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||
let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
|
||||
let mut prompt = String::new();
|
||||
|
||||
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) {
|
||||
let default_token_count =
|
||||
tiktoken_rs::model::get_context_size(args.model_name.as_str());
|
||||
let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);
|
||||
|
||||
for snippet in args.snippets {
|
||||
let mut snippet_prompt = template.to_string();
|
||||
let content = snippet.to_string();
|
||||
writeln!(snippet_prompt, "{content}").unwrap();
|
||||
|
||||
let token_count = encoding
|
||||
.encode_with_special_tokens(snippet_prompt.as_str())
|
||||
.len();
|
||||
if token_count <= remaining_token_count {
|
||||
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
remaining_token_count -= token_count;
|
||||
template = "";
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
use std::fmt::Write;
|
||||
|
||||
/// A piece of source text together with the metadata used to present it in a prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCodeSnippet {
    /// Path of the file the snippet belongs to, if known.
    // Fully qualified: this file does not import `PathBuf`.
    path: Option<std::path::PathBuf>,
    /// Name of the snippet's language, if known.
    language_name: Option<String>,
    /// The snippet text itself.
    content: String,
}
|
||||
|
||||
/// Broad classification of the file a prompt is generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PromptFileType {
    /// Natural-language text rather than source code.
    Text,
    /// Source code.
    Code,
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct PromptArguments {
|
||||
pub language_name: Option<String>,
|
||||
pub project_name: Option<String>,
|
||||
pub snippets: Vec<PromptCodeSnippet>,
|
||||
}
|
||||
|
||||
impl PromptArguments {
|
||||
pub fn get_file_type(&self) -> PromptFileType {
|
||||
if self
|
||||
.language_name
|
||||
.as_ref()
|
||||
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
PromptFileType::Code
|
||||
} else {
|
||||
PromptFileType::Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait PromptTemplate {
|
||||
fn generate(args: PromptArguments) -> String;
|
||||
}
|
||||
|
||||
struct EngineerPreamble {}
|
||||
|
||||
impl PromptTemplate for EngineerPreamble {
|
||||
fn generate(args: PromptArguments) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
match args.get_file_type() {
|
||||
PromptFileType::Code => {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are an expert {} engineer.",
|
||||
args.language_name.unwrap_or("".to_string())
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
PromptFileType::Text => {
|
||||
writeln!(prompt, "You are an expert engineer.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project_name) = args.project_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' in Zed the code editor."
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
||||
|
||||
struct RepositorySnippets {}
|
||||
|
||||
impl PromptTemplate for RepositorySnippets {
|
||||
fn generate(args: PromptArguments) -> String {}
|
||||
}
|
112
crates/ai/src/templates/base.rs
Normal file
112
crates/ai/src/templates/base.rs
Normal file
|
@ -0,0 +1,112 @@
|
|||
use std::cmp::Reverse;
|
||||
|
||||
use crate::templates::repository_context::PromptCodeSnippet;
|
||||
|
||||
/// Broad classification of the file a prompt is generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PromptFileType {
    /// Natural-language text rather than source code.
    Text,
    /// Source code.
    Code,
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct PromptArguments {
|
||||
pub model_name: String,
|
||||
pub language_name: Option<String>,
|
||||
pub project_name: Option<String>,
|
||||
pub snippets: Vec<PromptCodeSnippet>,
|
||||
pub reserved_tokens: usize,
|
||||
}
|
||||
|
||||
impl PromptArguments {
|
||||
pub(crate) fn get_file_type(&self) -> PromptFileType {
|
||||
if self
|
||||
.language_name
|
||||
.as_ref()
|
||||
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
PromptFileType::Code
|
||||
} else {
|
||||
PromptFileType::Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PromptTemplate {
|
||||
fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String;
|
||||
}
|
||||
|
||||
/// Relative importance of a template within a chain.
///
/// Variants are declared from least to most important, so the derived ordering gives
/// `Low < Medium < High`.
#[repr(i8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PromptPriority {
    Low,
    Medium,
    High,
}
|
||||
|
||||
pub struct PromptChain {
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
}
|
||||
|
||||
impl PromptChain {
|
||||
pub fn new(
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
) -> Self {
|
||||
// templates.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
PromptChain { args, templates }
|
||||
}
|
||||
|
||||
pub fn generate(&self, truncate: bool) -> anyhow::Result<String> {
|
||||
// Argsort based on Prompt Priority
|
||||
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
|
||||
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
|
||||
|
||||
println!("{:?}", sorted_indices);
|
||||
|
||||
let mut prompts = Vec::new();
|
||||
for (_, template) in &self.templates {
|
||||
prompts.push(template.generate(&self.args, None));
|
||||
}
|
||||
|
||||
anyhow::Ok(prompts.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    /// Checks that a chain renders each template and joins the results, in insertion
    /// order, with newlines.
    #[test]
    pub fn test_prompt_chain() {
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(&self, _args: &PromptArguments, _max_token_length: Option<usize>) -> String {
                "This is a test prompt template".to_string()
            }
        }

        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(&self, _args: &PromptArguments, _max_token_length: Option<usize>) -> String {
                "This is a low priority test prompt template".to_string()
            }
        }

        let args = PromptArguments {
            model_name: "gpt-4".to_string(),
            ..Default::default()
        };

        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (PromptPriority::High, Box::new(TestPromptTemplate {})),
            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
        ];
        let chain = PromptChain::new(args, templates);

        // Assert on the rendered prompt instead of printing it and panicking.
        let prompt = chain.generate(false).unwrap();
        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
        );
    }
}
|
3
crates/ai/src/templates/mod.rs
Normal file
3
crates/ai/src/templates/mod.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
pub mod base;
|
||||
pub mod preamble;
|
||||
pub mod repository_context;
|
34
crates/ai/src/templates/preamble.rs
Normal file
34
crates/ai/src/templates/preamble.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
|
||||
struct EngineerPreamble {}
|
||||
|
||||
impl PromptTemplate for EngineerPreamble {
|
||||
fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
match args.get_file_type() {
|
||||
PromptFileType::Code => {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are an expert {} engineer.",
|
||||
args.language_name.clone().unwrap_or("".to_string())
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
PromptFileType::Text => {
|
||||
writeln!(prompt, "You are an expert engineer.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project_name) = args.project_name.clone() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' in Zed the code editor."
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
49
crates/ai/src/templates/repository_context.rs
Normal file
49
crates/ai/src/templates/repository_context.rs
Normal file
|
@ -0,0 +1,49 @@
|
|||
use std::{ops::Range, path::PathBuf};
|
||||
|
||||
use gpui::{AsyncAppContext, ModelHandle};
|
||||
use language::{Anchor, Buffer};
|
||||
|
||||
/// A piece of buffer text together with the metadata used to present it in a prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCodeSnippet {
    /// Path of the file backing the buffer, if the buffer has one.
    path: Option<PathBuf>,
    /// Name of the buffer's language, if one is assigned.
    language_name: Option<String>,
    /// The snippet text itself.
    content: String,
}
|
||||
|
||||
impl PromptCodeSnippet {
|
||||
pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
|
||||
let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let content = snapshot.text_for_range(range.clone()).collect::<String>();
|
||||
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.and_then(|language| Some(language.name().to_string()));
|
||||
|
||||
let file_path = buffer
|
||||
.file()
|
||||
.and_then(|file| Some(file.path().to_path_buf()));
|
||||
|
||||
(content, language_name, file_path)
|
||||
});
|
||||
|
||||
PromptCodeSnippet {
|
||||
path: file_path,
|
||||
language_name,
|
||||
content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for PromptCodeSnippet {
|
||||
fn to_string(&self) -> String {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.and_then(|path| Some(path.to_string_lossy().to_string()))
|
||||
.unwrap_or("".to_string());
|
||||
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||
let content = self.content.clone();
|
||||
|
||||
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue