Add SyntaxMap methods for running queries and combining their results

This commit is contained in:
Max Brunsfeld 2022-08-23 14:26:09 -07:00
parent 71e17a54ae
commit 9113c94371

View file

@ -3,11 +3,19 @@ use crate::{
ToTreeSitterPoint, ToTreeSitterPoint,
}; };
use std::{ use std::{
borrow::Cow, cell::RefCell, cmp::Ordering, collections::BinaryHeap, ops::Range, sync::Arc, borrow::Cow,
cell::RefCell,
cmp::{Ordering, Reverse},
collections::BinaryHeap,
iter::Peekable,
ops::{DerefMut, Range},
sync::Arc,
}; };
use sum_tree::{Bias, SeekTarget, SumTree}; use sum_tree::{Bias, SeekTarget, SumTree};
use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint}; use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint};
use tree_sitter::{Node, Parser, Tree}; use tree_sitter::{
Node, Parser, Query, QueryCapture, QueryCaptures, QueryCursor, QueryMatch, QueryMatches, Tree,
};
thread_local! { thread_local! {
static PARSER: RefCell<Parser> = RefCell::new(Parser::new()); static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
@ -26,6 +34,42 @@ pub struct SyntaxSnapshot {
layers: SumTree<SyntaxLayer>, layers: SumTree<SyntaxLayer>,
} }
/// Iterator over query captures gathered from multiple syntax layers.
///
/// Built by `SyntaxSnapshot::captures`; yields `SyntaxMapCapture` items.
pub struct SyntaxMapCaptures<'a> {
// Per-layer capture streams, kept in ascending order of each layer's
// `sort_key()`, so the layer holding the next capture to yield is at index 0.
layers: Vec<SyntaxMapCapturesLayer<'a>>,
}
/// Query matches gathered from multiple syntax layers.
///
/// Built by `SyntaxSnapshot::matches`.
pub struct SyntaxMapMatches<'a> {
// Per-layer match streams, inserted in ascending order of each layer's
// `sort_key()`.
layers: Vec<SyntaxMapMatchesLayer<'a>>,
}
/// A single query capture, tagged with the syntax layer it came from.
pub struct SyntaxMapCapture<'a> {
/// Grammar of the layer that produced this capture.
pub grammar: &'a Grammar,
/// Depth of the layer that produced this capture.
pub depth: usize,
/// The captured syntax node.
pub node: Node<'a>,
/// Capture index, copied from the underlying tree-sitter `QueryCapture`.
pub index: u32,
}
/// A single query match, tagged with the syntax layer it came from.
///
/// NOTE(review): nothing in this part of the file constructs this type, so the
/// field descriptions below are inferred from the field names and types —
/// confirm against the code that builds it.
pub struct SyntaxMapMatch<'a> {
/// Grammar of the layer that produced this match.
pub grammar: &'a Grammar,
/// Depth of the layer that produced this match.
pub depth: usize,
/// Presumably the index of the matched pattern within the query — confirm.
pub pattern_index: usize,
/// Presumably the captures belonging to this match — confirm.
pub captures: &'a [QueryCapture<'a>],
}
/// Capture stream for one syntax layer, as built by `SyntaxSnapshot::captures`.
///
/// Field order matters: Rust drops fields in declaration order, so `captures`
/// is dropped before `_query_cursor`. `captures` is created from a cursor
/// borrowed out of that handle with its lifetime extended via `transmute`.
struct SyntaxMapCapturesLayer<'a> {
// Depth of the layer; the last component of the sort key.
depth: usize,
// Remaining captures for this layer; peekable so the next capture can be
// inspected for ordering without consuming it.
captures: Peekable<QueryCaptures<'a, 'a, TextProvider<'a>>>,
// Grammar of the layer, copied into every capture yielded from it.
grammar: &'a Grammar,
// Never read; held so the query cursor handle lives as long as `captures`.
_query_cursor: QueryCursorHandle,
}
/// Match stream for one syntax layer, as built by `SyntaxSnapshot::matches`.
///
/// Field order matters: Rust drops fields in declaration order, so `matches`
/// is dropped before `_query_cursor`. `matches` is created from a cursor
/// borrowed out of that handle with its lifetime extended via `transmute`.
struct SyntaxMapMatchesLayer<'a> {
// Depth of the layer; the last component of the sort key.
depth: usize,
// Remaining matches for this layer; peekable so the next match can be
// inspected for ordering without consuming it.
matches: Peekable<QueryMatches<'a, 'a, TextProvider<'a>>>,
// Grammar of the layer these matches come from.
grammar: &'a Grammar,
// Never read; held so the query cursor handle lives as long as `matches`.
_query_cursor: QueryCursorHandle,
}
#[derive(Clone)] #[derive(Clone)]
struct SyntaxLayer { struct SyntaxLayer {
depth: usize, depth: usize,
@ -385,6 +429,100 @@ impl SyntaxSnapshot {
self.layers = layers; self.layers = layers;
} }
/// Runs a query over every syntax layer intersecting `range` and returns an
/// iterator that merges the captures of all layers.
///
/// `query` selects the query to run for a given grammar; layers for which it
/// returns `None` are skipped, as are layers whose query yields no captures.
/// The returned layers are ordered by the sort key of their next capture.
pub fn captures<'a>(
    &'a self,
    range: Range<usize>,
    buffer: &'a BufferSnapshot,
    query: impl Fn(&Grammar) -> Option<&Query>,
) -> SyntaxMapCaptures {
    let mut merged = SyntaxMapCaptures { layers: Vec::new() };

    for (grammar, depth, root) in self.layers_for_range(range.clone(), buffer) {
        let layer_query = match query(grammar) {
            Some(layer_query) => layer_query,
            None => continue,
        };

        let mut cursor_handle = QueryCursorHandle::new();

        // TODO - add a Tree-sitter API to remove the need for this.
        //
        // The cursor's lifetime is extended to `'static` so that the capture
        // iterator can be stored next to the handle that owns the cursor.
        // NOTE(review): this relies on the iterator staying valid after
        // `cursor_handle` is moved into the layer below — confirm against the
        // tree-sitter bindings.
        let cursor = unsafe {
            std::mem::transmute::<_, &'static mut QueryCursor>(cursor_handle.deref_mut())
        };
        cursor.set_byte_range(range.clone());
        let layer_captures = cursor.captures(layer_query, root, TextProvider(buffer.as_rope()));

        let mut layer = SyntaxMapCapturesLayer {
            depth,
            grammar,
            captures: layer_captures.peekable(),
            _query_cursor: cursor_handle,
        };

        // A layer without a sort key has no captures at all, so drop it.
        if let Some(key) = layer.sort_key() {
            // Insert before the first existing layer whose key is not smaller.
            let layer_count = merged.layers.len();
            let insert_at = merged
                .layers
                .iter_mut()
                .position(|existing| {
                    !matches!(existing.sort_key(), Some(existing_key) if key > existing_key)
                })
                .unwrap_or(layer_count);
            merged.layers.insert(insert_at, layer);
        }
    }

    merged
}
/// Runs a query over every syntax layer intersecting `range` and collects
/// the per-layer match streams.
///
/// `query` selects the query to run for a given grammar; layers for which it
/// returns `None` are skipped, as are layers for which no sort key can be
/// computed. The returned layers are ordered by the sort key of their next
/// match.
pub fn matches<'a>(
    &'a self,
    range: Range<usize>,
    buffer: &'a BufferSnapshot,
    query: impl Fn(&Grammar) -> Option<&Query>,
) -> SyntaxMapMatches {
    let mut merged = SyntaxMapMatches { layers: Vec::new() };

    for (grammar, depth, root) in self.layers_for_range(range.clone(), buffer) {
        let layer_query = match query(grammar) {
            Some(layer_query) => layer_query,
            None => continue,
        };

        let mut cursor_handle = QueryCursorHandle::new();

        // TODO - add a Tree-sitter API to remove the need for this.
        //
        // The cursor's lifetime is extended to `'static` so that the match
        // iterator can be stored next to the handle that owns the cursor.
        // NOTE(review): this relies on the iterator staying valid after
        // `cursor_handle` is moved into the layer below — confirm against the
        // tree-sitter bindings.
        let cursor = unsafe {
            std::mem::transmute::<_, &'static mut QueryCursor>(cursor_handle.deref_mut())
        };
        cursor.set_byte_range(range.clone());
        let layer_matches = cursor.matches(layer_query, root, TextProvider(buffer.as_rope()));

        let mut layer = SyntaxMapMatchesLayer {
            depth,
            grammar,
            matches: layer_matches.peekable(),
            _query_cursor: cursor_handle,
        };

        if let Some(key) = layer.sort_key() {
            // Walk past every existing layer whose key is strictly smaller.
            let mut insert_at = 0;
            while insert_at < merged.layers.len() {
                match merged.layers[insert_at].sort_key() {
                    Some(existing_key) if key > existing_key => insert_at += 1,
                    _ => break,
                }
            }
            merged.layers.insert(insert_at, layer);
        }
    }

    merged
}
pub fn layers(&self, buffer: &BufferSnapshot) -> Vec<(&Grammar, Node)> { pub fn layers(&self, buffer: &BufferSnapshot) -> Vec<(&Grammar, Node)> {
self.layers self.layers
.iter() .iter()
@ -408,7 +546,7 @@ impl SyntaxSnapshot {
&self, &self,
range: Range<T>, range: Range<T>,
buffer: &BufferSnapshot, buffer: &BufferSnapshot,
) -> Vec<(&Grammar, Node)> { ) -> Vec<(&Grammar, usize, Node)> {
let start = buffer.anchor_before(range.start.to_offset(buffer)); let start = buffer.anchor_before(range.start.to_offset(buffer));
let end = buffer.anchor_after(range.end.to_offset(buffer)); let end = buffer.anchor_after(range.end.to_offset(buffer));
@ -424,6 +562,7 @@ impl SyntaxSnapshot {
if let Some(grammar) = &layer.language.grammar { if let Some(grammar) = &layer.language.grammar {
result.push(( result.push((
grammar.as_ref(), grammar.as_ref(),
layer.depth,
layer.tree.root_node_with_offset( layer.tree.root_node_with_offset(
layer.range.start.to_offset(buffer), layer.range.start.to_offset(buffer),
layer.range.start.to_point(buffer).to_ts_point(), layer.range.start.to_point(buffer).to_ts_point(),
@ -437,6 +576,60 @@ impl SyntaxSnapshot {
} }
} }
impl<'a> Iterator for SyntaxMapCaptures<'a> {
    type Item = SyntaxMapCapture<'a>;

    /// Yields the next capture from the front layer, then restores the
    /// ordering of `layers` by the sort key of each layer's next capture.
    fn next(&mut self) -> Option<Self::Item> {
        let head = self.layers.first_mut()?;
        let (query_match, capture_ix) = head.captures.next()?;
        let capture = query_match.captures[capture_ix as usize];
        let item = SyntaxMapCapture {
            grammar: head.grammar,
            depth: head.depth,
            node: capture.node,
            index: capture.index,
        };

        match head.sort_key() {
            // The front layer is exhausted; discard it.
            None => {
                self.layers.remove(0);
            }
            // The front layer has more captures; move it behind every later
            // layer whose next capture sorts strictly before its own.
            Some(head_key) => {
                let later_count = self.layers.len() - 1;
                let displaced = self.layers[1..]
                    .iter_mut()
                    .position(|later| {
                        !matches!(later.sort_key(), Some(later_key) if head_key > later_key)
                    })
                    .unwrap_or(later_count);
                if displaced > 0 {
                    self.layers[0..=displaced].rotate_left(1);
                }
            }
        }

        Some(item)
    }
}
impl<'a> SyntaxMapCapturesLayer<'a> {
    /// Ordering key for this layer's next capture, or `None` when the layer
    /// has no captures left.
    ///
    /// The key is `(start byte, Reverse(end byte), depth)` of the next
    /// captured node. Takes `&mut self` because peeking may advance the
    /// underlying iterator.
    fn sort_key(&mut self) -> Option<(usize, Reverse<usize>, usize)> {
        let depth = self.depth;
        self.captures.peek().map(|(query_match, capture_ix)| {
            let node = query_match.captures[*capture_ix].node;
            (node.start_byte(), Reverse(node.end_byte()), depth)
        })
    }
}
impl<'a> SyntaxMapMatchesLayer<'a> {
    /// Ordering key for this layer's next match, or `None` when the layer has
    /// no matches left.
    ///
    /// The key is `(start byte, Reverse(end byte), depth)`, where the range
    /// spans from the start of the match's first capture to the end of its
    /// last capture. Takes `&mut self` because peeking may advance the
    /// underlying iterator.
    fn sort_key(&mut self) -> Option<(usize, Reverse<usize>, usize)> {
        let mat = self.matches.peek()?;
        // A match may carry no captures at all. Returning `None` for it would
        // be indistinguishable from "no matches left" and cause callers to
        // drop the whole layer, so give it an empty range at offset 0 instead.
        let (start, end) = match (mat.captures.first(), mat.captures.last()) {
            (Some(first), Some(last)) => (first.node.start_byte(), last.node.end_byte()),
            _ => (0, 0),
        };
        Some((start, Reverse(end), self.depth))
    }
}
fn join_ranges( fn join_ranges(
a: impl Iterator<Item = Range<usize>>, a: impl Iterator<Item = Range<usize>>,
b: impl Iterator<Item = Range<usize>>, b: impl Iterator<Item = Range<usize>>,
@ -875,10 +1068,10 @@ mod tests {
"fn a() { dbg!(b.c(vec![d.«e»])) }", "fn a() { dbg!(b.c(vec![d.«e»])) }",
]); ]);
assert_node_ranges( assert_capture_ranges(
&syntax_map, &syntax_map,
&buffer, &buffer,
"(field_identifier) @_", &["field"],
"fn a() { dbg!(b.«c»(vec![d.«e»])) }", "fn a() { dbg!(b.«c»(vec![d.«e»])) }",
); );
} }
@ -909,10 +1102,10 @@ mod tests {
", ",
]); ]);
assert_node_ranges( assert_capture_ranges(
&syntax_map, &syntax_map,
&buffer, &buffer,
"(struct_expression) @_", &["struct"],
" "
fn a() { fn a() {
b!(«B {}»); b!(«B {}»);
@ -952,10 +1145,10 @@ mod tests {
", ",
]); ]);
assert_node_ranges( assert_capture_ranges(
&syntax_map, &syntax_map,
&buffer, &buffer,
"(field_identifier) @_", &["field"],
" "
fn a() { fn a() {
b!( b!(
@ -1129,6 +1322,13 @@ mod tests {
}, },
Some(tree_sitter_rust::language()), Some(tree_sitter_rust::language()),
) )
.with_highlights_query(
r#"
(field_identifier) @field
(struct_expression) @struct
"#,
)
.unwrap()
.with_injection_query( .with_injection_query(
r#" r#"
(macro_invocation (macro_invocation
@ -1156,7 +1356,7 @@ mod tests {
expected_layers.len(), expected_layers.len(),
"wrong number of layers" "wrong number of layers"
); );
for (i, ((_, node), expected_s_exp)) in for (i, ((_, _, node), expected_s_exp)) in
layers.iter().zip(expected_layers.iter()).enumerate() layers.iter().zip(expected_layers.iter()).enumerate()
{ {
let actual_s_exp = node.to_sexp(); let actual_s_exp = node.to_sexp();
@ -1170,18 +1370,25 @@ mod tests {
} }
} }
fn assert_node_ranges( fn assert_capture_ranges(
syntax_map: &SyntaxMap, syntax_map: &SyntaxMap,
buffer: &BufferSnapshot, buffer: &BufferSnapshot,
query: &str, highlight_query_capture_names: &[&str],
marked_string: &str, marked_string: &str,
) { ) {
let mut cursor = QueryCursorHandle::new();
let mut actual_ranges = Vec::<Range<usize>>::new(); let mut actual_ranges = Vec::<Range<usize>>::new();
for (grammar, node) in syntax_map.layers(buffer) { for capture in syntax_map.captures(0..buffer.len(), buffer, |grammar| {
let query = Query::new(grammar.ts_language, query).unwrap(); grammar.highlights_query.as_ref()
for (mat, ix) in cursor.captures(&query, node, TextProvider(buffer.as_rope())) { }) {
actual_ranges.push(mat.captures[ix].node.byte_range()); let name = &capture
.grammar
.highlights_query
.as_ref()
.unwrap()
.capture_names()[capture.index as usize];
dbg!(capture.node, capture.index, name);
if highlight_query_capture_names.contains(&name.as_str()) {
actual_ranges.push(capture.node.byte_range());
} }
} }