assistant: Make scripting a first-class concept instead of a tool (#26338)
This PR refactors the scripting functionality into a first-class concept of the assistant instead of a generic tool, which will allow us to build a more customized experience.

- The tool prompt has been slightly tweaked and is now included as a system message in all conversations. I'm getting decent results, but now that it isn't in the tools framework, it will probably require more refining.
- The model now includes an `<eval ...>` tag at the end of the message with the script. We parse this tag incrementally as it streams in so that we can indicate that we are generating a script before we see the closing `</eval>` tag. Later, this will also let us start interpreting the script as it arrives.
- Threads now hold a `ScriptSession` entity which manages the state of all scripts (from parsing to exited) in a centralized way, and will later collect all script operations so they can be displayed in the UI.
- `script_tool` has been renamed to `assistant_scripting`.
- Script source now opens in a regular read-only buffer.

Note: we still need to handle persistence properly.

Release Notes:

- N/A

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
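For illustration, here is a made-up exchange showing the shape of a message the model now produces (the script inside the tag is hypothetical and only uses the sandboxed `search` and `print` functions):

I'll look for TODO comments in the codebase.

<eval type="lua">
local results = search("TODO")
local count = 0
for _, result in ipairs(results) do
    count = count + #result.matches
end
print("TODO count: " .. count)
</eval>

The text before the tag streams to the user as normal message content, while everything between the tags accumulates as script source; because the tag is parsed incrementally, we can show that a script is being generated before the closing `</eval>` tag arrives.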
Parent: ed6bf7f161
Commit: e298301b40
16 changed files with 811 additions and 197 deletions
crates/assistant_scripting/Cargo.toml (new file, 34 lines)
@@ -0,0 +1,34 @@
[package]
name = "assistant_scripting"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"

[lints]
workspace = true

[lib]
path = "src/assistant_scripting.rs"
doctest = false

[dependencies]
anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
mlua.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
util.workspace = true

[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
settings = { workspace = true, features = ["test-support"] }
crates/assistant_scripting/LICENSE-GPL (new symbolic link, 1 line)
@@ -0,0 +1 @@
../../LICENSE-GPL
crates/assistant_scripting/src/assistant_scripting.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
mod session;
mod tag;

pub use session::*;
pub use tag::*;

pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt");
crates/assistant_scripting/src/sandbox_preamble.lua (new file, 40 lines)
@@ -0,0 +1,40 @@
---@diagnostic disable: undefined-global

-- Create a sandbox environment
local sandbox = {}

-- Allow access to standard libraries (safe subset)
sandbox.string = string
sandbox.table = table
sandbox.math = math
sandbox.print = sb_print
sandbox.type = type
sandbox.tostring = tostring
sandbox.tonumber = tonumber
sandbox.pairs = pairs
sandbox.ipairs = ipairs
sandbox.search = search

-- Create a sandboxed version of LuaFileIO
local io = {}

-- File functions
io.open = sb_io_open

-- Add the sandboxed io library to the sandbox environment
sandbox.io = io


-- Load the script with the sandbox environment
local user_script_fn, err = load(user_script, nil, "t", sandbox)

if not user_script_fn then
    error("Failed to load user script: " .. tostring(err))
end

-- Execute the user script within the sandbox
local success, result = pcall(user_script_fn)

if not success then
    error("Error executing user script: " .. tostring(result))
end
crates/assistant_scripting/src/session.rs (new file, 880 lines)
@@ -0,0 +1,880 @@
use collections::{HashMap, HashSet};
use futures::{
    channel::{mpsc, oneshot},
    pin_mut, SinkExt, StreamExt,
};
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project};
use regex::Regex;
use std::{
    cell::RefCell,
    path::{Path, PathBuf},
    sync::Arc,
};
use util::{paths::PathMatcher, ResultExt};

use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};

struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);

pub struct ScriptSession {
    project: Entity<Project>,
    // TODO Remove this
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    _invoke_foreground_fns: Task<()>,
    scripts: Vec<Script>,
}

impl ScriptSession {
    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
        ScriptSession {
            project,
            fs_changes: Arc::new(Mutex::new(HashMap::default())),
            foreground_fns_tx,
            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
                while let Some(foreground_fn) = foreground_fns_rx.next().await {
                    foreground_fn.0(this.clone(), cx.clone());
                }
            }),
            scripts: Vec::new(),
        }
    }

    pub fn new_script(&mut self) -> ScriptId {
        let id = ScriptId(self.scripts.len() as u32);
        let script = Script {
            id,
            state: ScriptState::Generating,
            source: SharedString::new_static(""),
        };
        self.scripts.push(script);
        id
    }

    pub fn run_script(
        &mut self,
        script_id: ScriptId,
        script_src: String,
        cx: &mut Context<Self>,
    ) -> Task<anyhow::Result<()>> {
        let script = self.get_mut(script_id);

        let stdout = Arc::new(Mutex::new(String::new()));
        script.source = script_src.clone().into();
        script.state = ScriptState::Running {
            stdout: stdout.clone(),
        };

        let task = self.run_lua(script_src, stdout, cx);

        cx.emit(ScriptEvent::Spawned(script_id));

        cx.spawn(|session, mut cx| async move {
            let result = task.await;

            session.update(&mut cx, |session, cx| {
                let script = session.get_mut(script_id);
                let stdout = script.stdout_snapshot();

                script.state = match result {
                    Ok(()) => ScriptState::Succeeded { stdout },
                    Err(error) => ScriptState::Failed { stdout, error },
                };

                cx.emit(ScriptEvent::Exited(script_id))
            })
        })
    }

    fn run_lua(
        &mut self,
        script: String,
        stdout: Arc<Mutex<String>>,
        cx: &mut Context<Self>,
    ) -> Task<anyhow::Result<()>> {
        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");

        // TODO Remove fs_changes
        let fs_changes = self.fs_changes.clone();
        // TODO Honor all worktrees instead of the first one
        let root_dir = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .next()
            .map(|worktree| worktree.read(cx).abs_path());

        let fs = self.project.read(cx).fs().clone();
        let foreground_fns_tx = self.foreground_fns_tx.clone();

        let task = cx.background_spawn({
            let stdout = stdout.clone();

            async move {
                let lua = Lua::new();
                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
                let globals = lua.globals();
                globals.set(
                    "sb_print",
                    lua.create_function({
                        let stdout = stdout.clone();
                        move |_, args: MultiValue| Self::print(args, &stdout)
                    })?,
                )?;
                globals.set(
                    "search",
                    lua.create_async_function({
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        let fs = fs.clone();
                        move |lua, regex| {
                            Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
                        }
                    })?,
                )?;
                globals.set(
                    "sb_io_open",
                    lua.create_function({
                        let fs_changes = fs_changes.clone();
                        let root_dir = root_dir.clone();
                        move |lua, (path_str, mode)| {
                            Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
                        }
                    })?,
                )?;
                globals.set("user_script", script)?;

                lua.load(SANDBOX_PREAMBLE).exec_async().await?;

                // Drop Lua instance to decrement reference count.
                drop(lua);

                anyhow::Ok(())
            }
        });

        task
    }

    pub fn get(&self, script_id: ScriptId) -> &Script {
        &self.scripts[script_id.0 as usize]
    }

    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
        &mut self.scripts[script_id.0 as usize]
    }

    /// Sandboxed print() function in Lua.
    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
        for (index, arg) in args.into_iter().enumerate() {
            // Lua's `print()` prints tab characters between each argument.
            if index > 0 {
                stdout.lock().push('\t');
            }

            // If the argument's to_string() fails, have the whole function call fail.
            stdout.lock().push_str(&arg.to_string()?);
        }
        stdout.lock().push('\n');

        Ok(())
    }

    /// Sandboxed io.open() function in Lua.
    fn io_open(
        lua: &Lua,
        fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
        root_dir: Option<&Arc<Path>>,
        path_str: String,
        mode: Option<String>,
    ) -> mlua::Result<(Option<Table>, String)> {
        let root_dir = root_dir
            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;

        let mode = mode.unwrap_or_else(|| "r".to_string());

        // Parse the mode string to determine read/write permissions
        let read_perm = mode.contains('r');
        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
        let append = mode.contains('a');
        let truncate = mode.contains('w');

        // This will be the Lua value returned from the `open` function.
        let file = lua.create_table()?;

        // Store file metadata in the file
        file.set("__path", path_str.clone())?;
        file.set("__mode", mode.clone())?;
        file.set("__read_perm", read_perm)?;
        file.set("__write_perm", write_perm)?;

        // Sandbox the path; it must be within root_dir
        let path: PathBuf = {
            let rust_path = Path::new(&path_str);

            // Get absolute path
            if rust_path.is_absolute() {
                // Check if path starts with root_dir prefix without resolving symlinks
                if !rust_path.starts_with(&root_dir) {
                    return Ok((
                        None,
                        format!(
                            "Error: Absolute path {} is outside the current working directory",
                            path_str
                        ),
                    ));
                }
                rust_path.to_path_buf()
            } else {
                // Make relative path absolute relative to cwd
                root_dir.join(rust_path)
            }
        };

        // close method
        let close_fn = {
            let fs_changes = fs_changes.clone();
            lua.create_function(move |_lua, file_userdata: mlua::Table| {
                let write_perm = file_userdata.get::<bool>("__write_perm")?;
                let path = file_userdata.get::<String>("__path")?;

                if write_perm {
                    // When closing a writable file, record the content
                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
                    let content_ref = content.borrow::<FileContent>()?;
                    let content_vec = content_ref.0.borrow();

                    // Don't actually write to disk; instead, just update fs_changes.
                    let path_buf = PathBuf::from(&path);
                    fs_changes
                        .lock()
                        .insert(path_buf.clone(), content_vec.clone());
                }

                Ok(true)
            })?
        };
        file.set("close", close_fn)?;

        // If it's a directory, give it a custom read() and return early.
        if path.is_dir() {
            // TODO handle the case where we changed it in the in-memory fs

            // Create a special directory handle
            file.set("__is_directory", true)?;

            // Store directory entries
            let entries = match std::fs::read_dir(&path) {
                Ok(entries) => {
                    let mut entry_names = Vec::new();
                    for entry in entries.flatten() {
                        entry_names.push(entry.file_name().to_string_lossy().into_owned());
                    }
                    entry_names
                }
                Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
            };

            // Save the list of entries
            file.set("__dir_entries", entries)?;
            file.set("__dir_position", 0usize)?;

            // Create a directory-specific read function
            let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
                let position = file_userdata.get::<usize>("__dir_position")?;
                let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;

                if position >= entries.len() {
                    return Ok(None); // No more entries
                }

                let entry = entries[position].clone();
                file_userdata.set("__dir_position", position + 1)?;

                Ok(Some(entry))
            })?;
            file.set("read", read_fn)?;

            // If we got this far, the directory was opened successfully
            return Ok((Some(file), String::new()));
        }

        let fs_changes_map = fs_changes.lock();

        let is_in_changes = fs_changes_map.contains_key(&path);
        let file_exists = is_in_changes || path.exists();
        let mut file_content = Vec::new();

        if file_exists && !truncate {
            if is_in_changes {
                file_content = fs_changes_map.get(&path).unwrap().clone();
            } else {
                // Try to read existing content if file exists and we're not truncating
                match std::fs::read(&path) {
                    Ok(content) => file_content = content,
                    Err(e) => return Ok((None, format!("Error reading file: {}", e))),
                }
            }
        }

        drop(fs_changes_map); // Unlock the fs_changes mutex.

        // If in append mode, position should be at the end
        let position = if append && file_exists {
            file_content.len()
        } else {
            0
        };
        file.set("__position", position)?;
        file.set(
            "__content",
            lua.create_userdata(FileContent(RefCell::new(file_content)))?,
        )?;

        // Create file methods

        // read method
        let read_fn = {
            lua.create_function(
                |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
                    let read_perm = file_userdata.get::<bool>("__read_perm")?;
                    if !read_perm {
                        return Err(mlua::Error::runtime("File not open for reading"));
                    }

                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
                    let mut position = file_userdata.get::<usize>("__position")?;
                    let content_ref = content.borrow::<FileContent>()?;
                    let content_vec = content_ref.0.borrow();

                    if position >= content_vec.len() {
                        return Ok(None); // EOF
                    }

                    match format {
                        Some(mlua::Value::String(s)) => {
                            let lossy_string = s.to_string_lossy();
                            let format_str: &str = lossy_string.as_ref();

                            // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
                            match &format_str[0..2] {
                                "*a" => {
                                    // Read entire file from current position
                                    let result = String::from_utf8_lossy(&content_vec[position..])
                                        .to_string();
                                    position = content_vec.len();
                                    file_userdata.set("__position", position)?;
                                    Ok(Some(result))
                                }
                                "*l" => {
                                    // Read next line
                                    let mut line = Vec::new();
                                    let mut found_newline = false;

                                    while position < content_vec.len() {
                                        let byte = content_vec[position];
                                        position += 1;

                                        if byte == b'\n' {
                                            found_newline = true;
                                            break;
                                        }

                                        // Skip \r in \r\n sequence but add it if it's alone
                                        if byte == b'\r' {
                                            if position < content_vec.len()
                                                && content_vec[position] == b'\n'
                                            {
                                                position += 1;
                                                found_newline = true;
                                                break;
                                            }
                                        }

                                        line.push(byte);
                                    }

                                    file_userdata.set("__position", position)?;

                                    if !found_newline
                                        && line.is_empty()
                                        && position >= content_vec.len()
                                    {
                                        return Ok(None); // EOF
                                    }

                                    let result = String::from_utf8_lossy(&line).to_string();
                                    Ok(Some(result))
                                }
                                "*n" => {
                                    // Try to parse as a number (number of bytes to read)
                                    match format_str.parse::<usize>() {
                                        Ok(n) => {
                                            let end =
                                                std::cmp::min(position + n, content_vec.len());
                                            let bytes = &content_vec[position..end];
                                            let result = String::from_utf8_lossy(bytes).to_string();
                                            position = end;
                                            file_userdata.set("__position", position)?;
                                            Ok(Some(result))
                                        }
                                        Err(_) => Err(mlua::Error::runtime(format!(
                                            "Invalid format: {}",
                                            format_str
                                        ))),
                                    }
                                }
                                "*L" => {
                                    // Read next line keeping the end of line
                                    let mut line = Vec::new();

                                    while position < content_vec.len() {
                                        let byte = content_vec[position];
                                        position += 1;

                                        line.push(byte);

                                        if byte == b'\n' {
                                            break;
                                        }

                                        // If we encounter a \r, add it and check if the next is \n
                                        if byte == b'\r'
                                            && position < content_vec.len()
                                            && content_vec[position] == b'\n'
                                        {
                                            line.push(content_vec[position]);
                                            position += 1;
                                            break;
                                        }
                                    }

                                    file_userdata.set("__position", position)?;

                                    if line.is_empty() && position >= content_vec.len() {
                                        return Ok(None); // EOF
                                    }

                                    let result = String::from_utf8_lossy(&line).to_string();
                                    Ok(Some(result))
                                }
                                _ => Err(mlua::Error::runtime(format!(
                                    "Unsupported format: {}",
                                    format_str
                                ))),
                            }
                        }
                        Some(mlua::Value::Number(n)) => {
                            // Read n bytes
                            let n = n as usize;
                            let end = std::cmp::min(position + n, content_vec.len());
                            let bytes = &content_vec[position..end];
                            let result = String::from_utf8_lossy(bytes).to_string();
                            position = end;
                            file_userdata.set("__position", position)?;
                            Ok(Some(result))
                        }
                        Some(_) => Err(mlua::Error::runtime("Invalid format")),
                        None => {
                            // Default is to read a line
                            let mut line = Vec::new();
                            let mut found_newline = false;

                            while position < content_vec.len() {
                                let byte = content_vec[position];
                                position += 1;

                                if byte == b'\n' {
                                    found_newline = true;
                                    break;
                                }

                                // Handle \r\n
                                if byte == b'\r' {
                                    if position < content_vec.len()
                                        && content_vec[position] == b'\n'
                                    {
                                        position += 1;
                                        found_newline = true;
                                        break;
                                    }
                                }

                                line.push(byte);
                            }

                            file_userdata.set("__position", position)?;

                            if !found_newline && line.is_empty() && position >= content_vec.len() {
                                return Ok(None); // EOF
                            }

                            let result = String::from_utf8_lossy(&line).to_string();
                            Ok(Some(result))
                        }
                    }
                },
            )?
        };
        file.set("read", read_fn)?;

        // write method
        let write_fn = {
            let fs_changes = fs_changes.clone();

            lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
                let write_perm = file_userdata.get::<bool>("__write_perm")?;
                if !write_perm {
                    return Err(mlua::Error::runtime("File not open for writing"));
                }

                let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
                let position = file_userdata.get::<usize>("__position")?;
                let content_ref = content.borrow::<FileContent>()?;
                let mut content_vec = content_ref.0.borrow_mut();

                let bytes = text.as_bytes();

                // Ensure the vector has enough capacity
                if position + bytes.len() > content_vec.len() {
                    content_vec.resize(position + bytes.len(), 0);
                }

                // Write the bytes
                for (i, &byte) in bytes.iter().enumerate() {
                    content_vec[position + i] = byte;
                }

                // Update position
                let new_position = position + bytes.len();
                file_userdata.set("__position", new_position)?;

                // Update fs_changes
                let path = file_userdata.get::<String>("__path")?;
                let path_buf = PathBuf::from(path);
                fs_changes.lock().insert(path_buf, content_vec.clone());

                Ok(true)
            })?
        };
        file.set("write", write_fn)?;

        // If we got this far, the file was opened successfully
        Ok((Some(file), String::new()))
    }

    async fn search(
        lua: Lua,
        mut foreground_tx: mpsc::Sender<ForegroundFn>,
        fs: Arc<dyn Fs>,
        regex: String,
    ) -> mlua::Result<Table> {
        // TODO: Allow specification of these options.
        let search_query = SearchQuery::regex(
            &regex,
            false,
            false,
            false,
            PathMatcher::default(),
            PathMatcher::default(),
            None,
        );
        let search_query = match search_query {
            Ok(query) => query,
            Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
        };

        // TODO: Should use `search_query.regex`. The tool description should also be updated,
        // as it specifies standard regex.
        let search_regex = match Regex::new(&regex) {
            Ok(re) => re,
            Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
        };

        let mut abs_paths_rx =
            Self::find_search_candidates(search_query, &mut foreground_tx).await?;

        let mut search_results: Vec<Table> = Vec::new();
        while let Some(path) = abs_paths_rx.next().await {
            // Skip files larger than 1MB
            if let Ok(Some(metadata)) = fs.metadata(&path).await {
                if metadata.len > 1_000_000 {
                    continue;
                }
            }

            // Attempt to read the file as text
            if let Ok(content) = fs.load(&path).await {
                let mut matches = Vec::new();

                // Find all regex matches in the content
                for capture in search_regex.find_iter(&content) {
                    matches.push(capture.as_str().to_string());
                }

                // If we found matches, create a result entry
                if !matches.is_empty() {
                    let result_entry = lua.create_table()?;
                    result_entry.set("path", path.to_string_lossy().to_string())?;

                    let matches_table = lua.create_table()?;
                    for (ix, m) in matches.iter().enumerate() {
                        matches_table.set(ix + 1, m.clone())?;
                    }
                    result_entry.set("matches", matches_table)?;

                    search_results.push(result_entry);
                }
            }
        }

        // Create a table to hold our results
        let results_table = lua.create_table()?;
        for (ix, entry) in search_results.into_iter().enumerate() {
            results_table.set(ix + 1, entry)?;
        }

        Ok(results_table)
    }

    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> mlua::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            cx.spawn(|worktree_store, cx| async move {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    worktree_store.read_with(&cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        .await
    }

    async fn run_foreground_fn<R: Send + 'static>(
        description: &str,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> anyhow::Result<R> + Send>,
    ) -> mlua::Result<R> {
        let (response_tx, response_rx) = oneshot::channel();
        let send_result = foreground_tx
            .send(ForegroundFn(Box::new(move |this, cx| {
                response_tx.send(function(this, cx)).ok();
            })))
            .await;
        match send_result {
            Ok(()) => (),
            Err(err) => {
                return Err(mlua::Error::runtime(format!(
                    "Internal error while enqueuing work for {description}: {err}"
                )))
            }
        }
        match response_rx.await {
            Ok(Ok(result)) => Ok(result),
            Ok(Err(err)) => Err(mlua::Error::runtime(format!(
                "Error while {description}: {err}"
            ))),
            Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
                "Internal error: response oneshot was canceled while {description}."
            ))),
        }
    }
}

struct FileContent(RefCell<Vec<u8>>);

impl UserData for FileContent {
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}

#[derive(Debug)]
pub enum ScriptEvent {
    Spawned(ScriptId),
    Exited(ScriptId),
}

impl EventEmitter<ScriptEvent> for ScriptSession {}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);

pub struct Script {
    pub id: ScriptId,
    pub state: ScriptState,
    pub source: SharedString,
}

pub enum ScriptState {
    Generating,
    Running {
        stdout: Arc<Mutex<String>>,
    },
    Succeeded {
        stdout: String,
    },
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}

impl Script {
    pub fn source_tag(&self) -> String {
        format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
    }

    /// If exited, returns a message with the output for the LLM
    pub fn output_message_for_llm(&self) -> Option<String> {
        match &self.state {
            ScriptState::Generating { .. } => None,
            ScriptState::Running { .. } => None,
            ScriptState::Succeeded { stdout } => {
                format!("Here's the script output:\n{}", stdout).into()
            }
            ScriptState::Failed { stdout, error } => format!(
                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
                error, stdout
            )
            .into(),
        }
    }

    /// Get a snapshot of the script's stdout
    pub fn stdout_snapshot(&self) -> String {
        match &self.state {
            ScriptState::Generating { .. } => String::new(),
            ScriptState::Running { stdout } => stdout.lock().clone(),
            ScriptState::Succeeded { stdout } => stdout.clone(),
            ScriptState::Failed { stdout, .. } => stdout.clone(),
        }
    }

    /// Returns the error if the script failed, otherwise None
    pub fn error(&self) -> Option<&anyhow::Error> {
        match &self.state {
            ScriptState::Generating { .. } => None,
            ScriptState::Running { .. } => None,
            ScriptState::Succeeded { .. } => None,
            ScriptState::Failed { error, .. } => Some(error),
        }
    }
}

#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print(" " .. match)
                end
            end
        "#;

        let output = test_script(script, cx).await.unwrap();
        assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
    }

    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "file1.txt": "Hello world!",
                "file2.txt": "Goodbye moon!"
            }),
        )
        .await;

        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (script_id, task) = session.update(cx, |session, cx| {
            let script_id = session.new_script();
            let task = session.run_script(script_id, source.to_string(), cx);

            (script_id, task)
        });

        task.await?;

        Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
    }

    fn init_test(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
    }
}
crates/assistant_scripting/src/system_prompt.txt (new file, 32 lines)
@@ -0,0 +1,32 @@
You can write a Lua script and I'll run it on my codebase and tell you what its
output was, including both stdout as well as the git diff of changes it made to
the filesystem. That way, you can get more information about the code base, or
make changes to the code base directly.

Put the Lua script inside of an `<eval>` tag like so:

<eval type="lua">
print("Hello, world!")
</eval>

The Lua script will have access to `io` and it will run with the current working
directory being in the root of the code base, so you can use it to explore,
search, make changes, etc. You can also have the script print things, and I'll
tell you what the output was. Note that `io` only has `open`, and then the file
it returns only has the methods read, write, and close - it doesn't have popen
or anything else.

There will be a global called `search` which accepts a regex (it's implemented
using Rust's regex crate, so use that regex syntax) and runs that regex on the
contents of every file in the code base (aside from gitignored files), then
returns an array of tables with two fields: "path" (the path to the file that
had the matches) and "matches" (an array of strings, with each string being a
match that was found within the file).

When I send you the script output, do not thank me for running it,
act as if you ran it yourself.

IMPORTANT!
Only include a maximum of one Lua script at the very end of your message
DO NOT WRITE ANYTHING ELSE AFTER THE SCRIPT. Wait for my response with the script
output to continue.
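As a sketch of how a generated script might combine the `search` and `io` surface described in the prompt above (the path `todos.txt` is hypothetical and used only for illustration):

-- Hypothetical example script; `todos.txt` is an illustrative relative path.
local results = search("TODO")
local file, err = io.open("todos.txt", "w")
if not file then
    print("could not open file: " .. err)
else
    for _, result in ipairs(results) do
        file:write(result.path .. "\n")
    end
    file:close()
end

Per the sandboxed `io.open` in session.rs, writes like this land in the in-memory `fs_changes` map rather than being flushed straight to disk.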
crates/assistant_scripting/src/tag.rs (new file, 260 lines)
@@ -0,0 +1,260 @@
pub const SCRIPT_START_TAG: &str = "<eval type=\"lua\">";
pub const SCRIPT_END_TAG: &str = "</eval>";

const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();

/// Parses a script tag in an assistant message as it is being streamed.
pub struct ScriptTagParser {
    state: State,
    buffer: Vec<u8>,
    tag_match_ix: usize,
}

enum State {
    Unstarted,
    Streaming,
    Ended,
}

#[derive(Debug, PartialEq)]
pub struct ChunkOutput {
    /// The chunk with script tags removed.
    pub content: String,
    /// The full script tag content. `None` until closed.
    pub script_source: Option<String>,
}

impl ScriptTagParser {
    /// Create a new script tag parser.
    pub fn new() -> Self {
        Self {
            state: State::Unstarted,
            buffer: Vec::new(),
            tag_match_ix: 0,
        }
    }

    /// Returns true if the parser has found a script tag.
    pub fn found_script(&self) -> bool {
        match self.state {
            State::Unstarted => false,
            State::Streaming | State::Ended => true,
        }
    }

    /// Process a new chunk of input, splitting it into surrounding content and script source.
    pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
        let mut content = Vec::with_capacity(input.len());

        for byte in input.bytes() {
            match self.state {
                State::Unstarted => {
                    if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
                        self.state = State::Streaming;
                        self.buffer = Vec::with_capacity(1024);
                        self.tag_match_ix = 0;
                    }
                }
                State::Streaming => {
                    if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
                        self.state = State::Ended;
                    }
                }
                State::Ended => content.push(byte),
            }
        }

        let content = unsafe { String::from_utf8_unchecked(content) };

        let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
            let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };

            Some(source)
        } else {
            None
        };

        ChunkOutput {
            content,
            script_source,
        }
    }
}

fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
    // this can't be a method because it'd require a mutable borrow on both self and self.buffer

    if match_tag_byte(byte, tag, tag_match_ix) {
        *tag_match_ix >= tag.len()
    } else {
        if *tag_match_ix > 0 {
            // push the partially matched tag to the buffer
            buffer.extend_from_slice(&tag[..*tag_match_ix]);
            *tag_match_ix = 0;

            // the tag might start to match again
            if match_tag_byte(byte, tag, tag_match_ix) {
                return *tag_match_ix >= tag.len();
            }
        }

        buffer.push(byte);

        false
    }
}

fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
    if byte == tag[*tag_match_ix] {
        *tag_match_ix += 1;
        true
    } else {
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_complete_tag() {
        let mut parser = ScriptTagParser::new();
        let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "");
        assert_eq!(
            result.script_source,
            Some("print(\"Hello, World!\")".to_string())
        );
    }

    #[test]
    fn test_no_tag() {
        let mut parser = ScriptTagParser::new();
        let input = "No tags here, just plain text";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "No tags here, just plain text");
        assert_eq!(result.script_source, None);
    }

    #[test]
    fn test_partial_end_tag() {
        let mut parser = ScriptTagParser::new();

        // Start the tag
        let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
        assert_eq!(result.content, "");
        assert_eq!(result.script_source, None);

        // Finish with the rest
        let result = parser.parse_chunk("val' + 'not the end';</eval>");
        assert_eq!(result.content, "");
        assert_eq!(
            result.script_source,
            Some("let x = '</eval' + 'not the end';".to_string())
        );
    }

    #[test]
    fn test_text_before_and_after_tag() {
        let mut parser = ScriptTagParser::new();
        let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
        let result = parser.parse_chunk(input);
        assert_eq!(result.content, "Before tag  After tag");
        assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
    }

    #[test]
    fn test_multiple_chunks_with_surrounding_text() {
        let mut parser = ScriptTagParser::new();

        // First chunk with text before
        let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
        assert_eq!(result.content, "Before script ");
        assert_eq!(result.script_source, None);

        // Second chunk with script content
        let result = parser.parse_chunk("\nlocal y = 20");
        assert_eq!(result.content, "");
        assert_eq!(result.script_source, None);

        // Last chunk with text after
        let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
        assert_eq!(result.content, " After script");
        assert_eq!(
            result.script_source,
            Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
        );

        let result = parser.parse_chunk(" there's more text");
        assert_eq!(result.content, " there's more text");
        assert_eq!(result.script_source, None);
    }

    #[test]
    fn test_partial_start_tag_matching() {
        let mut parser = ScriptTagParser::new();

        // partial match of start tag...
        let result = parser.parse_chunk("<ev");
        assert_eq!(result.content, "");

        // ...that's abandoned when the < of a real tag is encountered
        let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
        // ...so it gets pushed to content
        assert_eq!(result.content, "<ev");
        // ...and the real tag is parsed correctly
        assert_eq!(result.script_source, Some("script content".to_string()));
    }

    #[test]
    fn test_random_chunked_parsing() {
        use rand::rngs::StdRng;
        use rand::{Rng, SeedableRng};
        use std::time::{SystemTime, UNIX_EPOCH};

        let test_inputs = [
            "Before <eval type=\"lua\">print(\"Hello\")</eval> After",
            "No tags here at all",
            "<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
            "Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
        ];

        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        eprintln!("Using random seed: {}", seed);
        let mut rng = StdRng::seed_from_u64(seed);

        for test_input in &test_inputs {
            let mut reference_parser = ScriptTagParser::new();
            let expected = reference_parser.parse_chunk(test_input);

            let mut chunked_parser = ScriptTagParser::new();
            let mut remaining = test_input.as_bytes();
            let mut actual_content = String::new();
            let mut actual_script = None;

            while !remaining.is_empty() {
                let chunk_size = rng.gen_range(1..=remaining.len().min(5));
                let (chunk, rest) = remaining.split_at(chunk_size);
                remaining = rest;

                let chunk_str = std::str::from_utf8(chunk).unwrap();
                let result = chunked_parser.parse_chunk(chunk_str);

                actual_content.push_str(&result.content);
                if result.script_source.is_some() {
                    actual_script = result.script_source;
                }
            }

            assert_eq!(actual_content, expected.content);
            assert_eq!(actual_script, expected.script_source);
        }
    }
}