scripting tool: Use project buffers in io.open (#26425)

This PR makes `io.open` use our own implementation again, but instead of
the real filesystem, it will now use the project's to check file
metadata and perform read and writes using project buffers.

This also cleans up the `io.open` implementation by splitting it into
multiple methods, adds tests for various File I/O patterns, and fixes a
few bugs in read formats.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-11 00:52:16 -03:00 committed by GitHub
parent d562f58e76
commit bf11b888c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 502 additions and 335 deletions

View file

@ -17,6 +17,7 @@ anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
mlua.workspace = true
parking_lot.workspace = true
@ -31,6 +32,7 @@ util.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
settings = { workspace = true, features = ["test-support"] }

View file

@ -8,9 +8,9 @@ local sandbox = {}
-- to our in-memory log rather than to stdout, we will delete this loop (and re-enable
-- the I/O module being sandboxed below) to have things be sandboxed again.
for k, v in pairs(_G) do
if sandbox[k] == nil then
sandbox[k] = v
end
if sandbox[k] == nil then
sandbox[k] = v
end
end
-- Allow access to standard libraries (safe subset)
@ -29,24 +29,27 @@ sandbox.search = search
sandbox.outline = outline
-- Create a sandboxed version of LuaFileIO
local io = {}
-- local io = {};
--
-- For now we are using unsandboxed io
local io = _G.io;
-- File functions
io.open = sb_io_open
-- Add the sandboxed io library to the sandbox environment
-- sandbox.io = io -- Uncomment this line to re-enable sandboxed file I/O.
sandbox.io = io
-- Load the script with the sandbox environment
local user_script_fn, err = load(user_script, nil, "t", sandbox)
if not user_script_fn then
error("Failed to load user script: " .. tostring(err))
error("Failed to load user script: " .. tostring(err))
end
-- Execute the user script within the sandbox
local success, result = pcall(user_script_fn)
if not success then
error("Error executing user script: " .. tostring(result))
error("Error executing user script: " .. tostring(result))
end

View file

@ -36,7 +36,7 @@ impl ScriptingTool {
};
// TODO: Store a session per thread
let session = cx.new(|cx| ScriptSession::new(project, cx));
let session = cx.new(|cx| ScriptingSession::new(project, cx));
let lua_script = input.lua_script;
let (script_id, script_task) =

View file

@ -1,5 +1,5 @@
use anyhow::anyhow;
use collections::{HashMap, HashSet};
use collections::HashSet;
use futures::{
channel::{mpsc, oneshot},
pin_mut, SinkExt, StreamExt,
@ -7,32 +7,28 @@ use futures::{
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project};
use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
use regex::Regex;
use std::{
cell::RefCell,
path::{Path, PathBuf},
sync::Arc,
};
use util::{paths::PathMatcher, ResultExt};
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
pub struct ScriptSession {
pub struct ScriptingSession {
project: Entity<Project>,
// TODO Remove this
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
foreground_fns_tx: mpsc::Sender<ForegroundFn>,
_invoke_foreground_fns: Task<()>,
scripts: Vec<Script>,
}
impl ScriptSession {
impl ScriptingSession {
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
ScriptSession {
ScriptingSession {
project,
fs_changes: Arc::new(Mutex::new(HashMap::default())),
foreground_fns_tx,
_invoke_foreground_fns: cx.spawn(|this, cx| async move {
while let Some(foreground_fn) = foreground_fns_rx.next().await {
@ -88,15 +84,18 @@ impl ScriptSession {
) -> Task<anyhow::Result<()>> {
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
// TODO Remove fs_changes
let fs_changes = self.fs_changes.clone();
// TODO Honor all worktrees instead of the first one
let root_dir = self
let worktree_info = self
.project
.read(cx)
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).abs_path());
.map(|worktree| {
let worktree = worktree.read(cx);
(worktree.id(), worktree.abs_path())
});
let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
let fs = self.project.read(cx).fs().clone();
let foreground_fns_tx = self.foreground_fns_tx.clone();
@ -127,6 +126,7 @@ impl ScriptSession {
"search",
lua.create_async_function({
let foreground_fns_tx = foreground_fns_tx.clone();
let fs = fs.clone();
move |lua, regex| {
let mut foreground_fns_tx = foreground_fns_tx.clone();
let fs = fs.clone();
@ -142,6 +142,7 @@ impl ScriptSession {
"outline",
lua.create_async_function({
let root_dir = root_dir.clone();
let foreground_fns_tx = foreground_fns_tx.clone();
move |_lua, path| {
let mut foreground_fns_tx = foreground_fns_tx.clone();
let root_dir = root_dir.clone();
@ -155,11 +156,24 @@ impl ScriptSession {
)?;
globals.set(
"sb_io_open",
lua.create_function({
let fs_changes = fs_changes.clone();
let root_dir = root_dir.clone();
lua.create_async_function({
let worktree_info = worktree_info.clone();
let foreground_fns_tx = foreground_fns_tx.clone();
move |lua, (path_str, mode)| {
Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
let worktree_info = worktree_info.clone();
let mut foreground_fns_tx = foreground_fns_tx.clone();
let fs = fs.clone();
async move {
Self::io_open(
&lua,
worktree_info,
&mut foreground_fns_tx,
fs,
path_str,
mode,
)
.await
}
}
})?,
)?;
@ -202,14 +216,15 @@ impl ScriptSession {
}
/// Sandboxed io.open() function in Lua.
fn io_open(
async fn io_open(
lua: &Lua,
fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: Option<&Arc<Path>>,
worktree_info: Option<(WorktreeId, Arc<Path>)>,
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
fs: Arc<dyn Fs>,
path_str: String,
mode: Option<String>,
) -> mlua::Result<(Option<Table>, String)> {
let root_dir = root_dir
let (worktree_id, root_dir) = worktree_info
.ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
let mode = mode.unwrap_or_else(|| "r".to_string());
@ -224,7 +239,6 @@ impl ScriptSession {
let file = lua.create_table()?;
// Store file metadata in the file
file.set("__path", path_str.clone())?;
file.set("__mode", mode.clone())?;
file.set("__read_perm", read_perm)?;
file.set("__write_perm", write_perm)?;
@ -234,338 +248,354 @@ impl ScriptSession {
Err(err) => return Ok((None, format!("{err}"))),
};
// close method
let close_fn = {
let fs_changes = fs_changes.clone();
lua.create_function(move |_lua, file_userdata: mlua::Table| {
let write_perm = file_userdata.get::<bool>("__write_perm")?;
let path = file_userdata.get::<String>("__path")?;
let project_path = ProjectPath {
worktree_id,
path: Path::new(&path_str).into(),
};
if write_perm {
// When closing a writable file, record the content
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
let content_ref = content.borrow::<FileContent>()?;
let content_vec = content_ref.0.borrow();
// Don't actually write to disk; instead, just update fs_changes.
let path_buf = PathBuf::from(&path);
fs_changes
.lock()
.insert(path_buf.clone(), content_vec.clone());
// flush / close method
let flush_fn = {
let project_path = project_path.clone();
let foreground_tx = foreground_tx.clone();
lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
let project_path = project_path.clone();
let mut foreground_tx = foreground_tx.clone();
async move {
Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
}
Ok(true)
})?
};
file.set("close", close_fn)?;
file.set("flush", flush_fn.clone())?;
// We don't really hold files open, so we only need to flush on close
file.set("close", flush_fn)?;
// If it's a directory, give it a custom read() and return early.
if path.is_dir() {
// TODO handle the case where we changed it in the in-memory fs
// Create a special directory handle
file.set("__is_directory", true)?;
// Store directory entries
let entries = match std::fs::read_dir(&path) {
Ok(entries) => {
let mut entry_names = Vec::new();
for entry in entries.flatten() {
entry_names.push(entry.file_name().to_string_lossy().into_owned());
}
entry_names
}
Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
};
// Save the list of entries
file.set("__dir_entries", entries)?;
file.set("__dir_position", 0usize)?;
// Create a directory-specific read function
let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
let position = file_userdata.get::<usize>("__dir_position")?;
let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
if position >= entries.len() {
return Ok(None); // No more entries
}
let entry = entries[position].clone();
file_userdata.set("__dir_position", position + 1)?;
Ok(Some(entry))
})?;
file.set("read", read_fn)?;
// If we got this far, the directory was opened successfully
return Ok((Some(file), String::new()));
if fs.is_dir(&path).await {
return Self::io_file_dir(lua, fs, file, &path).await;
}
let fs_changes_map = fs_changes.lock();
let is_in_changes = fs_changes_map.contains_key(&path);
let file_exists = is_in_changes || path.exists();
let mut file_content = Vec::new();
if file_exists && !truncate {
if is_in_changes {
file_content = fs_changes_map.get(&path).unwrap().clone();
} else {
// Try to read existing content if file exists and we're not truncating
match std::fs::read(&path) {
Ok(content) => file_content = content,
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
}
if !truncate {
// Try to read existing content if we're not truncating
match Self::read_buffer(project_path.clone(), foreground_tx).await {
Ok(content) => file_content = content.into_bytes(),
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
}
}
drop(fs_changes_map); // Unlock the fs_changes mutex.
// If in append mode, position should be at the end
let position = if append && file_exists {
file_content.len()
} else {
0
};
let position = if append { file_content.len() } else { 0 };
file.set("__position", position)?;
file.set(
"__content",
lua.create_userdata(FileContent(RefCell::new(file_content)))?,
lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
)?;
// Create file methods
// read method
let read_fn = {
lua.create_function(
|_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
let read_perm = file_userdata.get::<bool>("__read_perm")?;
if !read_perm {
return Err(mlua::Error::runtime("File not open for reading"));
}
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
let mut position = file_userdata.get::<usize>("__position")?;
let content_ref = content.borrow::<FileContent>()?;
let content_vec = content_ref.0.borrow();
if position >= content_vec.len() {
return Ok(None); // EOF
}
match format {
Some(mlua::Value::String(s)) => {
let lossy_string = s.to_string_lossy();
let format_str: &str = lossy_string.as_ref();
// Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
match &format_str[0..2] {
"*a" => {
// Read entire file from current position
let result = String::from_utf8_lossy(&content_vec[position..])
.to_string();
position = content_vec.len();
file_userdata.set("__position", position)?;
Ok(Some(result))
}
"*l" => {
// Read next line
let mut line = Vec::new();
let mut found_newline = false;
while position < content_vec.len() {
let byte = content_vec[position];
position += 1;
if byte == b'\n' {
found_newline = true;
break;
}
// Skip \r in \r\n sequence but add it if it's alone
if byte == b'\r' {
if position < content_vec.len()
&& content_vec[position] == b'\n'
{
position += 1;
found_newline = true;
break;
}
}
line.push(byte);
}
file_userdata.set("__position", position)?;
if !found_newline
&& line.is_empty()
&& position >= content_vec.len()
{
return Ok(None); // EOF
}
let result = String::from_utf8_lossy(&line).to_string();
Ok(Some(result))
}
"*n" => {
// Try to parse as a number (number of bytes to read)
match format_str.parse::<usize>() {
Ok(n) => {
let end =
std::cmp::min(position + n, content_vec.len());
let bytes = &content_vec[position..end];
let result = String::from_utf8_lossy(bytes).to_string();
position = end;
file_userdata.set("__position", position)?;
Ok(Some(result))
}
Err(_) => Err(mlua::Error::runtime(format!(
"Invalid format: {}",
format_str
))),
}
}
"*L" => {
// Read next line keeping the end of line
let mut line = Vec::new();
while position < content_vec.len() {
let byte = content_vec[position];
position += 1;
line.push(byte);
if byte == b'\n' {
break;
}
// If we encounter a \r, add it and check if the next is \n
if byte == b'\r'
&& position < content_vec.len()
&& content_vec[position] == b'\n'
{
line.push(content_vec[position]);
position += 1;
break;
}
}
file_userdata.set("__position", position)?;
if line.is_empty() && position >= content_vec.len() {
return Ok(None); // EOF
}
let result = String::from_utf8_lossy(&line).to_string();
Ok(Some(result))
}
_ => Err(mlua::Error::runtime(format!(
"Unsupported format: {}",
format_str
))),
}
}
Some(mlua::Value::Number(n)) => {
// Read n bytes
let n = n as usize;
let end = std::cmp::min(position + n, content_vec.len());
let bytes = &content_vec[position..end];
let result = String::from_utf8_lossy(bytes).to_string();
position = end;
file_userdata.set("__position", position)?;
Ok(Some(result))
}
Some(_) => Err(mlua::Error::runtime("Invalid format")),
None => {
// Default is to read a line
let mut line = Vec::new();
let mut found_newline = false;
while position < content_vec.len() {
let byte = content_vec[position];
position += 1;
if byte == b'\n' {
found_newline = true;
break;
}
// Handle \r\n
if byte == b'\r' {
if position < content_vec.len()
&& content_vec[position] == b'\n'
{
position += 1;
found_newline = true;
break;
}
}
line.push(byte);
}
file_userdata.set("__position", position)?;
if !found_newline && line.is_empty() && position >= content_vec.len() {
return Ok(None); // EOF
}
let result = String::from_utf8_lossy(&line).to_string();
Ok(Some(result))
}
}
},
)?
};
let read_fn = lua.create_function(Self::io_file_read)?;
file.set("read", read_fn)?;
// write method
let write_fn = {
let fs_changes = fs_changes.clone();
lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
let write_perm = file_userdata.get::<bool>("__write_perm")?;
if !write_perm {
return Err(mlua::Error::runtime("File not open for writing"));
}
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
let position = file_userdata.get::<usize>("__position")?;
let content_ref = content.borrow::<FileContent>()?;
let mut content_vec = content_ref.0.borrow_mut();
let bytes = text.as_bytes();
// Ensure the vector has enough capacity
if position + bytes.len() > content_vec.len() {
content_vec.resize(position + bytes.len(), 0);
}
// Write the bytes
for (i, &byte) in bytes.iter().enumerate() {
content_vec[position + i] = byte;
}
// Update position
let new_position = position + bytes.len();
file_userdata.set("__position", new_position)?;
// Update fs_changes
let path = file_userdata.get::<String>("__path")?;
let path_buf = PathBuf::from(path);
fs_changes.lock().insert(path_buf, content_vec.clone());
Ok(true)
})?
};
let write_fn = lua.create_function(Self::io_file_write)?;
file.set("write", write_fn)?;
// If we got this far, the file was opened successfully
Ok((Some(file), String::new()))
}
async fn read_buffer(
project_path: ProjectPath,
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
) -> anyhow::Result<String> {
Self::run_foreground_fn(
"read file from buffer",
foreground_tx,
Box::new(move |session, mut cx| {
session.update(&mut cx, |session, cx| {
let open_buffer_task = session
.project
.update(cx, |project, cx| project.open_buffer(project_path, cx));
cx.spawn(|_, cx| async move {
let buffer = open_buffer_task.await?;
let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
Ok(text)
})
})
}),
)
.await??
.await
}
async fn io_file_flush(
file_userdata: mlua::Table,
project_path: ProjectPath,
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
) -> mlua::Result<bool> {
let write_perm = file_userdata.get::<bool>("__write_perm")?;
if write_perm {
// When closing a writable file, record the content
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
let content_ref = content.borrow::<FileContent>()?;
let text = {
let mut content_vec = content_ref.0.lock();
let content_vec = std::mem::take(&mut *content_vec);
String::from_utf8(content_vec).into_lua_err()?
};
Self::write_to_buffer(project_path, text, foreground_tx)
.await
.into_lua_err()?;
}
Ok(true)
}
async fn write_to_buffer(
project_path: ProjectPath,
text: String,
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
) -> anyhow::Result<()> {
Self::run_foreground_fn(
"write to buffer",
foreground_tx,
Box::new(move |session, mut cx| {
session.update(&mut cx, |session, cx| {
let open_buffer_task = session
.project
.update(cx, |project, cx| project.open_buffer(project_path, cx));
cx.spawn(move |session, mut cx| async move {
let buffer = open_buffer_task.await?;
let diff = buffer
.update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
.await;
buffer.update(&mut cx, |buffer, cx| {
buffer.apply_diff(diff, cx);
})?;
session
.update(&mut cx, |session, cx| {
session
.project
.update(cx, |project, cx| project.save_buffer(buffer, cx))
})?
.await
})
})
}),
)
.await??
.await
}
async fn io_file_dir(
lua: &Lua,
fs: Arc<dyn Fs>,
file: Table,
path: &Path,
) -> mlua::Result<(Option<Table>, String)> {
// Create a special directory handle
file.set("__is_directory", true)?;
// Store directory entries
let entries = match fs.read_dir(&path).await {
Ok(entries) => {
let mut entry_names = Vec::new();
// Process the stream of directory entries
pin_mut!(entries);
while let Some(Ok(entry_result)) = entries.next().await {
if let Some(file_name) = entry_result.file_name() {
entry_names.push(file_name.to_string_lossy().into_owned());
}
}
entry_names
}
Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
};
// Save the list of entries
file.set("__dir_entries", entries)?;
file.set("__dir_position", 0usize)?;
// Create a directory-specific read function
let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
let position = file_userdata.get::<usize>("__dir_position")?;
let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
if position >= entries.len() {
return Ok(None); // No more entries
}
let entry = entries[position].clone();
file_userdata.set("__dir_position", position + 1)?;
Ok(Some(entry))
})?;
file.set("read", read_fn)?;
// If we got this far, the directory was opened successfully
return Ok((Some(file), String::new()));
}
fn io_file_read(
lua: &Lua,
(file_userdata, format): (Table, Option<mlua::Value>),
) -> mlua::Result<Option<mlua::String>> {
let read_perm = file_userdata.get::<bool>("__read_perm")?;
if !read_perm {
return Err(mlua::Error::runtime("File not open for reading"));
}
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
let position = file_userdata.get::<usize>("__position")?;
let content_ref = content.borrow::<FileContent>()?;
let content = content_ref.0.lock();
if position >= content.len() {
return Ok(None); // EOF
}
let (result, new_position) = match Self::io_file_read_format(format)? {
FileReadFormat::All => {
// Read entire file from current position
let result = content[position..].to_vec();
(Some(result), content.len())
}
FileReadFormat::Line => {
if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
{
let mut line = content[position..position + next_newline_ix].to_vec();
if line.ends_with(b"\r") {
line.pop();
}
(Some(line), position + next_newline_ix + 1)
} else if position < content.len() {
let line = content[position..].to_vec();
(Some(line), content.len())
} else {
(None, position) // EOF
}
}
FileReadFormat::LineWithLineFeed => {
if position < content.len() {
let next_line_ix = content[position..]
.iter()
.position(|c| *c == b'\n')
.map_or(content.len(), |ix| position + ix + 1);
let line = content[position..next_line_ix].to_vec();
(Some(line), next_line_ix)
} else {
(None, position) // EOF
}
}
FileReadFormat::Bytes(n) => {
let end = std::cmp::min(position + n, content.len());
let result = content[position..end].to_vec();
(Some(result), end)
}
};
// Update the position in the file userdata
if new_position != position {
file_userdata.set("__position", new_position)?;
}
// Convert the result to a Lua string
match result {
Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
None => Ok(None),
}
}
fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
let format = match format {
Some(mlua::Value::String(s)) => {
let lossy_string = s.to_string_lossy();
let format_str: &str = lossy_string.as_ref();
// Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
match &format_str[0..2] {
"*a" => FileReadFormat::All,
"*l" => FileReadFormat::Line,
"*L" => FileReadFormat::LineWithLineFeed,
"*n" => {
// Try to parse as a number (number of bytes to read)
match format_str.parse::<usize>() {
Ok(n) => FileReadFormat::Bytes(n),
Err(_) => {
return Err(mlua::Error::runtime(format!(
"Invalid format: {}",
format_str
)))
}
}
}
_ => {
return Err(mlua::Error::runtime(format!(
"Unsupported format: {}",
format_str
)))
}
}
}
Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
Some(value) => {
return Err(mlua::Error::runtime(format!(
"Invalid file format {:?}",
value
)))
}
None => FileReadFormat::Line, // Default is to read a line
};
Ok(format)
}
fn io_file_write(
_lua: &Lua,
(file_userdata, text): (Table, mlua::String),
) -> mlua::Result<bool> {
let write_perm = file_userdata.get::<bool>("__write_perm")?;
if !write_perm {
return Err(mlua::Error::runtime("File not open for writing"));
}
let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
let position = file_userdata.get::<usize>("__position")?;
let content_ref = content.borrow::<FileContent>()?;
let mut content_vec = content_ref.0.lock();
let bytes = text.as_bytes();
// Ensure the vector has enough capacity
if position + bytes.len() > content_vec.len() {
content_vec.resize(position + bytes.len(), 0);
}
// Write the bytes
for (i, &byte) in bytes.iter().enumerate() {
content_vec[position + i] = byte;
}
// Update position
let new_position = position + bytes.len();
file_userdata.set("__position", new_position)?;
Ok(true)
}
async fn search(
lua: &Lua,
foreground_tx: &mut mpsc::Sender<ForegroundFn>,
@ -789,7 +819,14 @@ impl ScriptSession {
}
}
struct FileContent(RefCell<Vec<u8>>);
enum FileReadFormat {
All,
Line,
LineWithLineFeed,
Bytes(usize),
}
struct FileContent(Arc<Mutex<Vec<u8>>>);
impl UserData for FileContent {
fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
@ -804,6 +841,7 @@ pub struct Script {
pub state: ScriptState,
}
#[derive(Debug)]
pub enum ScriptState {
Running {
stdout: Arc<Mutex<String>>,
@ -863,6 +901,8 @@ mod tests {
assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
}
// search
#[gpui::test]
async fn test_search(cx: &mut TestAppContext) {
let script = r#"
@ -880,6 +920,117 @@ mod tests {
assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
}
// io.open
#[gpui::test]
async fn test_open_and_read_file(cx: &mut TestAppContext) {
let script = r#"
local file = io.open("file1.txt", "r")
local content = file:read()
print("Content:", content)
file:close()
"#;
let output = test_script(script, cx).await.unwrap();
assert_eq!(output, "Content:\tHello world!\n");
}
#[gpui::test]
async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
let script = r#"
local file = io.open("new_file.txt", "w")
file:write("This is new content")
file:close()
-- Read back to verify
local read_file = io.open("new_file.txt", "r")
if read_file then
local content = read_file:read("*a")
print("Written content:", content)
read_file:close()
end
"#;
let output = test_script(script, cx).await.unwrap();
assert_eq!(output, "Written content:\tThis is new content\n");
}
#[gpui::test]
async fn test_multiple_writes(cx: &mut TestAppContext) {
let script = r#"
-- Test writing to a file multiple times
local file = io.open("multiwrite.txt", "w")
file:write("First line\n")
file:write("Second line\n")
file:write("Third line")
file:close()
-- Read back to verify
local read_file = io.open("multiwrite.txt", "r")
if read_file then
local content = read_file:read("*a")
print("Full content:", content)
read_file:close()
end
"#;
let output = test_script(script, cx).await.unwrap();
assert_eq!(
output,
"Full content:\tFirst line\nSecond line\nThird line\n"
);
}
#[gpui::test]
async fn test_read_formats(cx: &mut TestAppContext) {
let script = r#"
local file = io.open("multiline.txt", "w")
file:write("Line 1\nLine 2\nLine 3")
file:close()
-- Test "*a" (all)
local f = io.open("multiline.txt", "r")
local all = f:read("*a")
print("All:", all)
f:close()
-- Test "*l" (line)
f = io.open("multiline.txt", "r")
local line1 = f:read("*l")
local line2 = f:read("*l")
local line3 = f:read("*l")
print("Line 1:", line1)
print("Line 2:", line2)
print("Line 3:", line3)
f:close()
-- Test "*L" (line with newline)
f = io.open("multiline.txt", "r")
local line_with_nl = f:read("*L")
print("Line with newline length:", #line_with_nl)
print("Last char:", string.byte(line_with_nl, #line_with_nl))
f:close()
-- Test number of bytes
f = io.open("multiline.txt", "r")
local bytes5 = f:read(5)
print("5 bytes:", bytes5)
f:close()
"#;
let output = test_script(script, cx).await.unwrap();
println!("{}", &output);
assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
assert!(output.contains("Line 1:\tLine 1"));
assert!(output.contains("Line 2:\tLine 2"));
assert!(output.contains("Line 3:\tLine 3"));
assert!(output.contains("Line with newline length:\t7"));
assert!(output.contains("Last char:\t10")); // LF
assert!(output.contains("5 bytes:\tLine "));
}
// helpers
async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
init_test(cx);
let fs = FakeFs::new(cx.executor());
@ -893,19 +1044,29 @@ mod tests {
.await;
let project = Project::test(fs, [Path::new("/")], cx).await;
let session = cx.new(|cx| ScriptSession::new(project, cx));
let session = cx.new(|cx| ScriptingSession::new(project, cx));
let (script_id, task) =
session.update(cx, |session, cx| session.run_script(source.to_string(), cx));
task.await;
Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
Ok(session.read_with(cx, |session, _cx| {
let script = session.get(script_id);
let stdout = script.stdout_snapshot();
if let ScriptState::Failed { error, .. } = &script.state {
panic!("Script failed:\n{}\n\n{}", error, stdout);
}
stdout
}))
}
fn init_test(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(Project::init_settings);
cx.update(language::init);
}
}