use ai::{ embedding::{Embedding, EmbeddingProvider}, models::TruncationDirection, }; use anyhow::{anyhow, Result}; use language::{Grammar, Language}; use rusqlite::{ types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, ToSql, }; use sha1::{Digest, Sha1}; use std::{ borrow::Cow, cmp::{self, Reverse}, collections::HashSet, ops::Range, path::Path, sync::Arc, }; use tree_sitter::{Parser, QueryCursor}; #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct SpanDigest(pub [u8; 20]); impl FromSql for SpanDigest { fn column_result(value: ValueRef) -> FromSqlResult { let blob = value.as_blob()?; let bytes = blob.try_into() .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize { expected_size: 20, blob_size: blob.len(), })?; return Ok(SpanDigest(bytes)); } } impl ToSql for SpanDigest { fn to_sql(&self) -> rusqlite::Result { self.0.to_sql() } } impl From<&'_ str> for SpanDigest { fn from(value: &'_ str) -> Self { let mut sha1 = Sha1::new(); sha1.update(value); Self(sha1.finalize().into()) } } #[derive(Debug, PartialEq, Clone)] pub struct Span { pub name: String, pub range: Range, pub content: String, pub embedding: Option, pub digest: SpanDigest, pub token_count: usize, } const CODE_CONTEXT_TEMPLATE: &str = "The below code snippet is from file ''\n\n```\n\n```"; const ENTIRE_FILE_TEMPLATE: &str = "The below snippet is from file ''\n\n```\n\n```"; const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file ''\n\n"; pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[ "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme", ]; pub struct CodeContextRetriever { pub parser: Parser, pub cursor: QueryCursor, pub embedding_provider: Arc, } // Every match has an item, this represents the fundamental treesitter symbol and anchors the search // Every match has one or more 'name' captures. These indicate the display range of the item for deduplication. // If there are preceeding comments, we track this with a context capture // If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture // If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture #[derive(Debug, Clone)] pub struct CodeContextMatch { pub start_col: usize, pub item_range: Option>, pub name_range: Option>, pub context_ranges: Vec>, pub collapse_ranges: Vec>, } impl CodeContextRetriever { pub fn new(embedding_provider: Arc) -> Self { Self { parser: Parser::new(), cursor: QueryCursor::new(), embedding_provider, } } fn parse_entire_file( &self, relative_path: Option<&Path>, language_name: Arc, content: &str, ) -> Result> { let document_span = ENTIRE_FILE_TEMPLATE .replace( "", &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()), ) .replace("", language_name.as_ref()) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); let model = self.embedding_provider.base_model(); let document_span = model.truncate( &document_span, model.capacity()?, ai::models::TruncationDirection::End, )?; let token_count = model.count_tokens(&document_span)?; Ok(vec![Span { range: 0..content.len(), content: document_span, embedding: Default::default(), name: language_name.to_string(), digest, token_count, }]) } fn parse_markdown_file( &self, relative_path: Option<&Path>, content: &str, ) -> Result> { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace( "", &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()), ) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); let model = self.embedding_provider.base_model(); let document_span = model.truncate( &document_span, model.capacity()?, ai::models::TruncationDirection::End, )?; let token_count = model.count_tokens(&document_span)?; Ok(vec![Span { range: 0..content.len(), content: document_span, embedding: None, name: "Markdown".to_string(), digest, token_count, }]) } fn get_matches_in_file( &mut self, content: &str, grammar: &Arc, ) -> Result> { let embedding_config = grammar .embedding_config .as_ref() .ok_or_else(|| anyhow!("no embedding queries"))?; self.parser.set_language(grammar.ts_language).unwrap(); let tree = self .parser .parse(&content, None) .ok_or_else(|| anyhow!("parsing failed"))?; let mut captures: Vec = Vec::new(); let mut collapse_ranges: Vec> = Vec::new(); let mut keep_ranges: Vec> = Vec::new(); for mat in self.cursor.matches( &embedding_config.query, tree.root_node(), content.as_bytes(), ) { let mut start_col = 0; let mut item_range: Option> = None; let mut name_range: Option> = None; let mut context_ranges: Vec> = Vec::new(); collapse_ranges.clear(); keep_ranges.clear(); for capture in mat.captures { if capture.index == embedding_config.item_capture_ix { item_range = Some(capture.node.byte_range()); start_col = capture.node.start_position().column; } else if Some(capture.index) == embedding_config.name_capture_ix { name_range = Some(capture.node.byte_range()); } else if Some(capture.index) == embedding_config.context_capture_ix { context_ranges.push(capture.node.byte_range()); } else if Some(capture.index) == embedding_config.collapse_capture_ix { collapse_ranges.push(capture.node.byte_range()); } else if Some(capture.index) == embedding_config.keep_capture_ix { keep_ranges.push(capture.node.byte_range()); } } captures.push(CodeContextMatch { start_col, item_range, name_range, context_ranges, collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges), }); } Ok(captures) } pub fn parse_file_with_template( &mut self, relative_path: Option<&Path>, content: &str, language: Arc, ) -> Result> { let language_name = language.name(); if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) { return self.parse_entire_file(relative_path, language_name, &content); } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) { return self.parse_markdown_file(relative_path, &content); } let mut spans = self.parse_file(content, language)?; for span in &mut spans { let document_content = CODE_CONTEXT_TEMPLATE .replace( "", &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()), ) .replace("", language_name.as_ref()) .replace("item", &span.content); let model = self.embedding_provider.base_model(); let document_content = model.truncate( &document_content, model.capacity()?, TruncationDirection::End, )?; let token_count = model.count_tokens(&document_content)?; span.content = document_content; span.token_count = token_count; } Ok(spans) } pub fn parse_file(&mut self, content: &str, language: Arc) -> Result> { let grammar = language .grammar() .ok_or_else(|| anyhow!("no grammar for language"))?; // Iterate through query matches let matches = self.get_matches_in_file(content, grammar)?; let language_scope = language.default_scope(); let placeholder = language_scope.collapsed_placeholder(); let mut spans = Vec::new(); let mut collapsed_ranges_within = Vec::new(); let mut parsed_name_ranges = HashSet::new(); for (i, context_match) in matches.iter().enumerate() { // Items which are collapsible but not embeddable have no item range let item_range = if let Some(item_range) = context_match.item_range.clone() { item_range } else { continue; }; // Checks for deduplication let name; if let Some(name_range) = context_match.name_range.clone() { name = content .get(name_range.clone()) .map_or(String::new(), |s| s.to_string()); if parsed_name_ranges.contains(&name_range) { continue; } parsed_name_ranges.insert(name_range); } else { name = String::new(); } collapsed_ranges_within.clear(); 'outer: for remaining_match in &matches[(i + 1)..] { for collapsed_range in &remaining_match.collapse_ranges { if item_range.start <= collapsed_range.start && item_range.end >= collapsed_range.end { collapsed_ranges_within.push(collapsed_range.clone()); } else { break 'outer; } } } collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end))); let mut span_content = String::new(); for context_range in &context_match.context_ranges { add_content_from_range( &mut span_content, content, context_range.clone(), context_match.start_col, ); span_content.push_str("\n"); } let mut offset = item_range.start; for collapsed_range in &collapsed_ranges_within { if collapsed_range.start > offset { add_content_from_range( &mut span_content, content, offset..collapsed_range.start, context_match.start_col, ); offset = collapsed_range.start; } if collapsed_range.end > offset { span_content.push_str(placeholder); offset = collapsed_range.end; } } if offset < item_range.end { add_content_from_range( &mut span_content, content, offset..item_range.end, context_match.start_col, ); } let sha1 = SpanDigest::from(span_content.as_str()); spans.push(Span { name, content: span_content, range: item_range.clone(), embedding: None, digest: sha1, token_count: 0, }) } return Ok(spans); } } pub(crate) fn subtract_ranges( ranges: &[Range], ranges_to_subtract: &[Range], ) -> Vec> { let mut result = Vec::new(); let mut ranges_to_subtract = ranges_to_subtract.iter().peekable(); for range in ranges { let mut offset = range.start; while offset < range.end { if let Some(range_to_subtract) = ranges_to_subtract.peek() { if offset < range_to_subtract.start { let next_offset = cmp::min(range_to_subtract.start, range.end); result.push(offset..next_offset); offset = next_offset; } else { let next_offset = cmp::min(range_to_subtract.end, range.end); offset = next_offset; } if offset >= range_to_subtract.end { ranges_to_subtract.next(); } } else { result.push(offset..range.end); offset = range.end; } } } result } fn add_content_from_range( output: &mut String, content: &str, range: Range, start_col: usize, ) { for mut line in content.get(range.clone()).unwrap_or("").lines() { for _ in 0..start_col { if line.starts_with(' ') { line = &line[1..]; } else { break; } } output.push_str(line); output.push('\n'); } output.pop(); }