diff --git a/Cargo.lock b/Cargo.lock
index 316587ee83..d2af4b0444 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8134,6 +8134,7 @@ checksum = "d3f763c1041eff92ffb5d7169968a327e1ed2ebfe425dac0ee5a35f29082534b"
 dependencies = [
  "bstr",
  "either",
+ "futures-util",
  "mlua-sys",
  "num-traits",
  "parking_lot",
@@ -11915,12 +11916,17 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assistant_tool",
+ "futures 0.3.31",
  "gpui",
  "mlua",
+ "parking_lot",
+ "project",
  "regex",
  "schemars",
  "serde",
  "serde_json",
+ "smol",
+ "util",
  "workspace",
 ]

diff --git a/Cargo.toml b/Cargo.toml
index 20e4a00a9f..5cb40db6ad 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -452,7 +452,7 @@ livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "
 ], default-features = false }
 log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
 markup5ever_rcdom = "0.3.0"
-mlua = { version = "0.10", features = ["lua54", "vendored"] }
+mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
 nanoid = "0.4"
 nbformat = { version = "0.10.0" }
 nix = "0.29"
diff --git a/crates/scripting_tool/Cargo.toml b/crates/scripting_tool/Cargo.toml
index f9045ff7f8..ae86dbb1ce 100644
--- a/crates/scripting_tool/Cargo.toml
+++ b/crates/scripting_tool/Cargo.toml
@@ -15,10 +15,15 @@ doctest = false
 [dependencies]
 anyhow.workspace = true
 assistant_tool.workspace = true
+futures.workspace = true
 gpui.workspace = true
 mlua.workspace = true
+parking_lot.workspace = true
+project.workspace = true
+regex.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+smol.workspace = true
+util.workspace = true
 workspace.workspace = true
-regex.workspace = true
diff --git a/crates/scripting_tool/src/scripting_tool.rs b/crates/scripting_tool/src/scripting_tool.rs
index 42a553494c..6797d003bd 100644
--- a/crates/scripting_tool/src/scripting_tool.rs
+++ b/crates/scripting_tool/src/scripting_tool.rs
@@ -1,16 +1,22 @@
 use anyhow::anyhow;
 use assistant_tool::{Tool, ToolRegistry};
-use gpui::{App, AppContext as _, Task, WeakEntity, Window};
+use futures::{
+    channel::{mpsc, oneshot},
+    SinkExt, StreamExt as _,
+};
+use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window};
 use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
+use parking_lot::Mutex;
+use project::{search::SearchQuery, ProjectPath, WorktreeId};
 use schemars::JsonSchema;
 use serde::Deserialize;
 use std::{
     cell::RefCell,
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     path::{Path, PathBuf},
-    rc::Rc,
     sync::Arc,
 };
+use util::paths::PathMatcher;
 use workspace::Workspace;

 pub fn init(cx: &App) {
@@ -59,32 +65,49 @@ string being a match that was found within the file)."#.into()
         _window: &mut Window,
         cx: &mut App,
     ) -> Task<anyhow::Result<String>> {
-        let root_dir = workspace.update(cx, |workspace, cx| {
+        let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| {
             let first_worktree = workspace
                 .visible_worktrees(cx)
                 .next()
                 .ok_or_else(|| anyhow!("no worktrees"))?;
-            workspace
-                .absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
-                .ok_or_else(|| anyhow!("no worktree root"))
+            let worktree_id = first_worktree.read(cx).id();
+            let root_dir = workspace
+                .absolute_path_of_worktree(worktree_id, cx)
+                .ok_or_else(|| anyhow!("no worktree root"))?;
+            Ok((root_dir, worktree_id))
         });
-        let root_dir = match root_dir {
-            Ok(root_dir) => root_dir,
-            Err(err) => return Task::ready(Err(err)),
-        };
-        let root_dir = match root_dir {
-            Ok(root_dir) => root_dir,
+        let (root_dir, worktree_id) = match worktree_root_dir_and_id {
+            Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id,
+            Ok(Err(err)) => return Task::ready(Err(err)),
             Err(err) => return Task::ready(Err(err)),
         };
         let input = match serde_json::from_value::<ScriptingToolInput>(input) {
             Err(err) => return Task::ready(Err(err.into())),
             Ok(input) => input,
         };
+
+        let (foreground_tx, mut foreground_rx) = mpsc::channel::<ForegroundFn>(1);
+
+        cx.spawn(move |cx| async move {
+            while let Some(request) = foreground_rx.next().await {
+                request.0(cx.clone());
+            }
+        })
+        .detach();
+
         let lua_script = input.lua_script;
         cx.background_spawn(async move {
             let fs_changes = HashMap::new();
-            let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
-                .map_err(|err| anyhow!(format!("{err}")))?;
+            let output = run_sandboxed_lua(
+                &lua_script,
+                fs_changes,
+                root_dir,
+                worktree_id,
+                workspace,
+                foreground_tx,
+            )
+            .await
+            .map_err(|err| anyhow!(format!("{err}")))?;

             let output = output.printed_lines.join("\n");
             Ok(format!("The script output the following:\n{output}"))
@@ -92,6 +115,38 @@ string being a match that was found within the file)."#.into()
     }
 }

+struct ForegroundFn(Box<dyn FnOnce(AsyncApp) + Send>);
+
+async fn run_foreground_fn<R: Send + 'static>(
+    description: &str,
+    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+    function: Box<dyn FnOnce(AsyncApp) -> anyhow::Result<R> + Send>,
+) -> Result<R> {
+    let (response_tx, response_rx) = oneshot::channel();
+    let send_result = foreground_tx
+        .send(ForegroundFn(Box::new(move |cx| {
+            response_tx.send(function(cx)).ok();
+        })))
+        .await;
+    match send_result {
+        Ok(()) => (),
+        Err(err) => {
+            return Err(mlua::Error::runtime(format!(
+                "Internal error while enqueuing work for {description}: {err}"
+            )))
+        }
+    }
+    match response_rx.await {
+        Ok(Ok(result)) => Ok(result),
+        Ok(Err(err)) => Err(mlua::Error::runtime(format!(
+            "Error while {description}: {err}"
+        ))),
+        Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
+            "Internal error: response oneshot was canceled while {description}."
+        ))),
+    }
+}
+
 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");

 struct FileContent(RefCell<Vec<u8>>);
@@ -103,7 +158,7 @@ impl UserData for FileContent {
 }

 /// Sandboxed print() function in Lua.
-fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
+fn print(lua: &Lua, printed_lines: Arc<Mutex<Vec<String>>>) -> Result<Function> {
     lua.create_function(move |_, args: MultiValue| {
         let mut string = String::new();
@@ -117,7 +172,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
             string.push_str(arg.to_string()?.as_str())
         }

-        printed_lines.borrow_mut().push(string);
+        printed_lines.lock().push(string);

         Ok(())
     })
@@ -125,103 +180,139 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>

 fn search(
     lua: &Lua,
-    _fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+    _fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
     root_dir: PathBuf,
+    worktree_id: WorktreeId,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: mpsc::Sender<ForegroundFn>,
 ) -> Result<Function> {
-    lua.create_function(move |lua, regex: String| {
-        use mlua::Table;
-        use regex::Regex;
-        use std::fs;
+    lua.create_async_function(move |lua, regex: String| {
+        let root_dir = root_dir.clone();
+        let workspace = workspace.clone();
+        let mut foreground_tx = foreground_tx.clone();
+        async move {
+            use mlua::Table;
+            use regex::Regex;
+            use std::fs;

-        // Function to recursively search directory
-        let search_regex = match Regex::new(&regex) {
-            Ok(re) => re,
-            Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
-        };
-
-        let mut search_results: Vec<Result<Table>> = Vec::new();
-
-        // Create an explicit stack for directories to process
-        let mut dir_stack = vec![root_dir.clone()];
-
-        while let Some(current_dir) = dir_stack.pop() {
-            // Process each entry in the current directory
-            let entries = match fs::read_dir(&current_dir) {
-                Ok(entries) => entries,
-                Err(e) => return Err(e.into()),
+            // TODO: Allow specification of these options.
+            let search_query = SearchQuery::regex(
+                &regex,
+                false,
+                false,
+                false,
+                PathMatcher::default(),
+                PathMatcher::default(),
+                None,
+            );
+            let search_query = match search_query {
+                Ok(query) => query,
+                Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
             };

-            for entry_result in entries {
-                let entry = match entry_result {
-                    Ok(e) => e,
-                    Err(e) => return Err(e.into()),
-                };
+            // TODO: Should use `search_query.regex`. The tool description should also be updated,
+            // as it specifies standard regex.
+            let search_regex = match Regex::new(&regex) {
+                Ok(re) => re,
+                Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
+            };

-                let path = entry.path();
+            let project_path_rx =
+                find_search_candidates(search_query, workspace, &mut foreground_tx).await?;

-                if path.is_dir() {
-                    // Skip .git directory and other common directories to ignore
-                    let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
-                    if !dir_name.starts_with('.')
-                        && dir_name != "node_modules"
-                        && dir_name != "target"
-                    {
-                        // Instead of recursive call, add to stack
-                        dir_stack.push(path);
+            let mut search_results: Vec<Result<Table>> = Vec::new();
+            while let Ok(project_path) = project_path_rx.recv().await {
+                if project_path.worktree_id != worktree_id {
+                    continue;
+                }
+
+                let path = root_dir.join(project_path.path);
+
+                // Skip files larger than 1MB
+                if let Ok(metadata) = fs::metadata(&path) {
+                    if metadata.len() > 1_000_000 {
+                        continue;
                     }
-                } else if path.is_file() {
-                    // Skip binary files and very large files
-                    if let Ok(metadata) = fs::metadata(&path) {
-                        if metadata.len() > 1_000_000 {
-                            // Skip files larger than 1MB
-                            continue;
-                        }
+                }
+
+                // Attempt to read the file as text
+                if let Ok(content) = fs::read_to_string(&path) {
+                    let mut matches = Vec::new();
+
+                    // Find all regex matches in the content
+                    for capture in search_regex.find_iter(&content) {
+                        matches.push(capture.as_str().to_string());
                     }

-                    // Attempt to read the file as text
-                    if let Ok(content) = fs::read_to_string(&path) {
-                        let mut matches = Vec::new();
+                    // If we found matches, create a result entry
+                    if !matches.is_empty() {
+                        let result_entry = lua.create_table()?;
+                        result_entry.set("path", path.to_string_lossy().to_string())?;

-                        // Find all regex matches in the content
-                        for capture in search_regex.find_iter(&content) {
-                            matches.push(capture.as_str().to_string());
+                        let matches_table = lua.create_table()?;
+                        for (i, m) in matches.iter().enumerate() {
+                            matches_table.set(i + 1, m.clone())?;
                         }
+                        result_entry.set("matches", matches_table)?;

-                        // If we found matches, create a result entry
-                        if !matches.is_empty() {
-                            let result_entry = lua.create_table()?;
-                            result_entry.set("path", path.to_string_lossy().to_string())?;
-
-                            let matches_table = lua.create_table()?;
-                            for (i, m) in matches.iter().enumerate() {
-                                matches_table.set(i + 1, m.clone())?;
-                            }
-                            result_entry.set("matches", matches_table)?;
-
-                            search_results.push(Ok(result_entry));
-                        }
+                        search_results.push(Ok(result_entry));
                     }
                 }
             }
-        }

-        // Create a table to hold our results
-        let results_table = lua.create_table()?;
-        for (i, result) in search_results.into_iter().enumerate() {
-            match result {
-                Ok(entry) => results_table.set(i + 1, entry)?,
-                Err(e) => return Err(e),
+            // Create a table to hold our results
+            let results_table = lua.create_table()?;
+            for (i, result) in search_results.into_iter().enumerate() {
+                match result {
+                    Ok(entry) => results_table.set(i + 1, entry)?,
+                    Err(e) => return Err(e),
+                }
             }
-        }

-        Ok(results_table)
+            Ok(results_table)
+        }
     })
 }

+async fn find_search_candidates(
+    search_query: SearchQuery,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: &mut mpsc::Sender<ForegroundFn>,
+) -> Result<smol::channel::Receiver<ProjectPath>> {
+    run_foreground_fn(
+        "finding search file candidates",
+        foreground_tx,
+        Box::new(move |mut cx| {
+            workspace.update(&mut cx, move |workspace, cx| {
+                workspace.project().update(cx, |project, cx| {
+                    project.worktree_store().update(cx, |worktree_store, cx| {
+                        // TODO: Better limit? For now this is the same as
+                        // MAX_SEARCH_RESULT_FILES.
+                        let limit = 5000;
+                        // TODO: Providing non-empty open_entries can make this a bit more
+                        // efficient as it can skip checking that these paths are textual.
+                        let open_entries = HashSet::default();
+                        // TODO: This is searching all worktrees, but should only search the
+                        // current worktree.
+                        worktree_store.find_search_candidates(
+                            search_query,
+                            limit,
+                            open_entries,
+                            project.fs().clone(),
+                            cx,
+                        )
+                    })
+                })
+            })
+        }),
+    )
+    .await
+}
+
 /// Sandboxed io.open() function in Lua.
 fn io_open(
     lua: &Lua,
-    fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
+    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
     root_dir: PathBuf,
 ) -> Result<Function> {
     lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@@ -281,7 +372,7 @@ fn io_open(
                         // Don't actually write to disk; instead, just update fs_changes.
                         let path_buf = PathBuf::from(&path);
                         fs_changes
-                            .borrow_mut()
+                            .lock()
                             .insert(path_buf.clone(), content_vec.clone());
                     }

@@ -333,13 +424,15 @@ fn io_open(
             return Ok((Some(file), String::new()));
         }

-        let is_in_changes = fs_changes.borrow().contains_key(&path);
+        let fs_changes_map = fs_changes.lock();
+
+        let is_in_changes = fs_changes_map.contains_key(&path);
         let file_exists = is_in_changes || path.exists();

         let mut file_content = Vec::new();
         if file_exists && !truncate {
             if is_in_changes {
-                file_content = fs_changes.borrow().get(&path).unwrap().clone();
+                file_content = fs_changes_map.get(&path).unwrap().clone();
             } else {
                 // Try to read existing content if file exists and we're not truncating
                 match std::fs::read(&path) {
@@ -349,6 +442,8 @@ fn io_open(
             }
         }

+        drop(fs_changes_map); // Unlock the fs_changes mutex.
+
         // If in append mode, position should be at the end
         let position = if append && file_exists {
             file_content.len()
@@ -582,9 +677,7 @@ fn io_open(
                     // Update fs_changes
                     let path = file_userdata.get::<String>("__path")?;
                     let path_buf = PathBuf::from(path);
-                    fs_changes
-                        .borrow_mut()
-                        .insert(path_buf, content_vec.clone());
+                    fs_changes.lock().insert(path_buf, content_vec.clone());

                     Ok(true)
                 })?
@@ -597,33 +690,46 @@ fn io_open(
 }

 /// Runs a Lua script in a sandboxed environment and returns the printed lines
-pub fn run_sandboxed_lua(
+async fn run_sandboxed_lua(
     script: &str,
     fs_changes: HashMap<PathBuf, Vec<u8>>,
     root_dir: PathBuf,
+    worktree_id: WorktreeId,
+    workspace: WeakEntity<Workspace>,
+    foreground_tx: mpsc::Sender<ForegroundFn>,
 ) -> Result<ScriptOutput> {
     let lua = Lua::new();
     lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
     let globals = lua.globals();

     // Track the lines the Lua script prints out.
-    let printed_lines = Rc::new(RefCell::new(Vec::new()));
-    let fs = Rc::new(RefCell::new(fs_changes));
+    let printed_lines = Arc::new(Mutex::new(Vec::new()));
+    let fs = Arc::new(Mutex::new(fs_changes));

     globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
-    globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
+    globals.set(
+        "search",
+        search(
+            &lua,
+            fs.clone(),
+            root_dir.clone(),
+            worktree_id,
+            workspace,
+            foreground_tx,
+        )?,
+    )?;
     globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
     globals.set("user_script", script)?;

-    lua.load(SANDBOX_PREAMBLE).exec()?;
+    lua.load(SANDBOX_PREAMBLE).exec_async().await?;

-    drop(lua); // Necessary so the Rc'd values get decremented.
+    drop(lua); // Decrements the Arcs so that try_unwrap works.

Ok(ScriptOutput { - printed_lines: Rc::try_unwrap(printed_lines) + printed_lines: Arc::try_unwrap(printed_lines) .expect("There are still other references to printed_lines") .into_inner(), - fs_changes: Rc::try_unwrap(fs) + fs_changes: Arc::try_unwrap(fs) .expect("There are still other references to fs_changes") .into_inner(), })
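
Reviewer note: the core pattern this diff introduces is the `ForegroundFn` / `run_foreground_fn` bridge. The Lua interpreter now runs on the background executor (`cx.background_spawn`), and any builtin that needs `App` state boxes a closure into a `ForegroundFn`, sends it over the bounded mpsc channel to the foreground task spawned in `run`, and awaits the answer through a oneshot. Below is a minimal sketch of how a future sandbox builtin could reuse that bridge; `visible_worktree_count` is hypothetical (not part of this diff) and assumes the imports already present in `scripting_tool.rs`:

```rust
// Hypothetical sketch, not in this diff: another async Lua builtin that
// reuses run_foreground_fn to read workspace state from the background.
fn visible_worktree_count(
    lua: &Lua,
    workspace: WeakEntity<Workspace>,
    foreground_tx: mpsc::Sender<ForegroundFn>,
) -> Result<Function> {
    lua.create_async_function(move |_lua, ()| {
        let workspace = workspace.clone();
        let mut foreground_tx = foreground_tx.clone();
        async move {
            // Enqueue a closure for the foreground task, then await its
            // response; the closure runs with an AsyncApp on the UI thread.
            run_foreground_fn(
                "counting visible worktrees",
                &mut foreground_tx,
                Box::new(move |mut cx| {
                    workspace.update(&mut cx, |workspace, cx| {
                        workspace.visible_worktrees(cx).count()
                    })
                }),
            )
            .await
        }
    })
}
```

Since the channel is bounded at 1, at most one foreground request per script is in flight at a time, so a busy Lua loop cannot flood the main thread; that appears to be the intent of `mpsc::channel::<ForegroundFn>(1)`.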