assistant: Make scripting a first-class concept instead of a tool (#26338)

This PR makes refactors the scripting functionality to be a first-class
concept of the assistant instead of a generic tool, which will allow us
to build a more customized experience.

- The tool prompt has been slightly tweaked and is now included as a
system message in all conversations. I'm getting decent results, but now
that it isn't in the tools framework, it will probably require more
refining.

- The model will now include an `<eval ...>` tag at the end of the
message with the script. We parse this tag incrementally as it streams
in so that we can indicate that we are generating a script before we see
the closing `</eval>` tag. Later, this will help us interpret the script
as it arrives also.

- Threads now hold a `ScriptSession` entity which manages the state of
all scripts (from parsing to exited) in a centralized way, and will
later collect all script operations so they can be displayed in the UI.

- `script_tool` has been renamed to `assistant_scripting` 

- Script source now opens in a regular read-only buffer  

Note: We still need to handle persistence properly

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Agus Zubiaga 2025-03-09 06:01:49 -03:00 committed by GitHub
parent ed6bf7f161
commit e298301b40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 811 additions and 197 deletions

41
Cargo.lock generated
View file

@ -450,6 +450,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_context_editor",
"assistant_scripting",
"assistant_settings",
"assistant_slash_command",
"assistant_tool",
@ -563,6 +564,25 @@ dependencies = [
"workspace",
]
[[package]]
name = "assistant_scripting"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"futures 0.3.31",
"gpui",
"mlua",
"parking_lot",
"project",
"rand 0.8.5",
"regex",
"serde",
"serde_json",
"settings",
"util",
]
[[package]]
name = "assistant_settings"
version = "0.1.0"
@ -11910,26 +11930,6 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
[[package]]
name = "scripting_tool"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"collections",
"futures 0.3.31",
"gpui",
"mlua",
"parking_lot",
"project",
"regex",
"schemars",
"serde",
"serde_json",
"settings",
"util",
]
[[package]]
name = "scrypt"
version = "0.11.0"
@ -16984,7 +16984,6 @@ dependencies = [
"repl",
"reqwest_client",
"rope",
"scripting_tool",
"search",
"serde",
"serde_json",

View file

@ -118,7 +118,7 @@ members = [
"crates/rope",
"crates/rpc",
"crates/schema_generator",
"crates/scripting_tool",
"crates/assistant_scripting",
"crates/search",
"crates/semantic_index",
"crates/semantic_version",
@ -318,7 +318,7 @@ reqwest_client = { path = "crates/reqwest_client" }
rich_text = { path = "crates/rich_text" }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
scripting_tool = { path = "crates/scripting_tool" }
assistant_scripting = { path = "crates/assistant_scripting" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" }

View file

@ -63,6 +63,7 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
assistant_scripting.workspace = true
streaming_diff.workspace = true
telemetry_events.workspace = true
terminal.workspace = true

View file

@ -1,11 +1,12 @@
use std::sync::Arc;
use collections::HashMap;
use assistant_scripting::{ScriptId, ScriptState};
use collections::{HashMap, HashSet};
use editor::{Editor, MultiBuffer};
use gpui::{
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
Task, TextStyleRefinement, UnderlineStyle,
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@ -14,6 +15,7 @@ use settings::Settings as _;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _;
use workspace::Workspace;
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
@ -21,6 +23,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
pub struct ActiveThread {
workspace: WeakEntity<Workspace>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
@ -30,6 +33,7 @@ pub struct ActiveThread {
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_scripts: HashSet<ScriptId>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
}
@ -40,6 +44,7 @@ struct EditMessageState {
impl ActiveThread {
pub fn new(
workspace: WeakEntity<Workspace>,
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
@ -52,6 +57,7 @@ impl ActiveThread {
];
let mut this = Self {
workspace,
language_registry,
thread_store,
thread: thread.clone(),
@ -59,6 +65,7 @@ impl ActiveThread {
messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
expanded_tool_uses: HashMap::default(),
expanded_scripts: HashSet::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| {
@ -241,7 +248,7 @@ impl ActiveThread {
fn handle_thread_event(
&mut self,
_: &Entity<Thread>,
_thread: &Entity<Thread>,
event: &ThreadEvent,
window: &mut Window,
cx: &mut Context<Self>,
@ -306,6 +313,14 @@ impl ActiveThread {
}
}
}
ThreadEvent::ScriptFinished => {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
self.thread.update(cx, |thread, cx| {
thread.send_to_model(model, RequestKind::Chat, false, cx);
});
}
}
}
}
@ -445,12 +460,16 @@ impl ActiveThread {
return Empty.into_any();
};
let context = self.thread.read(cx).context_for_message(message_id);
let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
let colors = cx.theme().colors();
let thread = self.thread.read(cx);
let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id);
// Don't render user messages that are just there for returning tool results.
if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
if message.role == Role::User
&& (thread.message_has_tool_results(message_id)
|| thread.message_has_script_output(message_id))
{
return Empty.into_any();
}
@ -463,6 +482,8 @@ impl ActiveThread {
.filter(|(id, _)| *id == message_id)
.map(|(_, state)| state.editor.clone());
let colors = cx.theme().colors();
let message_content = v_flex()
.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
@ -597,6 +618,7 @@ impl ActiveThread {
Role::Assistant => div()
.id(("message-container", ix))
.child(message_content)
.children(self.render_script(message_id, cx))
.map(|parent| {
if tool_uses.is_empty() {
return parent;
@ -716,6 +738,139 @@ impl ActiveThread {
}),
)
}
fn render_script(&self, message_id: MessageId, cx: &mut Context<Self>) -> Option<AnyElement> {
let script = self.thread.read(cx).script_for_message(message_id, cx)?;
let is_open = self.expanded_scripts.contains(&script.id);
let colors = cx.theme().colors();
let element = div().px_2p5().child(
v_flex()
.gap_1()
.rounded_lg()
.border_1()
.border_color(colors.border)
.child(
h_flex()
.justify_between()
.py_0p5()
.pl_1()
.pr_2()
.bg(colors.editor_foreground.opacity(0.02))
.when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
.when(!is_open, |element| element.rounded_md())
.border_color(colors.border)
.child(
h_flex()
.gap_1()
.child(Disclosure::new("script-disclosure", is_open).on_click(
cx.listener({
let script_id = script.id;
move |this, _event, _window, _cx| {
if this.expanded_scripts.contains(&script_id) {
this.expanded_scripts.remove(&script_id);
} else {
this.expanded_scripts.insert(script_id);
}
}
}),
))
// TODO: Generate script description
.child(Label::new("Script")),
)
.child(
h_flex()
.gap_1()
.child(
Label::new(match script.state {
ScriptState::Generating => "Generating",
ScriptState::Running { .. } => "Running",
ScriptState::Succeeded { .. } => "Finished",
ScriptState::Failed { .. } => "Error",
})
.size(LabelSize::XSmall)
.buffer_font(cx),
)
.child(
IconButton::new("view-source", IconName::Eye)
.icon_color(Color::Muted)
.disabled(matches!(script.state, ScriptState::Generating))
.on_click(cx.listener({
let source = script.source.clone();
move |this, _event, window, cx| {
this.open_script_source(source.clone(), window, cx);
}
})),
),
),
)
.when(is_open, |parent| {
let stdout = script.stdout_snapshot();
let error = script.error();
parent.child(
v_flex()
.p_2()
.bg(colors.editor_background)
.gap_2()
.child(if stdout.is_empty() && error.is_none() {
Label::new("No output yet")
.size(LabelSize::Small)
.color(Color::Muted)
} else {
Label::new(stdout).size(LabelSize::Small).buffer_font(cx)
})
.children(script.error().map(|err| {
Label::new(err.to_string())
.size(LabelSize::Small)
.color(Color::Error)
})),
)
}),
);
Some(element.into_any())
}
fn open_script_source(
&mut self,
source: SharedString,
window: &mut Window,
cx: &mut Context<'_, ActiveThread>,
) {
let language_registry = self.language_registry.clone();
let workspace = self.workspace.clone();
let source = source.clone();
cx.spawn_in(window, |_, mut cx| async move {
let lua = language_registry.language_for_name("Lua").await.log_err();
workspace.update_in(&mut cx, |workspace, window, cx| {
let project = workspace.project().clone();
let buffer = project.update(cx, |project, cx| {
project.create_local_buffer(&source.trim(), lua, cx)
});
let buffer = cx.new(|cx| {
MultiBuffer::singleton(buffer, cx)
// TODO: Generate script description
.with_title("Assistant script".into())
});
let editor = cx.new(|cx| {
let mut editor =
Editor::for_multibuffer(buffer, Some(project), true, window, cx);
editor.set_read_only(true);
editor
});
workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx);
})
})
.detach_and_log_err(cx);
}
}
impl Render for ActiveThread {

View file

@ -166,22 +166,25 @@ impl AssistantPanel {
let history_store =
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
let thread = cx.new(|cx| {
ActiveThread::new(
workspace.clone(),
thread.clone(),
thread_store.clone(),
language_registry.clone(),
window,
cx,
)
});
Self {
active_view: ActiveView::Thread,
workspace,
project: project.clone(),
fs: fs.clone(),
language_registry: language_registry.clone(),
thread_store: thread_store.clone(),
thread: cx.new(|cx| {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
language_registry,
window,
cx,
)
}),
thread_store: thread_store.clone(),
thread,
message_editor,
context_store,
context_editor: None,
@ -239,6 +242,7 @@ impl AssistantPanel {
self.active_view = ActiveView::Thread;
self.thread = cx.new(|cx| {
ActiveThread::new(
self.workspace.clone(),
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
@ -372,6 +376,7 @@ impl AssistantPanel {
this.active_view = ActiveView::Thread;
this.thread = cx.new(|cx| {
ActiveThread::new(
this.workspace.clone(),
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),

View file

@ -1,11 +1,14 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_scripting::{
Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
};
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@ -75,14 +78,21 @@ pub struct Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
script_output_messages: HashSet<MessageId>,
script_session: Entity<ScriptSession>,
_script_session_subscription: Subscription,
}
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
cx: &mut Context<Self>,
) -> Self {
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@ -97,6 +107,10 @@ impl Thread {
project,
tools,
tool_use: ToolUseState::new(),
scripts_by_assistant_message: HashMap::default(),
script_output_messages: HashSet::default(),
script_session,
_script_session_subscription: script_session_subscription,
}
}
@ -105,7 +119,7 @@ impl Thread {
saved: SavedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
saved
@ -115,6 +129,8 @@ impl Thread {
.unwrap_or(0),
);
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self {
id,
@ -138,6 +154,10 @@ impl Thread {
project,
tools,
tool_use,
scripts_by_assistant_message: HashMap::default(),
script_output_messages: HashSet::default(),
script_session,
_script_session_subscription: script_session_subscription,
}
}
@ -223,17 +243,22 @@ impl Thread {
self.tool_use.message_has_tool_results(message_id)
}
pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
self.script_output_messages.contains(&message_id)
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
context: Vec<ContextSnapshot>,
cx: &mut Context<Self>,
) {
) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids);
message_id
}
pub fn insert_message(
@ -302,6 +327,39 @@ impl Thread {
text
}
pub fn script_for_message<'a>(
&'a self,
message_id: MessageId,
cx: &'a App,
) -> Option<&'a Script> {
self.scripts_by_assistant_message
.get(&message_id)
.map(|script_id| self.script_session.read(cx).get(*script_id))
}
fn handle_script_event(
&mut self,
_script_session: Entity<ScriptSession>,
event: &ScriptEvent,
cx: &mut Context<Self>,
) {
match event {
ScriptEvent::Spawned(_) => {}
ScriptEvent::Exited(script_id) => {
if let Some(output_message) = self
.script_session
.read(cx)
.get(*script_id)
.output_message_for_llm()
{
let message_id = self.insert_user_message(output_message, vec![], cx);
self.script_output_messages.insert(message_id);
cx.emit(ThreadEvent::ScriptFinished)
}
}
}
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@ -330,7 +388,7 @@ impl Thread {
pub fn to_completion_request(
&self,
request_kind: RequestKind,
_cx: &App,
cx: &App,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
messages: vec![],
@ -339,6 +397,12 @@ impl Thread {
temperature: None,
};
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![SCRIPTING_PROMPT.to_string().into()],
cache: true,
});
let mut referenced_context_ids = HashSet::default();
for message in &self.messages {
@ -351,6 +415,7 @@ impl Thread {
content: Vec::new(),
cache: false,
};
match request_kind {
RequestKind::Chat => {
self.tool_use
@ -371,11 +436,20 @@ impl Thread {
RequestKind::Chat => {
self.tool_use
.attach_tool_uses(message.id, &mut request_message);
if matches!(message.role, Role::Assistant) {
if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
{
let script = self.script_session.read(cx).get(*script_id);
request_message.content.push(script.source_tag().into());
}
}
}
RequestKind::Summarize => {
// We don't care about tool use during summarization.
}
}
};
request.messages.push(request_message);
}
@ -412,6 +486,8 @@ impl Thread {
let stream_completion = async {
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
let mut script_tag_parser = ScriptTagParser::new();
let mut script_id = None;
while let Some(event) = events.next().await {
let event = event?;
@ -426,19 +502,43 @@ impl Thread {
}
LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk);
let chunk = script_tag_parser.parse_chunk(&chunk);
let message_id = if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk.content);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
chunk,
chunk.content,
));
last_message.id
} else {
// If we won't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(Role::Assistant, chunk, cx);
thread.insert_message(Role::Assistant, chunk.content, cx)
};
if script_id.is_none() && script_tag_parser.found_script() {
let id = thread
.script_session
.update(cx, |session, _cx| session.new_script());
thread.scripts_by_assistant_message.insert(message_id, id);
script_id = Some(id);
}
if let (Some(script_source), Some(script_id)) =
(chunk.script_source, script_id)
{
// TODO: move buffer to script and run as it streams
thread
.script_session
.update(cx, |this, cx| {
this.run_script(script_id, script_source, cx)
})
.detach_and_log_err(cx);
}
}
}
@ -661,6 +761,7 @@ pub enum ThreadEvent {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
},
ScriptFinished,
}
impl EventEmitter<ThreadEvent> for Thread {}

View file

@ -1,5 +1,5 @@
[package]
name = "scripting_tool"
name = "assistant_scripting"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@ -9,12 +9,11 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/scripting_tool.rs"
path = "src/assistant_scripting.rs"
doctest = false
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@ -22,7 +21,6 @@ mlua.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@ -32,4 +30,5 @@ util.workspace = true
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
settings = { workspace = true, features = ["test-support"] }

View file

@ -0,0 +1,7 @@
mod session;
mod tag;
pub use session::*;
pub use tag::*;
pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt");

View file

@ -1,10 +1,9 @@
use anyhow::Result;
use collections::{HashMap, HashSet};
use futures::{
channel::{mpsc, oneshot},
pin_mut, SinkExt, StreamExt,
};
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project};
@ -16,24 +15,23 @@ use std::{
};
use util::{paths::PathMatcher, ResultExt};
pub struct ScriptOutput {
pub stdout: String,
}
use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<Session>, AsyncApp) + Send>);
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
pub struct Session {
pub struct ScriptSession {
project: Entity<Project>,
// TODO Remove this
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
foreground_fns_tx: mpsc::Sender<ForegroundFn>,
_invoke_foreground_fns: Task<()>,
scripts: Vec<Script>,
}
impl Session {
impl ScriptSession {
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
Session {
ScriptSession {
project,
fs_changes: Arc::new(Mutex::new(HashMap::default())),
foreground_fns_tx,
@ -42,15 +40,62 @@ impl Session {
foreground_fn.0(this.clone(), cx.clone());
}
}),
scripts: Vec::new(),
}
}
/// Runs a Lua script in a sandboxed environment and returns the printed lines
pub fn new_script(&mut self) -> ScriptId {
let id = ScriptId(self.scripts.len() as u32);
let script = Script {
id,
state: ScriptState::Generating,
source: SharedString::new_static(""),
};
self.scripts.push(script);
id
}
pub fn run_script(
&mut self,
script: String,
script_id: ScriptId,
script_src: String,
cx: &mut Context<Self>,
) -> Task<Result<ScriptOutput>> {
) -> Task<anyhow::Result<()>> {
let script = self.get_mut(script_id);
let stdout = Arc::new(Mutex::new(String::new()));
script.source = script_src.clone().into();
script.state = ScriptState::Running {
stdout: stdout.clone(),
};
let task = self.run_lua(script_src, stdout, cx);
cx.emit(ScriptEvent::Spawned(script_id));
cx.spawn(|session, mut cx| async move {
let result = task.await;
session.update(&mut cx, |session, cx| {
let script = session.get_mut(script_id);
let stdout = script.stdout_snapshot();
script.state = match result {
Ok(()) => ScriptState::Succeeded { stdout },
Err(error) => ScriptState::Failed { stdout, error },
};
cx.emit(ScriptEvent::Exited(script_id))
})
})
}
fn run_lua(
&mut self,
script: String,
stdout: Arc<Mutex<String>>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
// TODO Remove fs_changes
@ -62,13 +107,17 @@ impl Session {
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).abs_path());
let fs = self.project.read(cx).fs().clone();
let foreground_fns_tx = self.foreground_fns_tx.clone();
cx.background_spawn(async move {
let task = cx.background_spawn({
let stdout = stdout.clone();
async move {
let lua = Lua::new();
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
let globals = lua.globals();
let stdout = Arc::new(Mutex::new(String::new()));
globals.set(
"sb_print",
lua.create_function({
@ -103,11 +152,19 @@ impl Session {
// Drop Lua instance to decrement reference count.
drop(lua);
let stdout = Arc::try_unwrap(stdout)
.expect("no more references to stdout")
.into_inner();
Ok(ScriptOutput { stdout })
})
anyhow::Ok(())
}
});
task
}
pub fn get(&self, script_id: ScriptId) -> &Script {
&self.scripts[script_id.0 as usize]
}
fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
&mut self.scripts[script_id.0 as usize]
}
/// Sandboxed print() function in Lua.
@ -678,6 +735,79 @@ impl UserData for FileContent {
}
}
#[derive(Debug)]
pub enum ScriptEvent {
Spawned(ScriptId),
Exited(ScriptId),
}
impl EventEmitter<ScriptEvent> for ScriptSession {}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
pub struct Script {
pub id: ScriptId,
pub state: ScriptState,
pub source: SharedString,
}
pub enum ScriptState {
Generating,
Running {
stdout: Arc<Mutex<String>>,
},
Succeeded {
stdout: String,
},
Failed {
stdout: String,
error: anyhow::Error,
},
}
impl Script {
pub fn source_tag(&self) -> String {
format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
}
/// If exited, returns a message with the output for the LLM
pub fn output_message_for_llm(&self) -> Option<String> {
match &self.state {
ScriptState::Generating { .. } => None,
ScriptState::Running { .. } => None,
ScriptState::Succeeded { stdout } => {
format!("Here's the script output:\n{}", stdout).into()
}
ScriptState::Failed { stdout, error } => format!(
"The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
error, stdout
)
.into(),
}
}
/// Get a snapshot of the script's stdout
pub fn stdout_snapshot(&self) -> String {
match &self.state {
ScriptState::Generating { .. } => String::new(),
ScriptState::Running { stdout } => stdout.lock().clone(),
ScriptState::Succeeded { stdout } => stdout.clone(),
ScriptState::Failed { stdout, .. } => stdout.clone(),
}
}
/// Returns the error if the script failed, otherwise None
pub fn error(&self) -> Option<&anyhow::Error> {
match &self.state {
ScriptState::Generating { .. } => None,
ScriptState::Running { .. } => None,
ScriptState::Succeeded { .. } => None,
ScriptState::Failed { error, .. } => Some(error),
}
}
}
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
@ -689,35 +819,17 @@ mod tests {
#[gpui::test]
async fn test_print(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let session = cx.new(|cx| Session::new(project, cx));
let script = r#"
print("Hello", "world!")
print("Goodbye", "moon!")
"#;
let output = session
.update(cx, |session, cx| session.run_script(script.to_string(), cx))
.await
.unwrap();
assert_eq!(output.stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
let output = test_script(script, cx).await.unwrap();
assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
}
#[gpui::test]
async fn test_search(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"file1.txt": "Hello world!",
"file2.txt": "Goodbye moon!"
}),
)
.await;
let project = Project::test(fs, [Path::new("/")], cx).await;
let session = cx.new(|cx| Session::new(project, cx));
let script = r#"
local results = search("world")
for i, result in ipairs(results) do
@ -728,11 +840,36 @@ mod tests {
end
end
"#;
let output = session
.update(cx, |session, cx| session.run_script(script.to_string(), cx))
.await
.unwrap();
assert_eq!(output.stdout, "File: /file1.txt\nMatches:\n world\n");
let output = test_script(script, cx).await.unwrap();
assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
}
async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"file1.txt": "Hello world!",
"file2.txt": "Goodbye moon!"
}),
)
.await;
let project = Project::test(fs, [Path::new("/")], cx).await;
let session = cx.new(|cx| ScriptSession::new(project, cx));
let (script_id, task) = session.update(cx, |session, cx| {
let script_id = session.new_script();
let task = session.run_script(script_id, source.to_string(), cx);
(script_id, task)
});
task.await?;
Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
}
fn init_test(cx: &mut TestAppContext) {

View file

@ -3,6 +3,12 @@ output was, including both stdout as well as the git diff of changes it made to
the filesystem. That way, you can get more information about the code base, or
make changes to the code base directly.
Put the Lua script inside of an `<eval>` tag like so:
<eval type="lua">
print("Hello, world!")
</eval>
The Lua script will have access to `io` and it will run with the current working
directory being in the root of the code base, so you can use it to explore,
search, make changes, etc. You can also have the script print things, and I'll
@ -10,13 +16,17 @@ tell you what the output was. Note that `io` only has `open`, and then the file
it returns only has the methods read, write, and close - it doesn't have popen
or anything else.
Also, I'm going to be putting this Lua script into JSON, so please don't use
Lua's double quote syntax for string literals - use one of Lua's other syntaxes
for string literals, so I don't have to escape the double quotes.
There will be a global called `search` which accepts a regex (it's implemented
using Rust's regex crate, so use that regex syntax) and runs that regex on the
contents of every file in the code base (aside from gitignored files), then
returns an array of tables with two fields: "path" (the path to the file that
had the matches) and "matches" (an array of strings, with each string being a
match that was found within the file).
When I send you the script output, do not thank me for running it,
act as if you ran it yourself.
IMPORTANT!
Only include a maximum of one Lua script at the very end of your message
DO NOT WRITE ANYTHING ELSE AFTER THE SCRIPT. Wait for my response with the script
output to continue.

View file

@ -0,0 +1,260 @@
pub const SCRIPT_START_TAG: &str = "<eval type=\"lua\">";
pub const SCRIPT_END_TAG: &str = "</eval>";
const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();
/// Parses a script tag in an assistant message as it is being streamed.
pub struct ScriptTagParser {
state: State,
buffer: Vec<u8>,
tag_match_ix: usize,
}
enum State {
Unstarted,
Streaming,
Ended,
}
#[derive(Debug, PartialEq)]
pub struct ChunkOutput {
/// The chunk with script tags removed.
pub content: String,
/// The full script tag content. `None` until closed.
pub script_source: Option<String>,
}
impl ScriptTagParser {
/// Create a new script tag parser.
pub fn new() -> Self {
Self {
state: State::Unstarted,
buffer: Vec::new(),
tag_match_ix: 0,
}
}
/// Returns true if the parser has found a script tag.
pub fn found_script(&self) -> bool {
match self.state {
State::Unstarted => false,
State::Streaming | State::Ended => true,
}
}
/// Process a new chunk of input, splitting it into surrounding content and script source.
pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
let mut content = Vec::with_capacity(input.len());
for byte in input.bytes() {
match self.state {
State::Unstarted => {
if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
self.state = State::Streaming;
self.buffer = Vec::with_capacity(1024);
self.tag_match_ix = 0;
}
}
State::Streaming => {
if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
self.state = State::Ended;
}
}
State::Ended => content.push(byte),
}
}
let content = unsafe { String::from_utf8_unchecked(content) };
let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };
Some(source)
} else {
None
};
ChunkOutput {
content,
script_source,
}
}
}
fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
// this can't be a method because it'd require a mutable borrow on both self and self.buffer
if match_tag_byte(byte, tag, tag_match_ix) {
*tag_match_ix >= tag.len()
} else {
if *tag_match_ix > 0 {
// push the partially matched tag to the buffer
buffer.extend_from_slice(&tag[..*tag_match_ix]);
*tag_match_ix = 0;
// the tag might start to match again
if match_tag_byte(byte, tag, tag_match_ix) {
return *tag_match_ix >= tag.len();
}
}
buffer.push(byte);
false
}
}
fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
if byte == tag[*tag_match_ix] {
*tag_match_ix += 1;
true
} else {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_complete_tag() {
let mut parser = ScriptTagParser::new();
let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
let result = parser.parse_chunk(input);
assert_eq!(result.content, "");
assert_eq!(
result.script_source,
Some("print(\"Hello, World!\")".to_string())
);
}
#[test]
fn test_no_tag() {
let mut parser = ScriptTagParser::new();
let input = "No tags here, just plain text";
let result = parser.parse_chunk(input);
assert_eq!(result.content, "No tags here, just plain text");
assert_eq!(result.script_source, None);
}
#[test]
fn test_partial_end_tag() {
let mut parser = ScriptTagParser::new();
// Start the tag
let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
assert_eq!(result.content, "");
assert_eq!(result.script_source, None);
// Finish with the rest
let result = parser.parse_chunk("val' + 'not the end';</eval>");
assert_eq!(result.content, "");
assert_eq!(
result.script_source,
Some("let x = '</eval' + 'not the end';".to_string())
);
}
#[test]
fn test_text_before_and_after_tag() {
let mut parser = ScriptTagParser::new();
let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
let result = parser.parse_chunk(input);
assert_eq!(result.content, "Before tag After tag");
assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
}
#[test]
fn test_multiple_chunks_with_surrounding_text() {
let mut parser = ScriptTagParser::new();
// First chunk with text before
let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
assert_eq!(result.content, "Before script ");
assert_eq!(result.script_source, None);
// Second chunk with script content
let result = parser.parse_chunk("\nlocal y = 20");
assert_eq!(result.content, "");
assert_eq!(result.script_source, None);
// Last chunk with text after
let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
assert_eq!(result.content, " After script");
assert_eq!(
result.script_source,
Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
);
let result = parser.parse_chunk(" there's more text");
assert_eq!(result.content, " there's more text");
assert_eq!(result.script_source, None);
}
#[test]
fn test_partial_start_tag_matching() {
let mut parser = ScriptTagParser::new();
// partial match of start tag...
let result = parser.parse_chunk("<ev");
assert_eq!(result.content, "");
// ...that's abandandoned when the < of a real tag is encountered
let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
// ...so it gets pushed to content
assert_eq!(result.content, "<ev");
// ...and the real tag is parsed correctly
assert_eq!(result.script_source, Some("script content".to_string()));
}
#[test]
fn test_random_chunked_parsing() {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::time::{SystemTime, UNIX_EPOCH};
let test_inputs = [
"Before <eval type=\"lua\">print(\"Hello\")</eval> After",
"No tags here at all",
"<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
"Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
];
let seed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
eprintln!("Using random seed: {}", seed);
let mut rng = StdRng::seed_from_u64(seed);
for test_input in &test_inputs {
let mut reference_parser = ScriptTagParser::new();
let expected = reference_parser.parse_chunk(test_input);
let mut chunked_parser = ScriptTagParser::new();
let mut remaining = test_input.as_bytes();
let mut actual_content = String::new();
let mut actual_script = None;
while !remaining.is_empty() {
let chunk_size = rng.gen_range(1..=remaining.len().min(5));
let (chunk, rest) = remaining.split_at(chunk_size);
remaining = rest;
let chunk_str = std::str::from_utf8(chunk).unwrap();
let result = chunked_parser.parse_chunk(chunk_str);
actual_content.push_str(&result.content);
if result.script_source.is_some() {
actual_script = result.script_source;
}
}
assert_eq!(actual_content, expected.content);
assert_eq!(actual_script, expected.script_source);
}
}
}

View file

@ -1,58 +0,0 @@
mod session;
use project::Project;
pub(crate) use session::*;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Entity, Task};
use schemars::JsonSchema;
use serde::Deserialize;
use std::sync::Arc;
pub fn init(cx: &App) {
let registry = ToolRegistry::global(cx);
registry.register_tool(ScriptingTool);
}
#[derive(Debug, Deserialize, JsonSchema)]
struct ScriptingToolInput {
lua_script: String,
}
struct ScriptingTool;
impl Tool for ScriptingTool {
fn name(&self) -> String {
"lua-interpreter".into()
}
fn description(&self) -> String {
include_str!("scripting_tool_description.txt").into()
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(ScriptingToolInput);
serde_json::to_value(&schema).unwrap()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
project: Entity<Project>,
cx: &mut App,
) -> Task<anyhow::Result<String>> {
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
Err(err) => return Task::ready(Err(err.into())),
Ok(input) => input,
};
let session = cx.new(|cx| Session::new(project, cx));
let lua_script = input.lua_script;
let script = session.update(cx, |session, cx| session.run_script(lua_script, cx));
cx.spawn(|_cx| async move {
let output = script.await?.stdout;
drop(session);
Ok(format!("The script output the following:\n{output}"))
})
}
}

View file

@ -98,7 +98,6 @@ remote.workspace = true
repl.workspace = true
reqwest_client.workspace = true
rope.workspace = true
scripting_tool.workspace = true
search.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -476,7 +476,6 @@ fn main() {
cx,
);
assistant_tools::init(cx);
scripting_tool::init(cx);
repl::init(app_state.fs.clone(), cx);
extension_host::init(
extension_host_proxy,