diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index 0cafb49d94..d0206cc41c 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -6,6 +6,7 @@ pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
     fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
@@ -47,6 +48,18 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                bpe.decode(tokens[length..].to_vec())
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
     fn capacity(&self) -> anyhow::Result<usize> {
         anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
     }
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index 923e1833c2..bda1d6c30e 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -190,6 +190,13 @@ pub(crate) mod tests {
                     .collect::<String>(),
             )
         }
+        fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+            anyhow::Ok(
+                content.chars().collect::<Vec<char>>()[length..]
+                    .into_iter()
+                    .collect::<String>(),
+            )
+        }
         fn capacity(&self) -> anyhow::Result<usize> {
             anyhow::Ok(self.capacity)
         }
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
index 5a6489a00c..253d24e469 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/templates/file_context.rs
@@ -1,9 +1,103 @@
 use anyhow::anyhow;
+use language::BufferSnapshot;
 use language::ToOffset;
 
+use crate::models::LanguageModel;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<usize>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+        let mut selected_window = String::new();
+        if start == end {
+            write!(selected_window, "<|START|>").unwrap();
+        } else {
+            write!(selected_window, "<|START|").unwrap();
+        }
+
+        write!(
+            selected_window,
+            "{}",
+            buffer.text_for_range(start..end).collect::<String>()
+        )
+        .unwrap();
+
+        if start != end {
+            write!(selected_window, "|END|>").unwrap();
+        }
+
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            };
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            let outside_tokens = start_window_tokens + end_window_tokens;
+            if outside_tokens > remaining_tokens {
+                let (start_goal_tokens, end_goal_tokens) =
+                    if start_window_tokens < end_window_tokens {
+                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                        remaining_tokens -= start_goal_tokens;
+                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    } else {
+                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                        remaining_tokens -= end_goal_tokens;
+                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    };
+
+                let truncated_start_window =
+                    model.truncate_start(&start_window, start_goal_tokens)?;
+                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+                writeln!(
+                    prompt,
+                    "{truncated_start_window}{selected_window}{truncated_end_window}"
+                )
+                .unwrap();
+                truncated = true;
+            } else {
+                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+            }
+        } else {
+            // If we dont have a selected range, include entire file.
+            writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+            // Dumb truncation strategy
+            if let Some(max_token_count) = max_token_count {
+                if model.count_tokens(&prompt)? > max_token_count {
+                    truncated = true;
+                    prompt = model.truncate(&prompt, max_token_count)?;
+                }
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
 
 pub struct FileContext {}
 
@@ -28,53 +122,24 @@
             .clone()
             .unwrap_or("".to_string())
             .to_lowercase();
-        writeln!(prompt, "```{language_name}").unwrap();
+
+        let (context, _, truncated) = retrieve_context(
+            buffer,
+            &args.selected_range,
+            args.model.clone(),
+            max_token_length,
+        )?;
+        writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
 
         if let Some(selected_range) = &args.selected_range {
             let start = selected_range.start.to_offset(buffer);
             let end = selected_range.end.to_offset(buffer);
 
-            writeln!(
-                prompt,
-                "{}",
-                buffer.text_for_range(0..start).collect::<String>()
-            )
-            .unwrap();
-
-            if start == end {
-                write!(prompt, "<|START|>").unwrap();
-            } else {
-                write!(prompt, "<|START|").unwrap();
-            }
-
-            write!(
-                prompt,
-                "{}",
-                buffer.text_for_range(start..end).collect::<String>()
-            )
-            .unwrap();
-
-            if start != end {
-                write!(prompt, "|END|>").unwrap();
-            }
-
-            write!(
-                prompt,
-                "{}",
-                buffer.text_for_range(end..buffer.len()).collect::<String>()
-            )
-            .unwrap();
-
-            writeln!(prompt, "```").unwrap();
-
             if start == end {
                 writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
             } else {
                 writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
             }
-        } else {
-            // If we dont have a selected range, include entire file.
-            writeln!(prompt, "{}", &buffer.text()).unwrap();
-            writeln!(prompt, "```").unwrap();
         }
 
         // Really dumb truncation strategy
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index dffcbc2923..c7b52a3540 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -166,6 +166,8 @@ pub fn generate_content_prompt(
     let chain = PromptChain::new(args, templates);
     let (prompt, _) = chain.generate(true)?;
 
+    println!("PROMPT: {:?}", &prompt);
+
     anyhow::Ok(prompt)
 }
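
A note on the semantics of the new trait method: in both the BPE-backed implementation and the char-based test implementation, `length` is the number of leading units to *drop* for `truncate_start`, while `truncate` keeps the first `length` units. A minimal tokenizer-free sketch of the two directions, using chars where `OpenAILanguageModel` uses `tiktoken_rs` BPE tokens:

```rust
// Sketch only: chars stand in for BPE tokens.
// `truncate` keeps the first `length` units; `truncate_start` drops them.
fn truncate(content: &str, length: usize) -> String {
    content.chars().take(length).collect()
}

fn truncate_start(content: &str, length: usize) -> String {
    content.chars().skip(length).collect()
}

fn main() {
    assert_eq!(truncate("0123456789", 4), "0123");
    assert_eq!(truncate_start("0123456789", 4), "456789");
}
```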
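The core of `retrieve_context` is the arithmetic that divides the remaining token budget between the text before and after the selection. Extracted into a hypothetical standalone `split_budget` helper (the diff inlines this logic), the strategy is: cap the smaller window at half the budget, then let the larger window absorb whatever is left:

```rust
// Sketch of the window-budgeting arithmetic inlined in retrieve_context.
// The smaller window gets at most half the remaining budget; the larger
// window then takes as much of the remainder as it can use.
fn split_budget(
    mut remaining_tokens: usize,
    start_window_tokens: usize,
    end_window_tokens: usize,
) -> (usize, usize) {
    if start_window_tokens < end_window_tokens {
        let start_goal = (remaining_tokens / 2).min(start_window_tokens);
        remaining_tokens -= start_goal;
        let end_goal = remaining_tokens.min(end_window_tokens);
        (start_goal, end_goal)
    } else {
        let end_goal = (remaining_tokens / 2).min(end_window_tokens);
        remaining_tokens -= end_goal;
        let start_goal = remaining_tokens.min(start_window_tokens);
        (start_goal, end_goal)
    }
}

fn main() {
    // 100 tokens left, 30 before the selection, 400 after: the small
    // start window fits whole, the end window absorbs the remaining 70.
    assert_eq!(split_budget(100, 30, 400), (30, 70));
    // Both windows oversized: each side gets half.
    assert_eq!(split_budget(100, 400, 400), (50, 50));
}
```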
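Finally, the marker convention the prompt text refers to: an empty selection (a bare cursor) is rendered as the single `<|START|>` span, while a non-empty selection is wrapped in `<|START|` … `|END|>`. A sketch with a hypothetical `mark_selection` helper (in the diff this logic is inlined in `retrieve_context`):

```rust
// Sketch of the cursor/selection marker convention written into the prompt.
fn mark_selection(before: &str, selected: &str, after: &str) -> String {
    if selected.is_empty() {
        // Cursor with no selection: a single marker at the cursor position.
        format!("{before}<|START|>{after}")
    } else {
        // Non-empty selection: wrap the selected text in open/close markers.
        format!("{before}<|START|{selected}|END|>{after}")
    }
}

fn main() {
    assert_eq!(mark_selection("let x", "", " = 1;"), "let x<|START|> = 1;");
    assert_eq!(
        mark_selection("let ", "x", " = 1;"),
        "let <|START|x|END|> = 1;"
    );
}
```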