From efde5aa2bb3bd1a34ef4d8d6c3554b7e5e554419 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 7 Mar 2025 16:44:36 +0100 Subject: [PATCH] Extract a `Session` struct to hold state about a given thread's scripting session (#26282) We're still recreating a session for every tool call, but the idea is to have a long-lived `Session` per assistant thread. Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga --- Cargo.lock | 3 +- crates/scripting_tool/Cargo.toml | 10 +- crates/scripting_tool/src/scripting_tool.rs | 842 +------------------- crates/scripting_tool/src/session.rs | 743 +++++++++++++++++ 4 files changed, 769 insertions(+), 829 deletions(-) create mode 100644 crates/scripting_tool/src/session.rs diff --git a/Cargo.lock b/Cargo.lock index d2af4b0444..4e9b781ed2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11916,6 +11916,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_tool", + "collections", "futures 0.3.31", "gpui", "mlua", @@ -11925,7 +11926,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "smol", + "settings", "util", "workspace", ] diff --git a/crates/scripting_tool/Cargo.toml b/crates/scripting_tool/Cargo.toml index ae86dbb1ce..775cf87f7d 100644 --- a/crates/scripting_tool/Cargo.toml +++ b/crates/scripting_tool/Cargo.toml @@ -15,6 +15,7 @@ doctest = false [dependencies] anyhow.workspace = true assistant_tool.workspace = true +collections.workspace = true futures.workspace = true gpui.workspace = true mlua.workspace = true @@ -24,6 +25,13 @@ regex.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true -smol.workspace = true +settings.workspace = true util.workspace = true workspace.workspace = true + +[dev-dependencies] +collections = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/scripting_tool/src/scripting_tool.rs b/crates/scripting_tool/src/scripting_tool.rs index 6797d003bd..414d7ff968 100644 --- a/crates/scripting_tool/src/scripting_tool.rs +++ b/crates/scripting_tool/src/scripting_tool.rs @@ -1,22 +1,12 @@ -use anyhow::anyhow; +mod session; + +pub(crate) use session::*; + use assistant_tool::{Tool, ToolRegistry}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, StreamExt as _, -}; -use gpui::{App, AppContext as _, AsyncApp, Task, WeakEntity, Window}; -use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods}; -use parking_lot::Mutex; -use project::{search::SearchQuery, ProjectPath, WorktreeId}; +use gpui::{App, AppContext as _, Task, WeakEntity, Window}; use schemars::JsonSchema; use serde::Deserialize; -use std::{ - cell::RefCell, - collections::{HashMap, HashSet}, - path::{Path, PathBuf}, - sync::Arc, -}; -use util::paths::PathMatcher; +use std::sync::Arc; use workspace::Workspace; pub fn init(cx: &App) { @@ -65,824 +55,22 @@ string being a match that was found within the file)."#.into() _window: &mut Window, cx: &mut App, ) -> Task> { - let worktree_root_dir_and_id = workspace.update(cx, |workspace, cx| { - let first_worktree = workspace - .visible_worktrees(cx) - .next() - .ok_or_else(|| anyhow!("no worktrees"))?; - let worktree_id = first_worktree.read(cx).id(); - let root_dir = workspace - .absolute_path_of_worktree(worktree_id, cx) - .ok_or_else(|| anyhow!("no worktree root"))?; - Ok((root_dir, worktree_id)) 
- }); - let (root_dir, worktree_id) = match worktree_root_dir_and_id { - Ok(Ok(worktree_root_dir_and_id)) => worktree_root_dir_and_id, - Ok(Err(err)) => return Task::ready(Err(err)), - Err(err) => return Task::ready(Err(err)), - }; let input = match serde_json::from_value::(input) { Err(err) => return Task::ready(Err(err.into())), Ok(input) => input, }; + let Ok(project) = workspace.read_with(cx, |workspace, _cx| workspace.project().clone()) + else { + return Task::ready(Err(anyhow::anyhow!("No project found"))); + }; - let (foreground_tx, mut foreground_rx) = mpsc::channel::(1); - - cx.spawn(move |cx| async move { - while let Some(request) = foreground_rx.next().await { - request.0(cx.clone()); - } - }) - .detach(); - + let session = cx.new(|cx| Session::new(project, cx)); let lua_script = input.lua_script; - cx.background_spawn(async move { - let fs_changes = HashMap::new(); - let output = run_sandboxed_lua( - &lua_script, - fs_changes, - root_dir, - worktree_id, - workspace, - foreground_tx, - ) - .await - .map_err(|err| anyhow!(format!("{err}")))?; - let output = output.printed_lines.join("\n"); - + let script = session.update(cx, |session, cx| session.run_script(lua_script, cx)); + cx.spawn(|_cx| async move { + let output = script.await?.stdout; + drop(session); Ok(format!("The script output the following:\n{output}")) }) } } - -struct ForegroundFn(Box); - -async fn run_foreground_fn( - description: &str, - foreground_tx: &mut mpsc::Sender, - function: Box anyhow::Result + Send>, -) -> Result { - let (response_tx, response_rx) = oneshot::channel(); - let send_result = foreground_tx - .send(ForegroundFn(Box::new(move |cx| { - response_tx.send(function(cx)).ok(); - }))) - .await; - match send_result { - Ok(()) => (), - Err(err) => { - return Err(mlua::Error::runtime(format!( - "Internal error while enqueuing work for {description}: {err}" - ))) - } - } - match response_rx.await { - Ok(Ok(result)) => Ok(result), - Ok(Err(err)) => Err(mlua::Error::runtime(format!( - "Error while {description}: {err}" - ))), - Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!( - "Internal error: response oneshot was canceled while {description}." - ))), - } -} - -const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua"); - -struct FileContent(RefCell>); - -impl UserData for FileContent { - fn add_methods>(_methods: &mut M) { - // FileContent doesn't have any methods so far. - } -} - -/// Sandboxed print() function in Lua. -fn print(lua: &Lua, printed_lines: Arc>>) -> Result { - lua.create_function(move |_, args: MultiValue| { - let mut string = String::new(); - - for arg in args.into_iter() { - // Lua's `print()` prints tab characters between each argument. - if !string.is_empty() { - string.push('\t'); - } - - // If the argument's to_string() fails, have the whole function call fail. - string.push_str(arg.to_string()?.as_str()) - } - - printed_lines.lock().push(string); - - Ok(()) - }) -} - -fn search( - lua: &Lua, - _fs_changes: Arc>>>, - root_dir: PathBuf, - worktree_id: WorktreeId, - workspace: WeakEntity, - foreground_tx: mpsc::Sender, -) -> Result { - lua.create_async_function(move |lua, regex: String| { - let root_dir = root_dir.clone(); - let workspace = workspace.clone(); - let mut foreground_tx = foreground_tx.clone(); - async move { - use mlua::Table; - use regex::Regex; - use std::fs; - - // TODO: Allow specification of these options. 
- let search_query = SearchQuery::regex( - ®ex, - false, - false, - false, - PathMatcher::default(), - PathMatcher::default(), - None, - ); - let search_query = match search_query { - Ok(query) => query, - Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))), - }; - - // TODO: Should use `search_query.regex`. The tool description should also be updated, - // as it specifies standard regex. - let search_regex = match Regex::new(®ex) { - Ok(re) => re, - Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))), - }; - - let project_path_rx = - find_search_candidates(search_query, workspace, &mut foreground_tx).await?; - - let mut search_results: Vec> = Vec::new(); - while let Ok(project_path) = project_path_rx.recv().await { - if project_path.worktree_id != worktree_id { - continue; - } - - let path = root_dir.join(project_path.path); - - // Skip files larger than 1MB - if let Ok(metadata) = fs::metadata(&path) { - if metadata.len() > 1_000_000 { - continue; - } - } - - // Attempt to read the file as text - if let Ok(content) = fs::read_to_string(&path) { - let mut matches = Vec::new(); - - // Find all regex matches in the content - for capture in search_regex.find_iter(&content) { - matches.push(capture.as_str().to_string()); - } - - // If we found matches, create a result entry - if !matches.is_empty() { - let result_entry = lua.create_table()?; - result_entry.set("path", path.to_string_lossy().to_string())?; - - let matches_table = lua.create_table()?; - for (i, m) in matches.iter().enumerate() { - matches_table.set(i + 1, m.clone())?; - } - result_entry.set("matches", matches_table)?; - - search_results.push(Ok(result_entry)); - } - } - } - - // Create a table to hold our results - let results_table = lua.create_table()?; - for (i, result) in search_results.into_iter().enumerate() { - match result { - Ok(entry) => results_table.set(i + 1, entry)?, - Err(e) => return Err(e), - } - } - - Ok(results_table) - } - }) -} - -async fn find_search_candidates( - search_query: SearchQuery, - workspace: WeakEntity, - foreground_tx: &mut mpsc::Sender, -) -> Result> { - run_foreground_fn( - "finding search file candidates", - foreground_tx, - Box::new(move |mut cx| { - workspace.update(&mut cx, move |workspace, cx| { - workspace.project().update(cx, |project, cx| { - project.worktree_store().update(cx, |worktree_store, cx| { - // TODO: Better limit? For now this is the same as - // MAX_SEARCH_RESULT_FILES. - let limit = 5000; - // TODO: Providing non-empty open_entries can make this a bit more - // efficient as it can skip checking that these paths are textual. - let open_entries = HashSet::default(); - // TODO: This is searching all worktrees, but should only search the - // current worktree - worktree_store.find_search_candidates( - search_query, - limit, - open_entries, - project.fs().clone(), - cx, - ) - }) - }) - }) - }), - ) - .await -} - -/// Sandboxed io.open() function in Lua. -fn io_open( - lua: &Lua, - fs_changes: Arc>>>, - root_dir: PathBuf, -) -> Result { - lua.create_function(move |lua, (path_str, mode): (String, Option)| { - let mode = mode.unwrap_or_else(|| "r".to_string()); - - // Parse the mode string to determine read/write permissions - let read_perm = mode.contains('r'); - let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+'); - let append = mode.contains('a'); - let truncate = mode.contains('w'); - - // This will be the Lua value returned from the `open` function. 
- let file = lua.create_table()?; - - // Store file metadata in the file - file.set("__path", path_str.clone())?; - file.set("__mode", mode.clone())?; - file.set("__read_perm", read_perm)?; - file.set("__write_perm", write_perm)?; - - // Sandbox the path; it must be within root_dir - let path: PathBuf = { - let rust_path = Path::new(&path_str); - - // Get absolute path - if rust_path.is_absolute() { - // Check if path starts with root_dir prefix without resolving symlinks - if !rust_path.starts_with(&root_dir) { - return Ok(( - None, - format!( - "Error: Absolute path {} is outside the current working directory", - path_str - ), - )); - } - rust_path.to_path_buf() - } else { - // Make relative path absolute relative to cwd - root_dir.join(rust_path) - } - }; - - // close method - let close_fn = { - let fs_changes = fs_changes.clone(); - lua.create_function(move |_lua, file_userdata: mlua::Table| { - let write_perm = file_userdata.get::("__write_perm")?; - let path = file_userdata.get::("__path")?; - - if write_perm { - // When closing a writable file, record the content - let content = file_userdata.get::("__content")?; - let content_ref = content.borrow::()?; - let content_vec = content_ref.0.borrow(); - - // Don't actually write to disk; instead, just update fs_changes. - let path_buf = PathBuf::from(&path); - fs_changes - .lock() - .insert(path_buf.clone(), content_vec.clone()); - } - - Ok(true) - })? - }; - file.set("close", close_fn)?; - - // If it's a directory, give it a custom read() and return early. - if path.is_dir() { - // TODO handle the case where we changed it in the in-memory fs - - // Create a special directory handle - file.set("__is_directory", true)?; - - // Store directory entries - let entries = match std::fs::read_dir(&path) { - Ok(entries) => { - let mut entry_names = Vec::new(); - for entry in entries.flatten() { - entry_names.push(entry.file_name().to_string_lossy().into_owned()); - } - entry_names - } - Err(e) => return Ok((None, format!("Error reading directory: {}", e))), - }; - - // Save the list of entries - file.set("__dir_entries", entries)?; - file.set("__dir_position", 0usize)?; - - // Create a directory-specific read function - let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| { - let position = file_userdata.get::("__dir_position")?; - let entries = file_userdata.get::>("__dir_entries")?; - - if position >= entries.len() { - return Ok(None); // No more entries - } - - let entry = entries[position].clone(); - file_userdata.set("__dir_position", position + 1)?; - - Ok(Some(entry)) - })?; - file.set("read", read_fn)?; - - // If we got this far, the directory was opened successfully - return Ok((Some(file), String::new())); - } - - let fs_changes_map = fs_changes.lock(); - - let is_in_changes = fs_changes_map.contains_key(&path); - let file_exists = is_in_changes || path.exists(); - let mut file_content = Vec::new(); - - if file_exists && !truncate { - if is_in_changes { - file_content = fs_changes_map.get(&path).unwrap().clone(); - } else { - // Try to read existing content if file exists and we're not truncating - match std::fs::read(&path) { - Ok(content) => file_content = content, - Err(e) => return Ok((None, format!("Error reading file: {}", e))), - } - } - } - - drop(fs_changes_map); // Unlock the fs_changes mutex. 
- - // If in append mode, position should be at the end - let position = if append && file_exists { - file_content.len() - } else { - 0 - }; - file.set("__position", position)?; - file.set( - "__content", - lua.create_userdata(FileContent(RefCell::new(file_content)))?, - )?; - - // Create file methods - - // read method - let read_fn = { - lua.create_function( - |_lua, (file_userdata, format): (mlua::Table, Option)| { - let read_perm = file_userdata.get::("__read_perm")?; - if !read_perm { - return Err(mlua::Error::runtime("File not open for reading")); - } - - let content = file_userdata.get::("__content")?; - let mut position = file_userdata.get::("__position")?; - let content_ref = content.borrow::()?; - let content_vec = content_ref.0.borrow(); - - if position >= content_vec.len() { - return Ok(None); // EOF - } - - match format { - Some(mlua::Value::String(s)) => { - let lossy_string = s.to_string_lossy(); - let format_str: &str = lossy_string.as_ref(); - - // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a" - match &format_str[0..2] { - "*a" => { - // Read entire file from current position - let result = String::from_utf8_lossy(&content_vec[position..]) - .to_string(); - position = content_vec.len(); - file_userdata.set("__position", position)?; - Ok(Some(result)) - } - "*l" => { - // Read next line - let mut line = Vec::new(); - let mut found_newline = false; - - while position < content_vec.len() { - let byte = content_vec[position]; - position += 1; - - if byte == b'\n' { - found_newline = true; - break; - } - - // Skip \r in \r\n sequence but add it if it's alone - if byte == b'\r' { - if position < content_vec.len() - && content_vec[position] == b'\n' - { - position += 1; - found_newline = true; - break; - } - } - - line.push(byte); - } - - file_userdata.set("__position", position)?; - - if !found_newline - && line.is_empty() - && position >= content_vec.len() - { - return Ok(None); // EOF - } - - let result = String::from_utf8_lossy(&line).to_string(); - Ok(Some(result)) - } - "*n" => { - // Try to parse as a number (number of bytes to read) - match format_str.parse::() { - Ok(n) => { - let end = - std::cmp::min(position + n, content_vec.len()); - let bytes = &content_vec[position..end]; - let result = String::from_utf8_lossy(bytes).to_string(); - position = end; - file_userdata.set("__position", position)?; - Ok(Some(result)) - } - Err(_) => Err(mlua::Error::runtime(format!( - "Invalid format: {}", - format_str - ))), - } - } - "*L" => { - // Read next line keeping the end of line - let mut line = Vec::new(); - - while position < content_vec.len() { - let byte = content_vec[position]; - position += 1; - - line.push(byte); - - if byte == b'\n' { - break; - } - - // If we encounter a \r, add it and check if the next is \n - if byte == b'\r' - && position < content_vec.len() - && content_vec[position] == b'\n' - { - line.push(content_vec[position]); - position += 1; - break; - } - } - - file_userdata.set("__position", position)?; - - if line.is_empty() && position >= content_vec.len() { - return Ok(None); // EOF - } - - let result = String::from_utf8_lossy(&line).to_string(); - Ok(Some(result)) - } - _ => Err(mlua::Error::runtime(format!( - "Unsupported format: {}", - format_str - ))), - } - } - Some(mlua::Value::Number(n)) => { - // Read n bytes - let n = n as usize; - let end = std::cmp::min(position + n, content_vec.len()); - let bytes = &content_vec[position..end]; - let result = String::from_utf8_lossy(bytes).to_string(); - position = 
end; - file_userdata.set("__position", position)?; - Ok(Some(result)) - } - Some(_) => Err(mlua::Error::runtime("Invalid format")), - None => { - // Default is to read a line - let mut line = Vec::new(); - let mut found_newline = false; - - while position < content_vec.len() { - let byte = content_vec[position]; - position += 1; - - if byte == b'\n' { - found_newline = true; - break; - } - - // Handle \r\n - if byte == b'\r' { - if position < content_vec.len() - && content_vec[position] == b'\n' - { - position += 1; - found_newline = true; - break; - } - } - - line.push(byte); - } - - file_userdata.set("__position", position)?; - - if !found_newline && line.is_empty() && position >= content_vec.len() { - return Ok(None); // EOF - } - - let result = String::from_utf8_lossy(&line).to_string(); - Ok(Some(result)) - } - } - }, - )? - }; - file.set("read", read_fn)?; - - // write method - let write_fn = { - let fs_changes = fs_changes.clone(); - - lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| { - let write_perm = file_userdata.get::("__write_perm")?; - if !write_perm { - return Err(mlua::Error::runtime("File not open for writing")); - } - - let content = file_userdata.get::("__content")?; - let position = file_userdata.get::("__position")?; - let content_ref = content.borrow::()?; - let mut content_vec = content_ref.0.borrow_mut(); - - let bytes = text.as_bytes(); - - // Ensure the vector has enough capacity - if position + bytes.len() > content_vec.len() { - content_vec.resize(position + bytes.len(), 0); - } - - // Write the bytes - for (i, &byte) in bytes.iter().enumerate() { - content_vec[position + i] = byte; - } - - // Update position - let new_position = position + bytes.len(); - file_userdata.set("__position", new_position)?; - - // Update fs_changes - let path = file_userdata.get::("__path")?; - let path_buf = PathBuf::from(path); - fs_changes.lock().insert(path_buf, content_vec.clone()); - - Ok(true) - })? - }; - file.set("write", write_fn)?; - - // If we got this far, the file was opened successfully - Ok((Some(file), String::new())) - }) -} - -/// Runs a Lua script in a sandboxed environment and returns the printed lines -async fn run_sandboxed_lua( - script: &str, - fs_changes: HashMap>, - root_dir: PathBuf, - worktree_id: WorktreeId, - workspace: WeakEntity, - foreground_tx: mpsc::Sender, -) -> Result { - let lua = Lua::new(); - lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB - let globals = lua.globals(); - - // Track the lines the Lua script prints out. - let printed_lines = Arc::new(Mutex::new(Vec::new())); - let fs = Arc::new(Mutex::new(fs_changes)); - - globals.set("sb_print", print(&lua, printed_lines.clone())?)?; - globals.set( - "search", - search( - &lua, - fs.clone(), - root_dir.clone(), - worktree_id, - workspace, - foreground_tx, - )?, - )?; - globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?; - globals.set("user_script", script)?; - - lua.load(SANDBOX_PREAMBLE).exec_async().await?; - - drop(lua); // Decrements the Arcs so that try_unwrap works. 
- - Ok(ScriptOutput { - printed_lines: Arc::try_unwrap(printed_lines) - .expect("There are still other references to printed_lines") - .into_inner(), - fs_changes: Arc::try_unwrap(fs) - .expect("There are still other references to fs_changes") - .into_inner(), - }) -} - -pub struct ScriptOutput { - printed_lines: Vec, - #[allow(dead_code)] - fs_changes: HashMap>, -} - -#[allow(dead_code)] -impl ScriptOutput { - fn fs_diff(&self) -> HashMap { - let mut diff_map = HashMap::new(); - for (path, content) in &self.fs_changes { - let diff = if path.exists() { - // Read the current file content - match std::fs::read(path) { - Ok(current_content) => { - // Convert both to strings for diffing - let new_content = String::from_utf8_lossy(content).to_string(); - let old_content = String::from_utf8_lossy(¤t_content).to_string(); - - // Generate a git-style diff - let new_lines: Vec<&str> = new_content.lines().collect(); - let old_lines: Vec<&str> = old_content.lines().collect(); - - let path_str = path.to_string_lossy(); - let mut diff = format!("diff --git a/{} b/{}\n", path_str, path_str); - diff.push_str(&format!("--- a/{}\n", path_str)); - diff.push_str(&format!("+++ b/{}\n", path_str)); - - // Very basic diff algorithm - this is simplified - let mut i = 0; - let mut j = 0; - - while i < old_lines.len() || j < new_lines.len() { - if i < old_lines.len() - && j < new_lines.len() - && old_lines[i] == new_lines[j] - { - i += 1; - j += 1; - continue; - } - - // Find next matching line - let mut next_i = i; - let mut next_j = j; - let mut found = false; - - // Look ahead for matches - for look_i in i..std::cmp::min(i + 10, old_lines.len()) { - for look_j in j..std::cmp::min(j + 10, new_lines.len()) { - if old_lines[look_i] == new_lines[look_j] { - next_i = look_i; - next_j = look_j; - found = true; - break; - } - } - if found { - break; - } - } - - // Output the hunk header - diff.push_str(&format!( - "@@ -{},{} +{},{} @@\n", - i + 1, - if found { - next_i - i - } else { - old_lines.len() - i - }, - j + 1, - if found { - next_j - j - } else { - new_lines.len() - j - } - )); - - // Output removed lines - for k in i..next_i { - diff.push_str(&format!("-{}\n", old_lines[k])); - } - - // Output added lines - for k in j..next_j { - diff.push_str(&format!("+{}\n", new_lines[k])); - } - - i = next_i; - j = next_j; - - if found { - i += 1; - j += 1; - } else { - break; - } - } - - diff - } - Err(_) => format!("Error reading current file: {}", path.display()), - } - } else { - // New file - let content_str = String::from_utf8_lossy(content).to_string(); - let path_str = path.to_string_lossy(); - let mut diff = format!("diff --git a/{} b/{}\n", path_str, path_str); - diff.push_str("new file mode 100644\n"); - diff.push_str("--- /dev/null\n"); - diff.push_str(&format!("+++ b/{}\n", path_str)); - - let lines: Vec<&str> = content_str.lines().collect(); - diff.push_str(&format!("@@ -0,0 +1,{} @@\n", lines.len())); - - for line in lines { - diff.push_str(&format!("+{}\n", line)); - } - - diff - }; - - diff_map.insert(path.clone(), diff); - } - - diff_map - } - - fn diff_to_string(&self) -> String { - let mut answer = String::new(); - let diff_map = self.fs_diff(); - - if diff_map.is_empty() { - return "No changes to files".to_string(); - } - - // Sort the paths for consistent output - let mut paths: Vec<&PathBuf> = diff_map.keys().collect(); - paths.sort(); - - for path in paths { - if !answer.is_empty() { - answer.push_str("\n"); - } - answer.push_str(&diff_map[path]); - } - - answer - } -} diff --git 
a/crates/scripting_tool/src/session.rs b/crates/scripting_tool/src/session.rs new file mode 100644 index 0000000000..36bd395fd2 --- /dev/null +++ b/crates/scripting_tool/src/session.rs @@ -0,0 +1,743 @@ +use anyhow::Result; +use collections::{HashMap, HashSet}; +use futures::{ + channel::{mpsc, oneshot}, + pin_mut, SinkExt, StreamExt, +}; +use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; +use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods}; +use parking_lot::Mutex; +use project::{search::SearchQuery, Fs, Project}; +use regex::Regex; +use std::{ + cell::RefCell, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::{paths::PathMatcher, ResultExt}; + +pub struct ScriptOutput { + pub stdout: String, +} + +struct ForegroundFn(Box, AsyncApp) + Send>); + +pub struct Session { + project: Entity, + // TODO Remove this + fs_changes: Arc>>>, + foreground_fns_tx: mpsc::Sender, + _invoke_foreground_fns: Task<()>, +} + +impl Session { + pub fn new(project: Entity, cx: &mut Context) -> Self { + let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128); + Session { + project, + fs_changes: Arc::new(Mutex::new(HashMap::default())), + foreground_fns_tx, + _invoke_foreground_fns: cx.spawn(|this, cx| async move { + while let Some(foreground_fn) = foreground_fns_rx.next().await { + foreground_fn.0(this.clone(), cx.clone()); + } + }), + } + } + + /// Runs a Lua script in a sandboxed environment and returns the printed lines + pub fn run_script( + &mut self, + script: String, + cx: &mut Context, + ) -> Task> { + const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua"); + + // TODO Remove fs_changes + let fs_changes = self.fs_changes.clone(); + // TODO Honor all worktrees instead of the first one + let root_dir = self + .project + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).abs_path()); + let fs = self.project.read(cx).fs().clone(); + let foreground_fns_tx = self.foreground_fns_tx.clone(); + cx.background_spawn(async move { + let lua = Lua::new(); + lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB + let globals = lua.globals(); + let stdout = Arc::new(Mutex::new(String::new())); + globals.set( + "sb_print", + lua.create_function({ + let stdout = stdout.clone(); + move |_, args: MultiValue| Self::print(args, &stdout) + })?, + )?; + globals.set( + "search", + lua.create_async_function({ + let foreground_fns_tx = foreground_fns_tx.clone(); + let fs = fs.clone(); + move |lua, regex| { + Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex) + } + })?, + )?; + globals.set( + "sb_io_open", + lua.create_function({ + let fs_changes = fs_changes.clone(); + let root_dir = root_dir.clone(); + move |lua, (path_str, mode)| { + Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode) + } + })?, + )?; + globals.set("user_script", script)?; + + lua.load(SANDBOX_PREAMBLE).exec_async().await?; + + // Drop Lua instance to decrement reference count. + drop(lua); + + let stdout = Arc::try_unwrap(stdout) + .expect("no more references to stdout") + .into_inner(); + Ok(ScriptOutput { stdout }) + }) + } + + /// Sandboxed print() function in Lua. + fn print(args: MultiValue, stdout: &Mutex) -> mlua::Result<()> { + for (index, arg) in args.into_iter().enumerate() { + // Lua's `print()` prints tab characters between each argument. + if index > 0 { + stdout.lock().push('\t'); + } + + // If the argument's to_string() fails, have the whole function call fail. 
+ stdout.lock().push_str(&arg.to_string()?); + } + stdout.lock().push('\n'); + + Ok(()) + } + + /// Sandboxed io.open() function in Lua. + fn io_open( + lua: &Lua, + fs_changes: &Arc>>>, + root_dir: Option<&Arc>, + path_str: String, + mode: Option, + ) -> mlua::Result<(Option, String)> { + let root_dir = root_dir + .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?; + + let mode = mode.unwrap_or_else(|| "r".to_string()); + + // Parse the mode string to determine read/write permissions + let read_perm = mode.contains('r'); + let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+'); + let append = mode.contains('a'); + let truncate = mode.contains('w'); + + // This will be the Lua value returned from the `open` function. + let file = lua.create_table()?; + + // Store file metadata in the file + file.set("__path", path_str.clone())?; + file.set("__mode", mode.clone())?; + file.set("__read_perm", read_perm)?; + file.set("__write_perm", write_perm)?; + + // Sandbox the path; it must be within root_dir + let path: PathBuf = { + let rust_path = Path::new(&path_str); + + // Get absolute path + if rust_path.is_absolute() { + // Check if path starts with root_dir prefix without resolving symlinks + if !rust_path.starts_with(&root_dir) { + return Ok(( + None, + format!( + "Error: Absolute path {} is outside the current working directory", + path_str + ), + )); + } + rust_path.to_path_buf() + } else { + // Make relative path absolute relative to cwd + root_dir.join(rust_path) + } + }; + + // close method + let close_fn = { + let fs_changes = fs_changes.clone(); + lua.create_function(move |_lua, file_userdata: mlua::Table| { + let write_perm = file_userdata.get::("__write_perm")?; + let path = file_userdata.get::("__path")?; + + if write_perm { + // When closing a writable file, record the content + let content = file_userdata.get::("__content")?; + let content_ref = content.borrow::()?; + let content_vec = content_ref.0.borrow(); + + // Don't actually write to disk; instead, just update fs_changes. + let path_buf = PathBuf::from(&path); + fs_changes + .lock() + .insert(path_buf.clone(), content_vec.clone()); + } + + Ok(true) + })? + }; + file.set("close", close_fn)?; + + // If it's a directory, give it a custom read() and return early. 
+ if path.is_dir() { + // TODO handle the case where we changed it in the in-memory fs + + // Create a special directory handle + file.set("__is_directory", true)?; + + // Store directory entries + let entries = match std::fs::read_dir(&path) { + Ok(entries) => { + let mut entry_names = Vec::new(); + for entry in entries.flatten() { + entry_names.push(entry.file_name().to_string_lossy().into_owned()); + } + entry_names + } + Err(e) => return Ok((None, format!("Error reading directory: {}", e))), + }; + + // Save the list of entries + file.set("__dir_entries", entries)?; + file.set("__dir_position", 0usize)?; + + // Create a directory-specific read function + let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| { + let position = file_userdata.get::("__dir_position")?; + let entries = file_userdata.get::>("__dir_entries")?; + + if position >= entries.len() { + return Ok(None); // No more entries + } + + let entry = entries[position].clone(); + file_userdata.set("__dir_position", position + 1)?; + + Ok(Some(entry)) + })?; + file.set("read", read_fn)?; + + // If we got this far, the directory was opened successfully + return Ok((Some(file), String::new())); + } + + let fs_changes_map = fs_changes.lock(); + + let is_in_changes = fs_changes_map.contains_key(&path); + let file_exists = is_in_changes || path.exists(); + let mut file_content = Vec::new(); + + if file_exists && !truncate { + if is_in_changes { + file_content = fs_changes_map.get(&path).unwrap().clone(); + } else { + // Try to read existing content if file exists and we're not truncating + match std::fs::read(&path) { + Ok(content) => file_content = content, + Err(e) => return Ok((None, format!("Error reading file: {}", e))), + } + } + } + + drop(fs_changes_map); // Unlock the fs_changes mutex. + + // If in append mode, position should be at the end + let position = if append && file_exists { + file_content.len() + } else { + 0 + }; + file.set("__position", position)?; + file.set( + "__content", + lua.create_userdata(FileContent(RefCell::new(file_content)))?, + )?; + + // Create file methods + + // read method + let read_fn = { + lua.create_function( + |_lua, (file_userdata, format): (mlua::Table, Option)| { + let read_perm = file_userdata.get::("__read_perm")?; + if !read_perm { + return Err(mlua::Error::runtime("File not open for reading")); + } + + let content = file_userdata.get::("__content")?; + let mut position = file_userdata.get::("__position")?; + let content_ref = content.borrow::()?; + let content_vec = content_ref.0.borrow(); + + if position >= content_vec.len() { + return Ok(None); // EOF + } + + match format { + Some(mlua::Value::String(s)) => { + let lossy_string = s.to_string_lossy(); + let format_str: &str = lossy_string.as_ref(); + + // Only consider the first 2 bytes, since it's common to pass e.g. 
"*all" instead of "*a" + match &format_str[0..2] { + "*a" => { + // Read entire file from current position + let result = String::from_utf8_lossy(&content_vec[position..]) + .to_string(); + position = content_vec.len(); + file_userdata.set("__position", position)?; + Ok(Some(result)) + } + "*l" => { + // Read next line + let mut line = Vec::new(); + let mut found_newline = false; + + while position < content_vec.len() { + let byte = content_vec[position]; + position += 1; + + if byte == b'\n' { + found_newline = true; + break; + } + + // Skip \r in \r\n sequence but add it if it's alone + if byte == b'\r' { + if position < content_vec.len() + && content_vec[position] == b'\n' + { + position += 1; + found_newline = true; + break; + } + } + + line.push(byte); + } + + file_userdata.set("__position", position)?; + + if !found_newline + && line.is_empty() + && position >= content_vec.len() + { + return Ok(None); // EOF + } + + let result = String::from_utf8_lossy(&line).to_string(); + Ok(Some(result)) + } + "*n" => { + // Try to parse as a number (number of bytes to read) + match format_str.parse::() { + Ok(n) => { + let end = + std::cmp::min(position + n, content_vec.len()); + let bytes = &content_vec[position..end]; + let result = String::from_utf8_lossy(bytes).to_string(); + position = end; + file_userdata.set("__position", position)?; + Ok(Some(result)) + } + Err(_) => Err(mlua::Error::runtime(format!( + "Invalid format: {}", + format_str + ))), + } + } + "*L" => { + // Read next line keeping the end of line + let mut line = Vec::new(); + + while position < content_vec.len() { + let byte = content_vec[position]; + position += 1; + + line.push(byte); + + if byte == b'\n' { + break; + } + + // If we encounter a \r, add it and check if the next is \n + if byte == b'\r' + && position < content_vec.len() + && content_vec[position] == b'\n' + { + line.push(content_vec[position]); + position += 1; + break; + } + } + + file_userdata.set("__position", position)?; + + if line.is_empty() && position >= content_vec.len() { + return Ok(None); // EOF + } + + let result = String::from_utf8_lossy(&line).to_string(); + Ok(Some(result)) + } + _ => Err(mlua::Error::runtime(format!( + "Unsupported format: {}", + format_str + ))), + } + } + Some(mlua::Value::Number(n)) => { + // Read n bytes + let n = n as usize; + let end = std::cmp::min(position + n, content_vec.len()); + let bytes = &content_vec[position..end]; + let result = String::from_utf8_lossy(bytes).to_string(); + position = end; + file_userdata.set("__position", position)?; + Ok(Some(result)) + } + Some(_) => Err(mlua::Error::runtime("Invalid format")), + None => { + // Default is to read a line + let mut line = Vec::new(); + let mut found_newline = false; + + while position < content_vec.len() { + let byte = content_vec[position]; + position += 1; + + if byte == b'\n' { + found_newline = true; + break; + } + + // Handle \r\n + if byte == b'\r' { + if position < content_vec.len() + && content_vec[position] == b'\n' + { + position += 1; + found_newline = true; + break; + } + } + + line.push(byte); + } + + file_userdata.set("__position", position)?; + + if !found_newline && line.is_empty() && position >= content_vec.len() { + return Ok(None); // EOF + } + + let result = String::from_utf8_lossy(&line).to_string(); + Ok(Some(result)) + } + } + }, + )? 
+ }; + file.set("read", read_fn)?; + + // write method + let write_fn = { + let fs_changes = fs_changes.clone(); + + lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| { + let write_perm = file_userdata.get::("__write_perm")?; + if !write_perm { + return Err(mlua::Error::runtime("File not open for writing")); + } + + let content = file_userdata.get::("__content")?; + let position = file_userdata.get::("__position")?; + let content_ref = content.borrow::()?; + let mut content_vec = content_ref.0.borrow_mut(); + + let bytes = text.as_bytes(); + + // Ensure the vector has enough capacity + if position + bytes.len() > content_vec.len() { + content_vec.resize(position + bytes.len(), 0); + } + + // Write the bytes + for (i, &byte) in bytes.iter().enumerate() { + content_vec[position + i] = byte; + } + + // Update position + let new_position = position + bytes.len(); + file_userdata.set("__position", new_position)?; + + // Update fs_changes + let path = file_userdata.get::("__path")?; + let path_buf = PathBuf::from(path); + fs_changes.lock().insert(path_buf, content_vec.clone()); + + Ok(true) + })? + }; + file.set("write", write_fn)?; + + // If we got this far, the file was opened successfully + Ok((Some(file), String::new())) + } + + async fn search( + lua: Lua, + mut foreground_tx: mpsc::Sender, + fs: Arc, + regex: String, + ) -> mlua::Result
{ + // TODO: Allow specification of these options. + let search_query = SearchQuery::regex( + ®ex, + false, + false, + false, + PathMatcher::default(), + PathMatcher::default(), + None, + ); + let search_query = match search_query { + Ok(query) => query, + Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))), + }; + + // TODO: Should use `search_query.regex`. The tool description should also be updated, + // as it specifies standard regex. + let search_regex = match Regex::new(®ex) { + Ok(re) => re, + Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))), + }; + + let mut abs_paths_rx = + Self::find_search_candidates(search_query, &mut foreground_tx).await?; + + let mut search_results: Vec
= Vec::new(); + while let Some(path) = abs_paths_rx.next().await { + // Skip files larger than 1MB + if let Ok(Some(metadata)) = fs.metadata(&path).await { + if metadata.len > 1_000_000 { + continue; + } + } + + // Attempt to read the file as text + if let Ok(content) = fs.load(&path).await { + let mut matches = Vec::new(); + + // Find all regex matches in the content + for capture in search_regex.find_iter(&content) { + matches.push(capture.as_str().to_string()); + } + + // If we found matches, create a result entry + if !matches.is_empty() { + let result_entry = lua.create_table()?; + result_entry.set("path", path.to_string_lossy().to_string())?; + + let matches_table = lua.create_table()?; + for (ix, m) in matches.iter().enumerate() { + matches_table.set(ix + 1, m.clone())?; + } + result_entry.set("matches", matches_table)?; + + search_results.push(result_entry); + } + } + } + + // Create a table to hold our results + let results_table = lua.create_table()?; + for (ix, entry) in search_results.into_iter().enumerate() { + results_table.set(ix + 1, entry)?; + } + + Ok(results_table) + } + + async fn find_search_candidates( + search_query: SearchQuery, + foreground_tx: &mut mpsc::Sender, + ) -> mlua::Result> { + Self::run_foreground_fn( + "finding search file candidates", + foreground_tx, + Box::new(move |session, mut cx| { + session.update(&mut cx, |session, cx| { + session.project.update(cx, |project, cx| { + project.worktree_store().update(cx, |worktree_store, cx| { + // TODO: Better limit? For now this is the same as + // MAX_SEARCH_RESULT_FILES. + let limit = 5000; + // TODO: Providing non-empty open_entries can make this a bit more + // efficient as it can skip checking that these paths are textual. + let open_entries = HashSet::default(); + let candidates = worktree_store.find_search_candidates( + search_query, + limit, + open_entries, + project.fs().clone(), + cx, + ); + let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded(); + cx.spawn(|worktree_store, cx| async move { + pin_mut!(candidates); + + while let Some(project_path) = candidates.next().await { + worktree_store.read_with(&cx, |worktree_store, cx| { + if let Some(worktree) = worktree_store + .worktree_for_id(project_path.worktree_id, cx) + { + if let Some(abs_path) = worktree + .read(cx) + .absolutize(&project_path.path) + .log_err() + { + abs_paths_tx.unbounded_send(abs_path)?; + } + } + anyhow::Ok(()) + })??; + } + anyhow::Ok(()) + }) + .detach(); + abs_paths_rx + }) + }) + }) + }), + ) + .await + } + + async fn run_foreground_fn( + description: &str, + foreground_tx: &mut mpsc::Sender, + function: Box, AsyncApp) -> anyhow::Result + Send>, + ) -> mlua::Result { + let (response_tx, response_rx) = oneshot::channel(); + let send_result = foreground_tx + .send(ForegroundFn(Box::new(move |this, cx| { + response_tx.send(function(this, cx)).ok(); + }))) + .await; + match send_result { + Ok(()) => (), + Err(err) => { + return Err(mlua::Error::runtime(format!( + "Internal error while enqueuing work for {description}: {err}" + ))) + } + } + match response_rx.await { + Ok(Ok(result)) => Ok(result), + Ok(Err(err)) => Err(mlua::Error::runtime(format!( + "Error while {description}: {err}" + ))), + Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!( + "Internal error: response oneshot was canceled while {description}." + ))), + } + } +} + +struct FileContent(RefCell>); + +impl UserData for FileContent { + fn add_methods>(_methods: &mut M) { + // FileContent doesn't have any methods so far. 
+ } +} + +#[cfg(test)] +mod tests { + use gpui::TestAppContext; + use project::FakeFs; + use serde_json::json; + use settings::SettingsStore; + + use super::*; + + #[gpui::test] + async fn test_print(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let session = cx.new(|cx| Session::new(project, cx)); + let script = r#" + print("Hello", "world!") + print("Goodbye", "moon!") + "#; + let output = session + .update(cx, |session, cx| session.run_script(script.to_string(), cx)) + .await + .unwrap(); + assert_eq!(output.stdout, "Hello\tworld!\nGoodbye\tmoon!\n"); + } + + #[gpui::test] + async fn test_search(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "file1.txt": "Hello world!", + "file2.txt": "Goodbye moon!" + }), + ) + .await; + let project = Project::test(fs, [Path::new("/")], cx).await; + let session = cx.new(|cx| Session::new(project, cx)); + let script = r#" + local results = search("world") + for i, result in ipairs(results) do + print("File: " .. result.path) + print("Matches:") + for j, match in ipairs(result.matches) do + print(" " .. match) + end + end + "#; + let output = session + .update(cx, |session, cx| session.run_script(script.to_string(), cx)) + .await + .unwrap(); + assert_eq!(output.stdout, "File: /file1.txt\nMatches:\n world\n"); + } + + fn init_test(cx: &mut TestAppContext) { + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(Project::init_settings); + } +}
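
For reference, below is a minimal sketch of the intended long-lived usage described in the commit message: one `Session` held per assistant thread and reused across tool calls, rather than recreated for each call as `ScriptingTool::run` still does above. The `ThreadState` struct and `run_lua` helper are hypothetical illustrations of that direction, not code from this patch; only `Session::new`, `Session::run_script`, and the surrounding gpui calls mirror what `session.rs` and `scripting_tool.rs` define above, and the sketch assumes it lives inside the `scripting_tool` crate, since `Session` is currently only re-exported as `pub(crate)`.

    // Sketch only: `ThreadState` and `run_lua` are hypothetical names.
    use gpui::{App, AppContext as _, Entity, Task};
    use project::Project;

    use crate::session::Session;

    /// Hypothetical per-thread state: the session is created once and reused,
    /// so session state (e.g. `fs_changes`) can persist across tool calls.
    struct ThreadState {
        session: Entity<Session>,
    }

    impl ThreadState {
        fn new(project: Entity<Project>, cx: &mut App) -> Self {
            Self {
                session: cx.new(|cx| Session::new(project, cx)),
            }
        }

        /// Run a script in this thread's long-lived session instead of a
        /// throwaway one created per tool call.
        fn run_lua(&self, lua_script: String, cx: &mut App) -> Task<anyhow::Result<String>> {
            let script = self
                .session
                .update(cx, |session, cx| session.run_script(lua_script, cx));
            cx.spawn(|_cx| async move {
                let output = script.await?.stdout;
                Ok(format!("The script output the following:\n{output}"))
            })
        }
    }

Compared to the current `ScriptingTool::run`, which drops the session as soon as the script finishes, holding the session on the thread would let state such as `fs_changes` accumulate across successive scripts in the same conversation.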