Add lua script access to code using cx + reuse project search logic (#26269)

Access to `cx` will be needed for anything that queries entities. In
this commit that means `WorktreeStore::find_search_candidates`; in the
future it will mean things like access to LSP, tree-sitter outlines,
etc.

Changes to support access to `cx` from functions provided to the Lua
script:

* Adds a channel for requests that require a `cx`. Work enqueued on this
channel is run on the foreground thread (see the first sketch below).

* Enables the `async` and `send` features of the `mlua` crate so that async
Rust functions can be called from Lua (see the second sketch below).

* Changes uses of `Rc<RefCell<...>>` to `Arc<Mutex<...>>` so that the
futures are `Send`.
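
As a rough illustration of the first bullet, here is a minimal sketch of the
foreground-request pattern with the gpui specifics stripped out. The `Cx` type
is a hypothetical stand-in for gpui's `AsyncApp`; the commit's real
counterparts are `ForegroundFn`, `run_foreground_fn`, and the `cx.spawn` loop
in the tool's `run` method, all visible in the diff below.

```rust
use anyhow::Result;
use futures::{
    channel::{mpsc, oneshot},
    SinkExt as _, StreamExt as _,
};

/// Hypothetical stand-in for gpui's `AsyncApp` handle.
#[derive(Clone)]
struct Cx;

/// Work that must run on the thread that owns the context.
struct ForegroundFn(Box<dyn FnOnce(Cx) + Send>);

/// Enqueue a closure onto the foreground channel and wait for its result
/// over a oneshot channel.
async fn run_foreground_fn<R: Send + 'static>(
    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    function: Box<dyn FnOnce(Cx) -> Result<R> + Send>,
) -> Result<R> {
    let (response_tx, response_rx) = oneshot::channel();
    foreground_tx
        .send(ForegroundFn(Box::new(move |cx| {
            response_tx.send(function(cx)).ok();
        })))
        .await?;
    response_rx.await?
}

/// The foreground side just drains the channel and invokes each request.
async fn foreground_loop(mut foreground_rx: mpsc::Receiver<ForegroundFn>, cx: Cx) {
    while let Some(request) = foreground_rx.next().await {
        request.0(cx.clone());
    }
}
```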

One benefit of reusing the project search logic for finding search candidates
is that it properly skips ignored paths.
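
Returning to the `mlua` features mentioned above: with `async` (and `send`)
enabled, an async Rust callback can be registered directly as a Lua global and
the script driven with `exec_async`. The snippet below is a simplified,
hypothetical sketch; the commit registers the much larger `search` function
shown in the diff further down.

```rust
use mlua::{Lua, Result};

// Minimal demo of mlua's `async` feature: an async Rust callback is
// registered as a Lua global, and the script is driven with `exec_async`.
async fn demo() -> Result<()> {
    let lua = Lua::new();

    // A stand-in async callback; the real commit registers `search` here.
    let search = lua.create_async_function(|_lua, pattern: String| async move {
        Ok(format!("pretend we searched for `{pattern}`"))
    })?;
    lua.globals().set("search", search)?;

    // The script can call into async Rust without blocking the executor.
    lua.load(r#"print(search("TODO"))"#).exec_async().await?;
    Ok(())
}
```

The `send` feature is what makes the resulting futures `Send`, which is also
why the shared `Rc<RefCell<...>>` state had to become `Arc<Mutex<...>>`.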

Release Notes:

- N/A
Michael Sloan 2025-03-07 03:02:49 -07:00 committed by GitHub
parent b0d1024f66
commit 205f9a9f03
4 changed files with 222 additions and 105 deletions

Cargo.lock (generated)

@@ -8134,6 +8134,7 @@ checksum = "d3f763c1041eff92ffb5d7169968a327e1ed2ebfe425dac0ee5a35f29082534b"
 dependencies = [
  "bstr",
  "either",
+ "futures-util",
  "mlua-sys",
  "num-traits",
  "parking_lot",
@@ -11915,12 +11916,17 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assistant_tool",
+ "futures 0.3.31",
  "gpui",
  "mlua",
+ "parking_lot",
+ "project",
  "regex",
  "schemars",
  "serde",
  "serde_json",
+ "smol",
+ "util",
  "workspace",
 ]


@@ -452,7 +452,7 @@ livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "
 ], default-features = false }
 log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
 markup5ever_rcdom = "0.3.0"
-mlua = { version = "0.10", features = ["lua54", "vendored"] }
+mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
 nanoid = "0.4"
 nbformat = { version = "0.10.0" }
 nix = "0.29"


@@ -15,10 +15,15 @@ doctest = false
 [dependencies]
 anyhow.workspace = true
 assistant_tool.workspace = true
+futures.workspace = true
 gpui.workspace = true
 mlua.workspace = true
+parking_lot.workspace = true
+project.workspace = true
+regex.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+smol.workspace = true
+util.workspace = true
 workspace.workspace = true
-regex.workspace = true


@@ -1,16 +1,22 @@
 use anyhow::anyhow;
 use assistant_tool::{Tool, ToolRegistry};
-use gpui::{App, AppContext as _, Task, WeakEntity, Window};
+use futures::{
+    channel::{mpsc, oneshot},
+    SinkExt, StreamExt as _,
+};
+use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window};
 use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
+use parking_lot::Mutex;
+use project::{search::SearchQuery, ProjectPath, WorktreeId};
 use schemars::JsonSchema;
 use serde::Deserialize;
 use std::{
     cell::RefCell,
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     path::{Path, PathBuf},
-    rc::Rc,
     sync::Arc,
 };
+use util::paths::PathMatcher;
 use workspace::Workspace;
 
 pub fn init(cx: &App) {
@@ -59,32 +65,49 @@ string being a match that was found within the file)."#.into()
         _window: &mut Window,
         cx: &mut App,
     ) -> Task<anyhow::Result<String>> {
-        let root_dir = workspace.update(cx, |workspace, cx| {
+        let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| {
             let first_worktree = workspace
                 .visible_worktrees(cx)
                 .next()
                 .ok_or_else(|| anyhow!("no worktrees"))?;
-            workspace
-                .absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
-                .ok_or_else(|| anyhow!("no worktree root"))
+            let worktree_id = first_worktree.read(cx).id();
+            let root_dir = workspace
+                .absolute_path_of_worktree(worktree_id, cx)
+                .ok_or_else(|| anyhow!("no worktree root"))?;
+            Ok((root_dir, worktree_id))
         });
-        let root_dir = match root_dir {
-            Ok(root_dir) => root_dir,
-            Err(err) => return Task::ready(Err(err)),
-        };
-        let root_dir = match root_dir {
-            Ok(root_dir) => root_dir,
+        let (root_dir, worktree_id) = match worktree_root_dir_and_id {
+            Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id,
+            Ok(Err(err)) => return Task::ready(Err(err)),
             Err(err) => return Task::ready(Err(err)),
         };
+
         let input = match serde_json::from_value::<ScriptingToolInput>(input) {
             Err(err) => return Task::ready(Err(err.into())),
            Ok(input) => input,
         };
 
+        let (foreground_tx, mut foreground_rx) = mpsc::channel::<ForegroundFn>(1);
+        cx.spawn(move |cx| async move {
+            while let Some(request) = foreground_rx.next().await {
+                request.0(cx.clone());
+            }
+        })
+        .detach();
+
         let lua_script = input.lua_script;
         cx.background_spawn(async move {
             let fs_changes = HashMap::new();
-            let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
-                .map_err(|err| anyhow!(format!("{err}")))?;
+            let output = run_sandboxed_lua(
+                &lua_script,
+                fs_changes,
+                root_dir,
+                worktree_id,
+                workspace,
+                foreground_tx,
+            )
+            .await
+            .map_err(|err| anyhow!(format!("{err}")))?;
             let output = output.printed_lines.join("\n");
             Ok(format!("The script output the following:\n{output}"))
@@ -92,6 +115,38 @@ string being a match that was found within the file)."#.into()
     }
 }
 
+struct ForegroundFn(Box<dyn FnOnce(AsyncApp) + Send>);
+
+async fn run_foreground_fn<R: Send + 'static>(
+    description: &str,
+    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+    function: Box<dyn FnOnce(AsyncApp) -> anyhow::Result<R> + Send>,
+) -> Result<R> {
+    let (response_tx, response_rx) = oneshot::channel();
+    let send_result = foreground_tx
+        .send(ForegroundFn(Box::new(move |cx| {
+            response_tx.send(function(cx)).ok();
+        })))
+        .await;
+    match send_result {
+        Ok(()) => (),
+        Err(err) => {
+            return Err(mlua::Error::runtime(format!(
+                "Internal error while enqueuing work for {description}: {err}"
+            )))
+        }
+    }
+    match response_rx.await {
+        Ok(Ok(result)) => Ok(result),
+        Ok(Err(err)) => Err(mlua::Error::runtime(format!(
+            "Error while {description}: {err}"
+        ))),
+        Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
+            "Internal error: response oneshot was canceled while {description}."
+        ))),
+    }
+}
+
 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
 
 struct FileContent(RefCell<Vec<u8>>);
@@ -103,7 +158,7 @@ impl UserData for FileContent {
 }
 
 /// Sandboxed print() function in Lua.
-fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
+fn print(lua: &Lua, printed_lines: Arc<Mutex<Vec<String>>>) -> Result<Function> {
     lua.create_function(move |_, args: MultiValue| {
         let mut string = String::new();
 
@@ -117,7 +172,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
             string.push_str(arg.to_string()?.as_str())
         }
 
-        printed_lines.borrow_mut().push(string);
+        printed_lines.lock().push(string);
         Ok(())
     })
 }
@@ -125,103 +180,139 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
 fn search(
     lua: &Lua,
-    _fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+    _fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
     root_dir: PathBuf,
+    worktree_id: WorktreeId,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: mpsc::Sender<ForegroundFn>,
 ) -> Result<Function> {
-    lua.create_function(move |lua, regex: String| {
-        use mlua::Table;
-        use regex::Regex;
-        use std::fs;
-
-        // Function to recursively search directory
-        let search_regex = match Regex::new(&regex) {
-            Ok(re) => re,
-            Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
-        };
-
-        let mut search_results: Vec<Result<Table>> = Vec::new();
-
-        // Create an explicit stack for directories to process
-        let mut dir_stack = vec![root_dir.clone()];
-
-        while let Some(current_dir) = dir_stack.pop() {
-            // Process each entry in the current directory
-            let entries = match fs::read_dir(&current_dir) {
-                Ok(entries) => entries,
-                Err(e) => return Err(e.into()),
-            };
-
-            for entry_result in entries {
-                let entry = match entry_result {
-                    Ok(e) => e,
-                    Err(e) => return Err(e.into()),
-                };
-
-                let path = entry.path();
-
-                if path.is_dir() {
-                    // Skip .git directory and other common directories to ignore
-                    let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
-                    if !dir_name.starts_with('.')
-                        && dir_name != "node_modules"
-                        && dir_name != "target"
-                    {
-                        // Instead of recursive call, add to stack
-                        dir_stack.push(path);
-                    }
-                } else if path.is_file() {
-                    // Skip binary files and very large files
-                    if let Ok(metadata) = fs::metadata(&path) {
-                        if metadata.len() > 1_000_000 {
-                            // Skip files larger than 1MB
-                            continue;
-                        }
-                    }
-
-                    // Attempt to read the file as text
-                    if let Ok(content) = fs::read_to_string(&path) {
-                        let mut matches = Vec::new();
-
-                        // Find all regex matches in the content
-                        for capture in search_regex.find_iter(&content) {
-                            matches.push(capture.as_str().to_string());
-                        }
-
-                        // If we found matches, create a result entry
-                        if !matches.is_empty() {
-                            let result_entry = lua.create_table()?;
-                            result_entry.set("path", path.to_string_lossy().to_string())?;
-                            let matches_table = lua.create_table()?;
-                            for (i, m) in matches.iter().enumerate() {
-                                matches_table.set(i + 1, m.clone())?;
-                            }
-                            result_entry.set("matches", matches_table)?;
-                            search_results.push(Ok(result_entry));
-                        }
-                    }
-                }
-            }
-        }
-
-        // Create a table to hold our results
-        let results_table = lua.create_table()?;
-        for (i, result) in search_results.into_iter().enumerate() {
-            match result {
-                Ok(entry) => results_table.set(i + 1, entry)?,
-                Err(e) => return Err(e),
-            }
-        }
-
-        Ok(results_table)
+    lua.create_async_function(move |lua, regex: String| {
+        let root_dir = root_dir.clone();
+        let workspace = workspace.clone();
+        let mut foreground_tx = foreground_tx.clone();
+        async move {
+            use mlua::Table;
+            use regex::Regex;
+            use std::fs;
+
+            // TODO: Allow specification of these options.
+            let search_query = SearchQuery::regex(
+                &regex,
+                false,
+                false,
+                false,
+                PathMatcher::default(),
+                PathMatcher::default(),
+                None,
+            );
+            let search_query = match search_query {
+                Ok(query) => query,
+                Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
+            };
+
+            // TODO: Should use `search_query.regex`. The tool description should also be updated,
+            // as it specifies standard regex.
+            let search_regex = match Regex::new(&regex) {
+                Ok(re) => re,
+                Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
+            };
+
+            let project_path_rx =
+                find_search_candidates(search_query, workspace, &mut foreground_tx).await?;
+
+            let mut search_results: Vec<Result<Table>> = Vec::new();
+            while let Ok(project_path) = project_path_rx.recv().await {
+                if project_path.worktree_id != worktree_id {
+                    continue;
+                }
+
+                let path = root_dir.join(project_path.path);
+
+                // Skip files larger than 1MB
+                if let Ok(metadata) = fs::metadata(&path) {
+                    if metadata.len() > 1_000_000 {
+                        continue;
+                    }
+                }
+
+                // Attempt to read the file as text
+                if let Ok(content) = fs::read_to_string(&path) {
+                    let mut matches = Vec::new();
+
+                    // Find all regex matches in the content
+                    for capture in search_regex.find_iter(&content) {
+                        matches.push(capture.as_str().to_string());
+                    }
+
+                    // If we found matches, create a result entry
+                    if !matches.is_empty() {
+                        let result_entry = lua.create_table()?;
+                        result_entry.set("path", path.to_string_lossy().to_string())?;
+                        let matches_table = lua.create_table()?;
+                        for (i, m) in matches.iter().enumerate() {
+                            matches_table.set(i + 1, m.clone())?;
+                        }
+                        result_entry.set("matches", matches_table)?;
+                        search_results.push(Ok(result_entry));
+                    }
+                }
+            }
+
+            // Create a table to hold our results
+            let results_table = lua.create_table()?;
+            for (i, result) in search_results.into_iter().enumerate() {
+                match result {
+                    Ok(entry) => results_table.set(i + 1, entry)?,
+                    Err(e) => return Err(e),
+                }
+            }
+
+            Ok(results_table)
+        }
     })
 }
 
+async fn find_search_candidates(
+    search_query: SearchQuery,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+) -> Result<smol::channel::Receiver<ProjectPath>> {
+    run_foreground_fn(
+        "finding search file candidates",
+        foreground_tx,
+        Box::new(move |mut cx| {
+            workspace.update(&mut cx, move |workspace, cx| {
+                workspace.project().update(cx, |project, cx| {
+                    project.worktree_store().update(cx, |worktree_store, cx| {
+                        // TODO: Better limit? For now this is the same as
+                        // MAX_SEARCH_RESULT_FILES.
+                        let limit = 5000;
+                        // TODO: Providing non-empty open_entries can make this a bit more
+                        // efficient as it can skip checking that these paths are textual.
+                        let open_entries = HashSet::default();
+                        // TODO: This is searching all worktrees, but should only search the
+                        // current worktree
+                        worktree_store.find_search_candidates(
+                            search_query,
+                            limit,
+                            open_entries,
+                            project.fs().clone(),
+                            cx,
+                        )
+                    })
+                })
+            })
+        }),
+    )
+    .await
+}
+
 /// Sandboxed io.open() function in Lua.
 fn io_open(
     lua: &Lua,
-    fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
     root_dir: PathBuf,
 ) -> Result<Function> {
     lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@@ -281,7 +372,7 @@ fn io_open(
                 // Don't actually write to disk; instead, just update fs_changes.
                 let path_buf = PathBuf::from(&path);
                 fs_changes
-                    .borrow_mut()
+                    .lock()
                     .insert(path_buf.clone(), content_vec.clone());
             }
@@ -333,13 +424,15 @@ fn io_open(
                 return Ok((Some(file), String::new()));
             }
 
-            let is_in_changes = fs_changes.borrow().contains_key(&path);
+            let fs_changes_map = fs_changes.lock();
+            let is_in_changes = fs_changes_map.contains_key(&path);
+
             let file_exists = is_in_changes || path.exists();
 
             let mut file_content = Vec::new();
             if file_exists && !truncate {
                 if is_in_changes {
-                    file_content = fs_changes.borrow().get(&path).unwrap().clone();
+                    file_content = fs_changes_map.get(&path).unwrap().clone();
                 } else {
                     // Try to read existing content if file exists and we're not truncating
                     match std::fs::read(&path) {
@@ -349,6 +442,8 @@ fn io_open(
                 }
             }
 
+            drop(fs_changes_map); // Unlock the fs_changes mutex.
+
             // If in append mode, position should be at the end
             let position = if append && file_exists {
                 file_content.len()
@@ -582,9 +677,7 @@ fn io_open(
             // Update fs_changes
             let path = file_userdata.get::<String>("__path")?;
             let path_buf = PathBuf::from(path);
-            fs_changes
-                .borrow_mut()
-                .insert(path_buf, content_vec.clone());
+            fs_changes.lock().insert(path_buf, content_vec.clone());
 
             Ok(true)
         })?
@@ -597,33 +690,46 @@ fn io_open(
 }
 
 /// Runs a Lua script in a sandboxed environment and returns the printed lines
-pub fn run_sandboxed_lua(
+async fn run_sandboxed_lua(
     script: &str,
     fs_changes: HashMap<PathBuf, Vec<u8>>,
     root_dir: PathBuf,
+    worktree_id: WorktreeId,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: mpsc::Sender<ForegroundFn>,
 ) -> Result<ScriptOutput> {
     let lua = Lua::new();
     lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
 
     let globals = lua.globals();
 
     // Track the lines the Lua script prints out.
-    let printed_lines = Rc::new(RefCell::new(Vec::new()));
-    let fs = Rc::new(RefCell::new(fs_changes));
+    let printed_lines = Arc::new(Mutex::new(Vec::new()));
+    let fs = Arc::new(Mutex::new(fs_changes));
 
     globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
-    globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
+    globals.set(
+        "search",
+        search(
+            &lua,
+            fs.clone(),
+            root_dir.clone(),
+            worktree_id,
+            workspace,
+            foreground_tx,
+        )?,
+    )?;
     globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
     globals.set("user_script", script)?;
 
-    lua.load(SANDBOX_PREAMBLE).exec()?;
+    lua.load(SANDBOX_PREAMBLE).exec_async().await?;
 
-    drop(lua); // Necessary so the Rc'd values get decremented.
+    drop(lua); // Decrements the Arcs so that try_unwrap works.
 
     Ok(ScriptOutput {
-        printed_lines: Rc::try_unwrap(printed_lines)
+        printed_lines: Arc::try_unwrap(printed_lines)
            .expect("There are still other references to printed_lines")
            .into_inner(),
-        fs_changes: Rc::try_unwrap(fs)
+        fs_changes: Arc::try_unwrap(fs)
            .expect("There are still other references to fs_changes")
            .into_inner(),
     })