scripting tool: Use project buffers in io.open
(#26425)
This PR makes `io.open` use our own implementation again, but instead of the real filesystem, it will now use the project's to check file metadata and perform read and writes using project buffers. This also cleans up the `io.open` implementation by splitting it into multiple methods, adds tests for various File I/O patterns, and fixes a few bugs in read formats. Release Notes: - N/A
This commit is contained in:
parent
d562f58e76
commit
bf11b888c3
5 changed files with 502 additions and 335 deletions
|
@ -8,9 +8,9 @@ local sandbox = {}
|
|||
-- to our in-memory log rather than to stdout, we will delete this loop (and re-enable
|
||||
-- the I/O module being sandboxed below) to have things be sandboxed again.
|
||||
for k, v in pairs(_G) do
|
||||
if sandbox[k] == nil then
|
||||
sandbox[k] = v
|
||||
end
|
||||
if sandbox[k] == nil then
|
||||
sandbox[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
-- Allow access to standard libraries (safe subset)
|
||||
|
@ -29,24 +29,27 @@ sandbox.search = search
|
|||
sandbox.outline = outline
|
||||
|
||||
-- Create a sandboxed version of LuaFileIO
|
||||
local io = {}
|
||||
-- local io = {};
|
||||
--
|
||||
-- For now we are using unsandboxed io
|
||||
local io = _G.io;
|
||||
|
||||
-- File functions
|
||||
io.open = sb_io_open
|
||||
|
||||
-- Add the sandboxed io library to the sandbox environment
|
||||
-- sandbox.io = io -- Uncomment this line to re-enable sandboxed file I/O.
|
||||
sandbox.io = io
|
||||
|
||||
-- Load the script with the sandbox environment
|
||||
local user_script_fn, err = load(user_script, nil, "t", sandbox)
|
||||
|
||||
if not user_script_fn then
|
||||
error("Failed to load user script: " .. tostring(err))
|
||||
error("Failed to load user script: " .. tostring(err))
|
||||
end
|
||||
|
||||
-- Execute the user script within the sandbox
|
||||
local success, result = pcall(user_script_fn)
|
||||
|
||||
if not success then
|
||||
error("Error executing user script: " .. tostring(result))
|
||||
error("Error executing user script: " .. tostring(result))
|
||||
end
|
||||
|
|
|
@ -36,7 +36,7 @@ impl ScriptingTool {
|
|||
};
|
||||
|
||||
// TODO: Store a session per thread
|
||||
let session = cx.new(|cx| ScriptSession::new(project, cx));
|
||||
let session = cx.new(|cx| ScriptingSession::new(project, cx));
|
||||
let lua_script = input.lua_script;
|
||||
|
||||
let (script_id, script_task) =
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use anyhow::anyhow;
|
||||
use collections::{HashMap, HashSet};
|
||||
use collections::HashSet;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
pin_mut, SinkExt, StreamExt,
|
||||
|
@ -7,32 +7,28 @@ use futures::{
|
|||
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
|
||||
use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
|
||||
use parking_lot::Mutex;
|
||||
use project::{search::SearchQuery, Fs, Project};
|
||||
use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
|
||||
use regex::Regex;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{paths::PathMatcher, ResultExt};
|
||||
|
||||
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
|
||||
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
|
||||
|
||||
pub struct ScriptSession {
|
||||
pub struct ScriptingSession {
|
||||
project: Entity<Project>,
|
||||
// TODO Remove this
|
||||
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
foreground_fns_tx: mpsc::Sender<ForegroundFn>,
|
||||
_invoke_foreground_fns: Task<()>,
|
||||
scripts: Vec<Script>,
|
||||
}
|
||||
|
||||
impl ScriptSession {
|
||||
impl ScriptingSession {
|
||||
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
|
||||
let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
|
||||
ScriptSession {
|
||||
ScriptingSession {
|
||||
project,
|
||||
fs_changes: Arc::new(Mutex::new(HashMap::default())),
|
||||
foreground_fns_tx,
|
||||
_invoke_foreground_fns: cx.spawn(|this, cx| async move {
|
||||
while let Some(foreground_fn) = foreground_fns_rx.next().await {
|
||||
|
@ -88,15 +84,18 @@ impl ScriptSession {
|
|||
) -> Task<anyhow::Result<()>> {
|
||||
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
||||
|
||||
// TODO Remove fs_changes
|
||||
let fs_changes = self.fs_changes.clone();
|
||||
// TODO Honor all worktrees instead of the first one
|
||||
let root_dir = self
|
||||
let worktree_info = self
|
||||
.project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| worktree.read(cx).abs_path());
|
||||
.map(|worktree| {
|
||||
let worktree = worktree.read(cx);
|
||||
(worktree.id(), worktree.abs_path())
|
||||
});
|
||||
|
||||
let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
|
||||
|
||||
let fs = self.project.read(cx).fs().clone();
|
||||
let foreground_fns_tx = self.foreground_fns_tx.clone();
|
||||
|
@ -127,6 +126,7 @@ impl ScriptSession {
|
|||
"search",
|
||||
lua.create_async_function({
|
||||
let foreground_fns_tx = foreground_fns_tx.clone();
|
||||
let fs = fs.clone();
|
||||
move |lua, regex| {
|
||||
let mut foreground_fns_tx = foreground_fns_tx.clone();
|
||||
let fs = fs.clone();
|
||||
|
@ -142,6 +142,7 @@ impl ScriptSession {
|
|||
"outline",
|
||||
lua.create_async_function({
|
||||
let root_dir = root_dir.clone();
|
||||
let foreground_fns_tx = foreground_fns_tx.clone();
|
||||
move |_lua, path| {
|
||||
let mut foreground_fns_tx = foreground_fns_tx.clone();
|
||||
let root_dir = root_dir.clone();
|
||||
|
@ -155,11 +156,24 @@ impl ScriptSession {
|
|||
)?;
|
||||
globals.set(
|
||||
"sb_io_open",
|
||||
lua.create_function({
|
||||
let fs_changes = fs_changes.clone();
|
||||
let root_dir = root_dir.clone();
|
||||
lua.create_async_function({
|
||||
let worktree_info = worktree_info.clone();
|
||||
let foreground_fns_tx = foreground_fns_tx.clone();
|
||||
move |lua, (path_str, mode)| {
|
||||
Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
|
||||
let worktree_info = worktree_info.clone();
|
||||
let mut foreground_fns_tx = foreground_fns_tx.clone();
|
||||
let fs = fs.clone();
|
||||
async move {
|
||||
Self::io_open(
|
||||
&lua,
|
||||
worktree_info,
|
||||
&mut foreground_fns_tx,
|
||||
fs,
|
||||
path_str,
|
||||
mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
})?,
|
||||
)?;
|
||||
|
@ -202,14 +216,15 @@ impl ScriptSession {
|
|||
}
|
||||
|
||||
/// Sandboxed io.open() function in Lua.
|
||||
fn io_open(
|
||||
async fn io_open(
|
||||
lua: &Lua,
|
||||
fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
root_dir: Option<&Arc<Path>>,
|
||||
worktree_info: Option<(WorktreeId, Arc<Path>)>,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
fs: Arc<dyn Fs>,
|
||||
path_str: String,
|
||||
mode: Option<String>,
|
||||
) -> mlua::Result<(Option<Table>, String)> {
|
||||
let root_dir = root_dir
|
||||
let (worktree_id, root_dir) = worktree_info
|
||||
.ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
|
||||
|
||||
let mode = mode.unwrap_or_else(|| "r".to_string());
|
||||
|
@ -224,7 +239,6 @@ impl ScriptSession {
|
|||
let file = lua.create_table()?;
|
||||
|
||||
// Store file metadata in the file
|
||||
file.set("__path", path_str.clone())?;
|
||||
file.set("__mode", mode.clone())?;
|
||||
file.set("__read_perm", read_perm)?;
|
||||
file.set("__write_perm", write_perm)?;
|
||||
|
@ -234,338 +248,354 @@ impl ScriptSession {
|
|||
Err(err) => return Ok((None, format!("{err}"))),
|
||||
};
|
||||
|
||||
// close method
|
||||
let close_fn = {
|
||||
let fs_changes = fs_changes.clone();
|
||||
lua.create_function(move |_lua, file_userdata: mlua::Table| {
|
||||
let write_perm = file_userdata.get::<bool>("__write_perm")?;
|
||||
let path = file_userdata.get::<String>("__path")?;
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: Path::new(&path_str).into(),
|
||||
};
|
||||
|
||||
if write_perm {
|
||||
// When closing a writable file, record the content
|
||||
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||
let content_ref = content.borrow::<FileContent>()?;
|
||||
let content_vec = content_ref.0.borrow();
|
||||
|
||||
// Don't actually write to disk; instead, just update fs_changes.
|
||||
let path_buf = PathBuf::from(&path);
|
||||
fs_changes
|
||||
.lock()
|
||||
.insert(path_buf.clone(), content_vec.clone());
|
||||
// flush / close method
|
||||
let flush_fn = {
|
||||
let project_path = project_path.clone();
|
||||
let foreground_tx = foreground_tx.clone();
|
||||
lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
|
||||
let project_path = project_path.clone();
|
||||
let mut foreground_tx = foreground_tx.clone();
|
||||
async move {
|
||||
Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
})?
|
||||
};
|
||||
file.set("close", close_fn)?;
|
||||
file.set("flush", flush_fn.clone())?;
|
||||
// We don't really hold files open, so we only need to flush on close
|
||||
file.set("close", flush_fn)?;
|
||||
|
||||
// If it's a directory, give it a custom read() and return early.
|
||||
if path.is_dir() {
|
||||
// TODO handle the case where we changed it in the in-memory fs
|
||||
|
||||
// Create a special directory handle
|
||||
file.set("__is_directory", true)?;
|
||||
|
||||
// Store directory entries
|
||||
let entries = match std::fs::read_dir(&path) {
|
||||
Ok(entries) => {
|
||||
let mut entry_names = Vec::new();
|
||||
for entry in entries.flatten() {
|
||||
entry_names.push(entry.file_name().to_string_lossy().into_owned());
|
||||
}
|
||||
entry_names
|
||||
}
|
||||
Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
|
||||
};
|
||||
|
||||
// Save the list of entries
|
||||
file.set("__dir_entries", entries)?;
|
||||
file.set("__dir_position", 0usize)?;
|
||||
|
||||
// Create a directory-specific read function
|
||||
let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
|
||||
let position = file_userdata.get::<usize>("__dir_position")?;
|
||||
let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
|
||||
|
||||
if position >= entries.len() {
|
||||
return Ok(None); // No more entries
|
||||
}
|
||||
|
||||
let entry = entries[position].clone();
|
||||
file_userdata.set("__dir_position", position + 1)?;
|
||||
|
||||
Ok(Some(entry))
|
||||
})?;
|
||||
file.set("read", read_fn)?;
|
||||
|
||||
// If we got this far, the directory was opened successfully
|
||||
return Ok((Some(file), String::new()));
|
||||
if fs.is_dir(&path).await {
|
||||
return Self::io_file_dir(lua, fs, file, &path).await;
|
||||
}
|
||||
|
||||
let fs_changes_map = fs_changes.lock();
|
||||
|
||||
let is_in_changes = fs_changes_map.contains_key(&path);
|
||||
let file_exists = is_in_changes || path.exists();
|
||||
let mut file_content = Vec::new();
|
||||
|
||||
if file_exists && !truncate {
|
||||
if is_in_changes {
|
||||
file_content = fs_changes_map.get(&path).unwrap().clone();
|
||||
} else {
|
||||
// Try to read existing content if file exists and we're not truncating
|
||||
match std::fs::read(&path) {
|
||||
Ok(content) => file_content = content,
|
||||
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
||||
}
|
||||
if !truncate {
|
||||
// Try to read existing content if we're not truncating
|
||||
match Self::read_buffer(project_path.clone(), foreground_tx).await {
|
||||
Ok(content) => file_content = content.into_bytes(),
|
||||
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
drop(fs_changes_map); // Unlock the fs_changes mutex.
|
||||
|
||||
// If in append mode, position should be at the end
|
||||
let position = if append && file_exists {
|
||||
file_content.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let position = if append { file_content.len() } else { 0 };
|
||||
file.set("__position", position)?;
|
||||
file.set(
|
||||
"__content",
|
||||
lua.create_userdata(FileContent(RefCell::new(file_content)))?,
|
||||
lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
|
||||
)?;
|
||||
|
||||
// Create file methods
|
||||
|
||||
// read method
|
||||
let read_fn = {
|
||||
lua.create_function(
|
||||
|_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
|
||||
let read_perm = file_userdata.get::<bool>("__read_perm")?;
|
||||
if !read_perm {
|
||||
return Err(mlua::Error::runtime("File not open for reading"));
|
||||
}
|
||||
|
||||
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||
let mut position = file_userdata.get::<usize>("__position")?;
|
||||
let content_ref = content.borrow::<FileContent>()?;
|
||||
let content_vec = content_ref.0.borrow();
|
||||
|
||||
if position >= content_vec.len() {
|
||||
return Ok(None); // EOF
|
||||
}
|
||||
|
||||
match format {
|
||||
Some(mlua::Value::String(s)) => {
|
||||
let lossy_string = s.to_string_lossy();
|
||||
let format_str: &str = lossy_string.as_ref();
|
||||
|
||||
// Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
|
||||
match &format_str[0..2] {
|
||||
"*a" => {
|
||||
// Read entire file from current position
|
||||
let result = String::from_utf8_lossy(&content_vec[position..])
|
||||
.to_string();
|
||||
position = content_vec.len();
|
||||
file_userdata.set("__position", position)?;
|
||||
Ok(Some(result))
|
||||
}
|
||||
"*l" => {
|
||||
// Read next line
|
||||
let mut line = Vec::new();
|
||||
let mut found_newline = false;
|
||||
|
||||
while position < content_vec.len() {
|
||||
let byte = content_vec[position];
|
||||
position += 1;
|
||||
|
||||
if byte == b'\n' {
|
||||
found_newline = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// Skip \r in \r\n sequence but add it if it's alone
|
||||
if byte == b'\r' {
|
||||
if position < content_vec.len()
|
||||
&& content_vec[position] == b'\n'
|
||||
{
|
||||
position += 1;
|
||||
found_newline = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
line.push(byte);
|
||||
}
|
||||
|
||||
file_userdata.set("__position", position)?;
|
||||
|
||||
if !found_newline
|
||||
&& line.is_empty()
|
||||
&& position >= content_vec.len()
|
||||
{
|
||||
return Ok(None); // EOF
|
||||
}
|
||||
|
||||
let result = String::from_utf8_lossy(&line).to_string();
|
||||
Ok(Some(result))
|
||||
}
|
||||
"*n" => {
|
||||
// Try to parse as a number (number of bytes to read)
|
||||
match format_str.parse::<usize>() {
|
||||
Ok(n) => {
|
||||
let end =
|
||||
std::cmp::min(position + n, content_vec.len());
|
||||
let bytes = &content_vec[position..end];
|
||||
let result = String::from_utf8_lossy(bytes).to_string();
|
||||
position = end;
|
||||
file_userdata.set("__position", position)?;
|
||||
Ok(Some(result))
|
||||
}
|
||||
Err(_) => Err(mlua::Error::runtime(format!(
|
||||
"Invalid format: {}",
|
||||
format_str
|
||||
))),
|
||||
}
|
||||
}
|
||||
"*L" => {
|
||||
// Read next line keeping the end of line
|
||||
let mut line = Vec::new();
|
||||
|
||||
while position < content_vec.len() {
|
||||
let byte = content_vec[position];
|
||||
position += 1;
|
||||
|
||||
line.push(byte);
|
||||
|
||||
if byte == b'\n' {
|
||||
break;
|
||||
}
|
||||
|
||||
// If we encounter a \r, add it and check if the next is \n
|
||||
if byte == b'\r'
|
||||
&& position < content_vec.len()
|
||||
&& content_vec[position] == b'\n'
|
||||
{
|
||||
line.push(content_vec[position]);
|
||||
position += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
file_userdata.set("__position", position)?;
|
||||
|
||||
if line.is_empty() && position >= content_vec.len() {
|
||||
return Ok(None); // EOF
|
||||
}
|
||||
|
||||
let result = String::from_utf8_lossy(&line).to_string();
|
||||
Ok(Some(result))
|
||||
}
|
||||
_ => Err(mlua::Error::runtime(format!(
|
||||
"Unsupported format: {}",
|
||||
format_str
|
||||
))),
|
||||
}
|
||||
}
|
||||
Some(mlua::Value::Number(n)) => {
|
||||
// Read n bytes
|
||||
let n = n as usize;
|
||||
let end = std::cmp::min(position + n, content_vec.len());
|
||||
let bytes = &content_vec[position..end];
|
||||
let result = String::from_utf8_lossy(bytes).to_string();
|
||||
position = end;
|
||||
file_userdata.set("__position", position)?;
|
||||
Ok(Some(result))
|
||||
}
|
||||
Some(_) => Err(mlua::Error::runtime("Invalid format")),
|
||||
None => {
|
||||
// Default is to read a line
|
||||
let mut line = Vec::new();
|
||||
let mut found_newline = false;
|
||||
|
||||
while position < content_vec.len() {
|
||||
let byte = content_vec[position];
|
||||
position += 1;
|
||||
|
||||
if byte == b'\n' {
|
||||
found_newline = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// Handle \r\n
|
||||
if byte == b'\r' {
|
||||
if position < content_vec.len()
|
||||
&& content_vec[position] == b'\n'
|
||||
{
|
||||
position += 1;
|
||||
found_newline = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
line.push(byte);
|
||||
}
|
||||
|
||||
file_userdata.set("__position", position)?;
|
||||
|
||||
if !found_newline && line.is_empty() && position >= content_vec.len() {
|
||||
return Ok(None); // EOF
|
||||
}
|
||||
|
||||
let result = String::from_utf8_lossy(&line).to_string();
|
||||
Ok(Some(result))
|
||||
}
|
||||
}
|
||||
},
|
||||
)?
|
||||
};
|
||||
let read_fn = lua.create_function(Self::io_file_read)?;
|
||||
file.set("read", read_fn)?;
|
||||
|
||||
// write method
|
||||
let write_fn = {
|
||||
let fs_changes = fs_changes.clone();
|
||||
|
||||
lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
|
||||
let write_perm = file_userdata.get::<bool>("__write_perm")?;
|
||||
if !write_perm {
|
||||
return Err(mlua::Error::runtime("File not open for writing"));
|
||||
}
|
||||
|
||||
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||
let position = file_userdata.get::<usize>("__position")?;
|
||||
let content_ref = content.borrow::<FileContent>()?;
|
||||
let mut content_vec = content_ref.0.borrow_mut();
|
||||
|
||||
let bytes = text.as_bytes();
|
||||
|
||||
// Ensure the vector has enough capacity
|
||||
if position + bytes.len() > content_vec.len() {
|
||||
content_vec.resize(position + bytes.len(), 0);
|
||||
}
|
||||
|
||||
// Write the bytes
|
||||
for (i, &byte) in bytes.iter().enumerate() {
|
||||
content_vec[position + i] = byte;
|
||||
}
|
||||
|
||||
// Update position
|
||||
let new_position = position + bytes.len();
|
||||
file_userdata.set("__position", new_position)?;
|
||||
|
||||
// Update fs_changes
|
||||
let path = file_userdata.get::<String>("__path")?;
|
||||
let path_buf = PathBuf::from(path);
|
||||
fs_changes.lock().insert(path_buf, content_vec.clone());
|
||||
|
||||
Ok(true)
|
||||
})?
|
||||
};
|
||||
let write_fn = lua.create_function(Self::io_file_write)?;
|
||||
file.set("write", write_fn)?;
|
||||
|
||||
// If we got this far, the file was opened successfully
|
||||
Ok((Some(file), String::new()))
|
||||
}
|
||||
|
||||
async fn read_buffer(
|
||||
project_path: ProjectPath,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::run_foreground_fn(
|
||||
"read file from buffer",
|
||||
foreground_tx,
|
||||
Box::new(move |session, mut cx| {
|
||||
session.update(&mut cx, |session, cx| {
|
||||
let open_buffer_task = session
|
||||
.project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
||||
|
||||
cx.spawn(|_, cx| async move {
|
||||
let buffer = open_buffer_task.await?;
|
||||
|
||||
let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
|
||||
Ok(text)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await??
|
||||
.await
|
||||
}
|
||||
|
||||
async fn io_file_flush(
|
||||
file_userdata: mlua::Table,
|
||||
project_path: ProjectPath,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
) -> mlua::Result<bool> {
|
||||
let write_perm = file_userdata.get::<bool>("__write_perm")?;
|
||||
|
||||
if write_perm {
|
||||
// When closing a writable file, record the content
|
||||
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||
let content_ref = content.borrow::<FileContent>()?;
|
||||
let text = {
|
||||
let mut content_vec = content_ref.0.lock();
|
||||
let content_vec = std::mem::take(&mut *content_vec);
|
||||
String::from_utf8(content_vec).into_lua_err()?
|
||||
};
|
||||
|
||||
Self::write_to_buffer(project_path, text, foreground_tx)
|
||||
.await
|
||||
.into_lua_err()?;
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn write_to_buffer(
|
||||
project_path: ProjectPath,
|
||||
text: String,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
) -> anyhow::Result<()> {
|
||||
Self::run_foreground_fn(
|
||||
"write to buffer",
|
||||
foreground_tx,
|
||||
Box::new(move |session, mut cx| {
|
||||
session.update(&mut cx, |session, cx| {
|
||||
let open_buffer_task = session
|
||||
.project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
||||
|
||||
cx.spawn(move |session, mut cx| async move {
|
||||
let buffer = open_buffer_task.await?;
|
||||
|
||||
let diff = buffer
|
||||
.update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
|
||||
.await;
|
||||
|
||||
buffer.update(&mut cx, |buffer, cx| {
|
||||
buffer.apply_diff(diff, cx);
|
||||
})?;
|
||||
|
||||
session
|
||||
.update(&mut cx, |session, cx| {
|
||||
session
|
||||
.project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer, cx))
|
||||
})?
|
||||
.await
|
||||
})
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await??
|
||||
.await
|
||||
}
|
||||
|
||||
async fn io_file_dir(
|
||||
lua: &Lua,
|
||||
fs: Arc<dyn Fs>,
|
||||
file: Table,
|
||||
path: &Path,
|
||||
) -> mlua::Result<(Option<Table>, String)> {
|
||||
// Create a special directory handle
|
||||
file.set("__is_directory", true)?;
|
||||
|
||||
// Store directory entries
|
||||
let entries = match fs.read_dir(&path).await {
|
||||
Ok(entries) => {
|
||||
let mut entry_names = Vec::new();
|
||||
|
||||
// Process the stream of directory entries
|
||||
pin_mut!(entries);
|
||||
while let Some(Ok(entry_result)) = entries.next().await {
|
||||
if let Some(file_name) = entry_result.file_name() {
|
||||
entry_names.push(file_name.to_string_lossy().into_owned());
|
||||
}
|
||||
}
|
||||
|
||||
entry_names
|
||||
}
|
||||
Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
|
||||
};
|
||||
|
||||
// Save the list of entries
|
||||
file.set("__dir_entries", entries)?;
|
||||
file.set("__dir_position", 0usize)?;
|
||||
|
||||
// Create a directory-specific read function
|
||||
let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
|
||||
let position = file_userdata.get::<usize>("__dir_position")?;
|
||||
let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
|
||||
|
||||
if position >= entries.len() {
|
||||
return Ok(None); // No more entries
|
||||
}
|
||||
|
||||
let entry = entries[position].clone();
|
||||
file_userdata.set("__dir_position", position + 1)?;
|
||||
|
||||
Ok(Some(entry))
|
||||
})?;
|
||||
file.set("read", read_fn)?;
|
||||
|
||||
// If we got this far, the directory was opened successfully
|
||||
return Ok((Some(file), String::new()));
|
||||
}
|
||||
|
||||
fn io_file_read(
|
||||
lua: &Lua,
|
||||
(file_userdata, format): (Table, Option<mlua::Value>),
|
||||
) -> mlua::Result<Option<mlua::String>> {
|
||||
let read_perm = file_userdata.get::<bool>("__read_perm")?;
|
||||
if !read_perm {
|
||||
return Err(mlua::Error::runtime("File not open for reading"));
|
||||
}
|
||||
|
||||
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||
let position = file_userdata.get::<usize>("__position")?;
|
||||
let content_ref = content.borrow::<FileContent>()?;
|
||||
let content = content_ref.0.lock();
|
||||
|
||||
if position >= content.len() {
|
||||
return Ok(None); // EOF
|
||||
}
|
||||
|
||||
let (result, new_position) = match Self::io_file_read_format(format)? {
|
||||
FileReadFormat::All => {
|
||||
// Read entire file from current position
|
||||
let result = content[position..].to_vec();
|
||||
(Some(result), content.len())
|
||||
}
|
||||
FileReadFormat::Line => {
|
||||
if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
|
||||
{
|
||||
let mut line = content[position..position + next_newline_ix].to_vec();
|
||||
if line.ends_with(b"\r") {
|
||||
line.pop();
|
||||
}
|
||||
(Some(line), position + next_newline_ix + 1)
|
||||
} else if position < content.len() {
|
||||
let line = content[position..].to_vec();
|
||||
(Some(line), content.len())
|
||||
} else {
|
||||
(None, position) // EOF
|
||||
}
|
||||
}
|
||||
FileReadFormat::LineWithLineFeed => {
|
||||
if position < content.len() {
|
||||
let next_line_ix = content[position..]
|
||||
.iter()
|
||||
.position(|c| *c == b'\n')
|
||||
.map_or(content.len(), |ix| position + ix + 1);
|
||||
let line = content[position..next_line_ix].to_vec();
|
||||
(Some(line), next_line_ix)
|
||||
} else {
|
||||
(None, position) // EOF
|
||||
}
|
||||
}
|
||||
FileReadFormat::Bytes(n) => {
|
||||
let end = std::cmp::min(position + n, content.len());
|
||||
let result = content[position..end].to_vec();
|
||||
(Some(result), end)
|
||||
}
|
||||
};
|
||||
|
||||
// Update the position in the file userdata
|
||||
if new_position != position {
|
||||
file_userdata.set("__position", new_position)?;
|
||||
}
|
||||
|
||||
// Convert the result to a Lua string
|
||||
match result {
|
||||
Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
|
||||
let format = match format {
|
||||
Some(mlua::Value::String(s)) => {
|
||||
let lossy_string = s.to_string_lossy();
|
||||
let format_str: &str = lossy_string.as_ref();
|
||||
|
||||
// Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
|
||||
match &format_str[0..2] {
|
||||
"*a" => FileReadFormat::All,
|
||||
"*l" => FileReadFormat::Line,
|
||||
"*L" => FileReadFormat::LineWithLineFeed,
|
||||
"*n" => {
|
||||
// Try to parse as a number (number of bytes to read)
|
||||
match format_str.parse::<usize>() {
|
||||
Ok(n) => FileReadFormat::Bytes(n),
|
||||
Err(_) => {
|
||||
return Err(mlua::Error::runtime(format!(
|
||||
"Invalid format: {}",
|
||||
format_str
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(mlua::Error::runtime(format!(
|
||||
"Unsupported format: {}",
|
||||
format_str
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
|
||||
Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
|
||||
Some(value) => {
|
||||
return Err(mlua::Error::runtime(format!(
|
||||
"Invalid file format {:?}",
|
||||
value
|
||||
)))
|
||||
}
|
||||
None => FileReadFormat::Line, // Default is to read a line
|
||||
};
|
||||
|
||||
Ok(format)
|
||||
}
|
||||
|
||||
fn io_file_write(
|
||||
_lua: &Lua,
|
||||
(file_userdata, text): (Table, mlua::String),
|
||||
) -> mlua::Result<bool> {
|
||||
let write_perm = file_userdata.get::<bool>("__write_perm")?;
|
||||
if !write_perm {
|
||||
return Err(mlua::Error::runtime("File not open for writing"));
|
||||
}
|
||||
|
||||
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
|
||||
let position = file_userdata.get::<usize>("__position")?;
|
||||
let content_ref = content.borrow::<FileContent>()?;
|
||||
let mut content_vec = content_ref.0.lock();
|
||||
|
||||
let bytes = text.as_bytes();
|
||||
|
||||
// Ensure the vector has enough capacity
|
||||
if position + bytes.len() > content_vec.len() {
|
||||
content_vec.resize(position + bytes.len(), 0);
|
||||
}
|
||||
|
||||
// Write the bytes
|
||||
for (i, &byte) in bytes.iter().enumerate() {
|
||||
content_vec[position + i] = byte;
|
||||
}
|
||||
|
||||
// Update position
|
||||
let new_position = position + bytes.len();
|
||||
file_userdata.set("__position", new_position)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn search(
|
||||
lua: &Lua,
|
||||
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
|
||||
|
@ -789,7 +819,14 @@ impl ScriptSession {
|
|||
}
|
||||
}
|
||||
|
||||
struct FileContent(RefCell<Vec<u8>>);
|
||||
enum FileReadFormat {
|
||||
All,
|
||||
Line,
|
||||
LineWithLineFeed,
|
||||
Bytes(usize),
|
||||
}
|
||||
|
||||
struct FileContent(Arc<Mutex<Vec<u8>>>);
|
||||
|
||||
impl UserData for FileContent {
|
||||
fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
|
||||
|
@ -804,6 +841,7 @@ pub struct Script {
|
|||
pub state: ScriptState,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ScriptState {
|
||||
Running {
|
||||
stdout: Arc<Mutex<String>>,
|
||||
|
@ -863,6 +901,8 @@ mod tests {
|
|||
assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
|
||||
}
|
||||
|
||||
// search
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_search(cx: &mut TestAppContext) {
|
||||
let script = r#"
|
||||
|
@ -880,6 +920,117 @@ mod tests {
|
|||
assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
|
||||
}
|
||||
|
||||
// io.open
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_open_and_read_file(cx: &mut TestAppContext) {
|
||||
let script = r#"
|
||||
local file = io.open("file1.txt", "r")
|
||||
local content = file:read()
|
||||
print("Content:", content)
|
||||
file:close()
|
||||
"#;
|
||||
|
||||
let output = test_script(script, cx).await.unwrap();
|
||||
assert_eq!(output, "Content:\tHello world!\n");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
|
||||
let script = r#"
|
||||
local file = io.open("new_file.txt", "w")
|
||||
file:write("This is new content")
|
||||
file:close()
|
||||
|
||||
-- Read back to verify
|
||||
local read_file = io.open("new_file.txt", "r")
|
||||
if read_file then
|
||||
local content = read_file:read("*a")
|
||||
print("Written content:", content)
|
||||
read_file:close()
|
||||
end
|
||||
"#;
|
||||
|
||||
let output = test_script(script, cx).await.unwrap();
|
||||
assert_eq!(output, "Written content:\tThis is new content\n");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_multiple_writes(cx: &mut TestAppContext) {
|
||||
let script = r#"
|
||||
-- Test writing to a file multiple times
|
||||
local file = io.open("multiwrite.txt", "w")
|
||||
file:write("First line\n")
|
||||
file:write("Second line\n")
|
||||
file:write("Third line")
|
||||
file:close()
|
||||
|
||||
-- Read back to verify
|
||||
local read_file = io.open("multiwrite.txt", "r")
|
||||
if read_file then
|
||||
local content = read_file:read("*a")
|
||||
print("Full content:", content)
|
||||
read_file:close()
|
||||
end
|
||||
"#;
|
||||
|
||||
let output = test_script(script, cx).await.unwrap();
|
||||
assert_eq!(
|
||||
output,
|
||||
"Full content:\tFirst line\nSecond line\nThird line\n"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_formats(cx: &mut TestAppContext) {
|
||||
let script = r#"
|
||||
local file = io.open("multiline.txt", "w")
|
||||
file:write("Line 1\nLine 2\nLine 3")
|
||||
file:close()
|
||||
|
||||
-- Test "*a" (all)
|
||||
local f = io.open("multiline.txt", "r")
|
||||
local all = f:read("*a")
|
||||
print("All:", all)
|
||||
f:close()
|
||||
|
||||
-- Test "*l" (line)
|
||||
f = io.open("multiline.txt", "r")
|
||||
local line1 = f:read("*l")
|
||||
local line2 = f:read("*l")
|
||||
local line3 = f:read("*l")
|
||||
print("Line 1:", line1)
|
||||
print("Line 2:", line2)
|
||||
print("Line 3:", line3)
|
||||
f:close()
|
||||
|
||||
-- Test "*L" (line with newline)
|
||||
f = io.open("multiline.txt", "r")
|
||||
local line_with_nl = f:read("*L")
|
||||
print("Line with newline length:", #line_with_nl)
|
||||
print("Last char:", string.byte(line_with_nl, #line_with_nl))
|
||||
f:close()
|
||||
|
||||
-- Test number of bytes
|
||||
f = io.open("multiline.txt", "r")
|
||||
local bytes5 = f:read(5)
|
||||
print("5 bytes:", bytes5)
|
||||
f:close()
|
||||
"#;
|
||||
|
||||
let output = test_script(script, cx).await.unwrap();
|
||||
println!("{}", &output);
|
||||
assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
|
||||
assert!(output.contains("Line 1:\tLine 1"));
|
||||
assert!(output.contains("Line 2:\tLine 2"));
|
||||
assert!(output.contains("Line 3:\tLine 3"));
|
||||
assert!(output.contains("Line with newline length:\t7"));
|
||||
assert!(output.contains("Last char:\t10")); // LF
|
||||
assert!(output.contains("5 bytes:\tLine "));
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
@ -893,19 +1044,29 @@ mod tests {
|
|||
.await;
|
||||
|
||||
let project = Project::test(fs, [Path::new("/")], cx).await;
|
||||
let session = cx.new(|cx| ScriptSession::new(project, cx));
|
||||
let session = cx.new(|cx| ScriptingSession::new(project, cx));
|
||||
|
||||
let (script_id, task) =
|
||||
session.update(cx, |session, cx| session.run_script(source.to_string(), cx));
|
||||
|
||||
task.await;
|
||||
|
||||
Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
|
||||
Ok(session.read_with(cx, |session, _cx| {
|
||||
let script = session.get(script_id);
|
||||
let stdout = script.stdout_snapshot();
|
||||
|
||||
if let ScriptState::Failed { error, .. } = &script.state {
|
||||
panic!("Script failed:\n{}\n\n{}", error, stdout);
|
||||
}
|
||||
|
||||
stdout
|
||||
}))
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(language::init);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue