Add lua script access to code using cx
+ reuse project search logic (#26269)
Access to `cx` will be needed for anything that queries entities. In this commit, that is the use of `WorktreeStore::find_search_candidates`. In the future, it will include things like access to LSP, tree-sitter outlines, etc. Changes to support access to `cx` from functions provided to the Lua script: * Adds a channel of requests that require a `cx`. Work enqueued to this channel is run on the foreground thread. * Adds the `async` and `send` features to the `mlua` crate so that async Rust functions can be used from Lua. * Changes uses of `Rc<RefCell<...>>` to `Arc<Mutex<...>>` so that the futures are `Send`. One benefit of reusing project search logic for search candidates is that it properly skips ignored paths. Release Notes: - N/A
This commit is contained in:
parent
b0d1024f66
commit
205f9a9f03
4 changed files with 222 additions and 105 deletions
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -8134,6 +8134,7 @@ checksum = "d3f763c1041eff92ffb5d7169968a327e1ed2ebfe425dac0ee5a35f29082534b"
|
|||
dependencies = [
|
||||
"bstr",
|
||||
"either",
|
||||
"futures-util",
|
||||
"mlua-sys",
|
||||
"num-traits",
|
||||
"parking_lot",
|
||||
|
@ -11915,12 +11916,17 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"mlua",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"regex",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"util",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
|
|
|
@ -452,7 +452,7 @@ livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "
|
|||
], default-features = false }
|
||||
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
|
||||
markup5ever_rcdom = "0.3.0"
|
||||
mlua = { version = "0.10", features = ["lua54", "vendored"] }
|
||||
mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
|
||||
nanoid = "0.4"
|
||||
nbformat = { version = "0.10.0" }
|
||||
nix = "0.29"
|
||||
|
|
|
@ -15,10 +15,15 @@ doctest = false
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
mlua.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
regex.workspace = true
|
||||
|
|
|
@ -1,16 +1,22 @@
|
|||
use anyhow::anyhow;
|
||||
use assistant_tool::{Tool, ToolRegistry};
|
||||
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
SinkExt, StreamExt as _,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window};
|
||||
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
|
||||
use parking_lot::Mutex;
|
||||
use project::{search::SearchQuery, ProjectPath, WorktreeId};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashMap,
|
||||
collections::{HashMap, HashSet},
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
};
|
||||
use util::paths::PathMatcher;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub fn init(cx: &App) {
|
||||
|
@ -59,32 +65,49 @@ string being a match that was found within the file)."#.into()
|
|||
_window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<anyhow::Result<String>> {
|
||||
let root_dir = workspace.update(cx, |workspace, cx| {
|
||||
let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| {
|
||||
let first_worktree = workspace
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("no worktrees"))?;
|
||||
workspace
|
||||
.absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
|
||||
.ok_or_else(|| anyhow!("no worktree root"))
|
||||
let worktree_id = first_worktree.read(cx).id();
|
||||
let root_dir = workspace
|
||||
.absolute_path_of_worktree(worktree_id, cx)
|
||||
.ok_or_else(|| anyhow!("no worktree root"))?;
|
||||
Ok((root_dir, worktree_id))
|
||||
});
|
||||
let root_dir = match root_dir {
|
||||
Ok(root_dir) => root_dir,
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
let root_dir = match root_dir {
|
||||
Ok(root_dir) => root_dir,
|
||||
let (root_dir, worktree_id) = match worktree_root_dir_and_id {
|
||||
Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id,
|
||||
Ok(Err(err)) => return Task::ready(Err(err)),
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
|
||||
Err(err) => return Task::ready(Err(err.into())),
|
||||
Ok(input) => input,
|
||||
};
|
||||
|
||||
let (foreground_tx, mut foreground_rx) = mpsc::channel::<ForegroundFn>(1);
|
||||
|
||||
cx.spawn(move |cx| async move {
|
||||
while let Some(request) = foreground_rx.next().await {
|
||||
request.0(cx.clone());
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let lua_script = input.lua_script;
|
||||
cx.background_spawn(async move {
|
||||
let fs_changes = HashMap::new();
|
||||
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
|
||||
.map_err(|err| anyhow!(format!("{err}")))?;
|
||||
let output = run_sandboxed_lua(
|
||||
&lua_script,
|
||||
fs_changes,
|
||||
root_dir,
|
||||
worktree_id,
|
||||
workspace,
|
||||
foreground_tx,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| anyhow!(format!("{err}")))?;
|
||||
let output = output.printed_lines.join("\n");
|
||||
|
||||
Ok(format!("The script output the following:\n{output}"))
|
||||
|
@ -92,6 +115,38 @@ string being a match that was found within the file)."#.into()
|
|||
}
|
||||
}
|
||||
|
||||
struct ForegroundFn(Box<dyn FnOnce(AsyncApp) + Send>);
|
||||
|
||||
async fn run_foreground_fn<R: Send + 'static>(
|
||||
description: &str,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
function: Box<dyn FnOnce(AsyncApp) -> anyhow::Result<R> + Send>,
|
||||
) -> Result<R> {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let send_result = foreground_tx
|
||||
.send(ForegroundFn(Box::new(move |cx| {
|
||||
response_tx.send(function(cx)).ok();
|
||||
})))
|
||||
.await;
|
||||
match send_result {
|
||||
Ok(()) => (),
|
||||
Err(err) => {
|
||||
return Err(mlua::Error::runtime(format!(
|
||||
"Internal error while enqueuing work for {description}: {err}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
match response_rx.await {
|
||||
Ok(Ok(result)) => Ok(result),
|
||||
Ok(Err(err)) => Err(mlua::Error::runtime(format!(
|
||||
"Error while {description}: {err}"
|
||||
))),
|
||||
Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
|
||||
"Internal error: response oneshot was canceled while {description}."
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
||||
|
||||
struct FileContent(RefCell<Vec<u8>>);
|
||||
|
@ -103,7 +158,7 @@ impl UserData for FileContent {
|
|||
}
|
||||
|
||||
/// Sandboxed print() function in Lua.
|
||||
fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
|
||||
fn print(lua: &Lua, printed_lines: Arc<Mutex<Vec<String>>>) -> Result<Function> {
|
||||
lua.create_function(move |_, args: MultiValue| {
|
||||
let mut string = String::new();
|
||||
|
||||
|
@ -117,7 +172,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
|
|||
string.push_str(arg.to_string()?.as_str())
|
||||
}
|
||||
|
||||
printed_lines.borrow_mut().push(string);
|
||||
printed_lines.lock().push(string);
|
||||
|
||||
Ok(())
|
||||
})
|
||||
|
@ -125,103 +180,139 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
|
|||
|
||||
fn search(
|
||||
lua: &Lua,
|
||||
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
||||
_fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
root_dir: PathBuf,
|
||||
worktree_id: WorktreeId,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
foreground_tx: mpsc::Sender<ForegroundFn>,
|
||||
) -> Result<Function> {
|
||||
lua.create_function(move |lua, regex: String| {
|
||||
use mlua::Table;
|
||||
use regex::Regex;
|
||||
use std::fs;
|
||||
lua.create_async_function(move |lua, regex: String| {
|
||||
let root_dir = root_dir.clone();
|
||||
let workspace = workspace.clone();
|
||||
let mut foreground_tx = foreground_tx.clone();
|
||||
async move {
|
||||
use mlua::Table;
|
||||
use regex::Regex;
|
||||
use std::fs;
|
||||
|
||||
// Function to recursively search directory
|
||||
let search_regex = match Regex::new(&regex) {
|
||||
Ok(re) => re,
|
||||
Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
|
||||
};
|
||||
|
||||
let mut search_results: Vec<Result<Table>> = Vec::new();
|
||||
|
||||
// Create an explicit stack for directories to process
|
||||
let mut dir_stack = vec![root_dir.clone()];
|
||||
|
||||
while let Some(current_dir) = dir_stack.pop() {
|
||||
// Process each entry in the current directory
|
||||
let entries = match fs::read_dir(&current_dir) {
|
||||
Ok(entries) => entries,
|
||||
Err(e) => return Err(e.into()),
|
||||
// TODO: Allow specification of these options.
|
||||
let search_query = SearchQuery::regex(
|
||||
&regex,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
PathMatcher::default(),
|
||||
PathMatcher::default(),
|
||||
None,
|
||||
);
|
||||
let search_query = match search_query {
|
||||
Ok(query) => query,
|
||||
Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
|
||||
};
|
||||
|
||||
for entry_result in entries {
|
||||
let entry = match entry_result {
|
||||
Ok(e) => e,
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
// TODO: Should use `search_query.regex`. The tool description should also be updated,
|
||||
// as it specifies standard regex.
|
||||
let search_regex = match Regex::new(&regex) {
|
||||
Ok(re) => re,
|
||||
Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
|
||||
};
|
||||
|
||||
let path = entry.path();
|
||||
let project_path_rx =
|
||||
find_search_candidates(search_query, workspace, &mut foreground_tx).await?;
|
||||
|
||||
if path.is_dir() {
|
||||
// Skip .git directory and other common directories to ignore
|
||||
let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
|
||||
if !dir_name.starts_with('.')
|
||||
&& dir_name != "node_modules"
|
||||
&& dir_name != "target"
|
||||
{
|
||||
// Instead of recursive call, add to stack
|
||||
dir_stack.push(path);
|
||||
let mut search_results: Vec<Result<Table>> = Vec::new();
|
||||
while let Ok(project_path) = project_path_rx.recv().await {
|
||||
if project_path.worktree_id != worktree_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let path = root_dir.join(project_path.path);
|
||||
|
||||
// Skip files larger than 1MB
|
||||
if let Ok(metadata) = fs::metadata(&path) {
|
||||
if metadata.len() > 1_000_000 {
|
||||
continue;
|
||||
}
|
||||
} else if path.is_file() {
|
||||
// Skip binary files and very large files
|
||||
if let Ok(metadata) = fs::metadata(&path) {
|
||||
if metadata.len() > 1_000_000 {
|
||||
// Skip files larger than 1MB
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to read the file as text
|
||||
if let Ok(content) = fs::read_to_string(&path) {
|
||||
let mut matches = Vec::new();
|
||||
|
||||
// Find all regex matches in the content
|
||||
for capture in search_regex.find_iter(&content) {
|
||||
matches.push(capture.as_str().to_string());
|
||||
}
|
||||
|
||||
// Attempt to read the file as text
|
||||
if let Ok(content) = fs::read_to_string(&path) {
|
||||
let mut matches = Vec::new();
|
||||
// If we found matches, create a result entry
|
||||
if !matches.is_empty() {
|
||||
let result_entry = lua.create_table()?;
|
||||
result_entry.set("path", path.to_string_lossy().to_string())?;
|
||||
|
||||
// Find all regex matches in the content
|
||||
for capture in search_regex.find_iter(&content) {
|
||||
matches.push(capture.as_str().to_string());
|
||||
let matches_table = lua.create_table()?;
|
||||
for (i, m) in matches.iter().enumerate() {
|
||||
matches_table.set(i + 1, m.clone())?;
|
||||
}
|
||||
result_entry.set("matches", matches_table)?;
|
||||
|
||||
// If we found matches, create a result entry
|
||||
if !matches.is_empty() {
|
||||
let result_entry = lua.create_table()?;
|
||||
result_entry.set("path", path.to_string_lossy().to_string())?;
|
||||
|
||||
let matches_table = lua.create_table()?;
|
||||
for (i, m) in matches.iter().enumerate() {
|
||||
matches_table.set(i + 1, m.clone())?;
|
||||
}
|
||||
result_entry.set("matches", matches_table)?;
|
||||
|
||||
search_results.push(Ok(result_entry));
|
||||
}
|
||||
search_results.push(Ok(result_entry));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a table to hold our results
|
||||
let results_table = lua.create_table()?;
|
||||
for (i, result) in search_results.into_iter().enumerate() {
|
||||
match result {
|
||||
Ok(entry) => results_table.set(i + 1, entry)?,
|
||||
Err(e) => return Err(e),
|
||||
// Create a table to hold our results
|
||||
let results_table = lua.create_table()?;
|
||||
for (i, result) in search_results.into_iter().enumerate() {
|
||||
match result {
|
||||
Ok(entry) => results_table.set(i + 1, entry)?,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results_table)
|
||||
Ok(results_table)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn find_search_candidates(
|
||||
search_query: SearchQuery,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
) -> Result<smol::channel::Receiver<ProjectPath>> {
|
||||
run_foreground_fn(
|
||||
"finding search file candidates",
|
||||
foreground_tx,
|
||||
Box::new(move |mut cx| {
|
||||
workspace.update(&mut cx, move |workspace, cx| {
|
||||
workspace.project().update(cx, |project, cx| {
|
||||
project.worktree_store().update(cx, |worktree_store, cx| {
|
||||
// TODO: Better limit? For now this is the same as
|
||||
// MAX_SEARCH_RESULT_FILES.
|
||||
let limit = 5000;
|
||||
// TODO: Providing non-empty open_entries can make this a bit more
|
||||
// efficient as it can skip checking that these paths are textual.
|
||||
let open_entries = HashSet::default();
|
||||
// TODO: This is searching all worktrees, but should only search the
|
||||
// current worktree
|
||||
worktree_store.find_search_candidates(
|
||||
search_query,
|
||||
limit,
|
||||
open_entries,
|
||||
project.fs().clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Sandboxed io.open() function in Lua.
|
||||
fn io_open(
|
||||
lua: &Lua,
|
||||
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
||||
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
root_dir: PathBuf,
|
||||
) -> Result<Function> {
|
||||
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
|
||||
|
@ -281,7 +372,7 @@ fn io_open(
|
|||
// Don't actually write to disk; instead, just update fs_changes.
|
||||
let path_buf = PathBuf::from(&path);
|
||||
fs_changes
|
||||
.borrow_mut()
|
||||
.lock()
|
||||
.insert(path_buf.clone(), content_vec.clone());
|
||||
}
|
||||
|
||||
|
@ -333,13 +424,15 @@ fn io_open(
|
|||
return Ok((Some(file), String::new()));
|
||||
}
|
||||
|
||||
let is_in_changes = fs_changes.borrow().contains_key(&path);
|
||||
let fs_changes_map = fs_changes.lock();
|
||||
|
||||
let is_in_changes = fs_changes_map.contains_key(&path);
|
||||
let file_exists = is_in_changes || path.exists();
|
||||
let mut file_content = Vec::new();
|
||||
|
||||
if file_exists && !truncate {
|
||||
if is_in_changes {
|
||||
file_content = fs_changes.borrow().get(&path).unwrap().clone();
|
||||
file_content = fs_changes_map.get(&path).unwrap().clone();
|
||||
} else {
|
||||
// Try to read existing content if file exists and we're not truncating
|
||||
match std::fs::read(&path) {
|
||||
|
@ -349,6 +442,8 @@ fn io_open(
|
|||
}
|
||||
}
|
||||
|
||||
drop(fs_changes_map); // Unlock the fs_changes mutex.
|
||||
|
||||
// If in append mode, position should be at the end
|
||||
let position = if append && file_exists {
|
||||
file_content.len()
|
||||
|
@ -582,9 +677,7 @@ fn io_open(
|
|||
// Update fs_changes
|
||||
let path = file_userdata.get::<String>("__path")?;
|
||||
let path_buf = PathBuf::from(path);
|
||||
fs_changes
|
||||
.borrow_mut()
|
||||
.insert(path_buf, content_vec.clone());
|
||||
fs_changes.lock().insert(path_buf, content_vec.clone());
|
||||
|
||||
Ok(true)
|
||||
})?
|
||||
|
@ -597,33 +690,46 @@ fn io_open(
|
|||
}
|
||||
|
||||
/// Runs a Lua script in a sandboxed environment and returns the printed lines
|
||||
pub fn run_sandboxed_lua(
|
||||
async fn run_sandboxed_lua(
|
||||
script: &str,
|
||||
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
||||
root_dir: PathBuf,
|
||||
worktree_id: WorktreeId,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
foreground_tx: mpsc::Sender<ForegroundFn>,
|
||||
) -> Result<ScriptOutput> {
|
||||
let lua = Lua::new();
|
||||
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
|
||||
let globals = lua.globals();
|
||||
|
||||
// Track the lines the Lua script prints out.
|
||||
let printed_lines = Rc::new(RefCell::new(Vec::new()));
|
||||
let fs = Rc::new(RefCell::new(fs_changes));
|
||||
let printed_lines = Arc::new(Mutex::new(Vec::new()));
|
||||
let fs = Arc::new(Mutex::new(fs_changes));
|
||||
|
||||
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
|
||||
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
|
||||
globals.set(
|
||||
"search",
|
||||
search(
|
||||
&lua,
|
||||
fs.clone(),
|
||||
root_dir.clone(),
|
||||
worktree_id,
|
||||
workspace,
|
||||
foreground_tx,
|
||||
)?,
|
||||
)?;
|
||||
globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
|
||||
globals.set("user_script", script)?;
|
||||
|
||||
lua.load(SANDBOX_PREAMBLE).exec()?;
|
||||
lua.load(SANDBOX_PREAMBLE).exec_async().await?;
|
||||
|
||||
drop(lua); // Necessary so the Rc'd values get decremented.
|
||||
drop(lua); // Decrements the Arcs so that try_unwrap works.
|
||||
|
||||
Ok(ScriptOutput {
|
||||
printed_lines: Rc::try_unwrap(printed_lines)
|
||||
printed_lines: Arc::try_unwrap(printed_lines)
|
||||
.expect("There are still other references to printed_lines")
|
||||
.into_inner(),
|
||||
fs_changes: Rc::try_unwrap(fs)
|
||||
fs_changes: Arc::try_unwrap(fs)
|
||||
.expect("There are still other references to fs_changes")
|
||||
.into_inner(),
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue