Add initial scripting_tool (#26066)
Just a basic implementation so we can start trying it out. Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Michael <michael@zed.dev>
This commit is contained in:
parent
886d8c1cab
commit
a4ec0af681
8 changed files with 932 additions and 0 deletions
80
Cargo.lock
generated
80
Cargo.lock
generated
|
@ -4310,6 +4310,12 @@ dependencies = [
|
||||||
"regex",
|
"regex",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_home"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "env_logger"
|
name = "env_logger"
|
||||||
version = "0.10.2"
|
version = "0.10.2"
|
||||||
|
@ -7665,6 +7671,25 @@ dependencies = [
|
||||||
"url",
|
"url",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lua-src"
|
||||||
|
version = "547.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1edaf29e3517b49b8b746701e5648ccb5785cde1c119062cbabbc5d5cd115e42"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "luajit-src"
|
||||||
|
version = "210.5.12+a4f56a4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b3a8e7962a5368d5f264d045a5a255e90f9aa3fc1941ae15a8d2940d42cac671"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"which 7.0.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lyon"
|
name = "lyon"
|
||||||
version = "1.0.1"
|
version = "1.0.1"
|
||||||
|
@ -8078,6 +8103,33 @@ dependencies = [
|
||||||
"strum",
|
"strum",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mlua"
|
||||||
|
version = "0.10.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d3f763c1041eff92ffb5d7169968a327e1ed2ebfe425dac0ee5a35f29082534b"
|
||||||
|
dependencies = [
|
||||||
|
"bstr",
|
||||||
|
"either",
|
||||||
|
"mlua-sys",
|
||||||
|
"num-traits",
|
||||||
|
"parking_lot",
|
||||||
|
"rustc-hash 2.1.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mlua-sys"
|
||||||
|
version = "0.6.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1901c1a635a22fe9250ffcc4fcc937c16b47c2e9e71adba8784af8bca1f69594"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"cfg-if",
|
||||||
|
"lua-src",
|
||||||
|
"luajit-src",
|
||||||
|
"pkg-config",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "msvc_spectre_libs"
|
name = "msvc_spectre_libs"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
|
@ -11829,6 +11881,21 @@ version = "1.0.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
|
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scripting_tool"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"assistant_tool",
|
||||||
|
"gpui",
|
||||||
|
"mlua",
|
||||||
|
"regex",
|
||||||
|
"schemars",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"workspace",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "scrypt"
|
name = "scrypt"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
|
@ -15723,6 +15790,18 @@ dependencies = [
|
||||||
"winsafe",
|
"winsafe",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "which"
|
||||||
|
version = "7.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2774c861e1f072b3aadc02f8ba886c26ad6321567ecc294c935434cad06f1283"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"env_home",
|
||||||
|
"rustix",
|
||||||
|
"winsafe",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "whoami"
|
name = "whoami"
|
||||||
version = "1.5.2"
|
version = "1.5.2"
|
||||||
|
@ -16863,6 +16942,7 @@ dependencies = [
|
||||||
"repl",
|
"repl",
|
||||||
"reqwest_client",
|
"reqwest_client",
|
||||||
"rope",
|
"rope",
|
||||||
|
"scripting_tool",
|
||||||
"search",
|
"search",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -117,6 +117,7 @@ members = [
|
||||||
"crates/rope",
|
"crates/rope",
|
||||||
"crates/rpc",
|
"crates/rpc",
|
||||||
"crates/schema_generator",
|
"crates/schema_generator",
|
||||||
|
"crates/scripting_tool",
|
||||||
"crates/search",
|
"crates/search",
|
||||||
"crates/semantic_index",
|
"crates/semantic_index",
|
||||||
"crates/semantic_version",
|
"crates/semantic_version",
|
||||||
|
@ -321,6 +322,7 @@ reqwest_client = { path = "crates/reqwest_client" }
|
||||||
rich_text = { path = "crates/rich_text" }
|
rich_text = { path = "crates/rich_text" }
|
||||||
rope = { path = "crates/rope" }
|
rope = { path = "crates/rope" }
|
||||||
rpc = { path = "crates/rpc" }
|
rpc = { path = "crates/rpc" }
|
||||||
|
scripting_tool = { path = "crates/scripting_tool" }
|
||||||
search = { path = "crates/search" }
|
search = { path = "crates/search" }
|
||||||
semantic_index = { path = "crates/semantic_index" }
|
semantic_index = { path = "crates/semantic_index" }
|
||||||
semantic_version = { path = "crates/semantic_version" }
|
semantic_version = { path = "crates/semantic_version" }
|
||||||
|
@ -453,6 +455,7 @@ livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "
|
||||||
], default-features = false }
|
], default-features = false }
|
||||||
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
|
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
|
||||||
markup5ever_rcdom = "0.3.0"
|
markup5ever_rcdom = "0.3.0"
|
||||||
|
mlua = { version = "0.10", features = ["lua54", "vendored"] }
|
||||||
nanoid = "0.4"
|
nanoid = "0.4"
|
||||||
nbformat = { version = "0.10.0" }
|
nbformat = { version = "0.10.0" }
|
||||||
nix = "0.29"
|
nix = "0.29"
|
||||||
|
|
24
crates/scripting_tool/Cargo.toml
Normal file
24
crates/scripting_tool/Cargo.toml
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
[package]
|
||||||
|
name = "scripting_tool"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition.workspace = true
|
||||||
|
publish.workspace = true
|
||||||
|
license = "GPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/scripting_tool.rs"
|
||||||
|
doctest = false
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
assistant_tool.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
|
mlua.workspace = true
|
||||||
|
schemars.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
workspace.workspace = true
|
||||||
|
regex.workspace = true
|
1
crates/scripting_tool/LICENSE-GPL
Symbolic link
1
crates/scripting_tool/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE-GPL
|
40
crates/scripting_tool/src/sandbox_preamble.lua
Normal file
40
crates/scripting_tool/src/sandbox_preamble.lua
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
---@diagnostic disable: undefined-global
|
||||||
|
|
||||||
|
-- Create a sandbox environment
|
||||||
|
local sandbox = {}
|
||||||
|
|
||||||
|
-- Allow access to standard libraries (safe subset)
|
||||||
|
sandbox.string = string
|
||||||
|
sandbox.table = table
|
||||||
|
sandbox.math = math
|
||||||
|
sandbox.print = sb_print
|
||||||
|
sandbox.type = type
|
||||||
|
sandbox.tostring = tostring
|
||||||
|
sandbox.tonumber = tonumber
|
||||||
|
sandbox.pairs = pairs
|
||||||
|
sandbox.ipairs = ipairs
|
||||||
|
sandbox.search = search
|
||||||
|
|
||||||
|
-- Create a sandboxed version of LuaFileIO
|
||||||
|
local io = {}
|
||||||
|
|
||||||
|
-- File functions
|
||||||
|
io.open = sb_io_open
|
||||||
|
|
||||||
|
-- Add the sandboxed io library to the sandbox environment
|
||||||
|
sandbox.io = io
|
||||||
|
|
||||||
|
|
||||||
|
-- Load the script with the sandbox environment
|
||||||
|
local user_script_fn, err = load(user_script, nil, "t", sandbox)
|
||||||
|
|
||||||
|
if not user_script_fn then
|
||||||
|
error("Failed to load user script: " .. tostring(err))
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Execute the user script within the sandbox
|
||||||
|
local success, result = pcall(user_script_fn)
|
||||||
|
|
||||||
|
if not success then
|
||||||
|
error("Error executing user script: " .. tostring(result))
|
||||||
|
end
|
782
crates/scripting_tool/src/scripting_tool.rs
Normal file
782
crates/scripting_tool/src/scripting_tool.rs
Normal file
|
@ -0,0 +1,782 @@
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use assistant_tool::{Tool, ToolRegistry};
|
||||||
|
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
||||||
|
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::{
|
||||||
|
cell::RefCell,
|
||||||
|
collections::HashMap,
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
rc::Rc,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use workspace::Workspace;
|
||||||
|
|
||||||
|
pub fn init(cx: &App) {
|
||||||
|
let registry = ToolRegistry::global(cx);
|
||||||
|
registry.register_tool(ScriptingTool);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, JsonSchema)]
|
||||||
|
struct ScriptingToolInput {
|
||||||
|
lua_script: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ScriptingTool;
|
||||||
|
|
||||||
|
impl Tool for ScriptingTool {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"lua-interpreter".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> String {
|
||||||
|
r#"You can write a Lua script and I'll run it on my code base and tell you what its output was,
|
||||||
|
including both stdout as well as the git diff of changes it made to the filesystem. That way,
|
||||||
|
you can get more information about the code base, or make changes to the code base directly.
|
||||||
|
The lua script will have access to `io` and it will run with the current working directory being in
|
||||||
|
the root of the code base, so you can use it to explore, search, make changes, etc. You can also have
|
||||||
|
the script print things, and I'll tell you what the output was. Note that `io` only has `open`, and
|
||||||
|
then the file it returns only has the methods read, write, and close - it doesn't have popen or
|
||||||
|
anything else. Also, I'm going to be putting this Lua script into JSON, so please don't use Lua's
|
||||||
|
double quote syntax for string literals - use one of Lua's other syntaxes for string literals, so I
|
||||||
|
don't have to escape the double quotes. There will be a global called `search` which accepts a regex
|
||||||
|
(it's implemented using Rust's regex crate, so use that regex syntax) and runs that regex on the contents
|
||||||
|
of every file in the code base (aside from gitignored files), then returns an array of tables with two
|
||||||
|
fields: "path" (the path to the file that had the matches) and "matches" (an array of strings, with each
|
||||||
|
string being a match that was found within the file)."#.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
let schema = schemars::schema_for!(ScriptingToolInput);
|
||||||
|
serde_json::to_value(&schema).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
workspace: WeakEntity<Workspace>,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<anyhow::Result<String>> {
|
||||||
|
let root_dir = workspace.update(cx, |workspace, cx| {
|
||||||
|
let first_worktree = workspace
|
||||||
|
.visible_worktrees(cx)
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow!("no worktrees"))?;
|
||||||
|
workspace
|
||||||
|
.absolute_path_of_worktree(first_worktree.read(cx).id(), cx)
|
||||||
|
.ok_or_else(|| anyhow!("no worktree root"))
|
||||||
|
});
|
||||||
|
let root_dir = match root_dir {
|
||||||
|
Ok(root_dir) => root_dir,
|
||||||
|
Err(err) => return Task::ready(Err(err)),
|
||||||
|
};
|
||||||
|
let root_dir = match root_dir {
|
||||||
|
Ok(root_dir) => root_dir,
|
||||||
|
Err(err) => return Task::ready(Err(err)),
|
||||||
|
};
|
||||||
|
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
|
||||||
|
Err(err) => return Task::ready(Err(err.into())),
|
||||||
|
Ok(input) => input,
|
||||||
|
};
|
||||||
|
let lua_script = input.lua_script;
|
||||||
|
cx.background_spawn(async move {
|
||||||
|
let fs_changes = HashMap::new();
|
||||||
|
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
|
||||||
|
.map_err(|err| anyhow!(format!("{err}")))?;
|
||||||
|
let output = output.printed_lines.join("\n");
|
||||||
|
|
||||||
|
Ok(format!("The script output the following:\n{output}"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
||||||
|
|
||||||
|
struct FileContent(RefCell<Vec<u8>>);
|
||||||
|
|
||||||
|
impl UserData for FileContent {
|
||||||
|
fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
|
||||||
|
// FileContent doesn't have any methods so far.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sandboxed print() function in Lua.
|
||||||
|
fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function> {
|
||||||
|
lua.create_function(move |_, args: MultiValue| {
|
||||||
|
let mut string = String::new();
|
||||||
|
|
||||||
|
for arg in args.into_iter() {
|
||||||
|
// Lua's `print()` prints tab characters between each argument.
|
||||||
|
if !string.is_empty() {
|
||||||
|
string.push('\t');
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the argument's to_string() fails, have the whole function call fail.
|
||||||
|
string.push_str(arg.to_string()?.as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
printed_lines.borrow_mut().push(string);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search(
|
||||||
|
lua: &Lua,
|
||||||
|
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
|
root_dir: PathBuf,
|
||||||
|
) -> Result<Function> {
|
||||||
|
lua.create_function(move |lua, regex: String| {
|
||||||
|
use mlua::Table;
|
||||||
|
use regex::Regex;
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
// Function to recursively search directory
|
||||||
|
let search_regex = match Regex::new(®ex) {
|
||||||
|
Ok(re) => re,
|
||||||
|
Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut search_results: Vec<Result<Table>> = Vec::new();
|
||||||
|
|
||||||
|
// Create an explicit stack for directories to process
|
||||||
|
let mut dir_stack = vec![root_dir.clone()];
|
||||||
|
|
||||||
|
while let Some(current_dir) = dir_stack.pop() {
|
||||||
|
// Process each entry in the current directory
|
||||||
|
let entries = match fs::read_dir(¤t_dir) {
|
||||||
|
Ok(entries) => entries,
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
for entry_result in entries {
|
||||||
|
let entry = match entry_result {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.is_dir() {
|
||||||
|
// Skip .git directory and other common directories to ignore
|
||||||
|
let dir_name = path.file_name().unwrap_or_default().to_string_lossy();
|
||||||
|
if !dir_name.starts_with('.')
|
||||||
|
&& dir_name != "node_modules"
|
||||||
|
&& dir_name != "target"
|
||||||
|
{
|
||||||
|
// Instead of recursive call, add to stack
|
||||||
|
dir_stack.push(path);
|
||||||
|
}
|
||||||
|
} else if path.is_file() {
|
||||||
|
// Skip binary files and very large files
|
||||||
|
if let Ok(metadata) = fs::metadata(&path) {
|
||||||
|
if metadata.len() > 1_000_000 {
|
||||||
|
// Skip files larger than 1MB
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to read the file as text
|
||||||
|
if let Ok(content) = fs::read_to_string(&path) {
|
||||||
|
let mut matches = Vec::new();
|
||||||
|
|
||||||
|
// Find all regex matches in the content
|
||||||
|
for capture in search_regex.find_iter(&content) {
|
||||||
|
matches.push(capture.as_str().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we found matches, create a result entry
|
||||||
|
if !matches.is_empty() {
|
||||||
|
let result_entry = lua.create_table()?;
|
||||||
|
result_entry.set("path", path.to_string_lossy().to_string())?;
|
||||||
|
|
||||||
|
let matches_table = lua.create_table()?;
|
||||||
|
for (i, m) in matches.iter().enumerate() {
|
||||||
|
matches_table.set(i + 1, m.clone())?;
|
||||||
|
}
|
||||||
|
result_entry.set("matches", matches_table)?;
|
||||||
|
|
||||||
|
search_results.push(Ok(result_entry));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a table to hold our results
|
||||||
|
let results_table = lua.create_table()?;
|
||||||
|
for (i, result) in search_results.into_iter().enumerate() {
|
||||||
|
match result {
|
||||||
|
Ok(entry) => results_table.set(i + 1, entry)?,
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results_table)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sandboxed io.open() function in Lua.
|
||||||
|
fn io_open(
|
||||||
|
lua: &Lua,
|
||||||
|
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
|
root_dir: PathBuf,
|
||||||
|
) -> Result<Function> {
|
||||||
|
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
|
||||||
|
let mode = mode.unwrap_or_else(|| "r".to_string());
|
||||||
|
|
||||||
|
// Parse the mode string to determine read/write permissions
|
||||||
|
let read_perm = mode.contains('r');
|
||||||
|
let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
|
||||||
|
let append = mode.contains('a');
|
||||||
|
let truncate = mode.contains('w');
|
||||||
|
|
||||||
|
// This will be the Lua value returned from the `open` function.
|
||||||
|
let file = lua.create_table()?;
|
||||||
|
|
||||||
|
// Store file metadata in the file
|
||||||
|
file.set("__path", path_str.clone())?;
|
||||||
|
file.set("__mode", mode.clone())?;
|
||||||
|
file.set("__read_perm", read_perm)?;
|
||||||
|
file.set("__write_perm", write_perm)?;
|
||||||
|
|
||||||
|
// Sandbox the path; it must be within root_dir
|
||||||
|
let path: PathBuf = {
|
||||||
|
let rust_path = Path::new(&path_str);
|
||||||
|
|
||||||
|
// Get absolute path
|
||||||
|
if rust_path.is_absolute() {
|
||||||
|
// Check if path starts with root_dir prefix without resolving symlinks
|
||||||
|
if !rust_path.starts_with(&root_dir) {
|
||||||
|
return Ok((
|
||||||
|
None,
|
||||||
|
format!(
|
||||||
|
"Error: Absolute path {} is outside the current working directory",
|
||||||
|
path_str
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
rust_path.to_path_buf()
|
||||||
|
} else {
|
||||||
|
// Make relative path absolute relative to cwd
|
||||||
|
root_dir.join(rust_path)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// close method
|
||||||
|
let close_fn = {
|
||||||
|
let fs_changes = fs_changes.clone();
|
||||||
|
lua.create_function(move |_lua, file_userdata: mlua::Table| {
|
||||||
|
let write_perm = file_userdata.get::<bool>("__write_perm")?;
|
||||||
|
let path = file_userdata.get::<String>("__path")?;
|
||||||
|
|
||||||
|
if write_perm {
|
||||||
|
// When closing a writable file, record the content
|
||||||
|
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||||
|
let content_ref = content.borrow::<FileContent>()?;
|
||||||
|
let content_vec = content_ref.0.borrow();
|
||||||
|
|
||||||
|
// Don't actually write to disk; instead, just update fs_changes.
|
||||||
|
let path_buf = PathBuf::from(&path);
|
||||||
|
fs_changes
|
||||||
|
.borrow_mut()
|
||||||
|
.insert(path_buf.clone(), content_vec.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(true)
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
file.set("close", close_fn)?;
|
||||||
|
|
||||||
|
// If it's a directory, give it a custom read() and return early.
|
||||||
|
if path.is_dir() {
|
||||||
|
// TODO handle the case where we changed it in the in-memory fs
|
||||||
|
|
||||||
|
// Create a special directory handle
|
||||||
|
file.set("__is_directory", true)?;
|
||||||
|
|
||||||
|
// Store directory entries
|
||||||
|
let entries = match std::fs::read_dir(&path) {
|
||||||
|
Ok(entries) => {
|
||||||
|
let mut entry_names = Vec::new();
|
||||||
|
for entry in entries.flatten() {
|
||||||
|
entry_names.push(entry.file_name().to_string_lossy().into_owned());
|
||||||
|
}
|
||||||
|
entry_names
|
||||||
|
}
|
||||||
|
Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Save the list of entries
|
||||||
|
file.set("__dir_entries", entries)?;
|
||||||
|
file.set("__dir_position", 0usize)?;
|
||||||
|
|
||||||
|
// Create a directory-specific read function
|
||||||
|
let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
|
||||||
|
let position = file_userdata.get::<usize>("__dir_position")?;
|
||||||
|
let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
|
||||||
|
|
||||||
|
if position >= entries.len() {
|
||||||
|
return Ok(None); // No more entries
|
||||||
|
}
|
||||||
|
|
||||||
|
let entry = entries[position].clone();
|
||||||
|
file_userdata.set("__dir_position", position + 1)?;
|
||||||
|
|
||||||
|
Ok(Some(entry))
|
||||||
|
})?;
|
||||||
|
file.set("read", read_fn)?;
|
||||||
|
|
||||||
|
// If we got this far, the directory was opened successfully
|
||||||
|
return Ok((Some(file), String::new()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_in_changes = fs_changes.borrow().contains_key(&path);
|
||||||
|
let file_exists = is_in_changes || path.exists();
|
||||||
|
let mut file_content = Vec::new();
|
||||||
|
|
||||||
|
if file_exists && !truncate {
|
||||||
|
if is_in_changes {
|
||||||
|
file_content = fs_changes.borrow().get(&path).unwrap().clone();
|
||||||
|
} else {
|
||||||
|
// Try to read existing content if file exists and we're not truncating
|
||||||
|
match std::fs::read(&path) {
|
||||||
|
Ok(content) => file_content = content,
|
||||||
|
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If in append mode, position should be at the end
|
||||||
|
let position = if append && file_exists {
|
||||||
|
file_content.len()
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
file.set("__position", position)?;
|
||||||
|
file.set(
|
||||||
|
"__content",
|
||||||
|
lua.create_userdata(FileContent(RefCell::new(file_content)))?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Create file methods
|
||||||
|
|
||||||
|
// read method
|
||||||
|
let read_fn = {
|
||||||
|
lua.create_function(
|
||||||
|
|_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
|
||||||
|
let read_perm = file_userdata.get::<bool>("__read_perm")?;
|
||||||
|
if !read_perm {
|
||||||
|
return Err(mlua::Error::runtime("File not open for reading"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||||
|
let mut position = file_userdata.get::<usize>("__position")?;
|
||||||
|
let content_ref = content.borrow::<FileContent>()?;
|
||||||
|
let content_vec = content_ref.0.borrow();
|
||||||
|
|
||||||
|
if position >= content_vec.len() {
|
||||||
|
return Ok(None); // EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
match format {
|
||||||
|
Some(mlua::Value::String(s)) => {
|
||||||
|
let lossy_string = s.to_string_lossy();
|
||||||
|
let format_str: &str = lossy_string.as_ref();
|
||||||
|
|
||||||
|
// Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
|
||||||
|
match &format_str[0..2] {
|
||||||
|
"*a" => {
|
||||||
|
// Read entire file from current position
|
||||||
|
let result = String::from_utf8_lossy(&content_vec[position..])
|
||||||
|
.to_string();
|
||||||
|
position = content_vec.len();
|
||||||
|
file_userdata.set("__position", position)?;
|
||||||
|
Ok(Some(result))
|
||||||
|
}
|
||||||
|
"*l" => {
|
||||||
|
// Read next line
|
||||||
|
let mut line = Vec::new();
|
||||||
|
let mut found_newline = false;
|
||||||
|
|
||||||
|
while position < content_vec.len() {
|
||||||
|
let byte = content_vec[position];
|
||||||
|
position += 1;
|
||||||
|
|
||||||
|
if byte == b'\n' {
|
||||||
|
found_newline = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip \r in \r\n sequence but add it if it's alone
|
||||||
|
if byte == b'\r' {
|
||||||
|
if position < content_vec.len()
|
||||||
|
&& content_vec[position] == b'\n'
|
||||||
|
{
|
||||||
|
position += 1;
|
||||||
|
found_newline = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
line.push(byte);
|
||||||
|
}
|
||||||
|
|
||||||
|
file_userdata.set("__position", position)?;
|
||||||
|
|
||||||
|
if !found_newline
|
||||||
|
&& line.is_empty()
|
||||||
|
&& position >= content_vec.len()
|
||||||
|
{
|
||||||
|
return Ok(None); // EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = String::from_utf8_lossy(&line).to_string();
|
||||||
|
Ok(Some(result))
|
||||||
|
}
|
||||||
|
"*n" => {
|
||||||
|
// Try to parse as a number (number of bytes to read)
|
||||||
|
match format_str.parse::<usize>() {
|
||||||
|
Ok(n) => {
|
||||||
|
let end =
|
||||||
|
std::cmp::min(position + n, content_vec.len());
|
||||||
|
let bytes = &content_vec[position..end];
|
||||||
|
let result = String::from_utf8_lossy(bytes).to_string();
|
||||||
|
position = end;
|
||||||
|
file_userdata.set("__position", position)?;
|
||||||
|
Ok(Some(result))
|
||||||
|
}
|
||||||
|
Err(_) => Err(mlua::Error::runtime(format!(
|
||||||
|
"Invalid format: {}",
|
||||||
|
format_str
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"*L" => {
|
||||||
|
// Read next line keeping the end of line
|
||||||
|
let mut line = Vec::new();
|
||||||
|
|
||||||
|
while position < content_vec.len() {
|
||||||
|
let byte = content_vec[position];
|
||||||
|
position += 1;
|
||||||
|
|
||||||
|
line.push(byte);
|
||||||
|
|
||||||
|
if byte == b'\n' {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we encounter a \r, add it and check if the next is \n
|
||||||
|
if byte == b'\r'
|
||||||
|
&& position < content_vec.len()
|
||||||
|
&& content_vec[position] == b'\n'
|
||||||
|
{
|
||||||
|
line.push(content_vec[position]);
|
||||||
|
position += 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file_userdata.set("__position", position)?;
|
||||||
|
|
||||||
|
if line.is_empty() && position >= content_vec.len() {
|
||||||
|
return Ok(None); // EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = String::from_utf8_lossy(&line).to_string();
|
||||||
|
Ok(Some(result))
|
||||||
|
}
|
||||||
|
_ => Err(mlua::Error::runtime(format!(
|
||||||
|
"Unsupported format: {}",
|
||||||
|
format_str
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(mlua::Value::Number(n)) => {
|
||||||
|
// Read n bytes
|
||||||
|
let n = n as usize;
|
||||||
|
let end = std::cmp::min(position + n, content_vec.len());
|
||||||
|
let bytes = &content_vec[position..end];
|
||||||
|
let result = String::from_utf8_lossy(bytes).to_string();
|
||||||
|
position = end;
|
||||||
|
file_userdata.set("__position", position)?;
|
||||||
|
Ok(Some(result))
|
||||||
|
}
|
||||||
|
Some(_) => Err(mlua::Error::runtime("Invalid format")),
|
||||||
|
None => {
|
||||||
|
// Default is to read a line
|
||||||
|
let mut line = Vec::new();
|
||||||
|
let mut found_newline = false;
|
||||||
|
|
||||||
|
while position < content_vec.len() {
|
||||||
|
let byte = content_vec[position];
|
||||||
|
position += 1;
|
||||||
|
|
||||||
|
if byte == b'\n' {
|
||||||
|
found_newline = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle \r\n
|
||||||
|
if byte == b'\r' {
|
||||||
|
if position < content_vec.len()
|
||||||
|
&& content_vec[position] == b'\n'
|
||||||
|
{
|
||||||
|
position += 1;
|
||||||
|
found_newline = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
line.push(byte);
|
||||||
|
}
|
||||||
|
|
||||||
|
file_userdata.set("__position", position)?;
|
||||||
|
|
||||||
|
if !found_newline && line.is_empty() && position >= content_vec.len() {
|
||||||
|
return Ok(None); // EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = String::from_utf8_lossy(&line).to_string();
|
||||||
|
Ok(Some(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
file.set("read", read_fn)?;
|
||||||
|
|
||||||
|
// write method
|
||||||
|
let write_fn = {
|
||||||
|
let fs_changes = fs_changes.clone();
|
||||||
|
|
||||||
|
lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
|
||||||
|
let write_perm = file_userdata.get::<bool>("__write_perm")?;
|
||||||
|
if !write_perm {
|
||||||
|
return Err(mlua::Error::runtime("File not open for writing"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||||
|
let position = file_userdata.get::<usize>("__position")?;
|
||||||
|
let content_ref = content.borrow::<FileContent>()?;
|
||||||
|
let mut content_vec = content_ref.0.borrow_mut();
|
||||||
|
|
||||||
|
let bytes = text.as_bytes();
|
||||||
|
|
||||||
|
// Ensure the vector has enough capacity
|
||||||
|
if position + bytes.len() > content_vec.len() {
|
||||||
|
content_vec.resize(position + bytes.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the bytes
|
||||||
|
for (i, &byte) in bytes.iter().enumerate() {
|
||||||
|
content_vec[position + i] = byte;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update position
|
||||||
|
let new_position = position + bytes.len();
|
||||||
|
file_userdata.set("__position", new_position)?;
|
||||||
|
|
||||||
|
// Update fs_changes
|
||||||
|
let path = file_userdata.get::<String>("__path")?;
|
||||||
|
let path_buf = PathBuf::from(path);
|
||||||
|
fs_changes
|
||||||
|
.borrow_mut()
|
||||||
|
.insert(path_buf, content_vec.clone());
|
||||||
|
|
||||||
|
Ok(true)
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
file.set("write", write_fn)?;
|
||||||
|
|
||||||
|
// If we got this far, the file was opened successfully
|
||||||
|
Ok((Some(file), String::new()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs a Lua script in a sandboxed environment and returns the printed lines
|
||||||
|
pub fn run_sandboxed_lua(
|
||||||
|
script: &str,
|
||||||
|
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
||||||
|
root_dir: PathBuf,
|
||||||
|
) -> Result<ScriptOutput> {
|
||||||
|
let lua = Lua::new();
|
||||||
|
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
|
||||||
|
let globals = lua.globals();
|
||||||
|
|
||||||
|
// Track the lines the Lua script prints out.
|
||||||
|
let printed_lines = Rc::new(RefCell::new(Vec::new()));
|
||||||
|
let fs = Rc::new(RefCell::new(fs_changes));
|
||||||
|
|
||||||
|
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
|
||||||
|
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
|
||||||
|
globals.set("sb_io_open", io_open(&lua, fs.clone(), root_dir)?)?;
|
||||||
|
globals.set("user_script", script)?;
|
||||||
|
|
||||||
|
lua.load(SANDBOX_PREAMBLE).exec()?;
|
||||||
|
|
||||||
|
drop(lua); // Necessary so the Rc'd values get decremented.
|
||||||
|
|
||||||
|
Ok(ScriptOutput {
|
||||||
|
printed_lines: Rc::try_unwrap(printed_lines)
|
||||||
|
.expect("There are still other references to printed_lines")
|
||||||
|
.into_inner(),
|
||||||
|
fs_changes: Rc::try_unwrap(fs)
|
||||||
|
.expect("There are still other references to fs_changes")
|
||||||
|
.into_inner(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ScriptOutput {
|
||||||
|
printed_lines: Vec<String>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
impl ScriptOutput {
|
||||||
|
fn fs_diff(&self) -> HashMap<PathBuf, String> {
|
||||||
|
let mut diff_map = HashMap::new();
|
||||||
|
for (path, content) in &self.fs_changes {
|
||||||
|
let diff = if path.exists() {
|
||||||
|
// Read the current file content
|
||||||
|
match std::fs::read(path) {
|
||||||
|
Ok(current_content) => {
|
||||||
|
// Convert both to strings for diffing
|
||||||
|
let new_content = String::from_utf8_lossy(content).to_string();
|
||||||
|
let old_content = String::from_utf8_lossy(¤t_content).to_string();
|
||||||
|
|
||||||
|
// Generate a git-style diff
|
||||||
|
let new_lines: Vec<&str> = new_content.lines().collect();
|
||||||
|
let old_lines: Vec<&str> = old_content.lines().collect();
|
||||||
|
|
||||||
|
let path_str = path.to_string_lossy();
|
||||||
|
let mut diff = format!("diff --git a/{} b/{}\n", path_str, path_str);
|
||||||
|
diff.push_str(&format!("--- a/{}\n", path_str));
|
||||||
|
diff.push_str(&format!("+++ b/{}\n", path_str));
|
||||||
|
|
||||||
|
// Very basic diff algorithm - this is simplified
|
||||||
|
let mut i = 0;
|
||||||
|
let mut j = 0;
|
||||||
|
|
||||||
|
while i < old_lines.len() || j < new_lines.len() {
|
||||||
|
if i < old_lines.len()
|
||||||
|
&& j < new_lines.len()
|
||||||
|
&& old_lines[i] == new_lines[j]
|
||||||
|
{
|
||||||
|
i += 1;
|
||||||
|
j += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find next matching line
|
||||||
|
let mut next_i = i;
|
||||||
|
let mut next_j = j;
|
||||||
|
let mut found = false;
|
||||||
|
|
||||||
|
// Look ahead for matches
|
||||||
|
for look_i in i..std::cmp::min(i + 10, old_lines.len()) {
|
||||||
|
for look_j in j..std::cmp::min(j + 10, new_lines.len()) {
|
||||||
|
if old_lines[look_i] == new_lines[look_j] {
|
||||||
|
next_i = look_i;
|
||||||
|
next_j = look_j;
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output the hunk header
|
||||||
|
diff.push_str(&format!(
|
||||||
|
"@@ -{},{} +{},{} @@\n",
|
||||||
|
i + 1,
|
||||||
|
if found {
|
||||||
|
next_i - i
|
||||||
|
} else {
|
||||||
|
old_lines.len() - i
|
||||||
|
},
|
||||||
|
j + 1,
|
||||||
|
if found {
|
||||||
|
next_j - j
|
||||||
|
} else {
|
||||||
|
new_lines.len() - j
|
||||||
|
}
|
||||||
|
));
|
||||||
|
|
||||||
|
// Output removed lines
|
||||||
|
for k in i..next_i {
|
||||||
|
diff.push_str(&format!("-{}\n", old_lines[k]));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output added lines
|
||||||
|
for k in j..next_j {
|
||||||
|
diff.push_str(&format!("+{}\n", new_lines[k]));
|
||||||
|
}
|
||||||
|
|
||||||
|
i = next_i;
|
||||||
|
j = next_j;
|
||||||
|
|
||||||
|
if found {
|
||||||
|
i += 1;
|
||||||
|
j += 1;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
diff
|
||||||
|
}
|
||||||
|
Err(_) => format!("Error reading current file: {}", path.display()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New file
|
||||||
|
let content_str = String::from_utf8_lossy(content).to_string();
|
||||||
|
let path_str = path.to_string_lossy();
|
||||||
|
let mut diff = format!("diff --git a/{} b/{}\n", path_str, path_str);
|
||||||
|
diff.push_str("new file mode 100644\n");
|
||||||
|
diff.push_str("--- /dev/null\n");
|
||||||
|
diff.push_str(&format!("+++ b/{}\n", path_str));
|
||||||
|
|
||||||
|
let lines: Vec<&str> = content_str.lines().collect();
|
||||||
|
diff.push_str(&format!("@@ -0,0 +1,{} @@\n", lines.len()));
|
||||||
|
|
||||||
|
for line in lines {
|
||||||
|
diff.push_str(&format!("+{}\n", line));
|
||||||
|
}
|
||||||
|
|
||||||
|
diff
|
||||||
|
};
|
||||||
|
|
||||||
|
diff_map.insert(path.clone(), diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
diff_map
|
||||||
|
}
|
||||||
|
|
||||||
|
fn diff_to_string(&self) -> String {
|
||||||
|
let mut answer = String::new();
|
||||||
|
let diff_map = self.fs_diff();
|
||||||
|
|
||||||
|
if diff_map.is_empty() {
|
||||||
|
return "No changes to files".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the paths for consistent output
|
||||||
|
let mut paths: Vec<&PathBuf> = diff_map.keys().collect();
|
||||||
|
paths.sort();
|
||||||
|
|
||||||
|
for path in paths {
|
||||||
|
if !answer.is_empty() {
|
||||||
|
answer.push_str("\n");
|
||||||
|
}
|
||||||
|
answer.push_str(&diff_map[path]);
|
||||||
|
}
|
||||||
|
|
||||||
|
answer
|
||||||
|
}
|
||||||
|
}
|
|
@ -98,6 +98,7 @@ remote.workspace = true
|
||||||
repl.workspace = true
|
repl.workspace = true
|
||||||
reqwest_client.workspace = true
|
reqwest_client.workspace = true
|
||||||
rope.workspace = true
|
rope.workspace = true
|
||||||
|
scripting_tool.workspace = true
|
||||||
search.workspace = true
|
search.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
|
|
@ -477,6 +477,7 @@ fn main() {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
assistant_tools::init(cx);
|
assistant_tools::init(cx);
|
||||||
|
scripting_tool::init(cx);
|
||||||
repl::init(app_state.fs.clone(), cx);
|
repl::init(app_state.fs.clone(), cx);
|
||||||
extension_host::init(
|
extension_host::init(
|
||||||
extension_host_proxy,
|
extension_host_proxy,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue