assistant: Make scripting a first-class concept instead of a tool (#26338)

This PR makes refactors the scripting functionality to be a first-class
concept of the assistant instead of a generic tool, which will allow us
to build a more customized experience.

- The tool prompt has been slightly tweaked and is now included as a
system message in all conversations. I'm getting decent results, but now
that it isn't in the tools framework, it will probably require more
refining.

- The model will now include an `<eval ...>` tag at the end of the
message with the script. We parse this tag incrementally as it streams
in so that we can indicate that we are generating a script before we see
the closing `</eval>` tag. Later, this will help us interpret the script
as it arrives also.

- Threads now hold a `ScriptSession` entity which manages the state of
all scripts (from parsing to exited) in a centralized way, and will
later collect all script operations so they can be displayed in the UI.

- `script_tool` has been renamed to `assistant_scripting` 

- Script source now opens in a regular read-only buffer  

Note: We still need to handle persistence properly

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Agus Zubiaga 2025-03-09 06:01:49 -03:00 committed by GitHub
parent ed6bf7f161
commit e298301b40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 811 additions and 197 deletions

41
Cargo.lock generated
View file

@ -450,6 +450,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assistant_context_editor", "assistant_context_editor",
"assistant_scripting",
"assistant_settings", "assistant_settings",
"assistant_slash_command", "assistant_slash_command",
"assistant_tool", "assistant_tool",
@ -563,6 +564,25 @@ dependencies = [
"workspace", "workspace",
] ]
[[package]]
name = "assistant_scripting"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"futures 0.3.31",
"gpui",
"mlua",
"parking_lot",
"project",
"rand 0.8.5",
"regex",
"serde",
"serde_json",
"settings",
"util",
]
[[package]] [[package]]
name = "assistant_settings" name = "assistant_settings"
version = "0.1.0" version = "0.1.0"
@ -11910,26 +11930,6 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
[[package]]
name = "scripting_tool"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"collections",
"futures 0.3.31",
"gpui",
"mlua",
"parking_lot",
"project",
"regex",
"schemars",
"serde",
"serde_json",
"settings",
"util",
]
[[package]] [[package]]
name = "scrypt" name = "scrypt"
version = "0.11.0" version = "0.11.0"
@ -16984,7 +16984,6 @@ dependencies = [
"repl", "repl",
"reqwest_client", "reqwest_client",
"rope", "rope",
"scripting_tool",
"search", "search",
"serde", "serde",
"serde_json", "serde_json",

View file

@ -118,7 +118,7 @@ members = [
"crates/rope", "crates/rope",
"crates/rpc", "crates/rpc",
"crates/schema_generator", "crates/schema_generator",
"crates/scripting_tool", "crates/assistant_scripting",
"crates/search", "crates/search",
"crates/semantic_index", "crates/semantic_index",
"crates/semantic_version", "crates/semantic_version",
@ -318,7 +318,7 @@ reqwest_client = { path = "crates/reqwest_client" }
rich_text = { path = "crates/rich_text" } rich_text = { path = "crates/rich_text" }
rope = { path = "crates/rope" } rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" } rpc = { path = "crates/rpc" }
scripting_tool = { path = "crates/scripting_tool" } assistant_scripting = { path = "crates/assistant_scripting" }
search = { path = "crates/search" } search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" } semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" } semantic_version = { path = "crates/semantic_version" }

View file

@ -63,6 +63,7 @@ serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
settings.workspace = true settings.workspace = true
smol.workspace = true smol.workspace = true
assistant_scripting.workspace = true
streaming_diff.workspace = true streaming_diff.workspace = true
telemetry_events.workspace = true telemetry_events.workspace = true
terminal.workspace = true terminal.workspace = true

View file

@ -1,11 +1,12 @@
use std::sync::Arc; use std::sync::Arc;
use collections::HashMap; use assistant_scripting::{ScriptId, ScriptState};
use collections::{HashMap, HashSet};
use editor::{Editor, MultiBuffer}; use editor::{Editor, MultiBuffer};
use gpui::{ use gpui::{
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty, list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
Task, TextStyleRefinement, UnderlineStyle, Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
}; };
use language::{Buffer, LanguageRegistry}; use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@ -14,6 +15,7 @@ use settings::Settings as _;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding}; use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _; use util::ResultExt as _;
use workspace::Workspace;
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent}; use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore; use crate::thread_store::ThreadStore;
@ -21,6 +23,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill; use crate::ui::ContextPill;
pub struct ActiveThread { pub struct ActiveThread {
workspace: WeakEntity<Workspace>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
thread: Entity<Thread>, thread: Entity<Thread>,
@ -30,6 +33,7 @@ pub struct ActiveThread {
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>, rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>, editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>, expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_scripts: HashSet<ScriptId>,
last_error: Option<ThreadError>, last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -40,6 +44,7 @@ struct EditMessageState {
impl ActiveThread { impl ActiveThread {
pub fn new( pub fn new(
workspace: WeakEntity<Workspace>,
thread: Entity<Thread>, thread: Entity<Thread>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
@ -52,6 +57,7 @@ impl ActiveThread {
]; ];
let mut this = Self { let mut this = Self {
workspace,
language_registry, language_registry,
thread_store, thread_store,
thread: thread.clone(), thread: thread.clone(),
@ -59,6 +65,7 @@ impl ActiveThread {
messages: Vec::new(), messages: Vec::new(),
rendered_messages_by_id: HashMap::default(), rendered_messages_by_id: HashMap::default(),
expanded_tool_uses: HashMap::default(), expanded_tool_uses: HashMap::default(),
expanded_scripts: HashSet::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade(); let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| { move |ix, window: &mut Window, cx: &mut App| {
@ -241,7 +248,7 @@ impl ActiveThread {
fn handle_thread_event( fn handle_thread_event(
&mut self, &mut self,
_: &Entity<Thread>, _thread: &Entity<Thread>,
event: &ThreadEvent, event: &ThreadEvent,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
@ -306,6 +313,14 @@ impl ActiveThread {
} }
} }
} }
ThreadEvent::ScriptFinished => {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
self.thread.update(cx, |thread, cx| {
thread.send_to_model(model, RequestKind::Chat, false, cx);
});
}
}
} }
} }
@ -445,12 +460,16 @@ impl ActiveThread {
return Empty.into_any(); return Empty.into_any();
}; };
let context = self.thread.read(cx).context_for_message(message_id); let thread = self.thread.read(cx);
let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
let colors = cx.theme().colors(); let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id);
// Don't render user messages that are just there for returning tool results. // Don't render user messages that are just there for returning tool results.
if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) { if message.role == Role::User
&& (thread.message_has_tool_results(message_id)
|| thread.message_has_script_output(message_id))
{
return Empty.into_any(); return Empty.into_any();
} }
@ -463,6 +482,8 @@ impl ActiveThread {
.filter(|(id, _)| *id == message_id) .filter(|(id, _)| *id == message_id)
.map(|(_, state)| state.editor.clone()); .map(|(_, state)| state.editor.clone());
let colors = cx.theme().colors();
let message_content = v_flex() let message_content = v_flex()
.child( .child(
if let Some(edit_message_editor) = edit_message_editor.clone() { if let Some(edit_message_editor) = edit_message_editor.clone() {
@ -597,6 +618,7 @@ impl ActiveThread {
Role::Assistant => div() Role::Assistant => div()
.id(("message-container", ix)) .id(("message-container", ix))
.child(message_content) .child(message_content)
.children(self.render_script(message_id, cx))
.map(|parent| { .map(|parent| {
if tool_uses.is_empty() { if tool_uses.is_empty() {
return parent; return parent;
@ -716,6 +738,139 @@ impl ActiveThread {
}), }),
) )
} }
fn render_script(&self, message_id: MessageId, cx: &mut Context<Self>) -> Option<AnyElement> {
let script = self.thread.read(cx).script_for_message(message_id, cx)?;
let is_open = self.expanded_scripts.contains(&script.id);
let colors = cx.theme().colors();
let element = div().px_2p5().child(
v_flex()
.gap_1()
.rounded_lg()
.border_1()
.border_color(colors.border)
.child(
h_flex()
.justify_between()
.py_0p5()
.pl_1()
.pr_2()
.bg(colors.editor_foreground.opacity(0.02))
.when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
.when(!is_open, |element| element.rounded_md())
.border_color(colors.border)
.child(
h_flex()
.gap_1()
.child(Disclosure::new("script-disclosure", is_open).on_click(
cx.listener({
let script_id = script.id;
move |this, _event, _window, _cx| {
if this.expanded_scripts.contains(&script_id) {
this.expanded_scripts.remove(&script_id);
} else {
this.expanded_scripts.insert(script_id);
}
}
}),
))
// TODO: Generate script description
.child(Label::new("Script")),
)
.child(
h_flex()
.gap_1()
.child(
Label::new(match script.state {
ScriptState::Generating => "Generating",
ScriptState::Running { .. } => "Running",
ScriptState::Succeeded { .. } => "Finished",
ScriptState::Failed { .. } => "Error",
})
.size(LabelSize::XSmall)
.buffer_font(cx),
)
.child(
IconButton::new("view-source", IconName::Eye)
.icon_color(Color::Muted)
.disabled(matches!(script.state, ScriptState::Generating))
.on_click(cx.listener({
let source = script.source.clone();
move |this, _event, window, cx| {
this.open_script_source(source.clone(), window, cx);
}
})),
),
),
)
.when(is_open, |parent| {
let stdout = script.stdout_snapshot();
let error = script.error();
parent.child(
v_flex()
.p_2()
.bg(colors.editor_background)
.gap_2()
.child(if stdout.is_empty() && error.is_none() {
Label::new("No output yet")
.size(LabelSize::Small)
.color(Color::Muted)
} else {
Label::new(stdout).size(LabelSize::Small).buffer_font(cx)
})
.children(script.error().map(|err| {
Label::new(err.to_string())
.size(LabelSize::Small)
.color(Color::Error)
})),
)
}),
);
Some(element.into_any())
}
fn open_script_source(
&mut self,
source: SharedString,
window: &mut Window,
cx: &mut Context<'_, ActiveThread>,
) {
let language_registry = self.language_registry.clone();
let workspace = self.workspace.clone();
let source = source.clone();
cx.spawn_in(window, |_, mut cx| async move {
let lua = language_registry.language_for_name("Lua").await.log_err();
workspace.update_in(&mut cx, |workspace, window, cx| {
let project = workspace.project().clone();
let buffer = project.update(cx, |project, cx| {
project.create_local_buffer(&source.trim(), lua, cx)
});
let buffer = cx.new(|cx| {
MultiBuffer::singleton(buffer, cx)
// TODO: Generate script description
.with_title("Assistant script".into())
});
let editor = cx.new(|cx| {
let mut editor =
Editor::for_multibuffer(buffer, Some(project), true, window, cx);
editor.set_read_only(true);
editor
});
workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx);
})
})
.detach_and_log_err(cx);
}
} }
impl Render for ActiveThread { impl Render for ActiveThread {

View file

@ -166,22 +166,25 @@ impl AssistantPanel {
let history_store = let history_store =
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx)); cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
let thread = cx.new(|cx| {
ActiveThread::new(
workspace.clone(),
thread.clone(),
thread_store.clone(),
language_registry.clone(),
window,
cx,
)
});
Self { Self {
active_view: ActiveView::Thread, active_view: ActiveView::Thread,
workspace, workspace,
project: project.clone(), project: project.clone(),
fs: fs.clone(), fs: fs.clone(),
language_registry: language_registry.clone(), language_registry,
thread_store: thread_store.clone(), thread_store: thread_store.clone(),
thread: cx.new(|cx| { thread,
ActiveThread::new(
thread.clone(),
thread_store.clone(),
language_registry,
window,
cx,
)
}),
message_editor, message_editor,
context_store, context_store,
context_editor: None, context_editor: None,
@ -239,6 +242,7 @@ impl AssistantPanel {
self.active_view = ActiveView::Thread; self.active_view = ActiveView::Thread;
self.thread = cx.new(|cx| { self.thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
self.workspace.clone(),
thread.clone(), thread.clone(),
self.thread_store.clone(), self.thread_store.clone(),
self.language_registry.clone(), self.language_registry.clone(),
@ -372,6 +376,7 @@ impl AssistantPanel {
this.active_view = ActiveView::Thread; this.active_view = ActiveView::Thread;
this.thread = cx.new(|cx| { this.thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
this.workspace.clone(),
thread.clone(), thread.clone(),
this.thread_store.clone(), this.thread_store.clone(),
this.language_registry.clone(), this.language_registry.clone(),

View file

@ -1,11 +1,14 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_scripting::{
Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
};
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _; use futures::StreamExt as _;
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task}; use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@ -75,14 +78,21 @@ pub struct Thread {
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState, tool_use: ToolUseState,
scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
script_output_messages: HashSet<MessageId>,
script_session: Entity<ScriptSession>,
_script_session_subscription: Subscription,
} }
impl Thread { impl Thread {
pub fn new( pub fn new(
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self { Self {
id: ThreadId::new(), id: ThreadId::new(),
updated_at: Utc::now(), updated_at: Utc::now(),
@ -97,6 +107,10 @@ impl Thread {
project, project,
tools, tools,
tool_use: ToolUseState::new(), tool_use: ToolUseState::new(),
scripts_by_assistant_message: HashMap::default(),
script_output_messages: HashSet::default(),
script_session,
_script_session_subscription: script_session_subscription,
} }
} }
@ -105,7 +119,7 @@ impl Thread {
saved: SavedThread, saved: SavedThread,
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let next_message_id = MessageId( let next_message_id = MessageId(
saved saved
@ -115,6 +129,8 @@ impl Thread {
.unwrap_or(0), .unwrap_or(0),
); );
let tool_use = ToolUseState::from_saved_messages(&saved.messages); let tool_use = ToolUseState::from_saved_messages(&saved.messages);
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self { Self {
id, id,
@ -138,6 +154,10 @@ impl Thread {
project, project,
tools, tools,
tool_use, tool_use,
scripts_by_assistant_message: HashMap::default(),
script_output_messages: HashSet::default(),
script_session,
_script_session_subscription: script_session_subscription,
} }
} }
@ -223,17 +243,22 @@ impl Thread {
self.tool_use.message_has_tool_results(message_id) self.tool_use.message_has_tool_results(message_id)
} }
pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
self.script_output_messages.contains(&message_id)
}
pub fn insert_user_message( pub fn insert_user_message(
&mut self, &mut self,
text: impl Into<String>, text: impl Into<String>,
context: Vec<ContextSnapshot>, context: Vec<ContextSnapshot>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx); let message_id = self.insert_message(Role::User, text, cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>(); let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context self.context
.extend(context.into_iter().map(|context| (context.id, context))); .extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids); self.context_by_message.insert(message_id, context_ids);
message_id
} }
pub fn insert_message( pub fn insert_message(
@ -302,6 +327,39 @@ impl Thread {
text text
} }
pub fn script_for_message<'a>(
&'a self,
message_id: MessageId,
cx: &'a App,
) -> Option<&'a Script> {
self.scripts_by_assistant_message
.get(&message_id)
.map(|script_id| self.script_session.read(cx).get(*script_id))
}
fn handle_script_event(
&mut self,
_script_session: Entity<ScriptSession>,
event: &ScriptEvent,
cx: &mut Context<Self>,
) {
match event {
ScriptEvent::Spawned(_) => {}
ScriptEvent::Exited(script_id) => {
if let Some(output_message) = self
.script_session
.read(cx)
.get(*script_id)
.output_message_for_llm()
{
let message_id = self.insert_user_message(output_message, vec![], cx);
self.script_output_messages.insert(message_id);
cx.emit(ThreadEvent::ScriptFinished)
}
}
}
}
pub fn send_to_model( pub fn send_to_model(
&mut self, &mut self,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
@ -330,7 +388,7 @@ impl Thread {
pub fn to_completion_request( pub fn to_completion_request(
&self, &self,
request_kind: RequestKind, request_kind: RequestKind,
_cx: &App, cx: &App,
) -> LanguageModelRequest { ) -> LanguageModelRequest {
let mut request = LanguageModelRequest { let mut request = LanguageModelRequest {
messages: vec![], messages: vec![],
@ -339,6 +397,12 @@ impl Thread {
temperature: None, temperature: None,
}; };
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![SCRIPTING_PROMPT.to_string().into()],
cache: true,
});
let mut referenced_context_ids = HashSet::default(); let mut referenced_context_ids = HashSet::default();
for message in &self.messages { for message in &self.messages {
@ -351,6 +415,7 @@ impl Thread {
content: Vec::new(), content: Vec::new(),
cache: false, cache: false,
}; };
match request_kind { match request_kind {
RequestKind::Chat => { RequestKind::Chat => {
self.tool_use self.tool_use
@ -371,11 +436,20 @@ impl Thread {
RequestKind::Chat => { RequestKind::Chat => {
self.tool_use self.tool_use
.attach_tool_uses(message.id, &mut request_message); .attach_tool_uses(message.id, &mut request_message);
if matches!(message.role, Role::Assistant) {
if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
{
let script = self.script_session.read(cx).get(*script_id);
request_message.content.push(script.source_tag().into());
}
}
} }
RequestKind::Summarize => { RequestKind::Summarize => {
// We don't care about tool use during summarization. // We don't care about tool use during summarization.
} }
} };
request.messages.push(request_message); request.messages.push(request_message);
} }
@ -412,6 +486,8 @@ impl Thread {
let stream_completion = async { let stream_completion = async {
let mut events = stream.await?; let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn; let mut stop_reason = StopReason::EndTurn;
let mut script_tag_parser = ScriptTagParser::new();
let mut script_id = None;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
let event = event?; let event = event?;
@ -426,19 +502,43 @@ impl Thread {
} }
LanguageModelCompletionEvent::Text(chunk) => { LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() { if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant { let chunk = script_tag_parser.parse_chunk(&chunk);
last_message.text.push_str(&chunk);
let message_id = if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk.content);
cx.emit(ThreadEvent::StreamedAssistantText( cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id, last_message.id,
chunk, chunk.content,
)); ));
last_message.id
} else { } else {
// If we won't have an Assistant message yet, assume this chunk marks the beginning // If we won't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response. // of a new Assistant response.
// //
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown. // will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(Role::Assistant, chunk, cx); thread.insert_message(Role::Assistant, chunk.content, cx)
};
if script_id.is_none() && script_tag_parser.found_script() {
let id = thread
.script_session
.update(cx, |session, _cx| session.new_script());
thread.scripts_by_assistant_message.insert(message_id, id);
script_id = Some(id);
}
if let (Some(script_source), Some(script_id)) =
(chunk.script_source, script_id)
{
// TODO: move buffer to script and run as it streams
thread
.script_session
.update(cx, |this, cx| {
this.run_script(script_id, script_source, cx)
})
.detach_and_log_err(cx);
} }
} }
} }
@ -661,6 +761,7 @@ pub enum ThreadEvent {
#[allow(unused)] #[allow(unused)]
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
}, },
ScriptFinished,
} }
impl EventEmitter<ThreadEvent> for Thread {} impl EventEmitter<ThreadEvent> for Thread {}

View file

@ -1,5 +1,5 @@
[package] [package]
name = "scripting_tool" name = "assistant_scripting"
version = "0.1.0" version = "0.1.0"
edition.workspace = true edition.workspace = true
publish.workspace = true publish.workspace = true
@ -9,12 +9,11 @@ license = "GPL-3.0-or-later"
workspace = true workspace = true
[lib] [lib]
path = "src/scripting_tool.rs" path = "src/assistant_scripting.rs"
doctest = false doctest = false
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true collections.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
@ -22,7 +21,6 @@ mlua.workspace = true
parking_lot.workspace = true parking_lot.workspace = true
project.workspace = true project.workspace = true
regex.workspace = true regex.workspace = true
schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
settings.workspace = true settings.workspace = true
@ -32,4 +30,5 @@ util.workspace = true
collections = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] }
rand.workspace = true
settings = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] }

View file

@ -0,0 +1,7 @@
mod session;
mod tag;
pub use session::*;
pub use tag::*;
pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt");

View file

@ -1,10 +1,9 @@
use anyhow::Result;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use futures::{ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
pin_mut, SinkExt, StreamExt, pin_mut, SinkExt, StreamExt,
}; };
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods}; use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
use parking_lot::Mutex; use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project}; use project::{search::SearchQuery, Fs, Project};
@ -16,24 +15,23 @@ use std::{
}; };
use util::{paths::PathMatcher, ResultExt}; use util::{paths::PathMatcher, ResultExt};
pub struct ScriptOutput { use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
pub stdout: String,
}
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<Session>, AsyncApp) + Send>); struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
pub struct Session { pub struct ScriptSession {
project: Entity<Project>, project: Entity<Project>,
// TODO Remove this // TODO Remove this
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>, fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
foreground_fns_tx: mpsc::Sender<ForegroundFn>, foreground_fns_tx: mpsc::Sender<ForegroundFn>,
_invoke_foreground_fns: Task<()>, _invoke_foreground_fns: Task<()>,
scripts: Vec<Script>,
} }
impl Session { impl ScriptSession {
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self { pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128); let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
Session { ScriptSession {
project, project,
fs_changes: Arc::new(Mutex::new(HashMap::default())), fs_changes: Arc::new(Mutex::new(HashMap::default())),
foreground_fns_tx, foreground_fns_tx,
@ -42,15 +40,62 @@ impl Session {
foreground_fn.0(this.clone(), cx.clone()); foreground_fn.0(this.clone(), cx.clone());
} }
}), }),
scripts: Vec::new(),
} }
} }
/// Runs a Lua script in a sandboxed environment and returns the printed lines pub fn new_script(&mut self) -> ScriptId {
let id = ScriptId(self.scripts.len() as u32);
let script = Script {
id,
state: ScriptState::Generating,
source: SharedString::new_static(""),
};
self.scripts.push(script);
id
}
pub fn run_script( pub fn run_script(
&mut self, &mut self,
script: String, script_id: ScriptId,
script_src: String,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Task<Result<ScriptOutput>> { ) -> Task<anyhow::Result<()>> {
let script = self.get_mut(script_id);
let stdout = Arc::new(Mutex::new(String::new()));
script.source = script_src.clone().into();
script.state = ScriptState::Running {
stdout: stdout.clone(),
};
let task = self.run_lua(script_src, stdout, cx);
cx.emit(ScriptEvent::Spawned(script_id));
cx.spawn(|session, mut cx| async move {
let result = task.await;
session.update(&mut cx, |session, cx| {
let script = session.get_mut(script_id);
let stdout = script.stdout_snapshot();
script.state = match result {
Ok(()) => ScriptState::Succeeded { stdout },
Err(error) => ScriptState::Failed { stdout, error },
};
cx.emit(ScriptEvent::Exited(script_id))
})
})
}
fn run_lua(
&mut self,
script: String,
stdout: Arc<Mutex<String>>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua"); const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
// TODO Remove fs_changes // TODO Remove fs_changes
@ -62,52 +107,64 @@ impl Session {
.visible_worktrees(cx) .visible_worktrees(cx)
.next() .next()
.map(|worktree| worktree.read(cx).abs_path()); .map(|worktree| worktree.read(cx).abs_path());
let fs = self.project.read(cx).fs().clone(); let fs = self.project.read(cx).fs().clone();
let foreground_fns_tx = self.foreground_fns_tx.clone(); let foreground_fns_tx = self.foreground_fns_tx.clone();
cx.background_spawn(async move {
let lua = Lua::new();
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
let globals = lua.globals();
let stdout = Arc::new(Mutex::new(String::new()));
globals.set(
"sb_print",
lua.create_function({
let stdout = stdout.clone();
move |_, args: MultiValue| Self::print(args, &stdout)
})?,
)?;
globals.set(
"search",
lua.create_async_function({
let foreground_fns_tx = foreground_fns_tx.clone();
let fs = fs.clone();
move |lua, regex| {
Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
}
})?,
)?;
globals.set(
"sb_io_open",
lua.create_function({
let fs_changes = fs_changes.clone();
let root_dir = root_dir.clone();
move |lua, (path_str, mode)| {
Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
}
})?,
)?;
globals.set("user_script", script)?;
lua.load(SANDBOX_PREAMBLE).exec_async().await?; let task = cx.background_spawn({
let stdout = stdout.clone();
// Drop Lua instance to decrement reference count. async move {
drop(lua); let lua = Lua::new();
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
let globals = lua.globals();
globals.set(
"sb_print",
lua.create_function({
let stdout = stdout.clone();
move |_, args: MultiValue| Self::print(args, &stdout)
})?,
)?;
globals.set(
"search",
lua.create_async_function({
let foreground_fns_tx = foreground_fns_tx.clone();
let fs = fs.clone();
move |lua, regex| {
Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
}
})?,
)?;
globals.set(
"sb_io_open",
lua.create_function({
let fs_changes = fs_changes.clone();
let root_dir = root_dir.clone();
move |lua, (path_str, mode)| {
Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
}
})?,
)?;
globals.set("user_script", script)?;
let stdout = Arc::try_unwrap(stdout) lua.load(SANDBOX_PREAMBLE).exec_async().await?;
.expect("no more references to stdout")
.into_inner(); // Drop Lua instance to decrement reference count.
Ok(ScriptOutput { stdout }) drop(lua);
})
anyhow::Ok(())
}
});
task
}
pub fn get(&self, script_id: ScriptId) -> &Script {
&self.scripts[script_id.0 as usize]
}
fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
&mut self.scripts[script_id.0 as usize]
} }
/// Sandboxed print() function in Lua. /// Sandboxed print() function in Lua.
@ -678,6 +735,79 @@ impl UserData for FileContent {
} }
} }
#[derive(Debug)]
pub enum ScriptEvent {
Spawned(ScriptId),
Exited(ScriptId),
}
impl EventEmitter<ScriptEvent> for ScriptSession {}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
pub struct Script {
pub id: ScriptId,
pub state: ScriptState,
pub source: SharedString,
}
pub enum ScriptState {
Generating,
Running {
stdout: Arc<Mutex<String>>,
},
Succeeded {
stdout: String,
},
Failed {
stdout: String,
error: anyhow::Error,
},
}
impl Script {
pub fn source_tag(&self) -> String {
format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
}
/// If exited, returns a message with the output for the LLM
pub fn output_message_for_llm(&self) -> Option<String> {
match &self.state {
ScriptState::Generating { .. } => None,
ScriptState::Running { .. } => None,
ScriptState::Succeeded { stdout } => {
format!("Here's the script output:\n{}", stdout).into()
}
ScriptState::Failed { stdout, error } => format!(
"The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
error, stdout
)
.into(),
}
}
/// Get a snapshot of the script's stdout
pub fn stdout_snapshot(&self) -> String {
match &self.state {
ScriptState::Generating { .. } => String::new(),
ScriptState::Running { stdout } => stdout.lock().clone(),
ScriptState::Succeeded { stdout } => stdout.clone(),
ScriptState::Failed { stdout, .. } => stdout.clone(),
}
}
/// Returns the error if the script failed, otherwise None
pub fn error(&self) -> Option<&anyhow::Error> {
match &self.state {
ScriptState::Generating { .. } => None,
ScriptState::Running { .. } => None,
ScriptState::Succeeded { .. } => None,
ScriptState::Failed { error, .. } => Some(error),
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use gpui::TestAppContext; use gpui::TestAppContext;
@ -689,35 +819,17 @@ mod tests {
#[gpui::test] #[gpui::test]
async fn test_print(cx: &mut TestAppContext) { async fn test_print(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let session = cx.new(|cx| Session::new(project, cx));
let script = r#" let script = r#"
print("Hello", "world!") print("Hello", "world!")
print("Goodbye", "moon!") print("Goodbye", "moon!")
"#; "#;
let output = session
.update(cx, |session, cx| session.run_script(script.to_string(), cx)) let output = test_script(script, cx).await.unwrap();
.await assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
.unwrap();
assert_eq!(output.stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
} }
#[gpui::test] #[gpui::test]
async fn test_search(cx: &mut TestAppContext) { async fn test_search(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"file1.txt": "Hello world!",
"file2.txt": "Goodbye moon!"
}),
)
.await;
let project = Project::test(fs, [Path::new("/")], cx).await;
let session = cx.new(|cx| Session::new(project, cx));
let script = r#" let script = r#"
local results = search("world") local results = search("world")
for i, result in ipairs(results) do for i, result in ipairs(results) do
@ -728,11 +840,36 @@ mod tests {
end end
end end
"#; "#;
let output = session
.update(cx, |session, cx| session.run_script(script.to_string(), cx)) let output = test_script(script, cx).await.unwrap();
.await assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
.unwrap(); }
assert_eq!(output.stdout, "File: /file1.txt\nMatches:\n world\n");
async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"file1.txt": "Hello world!",
"file2.txt": "Goodbye moon!"
}),
)
.await;
let project = Project::test(fs, [Path::new("/")], cx).await;
let session = cx.new(|cx| ScriptSession::new(project, cx));
let (script_id, task) = session.update(cx, |session, cx| {
let script_id = session.new_script();
let task = session.run_script(script_id, source.to_string(), cx);
(script_id, task)
});
task.await?;
Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
} }
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {

View file

@ -3,6 +3,12 @@ output was, including both stdout as well as the git diff of changes it made to
the filesystem. That way, you can get more information about the code base, or the filesystem. That way, you can get more information about the code base, or
make changes to the code base directly. make changes to the code base directly.
Put the Lua script inside of an `<eval>` tag like so:
<eval type="lua">
print("Hello, world!")
</eval>
The Lua script will have access to `io` and it will run with the current working The Lua script will have access to `io` and it will run with the current working
directory being in the root of the code base, so you can use it to explore, directory being in the root of the code base, so you can use it to explore,
search, make changes, etc. You can also have the script print things, and I'll search, make changes, etc. You can also have the script print things, and I'll
@ -10,13 +16,17 @@ tell you what the output was. Note that `io` only has `open`, and then the file
it returns only has the methods read, write, and close - it doesn't have popen it returns only has the methods read, write, and close - it doesn't have popen
or anything else. or anything else.
Also, I'm going to be putting this Lua script into JSON, so please don't use
Lua's double quote syntax for string literals - use one of Lua's other syntaxes
for string literals, so I don't have to escape the double quotes.
There will be a global called `search` which accepts a regex (it's implemented There will be a global called `search` which accepts a regex (it's implemented
using Rust's regex crate, so use that regex syntax) and runs that regex on the using Rust's regex crate, so use that regex syntax) and runs that regex on the
contents of every file in the code base (aside from gitignored files), then contents of every file in the code base (aside from gitignored files), then
returns an array of tables with two fields: "path" (the path to the file that returns an array of tables with two fields: "path" (the path to the file that
had the matches) and "matches" (an array of strings, with each string being a had the matches) and "matches" (an array of strings, with each string being a
match that was found within the file). match that was found within the file).
When I send you the script output, do not thank me for running it,
act as if you ran it yourself.
IMPORTANT!
Only include a maximum of one Lua script at the very end of your message
DO NOT WRITE ANYTHING ELSE AFTER THE SCRIPT. Wait for my response with the script
output to continue.

View file

@ -0,0 +1,260 @@
pub const SCRIPT_START_TAG: &str = "<eval type=\"lua\">";
pub const SCRIPT_END_TAG: &str = "</eval>";
const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();
/// Parses a script tag in an assistant message as it is being streamed.
pub struct ScriptTagParser {
state: State,
buffer: Vec<u8>,
tag_match_ix: usize,
}
enum State {
Unstarted,
Streaming,
Ended,
}
#[derive(Debug, PartialEq)]
pub struct ChunkOutput {
/// The chunk with script tags removed.
pub content: String,
/// The full script tag content. `None` until closed.
pub script_source: Option<String>,
}
impl ScriptTagParser {
/// Create a new script tag parser.
pub fn new() -> Self {
Self {
state: State::Unstarted,
buffer: Vec::new(),
tag_match_ix: 0,
}
}
/// Returns true if the parser has found a script tag.
pub fn found_script(&self) -> bool {
match self.state {
State::Unstarted => false,
State::Streaming | State::Ended => true,
}
}
/// Process a new chunk of input, splitting it into surrounding content and script source.
pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
let mut content = Vec::with_capacity(input.len());
for byte in input.bytes() {
match self.state {
State::Unstarted => {
if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
self.state = State::Streaming;
self.buffer = Vec::with_capacity(1024);
self.tag_match_ix = 0;
}
}
State::Streaming => {
if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
self.state = State::Ended;
}
}
State::Ended => content.push(byte),
}
}
let content = unsafe { String::from_utf8_unchecked(content) };
let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };
Some(source)
} else {
None
};
ChunkOutput {
content,
script_source,
}
}
}
fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
// this can't be a method because it'd require a mutable borrow on both self and self.buffer
if match_tag_byte(byte, tag, tag_match_ix) {
*tag_match_ix >= tag.len()
} else {
if *tag_match_ix > 0 {
// push the partially matched tag to the buffer
buffer.extend_from_slice(&tag[..*tag_match_ix]);
*tag_match_ix = 0;
// the tag might start to match again
if match_tag_byte(byte, tag, tag_match_ix) {
return *tag_match_ix >= tag.len();
}
}
buffer.push(byte);
false
}
}
fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
if byte == tag[*tag_match_ix] {
*tag_match_ix += 1;
true
} else {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_complete_tag() {
let mut parser = ScriptTagParser::new();
let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
let result = parser.parse_chunk(input);
assert_eq!(result.content, "");
assert_eq!(
result.script_source,
Some("print(\"Hello, World!\")".to_string())
);
}
#[test]
fn test_no_tag() {
let mut parser = ScriptTagParser::new();
let input = "No tags here, just plain text";
let result = parser.parse_chunk(input);
assert_eq!(result.content, "No tags here, just plain text");
assert_eq!(result.script_source, None);
}
#[test]
fn test_partial_end_tag() {
let mut parser = ScriptTagParser::new();
// Start the tag
let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
assert_eq!(result.content, "");
assert_eq!(result.script_source, None);
// Finish with the rest
let result = parser.parse_chunk("val' + 'not the end';</eval>");
assert_eq!(result.content, "");
assert_eq!(
result.script_source,
Some("let x = '</eval' + 'not the end';".to_string())
);
}
#[test]
fn test_text_before_and_after_tag() {
let mut parser = ScriptTagParser::new();
let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
let result = parser.parse_chunk(input);
assert_eq!(result.content, "Before tag After tag");
assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
}
#[test]
fn test_multiple_chunks_with_surrounding_text() {
let mut parser = ScriptTagParser::new();
// First chunk with text before
let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
assert_eq!(result.content, "Before script ");
assert_eq!(result.script_source, None);
// Second chunk with script content
let result = parser.parse_chunk("\nlocal y = 20");
assert_eq!(result.content, "");
assert_eq!(result.script_source, None);
// Last chunk with text after
let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
assert_eq!(result.content, " After script");
assert_eq!(
result.script_source,
Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
);
let result = parser.parse_chunk(" there's more text");
assert_eq!(result.content, " there's more text");
assert_eq!(result.script_source, None);
}
#[test]
fn test_partial_start_tag_matching() {
let mut parser = ScriptTagParser::new();
// partial match of start tag...
let result = parser.parse_chunk("<ev");
assert_eq!(result.content, "");
// ...that's abandandoned when the < of a real tag is encountered
let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
// ...so it gets pushed to content
assert_eq!(result.content, "<ev");
// ...and the real tag is parsed correctly
assert_eq!(result.script_source, Some("script content".to_string()));
}
#[test]
fn test_random_chunked_parsing() {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::time::{SystemTime, UNIX_EPOCH};
let test_inputs = [
"Before <eval type=\"lua\">print(\"Hello\")</eval> After",
"No tags here at all",
"<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
"Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
];
let seed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
eprintln!("Using random seed: {}", seed);
let mut rng = StdRng::seed_from_u64(seed);
for test_input in &test_inputs {
let mut reference_parser = ScriptTagParser::new();
let expected = reference_parser.parse_chunk(test_input);
let mut chunked_parser = ScriptTagParser::new();
let mut remaining = test_input.as_bytes();
let mut actual_content = String::new();
let mut actual_script = None;
while !remaining.is_empty() {
let chunk_size = rng.gen_range(1..=remaining.len().min(5));
let (chunk, rest) = remaining.split_at(chunk_size);
remaining = rest;
let chunk_str = std::str::from_utf8(chunk).unwrap();
let result = chunked_parser.parse_chunk(chunk_str);
actual_content.push_str(&result.content);
if result.script_source.is_some() {
actual_script = result.script_source;
}
}
assert_eq!(actual_content, expected.content);
assert_eq!(actual_script, expected.script_source);
}
}
}

View file

@ -1,58 +0,0 @@
mod session;
use project::Project;
pub(crate) use session::*;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Entity, Task};
use schemars::JsonSchema;
use serde::Deserialize;
use std::sync::Arc;
pub fn init(cx: &App) {
let registry = ToolRegistry::global(cx);
registry.register_tool(ScriptingTool);
}
#[derive(Debug, Deserialize, JsonSchema)]
struct ScriptingToolInput {
lua_script: String,
}
struct ScriptingTool;
impl Tool for ScriptingTool {
fn name(&self) -> String {
"lua-interpreter".into()
}
fn description(&self) -> String {
include_str!("scripting_tool_description.txt").into()
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(ScriptingToolInput);
serde_json::to_value(&schema).unwrap()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
project: Entity<Project>,
cx: &mut App,
) -> Task<anyhow::Result<String>> {
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
Err(err) => return Task::ready(Err(err.into())),
Ok(input) => input,
};
let session = cx.new(|cx| Session::new(project, cx));
let lua_script = input.lua_script;
let script = session.update(cx, |session, cx| session.run_script(lua_script, cx));
cx.spawn(|_cx| async move {
let output = script.await?.stdout;
drop(session);
Ok(format!("The script output the following:\n{output}"))
})
}
}

View file

@ -98,7 +98,6 @@ remote.workspace = true
repl.workspace = true repl.workspace = true
reqwest_client.workspace = true reqwest_client.workspace = true
rope.workspace = true rope.workspace = true
scripting_tool.workspace = true
search.workspace = true search.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true

View file

@ -476,7 +476,6 @@ fn main() {
cx, cx,
); );
assistant_tools::init(cx); assistant_tools::init(cx);
scripting_tool::init(cx);
repl::init(app_state.fs.clone(), cx); repl::init(app_state.fs.clone(), cx);
extension_host::init( extension_host::init(
extension_host_proxy, extension_host_proxy,