Add lua script access to code using cx + reuse project search logic (#26269)

Access to `cx` will be needed for anything that queries entities. In
this commit it is used for `WorktreeStore::find_search_candidates`. In
the future it will be needed for things like access to LSP, tree-sitter
outlines, etc.

Changes to support access to `cx` from functions provided to the Lua
script:

* Adds a channel of requests that require a `cx`. Work enqueued to this
channel is run on the foreground thread.

* Adds `async` and `send` features to `mlua` crate so that async rust
functions can be used from Lua.

* Changes uses of `Rc<RefCell<...>>` to `Arc<Mutex<...>>` so that the
futures are `Send`.

One benefit of reusing project search logic for finding search candidates
is that it properly respects ignored paths.

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-03-07 03:02:49 -07:00 committed by GitHub
parent b0d1024f66
commit 205f9a9f03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 222 additions and 105 deletions

6
Cargo.lock generated
View file

@ -8134,6 +8134,7 @@ checksum = "d3f763c1041eff92ffb5d7169968a327e1ed2ebfe425dac0ee5a35f29082534b"
dependencies = [
"bstr",
"either",
"futures-util",
"mlua-sys",
"num-traits",
"parking_lot",
@ -11915,12 +11916,17 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"futures 0.3.31",
"gpui",
"mlua",
"parking_lot",
"project",
"regex",
"schemars",
"serde",
"serde_json",
"smol",
"util",
"workspace",
]

View file

@ -452,7 +452,7 @@ livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "
], default-features = false }
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
markup5ever_rcdom = "0.3.0"
mlua = { version = "0.10", features = ["lua54", "vendored"] }
mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
nanoid = "0.4"
nbformat = { version = "0.10.0" }
nix = "0.29"

View file

@ -15,10 +15,15 @@ doctest = false
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
futures.workspace = true
gpui.workspace = true
mlua.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
util.workspace = true
workspace.workspace = true
regex.workspace = true

View file

@ -1,16 +1,22 @@
use anyhow::anyhow;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
use futures::{
channel::{mpsc, oneshot},
SinkExt, StreamExt as _,
};
use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window};
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, ProjectPath, WorktreeId};
use schemars::JsonSchema;
use serde::Deserialize;
use std::{
cell::RefCell,
collections::HashMap,
collections::{HashMap, HashSet},
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
};
use util::paths::PathMatcher;
use workspace::Workspace;
pub fn init(cx: &App) {
@ -59,32 +65,49 @@ string being a match that was found within the file)."#.into()
_window: &mut Window,
cx: &mut App,
) -> Task<anyhow::Result<String>> {
let root_dir = workspace.update(cx, |workspace, cx| {
let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| {
let first_worktree = workspace
.visible_worktrees(cx)
.next()
.ok_or_else(|| anyhow!("no worktrees"))?;
workspace
.absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
.ok_or_else(|| anyhow!("no worktree root"))
let worktree_id = first_worktree.read(cx).id();
let root_dir = workspace
.absolute_path_of_worktree(worktree_id, cx)
.ok_or_else(|| anyhow!("no worktree root"))?;
Ok((root_dir, worktree_id))
});
let root_dir = match root_dir {
Ok(root_dir) => root_dir,
Err(err) => return Task::ready(Err(err)),
};
let root_dir = match root_dir {
Ok(root_dir) => root_dir,
let (root_dir, worktree_id) = match worktree_root_dir_and_id {
Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id,
Ok(Err(err)) => return Task::ready(Err(err)),
Err(err) => return Task::ready(Err(err)),
};
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
Err(err) => return Task::ready(Err(err.into())),
Ok(input) => input,
};
let (foreground_tx, mut foreground_rx) = mpsc::channel::<ForegroundFn>(1);
cx.spawn(move |cx| async move {
while let Some(request) = foreground_rx.next().await {
request.0(cx.clone());
}
})
.detach();
let lua_script = input.lua_script;
cx.background_spawn(async move {
let fs_changes = HashMap::new();
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
.map_err(|err| anyhow!(format!("{err}")))?;
let output = run_sandboxed_lua(
&lua_script,
fs_changes,
root_dir,
worktree_id,
workspace,
foreground_tx,
)
.await
.map_err(|err| anyhow!(format!("{err}")))?;
let output = output.printed_lines.join("\n");
Ok(format!("The script output the following:\n{output}"))
@ -92,6 +115,38 @@ string being a match that was found within the file)."#.into()
}
}
/// A closure that must run on the foreground (main) thread because it needs
/// access to the app context.
struct ForegroundFn(Box<dyn FnOnce(AsyncApp) + Send>);

/// Enqueues `function` on the foreground channel and awaits its result.
///
/// `description` is interpolated into error messages. All failure modes are
/// reported as `mlua::Error::runtime` so they surface cleanly inside Lua:
/// failure to enqueue, an `Err` returned by `function`, or cancellation of
/// the response channel.
async fn run_foreground_fn<R: Send + 'static>(
    description: &str,
    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    function: Box<dyn FnOnce(AsyncApp) -> anyhow::Result<R> + Send>,
) -> Result<R> {
    let (response_tx, response_rx) = oneshot::channel();
    let enqueue_result = foreground_tx
        .send(ForegroundFn(Box::new(move |cx| {
            // The receiving side may already be gone; ignore the send failure.
            response_tx.send(function(cx)).ok();
        })))
        .await;
    if let Err(err) = enqueue_result {
        return Err(mlua::Error::runtime(format!(
            "Internal error while enqueuing work for {description}: {err}"
        )));
    }
    match response_rx.await {
        Ok(Ok(value)) => Ok(value),
        Ok(Err(err)) => Err(mlua::Error::runtime(format!(
            "Error while {description}: {err}"
        ))),
        Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
            "Internal error: response oneshot was canceled while {description}."
        ))),
    }
}
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
struct FileContent(RefCell<Vec<u8>>);
@ -103,7 +158,7 @@ impl UserData for FileContent {
}
/// Sandboxed print() function in Lua.
fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
fn print(lua: &Lua, printed_lines: Arc<Mutex<Vec<String>>>) -> Result<Function> {
lua.create_function(move |_, args: MultiValue| {
let mut string = String::new();
@ -117,7 +172,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
string.push_str(arg.to_string()?.as_str())
}
printed_lines.borrow_mut().push(string);
printed_lines.lock().push(string);
Ok(())
})
@ -125,103 +180,139 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
fn search(
lua: &Lua,
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
_fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
worktree_id: WorktreeId,
workspace: WeakEntity<Workspace>,
foreground_tx: mpsc::Sender<ForegroundFn>,
) -> Result<Function> {
lua.create_function(move |lua, regex: String| {
use mlua::Table;
use regex::Regex;
use std::fs;
lua.create_async_function(move |lua, regex: String| {
let root_dir = root_dir.clone();
let workspace = workspace.clone();
let mut foreground_tx = foreground_tx.clone();
async move {
use mlua::Table;
use regex::Regex;
use std::fs;
// Function to recursively search directory
let search_regex = match Regex::new(&regex) {
Ok(re) => re,
Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
};
let mut search_results: Vec<Result<Table>> = Vec::new();
// Create an explicit stack for directories to process
let mut dir_stack = vec![root_dir.clone()];
while let Some(current_dir) = dir_stack.pop() {
// Process each entry in the current directory
let entries = match fs::read_dir(&current_dir) {
Ok(entries) => entries,
Err(e) => return Err(e.into()),
// TODO: Allow specification of these options.
let search_query = SearchQuery::regex(
&regex,
false,
false,
false,
PathMatcher::default(),
PathMatcher::default(),
None,
);
let search_query = match search_query {
Ok(query) => query,
Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
};
for entry_result in entries {
let entry = match entry_result {
Ok(e) => e,
Err(e) => return Err(e.into()),
};
// TODO: Should use `search_query.regex`. The tool description should also be updated,
// as it specifies standard regex.
let search_regex = match Regex::new(&regex) {
Ok(re) => re,
Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
};
let path = entry.path();
let project_path_rx =
find_search_candidates(search_query, workspace, &mut foreground_tx).await?;
if path.is_dir() {
// Skip .git directory and other common directories to ignore
let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
if !dir_name.starts_with('.')
&& dir_name != "node_modules"
&& dir_name != "target"
{
// Instead of recursive call, add to stack
dir_stack.push(path);
let mut search_results: Vec<Result<Table>> = Vec::new();
while let Ok(project_path) = project_path_rx.recv().await {
if project_path.worktree_id != worktree_id {
continue;
}
let path = root_dir.join(project_path.path);
// Skip files larger than 1MB
if let Ok(metadata) = fs::metadata(&path) {
if metadata.len() > 1_000_000 {
continue;
}
} else if path.is_file() {
// Skip binary files and very large files
if let Ok(metadata) = fs::metadata(&path) {
if metadata.len() > 1_000_000 {
// Skip files larger than 1MB
continue;
}
}
// Attempt to read the file as text
if let Ok(content) = fs::read_to_string(&path) {
let mut matches = Vec::new();
// Find all regex matches in the content
for capture in search_regex.find_iter(&content) {
matches.push(capture.as_str().to_string());
}
// Attempt to read the file as text
if let Ok(content) = fs::read_to_string(&path) {
let mut matches = Vec::new();
// If we found matches, create a result entry
if !matches.is_empty() {
let result_entry = lua.create_table()?;
result_entry.set("path", path.to_string_lossy().to_string())?;
// Find all regex matches in the content
for capture in search_regex.find_iter(&content) {
matches.push(capture.as_str().to_string());
let matches_table = lua.create_table()?;
for (i, m) in matches.iter().enumerate() {
matches_table.set(i + 1, m.clone())?;
}
result_entry.set("matches", matches_table)?;
// If we found matches, create a result entry
if !matches.is_empty() {
let result_entry = lua.create_table()?;
result_entry.set("path", path.to_string_lossy().to_string())?;
let matches_table = lua.create_table()?;
for (i, m) in matches.iter().enumerate() {
matches_table.set(i + 1, m.clone())?;
}
result_entry.set("matches", matches_table)?;
search_results.push(Ok(result_entry));
}
search_results.push(Ok(result_entry));
}
}
}
}
// Create a table to hold our results
let results_table = lua.create_table()?;
for (i, result) in search_results.into_iter().enumerate() {
match result {
Ok(entry) => results_table.set(i + 1, entry)?,
Err(e) => return Err(e),
// Create a table to hold our results
let results_table = lua.create_table()?;
for (i, result) in search_results.into_iter().enumerate() {
match result {
Ok(entry) => results_table.set(i + 1, entry)?,
Err(e) => return Err(e),
}
}
}
Ok(results_table)
Ok(results_table)
}
})
}
/// Asks the project's `WorktreeStore` for files that might match
/// `search_query`, returning a channel of candidate `ProjectPath`s.
///
/// Entity access requires an app context, so the lookup is forwarded to the
/// foreground thread via `run_foreground_fn`. Candidates are a superset of
/// actual matches — callers still need to check each file's contents.
async fn find_search_candidates(
    search_query: SearchQuery,
    workspace: WeakEntity<Workspace>,
    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
) -> Result<smol::channel::Receiver<ProjectPath>> {
    run_foreground_fn(
        "finding search file candidates",
        foreground_tx,
        Box::new(move |mut cx| {
            workspace.update(&mut cx, move |workspace, cx| {
                workspace.project().update(cx, |project, cx| {
                    project.worktree_store().update(cx, |worktree_store, cx| {
                        // TODO: Better limit? For now this is the same as
                        // MAX_SEARCH_RESULT_FILES.
                        let limit = 5000;
                        // TODO: Providing non-empty open_entries can make this a bit more
                        // efficient as it can skip checking that these paths are textual.
                        let open_entries = HashSet::default();
                        // TODO: This is searching all worktrees, but should only search the
                        // current worktree
                        worktree_store.find_search_candidates(
                            search_query,
                            limit,
                            open_entries,
                            project.fs().clone(),
                            cx,
                        )
                    })
                })
            })
        }),
    )
    .await
}
/// Sandboxed io.open() function in Lua.
fn io_open(
lua: &Lua,
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
) -> Result<Function> {
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@ -281,7 +372,7 @@ fn io_open(
// Don't actually write to disk; instead, just update fs_changes.
let path_buf = PathBuf::from(&path);
fs_changes
.borrow_mut()
.lock()
.insert(path_buf.clone(), content_vec.clone());
}
@ -333,13 +424,15 @@ fn io_open(
return Ok((Some(file), String::new()));
}
let is_in_changes = fs_changes.borrow().contains_key(&path);
let fs_changes_map = fs_changes.lock();
let is_in_changes = fs_changes_map.contains_key(&path);
let file_exists = is_in_changes || path.exists();
let mut file_content = Vec::new();
if file_exists && !truncate {
if is_in_changes {
file_content = fs_changes.borrow().get(&path).unwrap().clone();
file_content = fs_changes_map.get(&path).unwrap().clone();
} else {
// Try to read existing content if file exists and we're not truncating
match std::fs::read(&path) {
@ -349,6 +442,8 @@ fn io_open(
}
}
drop(fs_changes_map); // Unlock the fs_changes mutex.
// If in append mode, position should be at the end
let position = if append && file_exists {
file_content.len()
@ -582,9 +677,7 @@ fn io_open(
// Update fs_changes
let path = file_userdata.get::<String>("__path")?;
let path_buf = PathBuf::from(path);
fs_changes
.borrow_mut()
.insert(path_buf, content_vec.clone());
fs_changes.lock().insert(path_buf, content_vec.clone());
Ok(true)
})?
@ -597,33 +690,46 @@ fn io_open(
}
/// Runs a Lua script in a sandboxed environment and returns the printed lines
pub fn run_sandboxed_lua(
async fn run_sandboxed_lua(
script: &str,
fs_changes: HashMap<PathBuf, Vec<u8>>,
root_dir: PathBuf,
worktree_id: WorktreeId,
workspace: WeakEntity<Workspace>,
foreground_tx: mpsc::Sender<ForegroundFn>,
) -> Result<ScriptOutput> {
let lua = Lua::new();
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
let globals = lua.globals();
// Track the lines the Lua script prints out.
let printed_lines = Rc::new(RefCell::new(Vec::new()));
let fs = Rc::new(RefCell::new(fs_changes));
let printed_lines = Arc::new(Mutex::new(Vec::new()));
let fs = Arc::new(Mutex::new(fs_changes));
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
globals.set(
"search",
search(
&lua,
fs.clone(),
root_dir.clone(),
worktree_id,
workspace,
foreground_tx,
)?,
)?;
globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
globals.set("user_script", script)?;
lua.load(SANDBOX_PREAMBLE).exec()?;
lua.load(SANDBOX_PREAMBLE).exec_async().await?;
drop(lua); // Necessary so the Rc'd values get decremented.
drop(lua); // Decrements the Arcs so that try_unwrap works.
Ok(ScriptOutput {
printed_lines: Rc::try_unwrap(printed_lines)
printed_lines: Arc::try_unwrap(printed_lines)
.expect("There are still other references to printed_lines")
.into_inner(),
fs_changes: Rc::try_unwrap(fs)
fs_changes: Arc::try_unwrap(fs)
.expect("There are still other references to fs_changes")
.into_inner(),
})