assistant: Make scripting a first-class concept instead of a tool (#26338)
This PR makes refactors the scripting functionality to be a first-class concept of the assistant instead of a generic tool, which will allow us to build a more customized experience. - The tool prompt has been slightly tweaked and is now included as a system message in all conversations. I'm getting decent results, but now that it isn't in the tools framework, it will probably require more refining. - The model will now include an `<eval ...>` tag at the end of the message with the script. We parse this tag incrementally as it streams in so that we can indicate that we are generating a script before we see the closing `</eval>` tag. Later, this will help us interpret the script as it arrives also. - Threads now hold a `ScriptSession` entity which manages the state of all scripts (from parsing to exited) in a centralized way, and will later collect all script operations so they can be displayed in the UI. - `script_tool` has been renamed to `assistant_scripting` - Script source now opens in a regular read-only buffer Note: We still need to handle persistence properly Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
ed6bf7f161
commit
e298301b40
16 changed files with 811 additions and 197 deletions
41
Cargo.lock
generated
41
Cargo.lock
generated
|
@ -450,6 +450,7 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assistant_context_editor",
|
"assistant_context_editor",
|
||||||
|
"assistant_scripting",
|
||||||
"assistant_settings",
|
"assistant_settings",
|
||||||
"assistant_slash_command",
|
"assistant_slash_command",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
|
@ -563,6 +564,25 @@ dependencies = [
|
||||||
"workspace",
|
"workspace",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "assistant_scripting"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"collections",
|
||||||
|
"futures 0.3.31",
|
||||||
|
"gpui",
|
||||||
|
"mlua",
|
||||||
|
"parking_lot",
|
||||||
|
"project",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"regex",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"settings",
|
||||||
|
"util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "assistant_settings"
|
name = "assistant_settings"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
@ -11910,26 +11930,6 @@ version = "1.0.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
|
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "scripting_tool"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"assistant_tool",
|
|
||||||
"collections",
|
|
||||||
"futures 0.3.31",
|
|
||||||
"gpui",
|
|
||||||
"mlua",
|
|
||||||
"parking_lot",
|
|
||||||
"project",
|
|
||||||
"regex",
|
|
||||||
"schemars",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"settings",
|
|
||||||
"util",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "scrypt"
|
name = "scrypt"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
|
@ -16984,7 +16984,6 @@ dependencies = [
|
||||||
"repl",
|
"repl",
|
||||||
"reqwest_client",
|
"reqwest_client",
|
||||||
"rope",
|
"rope",
|
||||||
"scripting_tool",
|
|
||||||
"search",
|
"search",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -118,7 +118,7 @@ members = [
|
||||||
"crates/rope",
|
"crates/rope",
|
||||||
"crates/rpc",
|
"crates/rpc",
|
||||||
"crates/schema_generator",
|
"crates/schema_generator",
|
||||||
"crates/scripting_tool",
|
"crates/assistant_scripting",
|
||||||
"crates/search",
|
"crates/search",
|
||||||
"crates/semantic_index",
|
"crates/semantic_index",
|
||||||
"crates/semantic_version",
|
"crates/semantic_version",
|
||||||
|
@ -318,7 +318,7 @@ reqwest_client = { path = "crates/reqwest_client" }
|
||||||
rich_text = { path = "crates/rich_text" }
|
rich_text = { path = "crates/rich_text" }
|
||||||
rope = { path = "crates/rope" }
|
rope = { path = "crates/rope" }
|
||||||
rpc = { path = "crates/rpc" }
|
rpc = { path = "crates/rpc" }
|
||||||
scripting_tool = { path = "crates/scripting_tool" }
|
assistant_scripting = { path = "crates/assistant_scripting" }
|
||||||
search = { path = "crates/search" }
|
search = { path = "crates/search" }
|
||||||
semantic_index = { path = "crates/semantic_index" }
|
semantic_index = { path = "crates/semantic_index" }
|
||||||
semantic_version = { path = "crates/semantic_version" }
|
semantic_version = { path = "crates/semantic_version" }
|
||||||
|
|
|
@ -63,6 +63,7 @@ serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
|
assistant_scripting.workspace = true
|
||||||
streaming_diff.workspace = true
|
streaming_diff.workspace = true
|
||||||
telemetry_events.workspace = true
|
telemetry_events.workspace = true
|
||||||
terminal.workspace = true
|
terminal.workspace = true
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use collections::HashMap;
|
use assistant_scripting::{ScriptId, ScriptState};
|
||||||
|
use collections::{HashMap, HashSet};
|
||||||
use editor::{Editor, MultiBuffer};
|
use editor::{Editor, MultiBuffer};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
|
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
|
||||||
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
||||||
Task, TextStyleRefinement, UnderlineStyle,
|
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
|
||||||
};
|
};
|
||||||
use language::{Buffer, LanguageRegistry};
|
use language::{Buffer, LanguageRegistry};
|
||||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||||
|
@ -14,6 +15,7 @@ use settings::Settings as _;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, Disclosure, KeyBinding};
|
use ui::{prelude::*, Disclosure, KeyBinding};
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
|
use workspace::Workspace;
|
||||||
|
|
||||||
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
|
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
|
||||||
use crate::thread_store::ThreadStore;
|
use crate::thread_store::ThreadStore;
|
||||||
|
@ -21,6 +23,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
|
||||||
use crate::ui::ContextPill;
|
use crate::ui::ContextPill;
|
||||||
|
|
||||||
pub struct ActiveThread {
|
pub struct ActiveThread {
|
||||||
|
workspace: WeakEntity<Workspace>,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
|
@ -30,6 +33,7 @@ pub struct ActiveThread {
|
||||||
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
|
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
|
||||||
editing_message: Option<(MessageId, EditMessageState)>,
|
editing_message: Option<(MessageId, EditMessageState)>,
|
||||||
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
||||||
|
expanded_scripts: HashSet<ScriptId>,
|
||||||
last_error: Option<ThreadError>,
|
last_error: Option<ThreadError>,
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
}
|
}
|
||||||
|
@ -40,6 +44,7 @@ struct EditMessageState {
|
||||||
|
|
||||||
impl ActiveThread {
|
impl ActiveThread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
workspace: WeakEntity<Workspace>,
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
|
@ -52,6 +57,7 @@ impl ActiveThread {
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
|
workspace,
|
||||||
language_registry,
|
language_registry,
|
||||||
thread_store,
|
thread_store,
|
||||||
thread: thread.clone(),
|
thread: thread.clone(),
|
||||||
|
@ -59,6 +65,7 @@ impl ActiveThread {
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
rendered_messages_by_id: HashMap::default(),
|
rendered_messages_by_id: HashMap::default(),
|
||||||
expanded_tool_uses: HashMap::default(),
|
expanded_tool_uses: HashMap::default(),
|
||||||
|
expanded_scripts: HashSet::default(),
|
||||||
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
|
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
|
||||||
let this = cx.entity().downgrade();
|
let this = cx.entity().downgrade();
|
||||||
move |ix, window: &mut Window, cx: &mut App| {
|
move |ix, window: &mut Window, cx: &mut App| {
|
||||||
|
@ -241,7 +248,7 @@ impl ActiveThread {
|
||||||
|
|
||||||
fn handle_thread_event(
|
fn handle_thread_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
_: &Entity<Thread>,
|
_thread: &Entity<Thread>,
|
||||||
event: &ThreadEvent,
|
event: &ThreadEvent,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
|
@ -306,6 +313,14 @@ impl ActiveThread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ThreadEvent::ScriptFinished => {
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
if let Some(model) = model_registry.active_model() {
|
||||||
|
self.thread.update(cx, |thread, cx| {
|
||||||
|
thread.send_to_model(model, RequestKind::Chat, false, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -445,12 +460,16 @@ impl ActiveThread {
|
||||||
return Empty.into_any();
|
return Empty.into_any();
|
||||||
};
|
};
|
||||||
|
|
||||||
let context = self.thread.read(cx).context_for_message(message_id);
|
let thread = self.thread.read(cx);
|
||||||
let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
|
|
||||||
let colors = cx.theme().colors();
|
let context = thread.context_for_message(message_id);
|
||||||
|
let tool_uses = thread.tool_uses_for_message(message_id);
|
||||||
|
|
||||||
// Don't render user messages that are just there for returning tool results.
|
// Don't render user messages that are just there for returning tool results.
|
||||||
if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
|
if message.role == Role::User
|
||||||
|
&& (thread.message_has_tool_results(message_id)
|
||||||
|
|| thread.message_has_script_output(message_id))
|
||||||
|
{
|
||||||
return Empty.into_any();
|
return Empty.into_any();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -463,6 +482,8 @@ impl ActiveThread {
|
||||||
.filter(|(id, _)| *id == message_id)
|
.filter(|(id, _)| *id == message_id)
|
||||||
.map(|(_, state)| state.editor.clone());
|
.map(|(_, state)| state.editor.clone());
|
||||||
|
|
||||||
|
let colors = cx.theme().colors();
|
||||||
|
|
||||||
let message_content = v_flex()
|
let message_content = v_flex()
|
||||||
.child(
|
.child(
|
||||||
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
||||||
|
@ -597,6 +618,7 @@ impl ActiveThread {
|
||||||
Role::Assistant => div()
|
Role::Assistant => div()
|
||||||
.id(("message-container", ix))
|
.id(("message-container", ix))
|
||||||
.child(message_content)
|
.child(message_content)
|
||||||
|
.children(self.render_script(message_id, cx))
|
||||||
.map(|parent| {
|
.map(|parent| {
|
||||||
if tool_uses.is_empty() {
|
if tool_uses.is_empty() {
|
||||||
return parent;
|
return parent;
|
||||||
|
@ -716,6 +738,139 @@ impl ActiveThread {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_script(&self, message_id: MessageId, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||||
|
let script = self.thread.read(cx).script_for_message(message_id, cx)?;
|
||||||
|
|
||||||
|
let is_open = self.expanded_scripts.contains(&script.id);
|
||||||
|
let colors = cx.theme().colors();
|
||||||
|
|
||||||
|
let element = div().px_2p5().child(
|
||||||
|
v_flex()
|
||||||
|
.gap_1()
|
||||||
|
.rounded_lg()
|
||||||
|
.border_1()
|
||||||
|
.border_color(colors.border)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.justify_between()
|
||||||
|
.py_0p5()
|
||||||
|
.pl_1()
|
||||||
|
.pr_2()
|
||||||
|
.bg(colors.editor_foreground.opacity(0.02))
|
||||||
|
.when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
|
||||||
|
.when(!is_open, |element| element.rounded_md())
|
||||||
|
.border_color(colors.border)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1()
|
||||||
|
.child(Disclosure::new("script-disclosure", is_open).on_click(
|
||||||
|
cx.listener({
|
||||||
|
let script_id = script.id;
|
||||||
|
move |this, _event, _window, _cx| {
|
||||||
|
if this.expanded_scripts.contains(&script_id) {
|
||||||
|
this.expanded_scripts.remove(&script_id);
|
||||||
|
} else {
|
||||||
|
this.expanded_scripts.insert(script_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
// TODO: Generate script description
|
||||||
|
.child(Label::new("Script")),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1()
|
||||||
|
.child(
|
||||||
|
Label::new(match script.state {
|
||||||
|
ScriptState::Generating => "Generating",
|
||||||
|
ScriptState::Running { .. } => "Running",
|
||||||
|
ScriptState::Succeeded { .. } => "Finished",
|
||||||
|
ScriptState::Failed { .. } => "Error",
|
||||||
|
})
|
||||||
|
.size(LabelSize::XSmall)
|
||||||
|
.buffer_font(cx),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
IconButton::new("view-source", IconName::Eye)
|
||||||
|
.icon_color(Color::Muted)
|
||||||
|
.disabled(matches!(script.state, ScriptState::Generating))
|
||||||
|
.on_click(cx.listener({
|
||||||
|
let source = script.source.clone();
|
||||||
|
move |this, _event, window, cx| {
|
||||||
|
this.open_script_source(source.clone(), window, cx);
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.when(is_open, |parent| {
|
||||||
|
let stdout = script.stdout_snapshot();
|
||||||
|
let error = script.error();
|
||||||
|
|
||||||
|
parent.child(
|
||||||
|
v_flex()
|
||||||
|
.p_2()
|
||||||
|
.bg(colors.editor_background)
|
||||||
|
.gap_2()
|
||||||
|
.child(if stdout.is_empty() && error.is_none() {
|
||||||
|
Label::new("No output yet")
|
||||||
|
.size(LabelSize::Small)
|
||||||
|
.color(Color::Muted)
|
||||||
|
} else {
|
||||||
|
Label::new(stdout).size(LabelSize::Small).buffer_font(cx)
|
||||||
|
})
|
||||||
|
.children(script.error().map(|err| {
|
||||||
|
Label::new(err.to_string())
|
||||||
|
.size(LabelSize::Small)
|
||||||
|
.color(Color::Error)
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
Some(element.into_any())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open_script_source(
|
||||||
|
&mut self,
|
||||||
|
source: SharedString,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<'_, ActiveThread>,
|
||||||
|
) {
|
||||||
|
let language_registry = self.language_registry.clone();
|
||||||
|
let workspace = self.workspace.clone();
|
||||||
|
let source = source.clone();
|
||||||
|
|
||||||
|
cx.spawn_in(window, |_, mut cx| async move {
|
||||||
|
let lua = language_registry.language_for_name("Lua").await.log_err();
|
||||||
|
|
||||||
|
workspace.update_in(&mut cx, |workspace, window, cx| {
|
||||||
|
let project = workspace.project().clone();
|
||||||
|
|
||||||
|
let buffer = project.update(cx, |project, cx| {
|
||||||
|
project.create_local_buffer(&source.trim(), lua, cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
let buffer = cx.new(|cx| {
|
||||||
|
MultiBuffer::singleton(buffer, cx)
|
||||||
|
// TODO: Generate script description
|
||||||
|
.with_title("Assistant script".into())
|
||||||
|
});
|
||||||
|
|
||||||
|
let editor = cx.new(|cx| {
|
||||||
|
let mut editor =
|
||||||
|
Editor::for_multibuffer(buffer, Some(project), true, window, cx);
|
||||||
|
editor.set_read_only(true);
|
||||||
|
editor
|
||||||
|
});
|
||||||
|
|
||||||
|
workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx);
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Render for ActiveThread {
|
impl Render for ActiveThread {
|
||||||
|
|
|
@ -166,22 +166,25 @@ impl AssistantPanel {
|
||||||
let history_store =
|
let history_store =
|
||||||
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
|
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
|
||||||
|
|
||||||
|
let thread = cx.new(|cx| {
|
||||||
|
ActiveThread::new(
|
||||||
|
workspace.clone(),
|
||||||
|
thread.clone(),
|
||||||
|
thread_store.clone(),
|
||||||
|
language_registry.clone(),
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
active_view: ActiveView::Thread,
|
active_view: ActiveView::Thread,
|
||||||
workspace,
|
workspace,
|
||||||
project: project.clone(),
|
project: project.clone(),
|
||||||
fs: fs.clone(),
|
fs: fs.clone(),
|
||||||
language_registry: language_registry.clone(),
|
language_registry,
|
||||||
thread_store: thread_store.clone(),
|
thread_store: thread_store.clone(),
|
||||||
thread: cx.new(|cx| {
|
thread,
|
||||||
ActiveThread::new(
|
|
||||||
thread.clone(),
|
|
||||||
thread_store.clone(),
|
|
||||||
language_registry,
|
|
||||||
window,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
}),
|
|
||||||
message_editor,
|
message_editor,
|
||||||
context_store,
|
context_store,
|
||||||
context_editor: None,
|
context_editor: None,
|
||||||
|
@ -239,6 +242,7 @@ impl AssistantPanel {
|
||||||
self.active_view = ActiveView::Thread;
|
self.active_view = ActiveView::Thread;
|
||||||
self.thread = cx.new(|cx| {
|
self.thread = cx.new(|cx| {
|
||||||
ActiveThread::new(
|
ActiveThread::new(
|
||||||
|
self.workspace.clone(),
|
||||||
thread.clone(),
|
thread.clone(),
|
||||||
self.thread_store.clone(),
|
self.thread_store.clone(),
|
||||||
self.language_registry.clone(),
|
self.language_registry.clone(),
|
||||||
|
@ -372,6 +376,7 @@ impl AssistantPanel {
|
||||||
this.active_view = ActiveView::Thread;
|
this.active_view = ActiveView::Thread;
|
||||||
this.thread = cx.new(|cx| {
|
this.thread = cx.new(|cx| {
|
||||||
ActiveThread::new(
|
ActiveThread::new(
|
||||||
|
this.workspace.clone(),
|
||||||
thread.clone(),
|
thread.clone(),
|
||||||
this.thread_store.clone(),
|
this.thread_store.clone(),
|
||||||
this.language_registry.clone(),
|
this.language_registry.clone(),
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use assistant_scripting::{
|
||||||
|
Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
|
||||||
|
};
|
||||||
use assistant_tool::ToolWorkingSet;
|
use assistant_tool::ToolWorkingSet;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use futures::StreamExt as _;
|
use futures::StreamExt as _;
|
||||||
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
|
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
|
@ -75,14 +78,21 @@ pub struct Thread {
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
tool_use: ToolUseState,
|
tool_use: ToolUseState,
|
||||||
|
scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
|
||||||
|
script_output_messages: HashSet<MessageId>,
|
||||||
|
script_session: Entity<ScriptSession>,
|
||||||
|
_script_session_subscription: Subscription,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
_cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
|
||||||
|
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id: ThreadId::new(),
|
id: ThreadId::new(),
|
||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
|
@ -97,6 +107,10 @@ impl Thread {
|
||||||
project,
|
project,
|
||||||
tools,
|
tools,
|
||||||
tool_use: ToolUseState::new(),
|
tool_use: ToolUseState::new(),
|
||||||
|
scripts_by_assistant_message: HashMap::default(),
|
||||||
|
script_output_messages: HashSet::default(),
|
||||||
|
script_session,
|
||||||
|
_script_session_subscription: script_session_subscription,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +119,7 @@ impl Thread {
|
||||||
saved: SavedThread,
|
saved: SavedThread,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
_cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let next_message_id = MessageId(
|
let next_message_id = MessageId(
|
||||||
saved
|
saved
|
||||||
|
@ -115,6 +129,8 @@ impl Thread {
|
||||||
.unwrap_or(0),
|
.unwrap_or(0),
|
||||||
);
|
);
|
||||||
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
||||||
|
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
|
||||||
|
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id,
|
id,
|
||||||
|
@ -138,6 +154,10 @@ impl Thread {
|
||||||
project,
|
project,
|
||||||
tools,
|
tools,
|
||||||
tool_use,
|
tool_use,
|
||||||
|
scripts_by_assistant_message: HashMap::default(),
|
||||||
|
script_output_messages: HashSet::default(),
|
||||||
|
script_session,
|
||||||
|
_script_session_subscription: script_session_subscription,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,17 +243,22 @@ impl Thread {
|
||||||
self.tool_use.message_has_tool_results(message_id)
|
self.tool_use.message_has_tool_results(message_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
|
||||||
|
self.script_output_messages.contains(&message_id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn insert_user_message(
|
pub fn insert_user_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
text: impl Into<String>,
|
text: impl Into<String>,
|
||||||
context: Vec<ContextSnapshot>,
|
context: Vec<ContextSnapshot>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) -> MessageId {
|
||||||
let message_id = self.insert_message(Role::User, text, cx);
|
let message_id = self.insert_message(Role::User, text, cx);
|
||||||
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
|
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
|
||||||
self.context
|
self.context
|
||||||
.extend(context.into_iter().map(|context| (context.id, context)));
|
.extend(context.into_iter().map(|context| (context.id, context)));
|
||||||
self.context_by_message.insert(message_id, context_ids);
|
self.context_by_message.insert(message_id, context_ids);
|
||||||
|
message_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert_message(
|
pub fn insert_message(
|
||||||
|
@ -302,6 +327,39 @@ impl Thread {
|
||||||
text
|
text
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn script_for_message<'a>(
|
||||||
|
&'a self,
|
||||||
|
message_id: MessageId,
|
||||||
|
cx: &'a App,
|
||||||
|
) -> Option<&'a Script> {
|
||||||
|
self.scripts_by_assistant_message
|
||||||
|
.get(&message_id)
|
||||||
|
.map(|script_id| self.script_session.read(cx).get(*script_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_script_event(
|
||||||
|
&mut self,
|
||||||
|
_script_session: Entity<ScriptSession>,
|
||||||
|
event: &ScriptEvent,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
match event {
|
||||||
|
ScriptEvent::Spawned(_) => {}
|
||||||
|
ScriptEvent::Exited(script_id) => {
|
||||||
|
if let Some(output_message) = self
|
||||||
|
.script_session
|
||||||
|
.read(cx)
|
||||||
|
.get(*script_id)
|
||||||
|
.output_message_for_llm()
|
||||||
|
{
|
||||||
|
let message_id = self.insert_user_message(output_message, vec![], cx);
|
||||||
|
self.script_output_messages.insert(message_id);
|
||||||
|
cx.emit(ThreadEvent::ScriptFinished)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn send_to_model(
|
pub fn send_to_model(
|
||||||
&mut self,
|
&mut self,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
|
@ -330,7 +388,7 @@ impl Thread {
|
||||||
pub fn to_completion_request(
|
pub fn to_completion_request(
|
||||||
&self,
|
&self,
|
||||||
request_kind: RequestKind,
|
request_kind: RequestKind,
|
||||||
_cx: &App,
|
cx: &App,
|
||||||
) -> LanguageModelRequest {
|
) -> LanguageModelRequest {
|
||||||
let mut request = LanguageModelRequest {
|
let mut request = LanguageModelRequest {
|
||||||
messages: vec![],
|
messages: vec![],
|
||||||
|
@ -339,6 +397,12 @@ impl Thread {
|
||||||
temperature: None,
|
temperature: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
request.messages.push(LanguageModelRequestMessage {
|
||||||
|
role: Role::System,
|
||||||
|
content: vec![SCRIPTING_PROMPT.to_string().into()],
|
||||||
|
cache: true,
|
||||||
|
});
|
||||||
|
|
||||||
let mut referenced_context_ids = HashSet::default();
|
let mut referenced_context_ids = HashSet::default();
|
||||||
|
|
||||||
for message in &self.messages {
|
for message in &self.messages {
|
||||||
|
@ -351,6 +415,7 @@ impl Thread {
|
||||||
content: Vec::new(),
|
content: Vec::new(),
|
||||||
cache: false,
|
cache: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
match request_kind {
|
match request_kind {
|
||||||
RequestKind::Chat => {
|
RequestKind::Chat => {
|
||||||
self.tool_use
|
self.tool_use
|
||||||
|
@ -371,11 +436,20 @@ impl Thread {
|
||||||
RequestKind::Chat => {
|
RequestKind::Chat => {
|
||||||
self.tool_use
|
self.tool_use
|
||||||
.attach_tool_uses(message.id, &mut request_message);
|
.attach_tool_uses(message.id, &mut request_message);
|
||||||
|
|
||||||
|
if matches!(message.role, Role::Assistant) {
|
||||||
|
if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
|
||||||
|
{
|
||||||
|
let script = self.script_session.read(cx).get(*script_id);
|
||||||
|
|
||||||
|
request_message.content.push(script.source_tag().into());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
RequestKind::Summarize => {
|
RequestKind::Summarize => {
|
||||||
// We don't care about tool use during summarization.
|
// We don't care about tool use during summarization.
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
request.messages.push(request_message);
|
request.messages.push(request_message);
|
||||||
}
|
}
|
||||||
|
@ -412,6 +486,8 @@ impl Thread {
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
let mut events = stream.await?;
|
let mut events = stream.await?;
|
||||||
let mut stop_reason = StopReason::EndTurn;
|
let mut stop_reason = StopReason::EndTurn;
|
||||||
|
let mut script_tag_parser = ScriptTagParser::new();
|
||||||
|
let mut script_id = None;
|
||||||
|
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
let event = event?;
|
let event = event?;
|
||||||
|
@ -426,19 +502,43 @@ impl Thread {
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::Text(chunk) => {
|
LanguageModelCompletionEvent::Text(chunk) => {
|
||||||
if let Some(last_message) = thread.messages.last_mut() {
|
if let Some(last_message) = thread.messages.last_mut() {
|
||||||
if last_message.role == Role::Assistant {
|
let chunk = script_tag_parser.parse_chunk(&chunk);
|
||||||
last_message.text.push_str(&chunk);
|
|
||||||
|
let message_id = if last_message.role == Role::Assistant {
|
||||||
|
last_message.text.push_str(&chunk.content);
|
||||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||||
last_message.id,
|
last_message.id,
|
||||||
chunk,
|
chunk.content,
|
||||||
));
|
));
|
||||||
|
last_message.id
|
||||||
} else {
|
} else {
|
||||||
// If we won't have an Assistant message yet, assume this chunk marks the beginning
|
// If we won't have an Assistant message yet, assume this chunk marks the beginning
|
||||||
// of a new Assistant response.
|
// of a new Assistant response.
|
||||||
//
|
//
|
||||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||||
thread.insert_message(Role::Assistant, chunk, cx);
|
thread.insert_message(Role::Assistant, chunk.content, cx)
|
||||||
|
};
|
||||||
|
|
||||||
|
if script_id.is_none() && script_tag_parser.found_script() {
|
||||||
|
let id = thread
|
||||||
|
.script_session
|
||||||
|
.update(cx, |session, _cx| session.new_script());
|
||||||
|
thread.scripts_by_assistant_message.insert(message_id, id);
|
||||||
|
|
||||||
|
script_id = Some(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (Some(script_source), Some(script_id)) =
|
||||||
|
(chunk.script_source, script_id)
|
||||||
|
{
|
||||||
|
// TODO: move buffer to script and run as it streams
|
||||||
|
thread
|
||||||
|
.script_session
|
||||||
|
.update(cx, |this, cx| {
|
||||||
|
this.run_script(script_id, script_source, cx)
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -661,6 +761,7 @@ pub enum ThreadEvent {
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
tool_use_id: LanguageModelToolUseId,
|
tool_use_id: LanguageModelToolUseId,
|
||||||
},
|
},
|
||||||
|
ScriptFinished,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EventEmitter<ThreadEvent> for Thread {}
|
impl EventEmitter<ThreadEvent> for Thread {}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
[package]
|
[package]
|
||||||
name = "scripting_tool"
|
name = "assistant_scripting"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
publish.workspace = true
|
publish.workspace = true
|
||||||
|
@ -9,12 +9,11 @@ license = "GPL-3.0-or-later"
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/scripting_tool.rs"
|
path = "src/assistant_scripting.rs"
|
||||||
doctest = false
|
doctest = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
@ -22,7 +21,6 @@ mlua.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
schemars.workspace = true
|
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
|
@ -32,4 +30,5 @@ util.workspace = true
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
project = { workspace = true, features = ["test-support"] }
|
project = { workspace = true, features = ["test-support"] }
|
||||||
|
rand.workspace = true
|
||||||
settings = { workspace = true, features = ["test-support"] }
|
settings = { workspace = true, features = ["test-support"] }
|
7
crates/assistant_scripting/src/assistant_scripting.rs
Normal file
7
crates/assistant_scripting/src/assistant_scripting.rs
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
mod session;
|
||||||
|
mod tag;
|
||||||
|
|
||||||
|
pub use session::*;
|
||||||
|
pub use tag::*;
|
||||||
|
|
||||||
|
pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt");
|
|
@ -1,10 +1,9 @@
|
||||||
use anyhow::Result;
|
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
pin_mut, SinkExt, StreamExt,
|
pin_mut, SinkExt, StreamExt,
|
||||||
};
|
};
|
||||||
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
|
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||||
use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
|
use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use project::{search::SearchQuery, Fs, Project};
|
use project::{search::SearchQuery, Fs, Project};
|
||||||
|
@ -16,24 +15,23 @@ use std::{
|
||||||
};
|
};
|
||||||
use util::{paths::PathMatcher, ResultExt};
|
use util::{paths::PathMatcher, ResultExt};
|
||||||
|
|
||||||
pub struct ScriptOutput {
|
use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
|
||||||
pub stdout: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<Session>, AsyncApp) + Send>);
|
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
|
||||||
|
|
||||||
pub struct Session {
|
pub struct ScriptSession {
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
// TODO Remove this
|
// TODO Remove this
|
||||||
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
foreground_fns_tx: mpsc::Sender<ForegroundFn>,
|
foreground_fns_tx: mpsc::Sender<ForegroundFn>,
|
||||||
_invoke_foreground_fns: Task<()>,
|
_invoke_foreground_fns: Task<()>,
|
||||||
|
scripts: Vec<Script>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl ScriptSession {
|
||||||
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
|
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
|
||||||
let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
|
let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
|
||||||
Session {
|
ScriptSession {
|
||||||
project,
|
project,
|
||||||
fs_changes: Arc::new(Mutex::new(HashMap::default())),
|
fs_changes: Arc::new(Mutex::new(HashMap::default())),
|
||||||
foreground_fns_tx,
|
foreground_fns_tx,
|
||||||
|
@ -42,15 +40,62 @@ impl Session {
|
||||||
foreground_fn.0(this.clone(), cx.clone());
|
foreground_fn.0(this.clone(), cx.clone());
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
scripts: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs a Lua script in a sandboxed environment and returns the printed lines
|
pub fn new_script(&mut self) -> ScriptId {
|
||||||
|
let id = ScriptId(self.scripts.len() as u32);
|
||||||
|
let script = Script {
|
||||||
|
id,
|
||||||
|
state: ScriptState::Generating,
|
||||||
|
source: SharedString::new_static(""),
|
||||||
|
};
|
||||||
|
self.scripts.push(script);
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
pub fn run_script(
|
pub fn run_script(
|
||||||
&mut self,
|
&mut self,
|
||||||
script: String,
|
script_id: ScriptId,
|
||||||
|
script_src: String,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Task<Result<ScriptOutput>> {
|
) -> Task<anyhow::Result<()>> {
|
||||||
|
let script = self.get_mut(script_id);
|
||||||
|
|
||||||
|
let stdout = Arc::new(Mutex::new(String::new()));
|
||||||
|
script.source = script_src.clone().into();
|
||||||
|
script.state = ScriptState::Running {
|
||||||
|
stdout: stdout.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let task = self.run_lua(script_src, stdout, cx);
|
||||||
|
|
||||||
|
cx.emit(ScriptEvent::Spawned(script_id));
|
||||||
|
|
||||||
|
cx.spawn(|session, mut cx| async move {
|
||||||
|
let result = task.await;
|
||||||
|
|
||||||
|
session.update(&mut cx, |session, cx| {
|
||||||
|
let script = session.get_mut(script_id);
|
||||||
|
let stdout = script.stdout_snapshot();
|
||||||
|
|
||||||
|
script.state = match result {
|
||||||
|
Ok(()) => ScriptState::Succeeded { stdout },
|
||||||
|
Err(error) => ScriptState::Failed { stdout, error },
|
||||||
|
};
|
||||||
|
|
||||||
|
cx.emit(ScriptEvent::Exited(script_id))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_lua(
|
||||||
|
&mut self,
|
||||||
|
script: String,
|
||||||
|
stdout: Arc<Mutex<String>>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Task<anyhow::Result<()>> {
|
||||||
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
||||||
|
|
||||||
// TODO Remove fs_changes
|
// TODO Remove fs_changes
|
||||||
|
@ -62,52 +107,64 @@ impl Session {
|
||||||
.visible_worktrees(cx)
|
.visible_worktrees(cx)
|
||||||
.next()
|
.next()
|
||||||
.map(|worktree| worktree.read(cx).abs_path());
|
.map(|worktree| worktree.read(cx).abs_path());
|
||||||
|
|
||||||
let fs = self.project.read(cx).fs().clone();
|
let fs = self.project.read(cx).fs().clone();
|
||||||
let foreground_fns_tx = self.foreground_fns_tx.clone();
|
let foreground_fns_tx = self.foreground_fns_tx.clone();
|
||||||
cx.background_spawn(async move {
|
|
||||||
let lua = Lua::new();
|
|
||||||
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
|
|
||||||
let globals = lua.globals();
|
|
||||||
let stdout = Arc::new(Mutex::new(String::new()));
|
|
||||||
globals.set(
|
|
||||||
"sb_print",
|
|
||||||
lua.create_function({
|
|
||||||
let stdout = stdout.clone();
|
|
||||||
move |_, args: MultiValue| Self::print(args, &stdout)
|
|
||||||
})?,
|
|
||||||
)?;
|
|
||||||
globals.set(
|
|
||||||
"search",
|
|
||||||
lua.create_async_function({
|
|
||||||
let foreground_fns_tx = foreground_fns_tx.clone();
|
|
||||||
let fs = fs.clone();
|
|
||||||
move |lua, regex| {
|
|
||||||
Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
|
|
||||||
}
|
|
||||||
})?,
|
|
||||||
)?;
|
|
||||||
globals.set(
|
|
||||||
"sb_io_open",
|
|
||||||
lua.create_function({
|
|
||||||
let fs_changes = fs_changes.clone();
|
|
||||||
let root_dir = root_dir.clone();
|
|
||||||
move |lua, (path_str, mode)| {
|
|
||||||
Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
|
|
||||||
}
|
|
||||||
})?,
|
|
||||||
)?;
|
|
||||||
globals.set("user_script", script)?;
|
|
||||||
|
|
||||||
lua.load(SANDBOX_PREAMBLE).exec_async().await?;
|
let task = cx.background_spawn({
|
||||||
|
let stdout = stdout.clone();
|
||||||
|
|
||||||
// Drop Lua instance to decrement reference count.
|
async move {
|
||||||
drop(lua);
|
let lua = Lua::new();
|
||||||
|
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
|
||||||
|
let globals = lua.globals();
|
||||||
|
globals.set(
|
||||||
|
"sb_print",
|
||||||
|
lua.create_function({
|
||||||
|
let stdout = stdout.clone();
|
||||||
|
move |_, args: MultiValue| Self::print(args, &stdout)
|
||||||
|
})?,
|
||||||
|
)?;
|
||||||
|
globals.set(
|
||||||
|
"search",
|
||||||
|
lua.create_async_function({
|
||||||
|
let foreground_fns_tx = foreground_fns_tx.clone();
|
||||||
|
let fs = fs.clone();
|
||||||
|
move |lua, regex| {
|
||||||
|
Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
|
||||||
|
}
|
||||||
|
})?,
|
||||||
|
)?;
|
||||||
|
globals.set(
|
||||||
|
"sb_io_open",
|
||||||
|
lua.create_function({
|
||||||
|
let fs_changes = fs_changes.clone();
|
||||||
|
let root_dir = root_dir.clone();
|
||||||
|
move |lua, (path_str, mode)| {
|
||||||
|
Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
|
||||||
|
}
|
||||||
|
})?,
|
||||||
|
)?;
|
||||||
|
globals.set("user_script", script)?;
|
||||||
|
|
||||||
let stdout = Arc::try_unwrap(stdout)
|
lua.load(SANDBOX_PREAMBLE).exec_async().await?;
|
||||||
.expect("no more references to stdout")
|
|
||||||
.into_inner();
|
// Drop Lua instance to decrement reference count.
|
||||||
Ok(ScriptOutput { stdout })
|
drop(lua);
|
||||||
})
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
task
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, script_id: ScriptId) -> &Script {
|
||||||
|
&self.scripts[script_id.0 as usize]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
|
||||||
|
&mut self.scripts[script_id.0 as usize]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sandboxed print() function in Lua.
|
/// Sandboxed print() function in Lua.
|
||||||
|
@ -678,6 +735,79 @@ impl UserData for FileContent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ScriptEvent {
|
||||||
|
Spawned(ScriptId),
|
||||||
|
Exited(ScriptId),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventEmitter<ScriptEvent> for ScriptSession {}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
pub struct ScriptId(u32);
|
||||||
|
|
||||||
|
pub struct Script {
|
||||||
|
pub id: ScriptId,
|
||||||
|
pub state: ScriptState,
|
||||||
|
pub source: SharedString,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ScriptState {
|
||||||
|
Generating,
|
||||||
|
Running {
|
||||||
|
stdout: Arc<Mutex<String>>,
|
||||||
|
},
|
||||||
|
Succeeded {
|
||||||
|
stdout: String,
|
||||||
|
},
|
||||||
|
Failed {
|
||||||
|
stdout: String,
|
||||||
|
error: anyhow::Error,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Script {
|
||||||
|
pub fn source_tag(&self) -> String {
|
||||||
|
format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If exited, returns a message with the output for the LLM
|
||||||
|
pub fn output_message_for_llm(&self) -> Option<String> {
|
||||||
|
match &self.state {
|
||||||
|
ScriptState::Generating { .. } => None,
|
||||||
|
ScriptState::Running { .. } => None,
|
||||||
|
ScriptState::Succeeded { stdout } => {
|
||||||
|
format!("Here's the script output:\n{}", stdout).into()
|
||||||
|
}
|
||||||
|
ScriptState::Failed { stdout, error } => format!(
|
||||||
|
"The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
|
||||||
|
error, stdout
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a snapshot of the script's stdout
|
||||||
|
pub fn stdout_snapshot(&self) -> String {
|
||||||
|
match &self.state {
|
||||||
|
ScriptState::Generating { .. } => String::new(),
|
||||||
|
ScriptState::Running { stdout } => stdout.lock().clone(),
|
||||||
|
ScriptState::Succeeded { stdout } => stdout.clone(),
|
||||||
|
ScriptState::Failed { stdout, .. } => stdout.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the error if the script failed, otherwise None
|
||||||
|
pub fn error(&self) -> Option<&anyhow::Error> {
|
||||||
|
match &self.state {
|
||||||
|
ScriptState::Generating { .. } => None,
|
||||||
|
ScriptState::Running { .. } => None,
|
||||||
|
ScriptState::Succeeded { .. } => None,
|
||||||
|
ScriptState::Failed { error, .. } => Some(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
|
@ -689,35 +819,17 @@ mod tests {
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_print(cx: &mut TestAppContext) {
|
async fn test_print(cx: &mut TestAppContext) {
|
||||||
init_test(cx);
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
|
||||||
let project = Project::test(fs, [], cx).await;
|
|
||||||
let session = cx.new(|cx| Session::new(project, cx));
|
|
||||||
let script = r#"
|
let script = r#"
|
||||||
print("Hello", "world!")
|
print("Hello", "world!")
|
||||||
print("Goodbye", "moon!")
|
print("Goodbye", "moon!")
|
||||||
"#;
|
"#;
|
||||||
let output = session
|
|
||||||
.update(cx, |session, cx| session.run_script(script.to_string(), cx))
|
let output = test_script(script, cx).await.unwrap();
|
||||||
.await
|
assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
|
||||||
.unwrap();
|
|
||||||
assert_eq!(output.stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_search(cx: &mut TestAppContext) {
|
async fn test_search(cx: &mut TestAppContext) {
|
||||||
init_test(cx);
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
|
||||||
fs.insert_tree(
|
|
||||||
"/",
|
|
||||||
json!({
|
|
||||||
"file1.txt": "Hello world!",
|
|
||||||
"file2.txt": "Goodbye moon!"
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
let project = Project::test(fs, [Path::new("/")], cx).await;
|
|
||||||
let session = cx.new(|cx| Session::new(project, cx));
|
|
||||||
let script = r#"
|
let script = r#"
|
||||||
local results = search("world")
|
local results = search("world")
|
||||||
for i, result in ipairs(results) do
|
for i, result in ipairs(results) do
|
||||||
|
@ -728,11 +840,36 @@ mod tests {
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
"#;
|
"#;
|
||||||
let output = session
|
|
||||||
.update(cx, |session, cx| session.run_script(script.to_string(), cx))
|
let output = test_script(script, cx).await.unwrap();
|
||||||
.await
|
assert_eq!(output, "File: /file1.txt\nMatches:\n world\n");
|
||||||
.unwrap();
|
}
|
||||||
assert_eq!(output.stdout, "File: /file1.txt\nMatches:\n world\n");
|
|
||||||
|
async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
|
||||||
|
init_test(cx);
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree(
|
||||||
|
"/",
|
||||||
|
json!({
|
||||||
|
"file1.txt": "Hello world!",
|
||||||
|
"file2.txt": "Goodbye moon!"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let project = Project::test(fs, [Path::new("/")], cx).await;
|
||||||
|
let session = cx.new(|cx| ScriptSession::new(project, cx));
|
||||||
|
|
||||||
|
let (script_id, task) = session.update(cx, |session, cx| {
|
||||||
|
let script_id = session.new_script();
|
||||||
|
let task = session.run_script(script_id, source.to_string(), cx);
|
||||||
|
|
||||||
|
(script_id, task)
|
||||||
|
});
|
||||||
|
|
||||||
|
task.await?;
|
||||||
|
|
||||||
|
Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
|
@ -3,6 +3,12 @@ output was, including both stdout as well as the git diff of changes it made to
|
||||||
the filesystem. That way, you can get more information about the code base, or
|
the filesystem. That way, you can get more information about the code base, or
|
||||||
make changes to the code base directly.
|
make changes to the code base directly.
|
||||||
|
|
||||||
|
Put the Lua script inside of an `<eval>` tag like so:
|
||||||
|
|
||||||
|
<eval type="lua">
|
||||||
|
print("Hello, world!")
|
||||||
|
</eval>
|
||||||
|
|
||||||
The Lua script will have access to `io` and it will run with the current working
|
The Lua script will have access to `io` and it will run with the current working
|
||||||
directory being in the root of the code base, so you can use it to explore,
|
directory being in the root of the code base, so you can use it to explore,
|
||||||
search, make changes, etc. You can also have the script print things, and I'll
|
search, make changes, etc. You can also have the script print things, and I'll
|
||||||
|
@ -10,13 +16,17 @@ tell you what the output was. Note that `io` only has `open`, and then the file
|
||||||
it returns only has the methods read, write, and close - it doesn't have popen
|
it returns only has the methods read, write, and close - it doesn't have popen
|
||||||
or anything else.
|
or anything else.
|
||||||
|
|
||||||
Also, I'm going to be putting this Lua script into JSON, so please don't use
|
|
||||||
Lua's double quote syntax for string literals - use one of Lua's other syntaxes
|
|
||||||
for string literals, so I don't have to escape the double quotes.
|
|
||||||
|
|
||||||
There will be a global called `search` which accepts a regex (it's implemented
|
There will be a global called `search` which accepts a regex (it's implemented
|
||||||
using Rust's regex crate, so use that regex syntax) and runs that regex on the
|
using Rust's regex crate, so use that regex syntax) and runs that regex on the
|
||||||
contents of every file in the code base (aside from gitignored files), then
|
contents of every file in the code base (aside from gitignored files), then
|
||||||
returns an array of tables with two fields: "path" (the path to the file that
|
returns an array of tables with two fields: "path" (the path to the file that
|
||||||
had the matches) and "matches" (an array of strings, with each string being a
|
had the matches) and "matches" (an array of strings, with each string being a
|
||||||
match that was found within the file).
|
match that was found within the file).
|
||||||
|
|
||||||
|
When I send you the script output, do not thank me for running it,
|
||||||
|
act as if you ran it yourself.
|
||||||
|
|
||||||
|
IMPORTANT!
|
||||||
|
Only include a maximum of one Lua script at the very end of your message
|
||||||
|
DO NOT WRITE ANYTHING ELSE AFTER THE SCRIPT. Wait for my response with the script
|
||||||
|
output to continue.
|
260
crates/assistant_scripting/src/tag.rs
Normal file
260
crates/assistant_scripting/src/tag.rs
Normal file
|
@ -0,0 +1,260 @@
|
||||||
|
pub const SCRIPT_START_TAG: &str = "<eval type=\"lua\">";
|
||||||
|
pub const SCRIPT_END_TAG: &str = "</eval>";
|
||||||
|
|
||||||
|
const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes();
|
||||||
|
const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes();
|
||||||
|
|
||||||
|
/// Parses a script tag in an assistant message as it is being streamed.
|
||||||
|
pub struct ScriptTagParser {
|
||||||
|
state: State,
|
||||||
|
buffer: Vec<u8>,
|
||||||
|
tag_match_ix: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum State {
|
||||||
|
Unstarted,
|
||||||
|
Streaming,
|
||||||
|
Ended,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub struct ChunkOutput {
|
||||||
|
/// The chunk with script tags removed.
|
||||||
|
pub content: String,
|
||||||
|
/// The full script tag content. `None` until closed.
|
||||||
|
pub script_source: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScriptTagParser {
|
||||||
|
/// Create a new script tag parser.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
state: State::Unstarted,
|
||||||
|
buffer: Vec::new(),
|
||||||
|
tag_match_ix: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the parser has found a script tag.
|
||||||
|
pub fn found_script(&self) -> bool {
|
||||||
|
match self.state {
|
||||||
|
State::Unstarted => false,
|
||||||
|
State::Streaming | State::Ended => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a new chunk of input, splitting it into surrounding content and script source.
|
||||||
|
pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput {
|
||||||
|
let mut content = Vec::with_capacity(input.len());
|
||||||
|
|
||||||
|
for byte in input.bytes() {
|
||||||
|
match self.state {
|
||||||
|
State::Unstarted => {
|
||||||
|
if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) {
|
||||||
|
self.state = State::Streaming;
|
||||||
|
self.buffer = Vec::with_capacity(1024);
|
||||||
|
self.tag_match_ix = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
State::Streaming => {
|
||||||
|
if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) {
|
||||||
|
self.state = State::Ended;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
State::Ended => content.push(byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = unsafe { String::from_utf8_unchecked(content) };
|
||||||
|
|
||||||
|
let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() {
|
||||||
|
let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) };
|
||||||
|
|
||||||
|
Some(source)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
ChunkOutput {
|
||||||
|
content,
|
||||||
|
script_source,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec<u8>) -> bool {
|
||||||
|
// this can't be a method because it'd require a mutable borrow on both self and self.buffer
|
||||||
|
|
||||||
|
if match_tag_byte(byte, tag, tag_match_ix) {
|
||||||
|
*tag_match_ix >= tag.len()
|
||||||
|
} else {
|
||||||
|
if *tag_match_ix > 0 {
|
||||||
|
// push the partially matched tag to the buffer
|
||||||
|
buffer.extend_from_slice(&tag[..*tag_match_ix]);
|
||||||
|
*tag_match_ix = 0;
|
||||||
|
|
||||||
|
// the tag might start to match again
|
||||||
|
if match_tag_byte(byte, tag, tag_match_ix) {
|
||||||
|
return *tag_match_ix >= tag.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.push(byte);
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool {
|
||||||
|
if byte == tag[*tag_match_ix] {
|
||||||
|
*tag_match_ix += 1;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_complete_tag() {
|
||||||
|
let mut parser = ScriptTagParser::new();
|
||||||
|
let input = "<eval type=\"lua\">print(\"Hello, World!\")</eval>";
|
||||||
|
let result = parser.parse_chunk(input);
|
||||||
|
assert_eq!(result.content, "");
|
||||||
|
assert_eq!(
|
||||||
|
result.script_source,
|
||||||
|
Some("print(\"Hello, World!\")".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_no_tag() {
|
||||||
|
let mut parser = ScriptTagParser::new();
|
||||||
|
let input = "No tags here, just plain text";
|
||||||
|
let result = parser.parse_chunk(input);
|
||||||
|
assert_eq!(result.content, "No tags here, just plain text");
|
||||||
|
assert_eq!(result.script_source, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_partial_end_tag() {
|
||||||
|
let mut parser = ScriptTagParser::new();
|
||||||
|
|
||||||
|
// Start the tag
|
||||||
|
let result = parser.parse_chunk("<eval type=\"lua\">let x = '</e");
|
||||||
|
assert_eq!(result.content, "");
|
||||||
|
assert_eq!(result.script_source, None);
|
||||||
|
|
||||||
|
// Finish with the rest
|
||||||
|
let result = parser.parse_chunk("val' + 'not the end';</eval>");
|
||||||
|
assert_eq!(result.content, "");
|
||||||
|
assert_eq!(
|
||||||
|
result.script_source,
|
||||||
|
Some("let x = '</eval' + 'not the end';".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_text_before_and_after_tag() {
|
||||||
|
let mut parser = ScriptTagParser::new();
|
||||||
|
let input = "Before tag <eval type=\"lua\">print(\"Hello\")</eval> After tag";
|
||||||
|
let result = parser.parse_chunk(input);
|
||||||
|
assert_eq!(result.content, "Before tag After tag");
|
||||||
|
assert_eq!(result.script_source, Some("print(\"Hello\")".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiple_chunks_with_surrounding_text() {
|
||||||
|
let mut parser = ScriptTagParser::new();
|
||||||
|
|
||||||
|
// First chunk with text before
|
||||||
|
let result = parser.parse_chunk("Before script <eval type=\"lua\">local x = 10");
|
||||||
|
assert_eq!(result.content, "Before script ");
|
||||||
|
assert_eq!(result.script_source, None);
|
||||||
|
|
||||||
|
// Second chunk with script content
|
||||||
|
let result = parser.parse_chunk("\nlocal y = 20");
|
||||||
|
assert_eq!(result.content, "");
|
||||||
|
assert_eq!(result.script_source, None);
|
||||||
|
|
||||||
|
// Last chunk with text after
|
||||||
|
let result = parser.parse_chunk("\nprint(x + y)</eval> After script");
|
||||||
|
assert_eq!(result.content, " After script");
|
||||||
|
assert_eq!(
|
||||||
|
result.script_source,
|
||||||
|
Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = parser.parse_chunk(" there's more text");
|
||||||
|
assert_eq!(result.content, " there's more text");
|
||||||
|
assert_eq!(result.script_source, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_partial_start_tag_matching() {
|
||||||
|
let mut parser = ScriptTagParser::new();
|
||||||
|
|
||||||
|
// partial match of start tag...
|
||||||
|
let result = parser.parse_chunk("<ev");
|
||||||
|
assert_eq!(result.content, "");
|
||||||
|
|
||||||
|
// ...that's abandandoned when the < of a real tag is encountered
|
||||||
|
let result = parser.parse_chunk("<eval type=\"lua\">script content</eval>");
|
||||||
|
// ...so it gets pushed to content
|
||||||
|
assert_eq!(result.content, "<ev");
|
||||||
|
// ...and the real tag is parsed correctly
|
||||||
|
assert_eq!(result.script_source, Some("script content".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_random_chunked_parsing() {
|
||||||
|
use rand::rngs::StdRng;
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
let test_inputs = [
|
||||||
|
"Before <eval type=\"lua\">print(\"Hello\")</eval> After",
|
||||||
|
"No tags here at all",
|
||||||
|
"<eval type=\"lua\">local x = 10\nlocal y = 20\nprint(x + y)</eval>",
|
||||||
|
"Text <eval type=\"lua\">if true then\nprint(\"nested </e\")\nend</eval> more",
|
||||||
|
];
|
||||||
|
|
||||||
|
let seed = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
eprintln!("Using random seed: {}", seed);
|
||||||
|
let mut rng = StdRng::seed_from_u64(seed);
|
||||||
|
|
||||||
|
for test_input in &test_inputs {
|
||||||
|
let mut reference_parser = ScriptTagParser::new();
|
||||||
|
let expected = reference_parser.parse_chunk(test_input);
|
||||||
|
|
||||||
|
let mut chunked_parser = ScriptTagParser::new();
|
||||||
|
let mut remaining = test_input.as_bytes();
|
||||||
|
let mut actual_content = String::new();
|
||||||
|
let mut actual_script = None;
|
||||||
|
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
let chunk_size = rng.gen_range(1..=remaining.len().min(5));
|
||||||
|
let (chunk, rest) = remaining.split_at(chunk_size);
|
||||||
|
remaining = rest;
|
||||||
|
|
||||||
|
let chunk_str = std::str::from_utf8(chunk).unwrap();
|
||||||
|
let result = chunked_parser.parse_chunk(chunk_str);
|
||||||
|
|
||||||
|
actual_content.push_str(&result.content);
|
||||||
|
if result.script_source.is_some() {
|
||||||
|
actual_script = result.script_source;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(actual_content, expected.content);
|
||||||
|
assert_eq!(actual_script, expected.script_source);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,58 +0,0 @@
|
||||||
mod session;
|
|
||||||
|
|
||||||
use project::Project;
|
|
||||||
pub(crate) use session::*;
|
|
||||||
|
|
||||||
use assistant_tool::{Tool, ToolRegistry};
|
|
||||||
use gpui::{App, AppContext as _, Entity, Task};
|
|
||||||
use schemars::JsonSchema;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
pub fn init(cx: &App) {
|
|
||||||
let registry = ToolRegistry::global(cx);
|
|
||||||
registry.register_tool(ScriptingTool);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, JsonSchema)]
|
|
||||||
struct ScriptingToolInput {
|
|
||||||
lua_script: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ScriptingTool;
|
|
||||||
|
|
||||||
impl Tool for ScriptingTool {
|
|
||||||
fn name(&self) -> String {
|
|
||||||
"lua-interpreter".into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> String {
|
|
||||||
include_str!("scripting_tool_description.txt").into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
|
||||||
let schema = schemars::schema_for!(ScriptingToolInput);
|
|
||||||
serde_json::to_value(&schema).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(
|
|
||||||
self: Arc<Self>,
|
|
||||||
input: serde_json::Value,
|
|
||||||
project: Entity<Project>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<anyhow::Result<String>> {
|
|
||||||
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
|
|
||||||
Err(err) => return Task::ready(Err(err.into())),
|
|
||||||
Ok(input) => input,
|
|
||||||
};
|
|
||||||
|
|
||||||
let session = cx.new(|cx| Session::new(project, cx));
|
|
||||||
let lua_script = input.lua_script;
|
|
||||||
let script = session.update(cx, |session, cx| session.run_script(lua_script, cx));
|
|
||||||
cx.spawn(|_cx| async move {
|
|
||||||
let output = script.await?.stdout;
|
|
||||||
drop(session);
|
|
||||||
Ok(format!("The script output the following:\n{output}"))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -98,7 +98,6 @@ remote.workspace = true
|
||||||
repl.workspace = true
|
repl.workspace = true
|
||||||
reqwest_client.workspace = true
|
reqwest_client.workspace = true
|
||||||
rope.workspace = true
|
rope.workspace = true
|
||||||
scripting_tool.workspace = true
|
|
||||||
search.workspace = true
|
search.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
|
|
@ -476,7 +476,6 @@ fn main() {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
assistant_tools::init(cx);
|
assistant_tools::init(cx);
|
||||||
scripting_tool::init(cx);
|
|
||||||
repl::init(app_state.fs.clone(), cx);
|
repl::init(app_state.fs.clone(), cx);
|
||||||
extension_host::init(
|
extension_host::init(
|
||||||
extension_host_proxy,
|
extension_host_proxy,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue