diff --git a/Cargo.lock b/Cargo.lock index 3995e9518b..56850e38c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,7 +450,6 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_context_editor", - "assistant_scripting", "assistant_settings", "assistant_slash_command", "assistant_tool", @@ -564,26 +563,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "assistant_scripting" -version = "0.1.0" -dependencies = [ - "anyhow", - "collections", - "futures 0.3.31", - "gpui", - "log", - "mlua", - "parking_lot", - "project", - "rand 0.8.5", - "regex", - "serde", - "serde_json", - "settings", - "util", -] - [[package]] name = "assistant_settings" version = "0.1.0" @@ -11931,6 +11910,28 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" +[[package]] +name = "scripting_tool" +version = "0.1.0" +dependencies = [ + "anyhow", + "assistant_tool", + "collections", + "futures 0.3.31", + "gpui", + "log", + "mlua", + "parking_lot", + "project", + "rand 0.8.5", + "regex", + "schemars", + "serde", + "serde_json", + "settings", + "util", +] + [[package]] name = "scrypt" version = "0.11.0" @@ -16985,6 +16986,7 @@ dependencies = [ "repl", "reqwest_client", "rope", + "scripting_tool", "search", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 70453e455f..ddcd16eaca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ "crates/assistant", "crates/assistant2", "crates/assistant_context_editor", - "crates/assistant_scripting", + "crates/scripting_tool", "crates/assistant_settings", "crates/assistant_slash_command", "crates/assistant_slash_commands", @@ -318,7 +318,7 @@ reqwest_client = { path = "crates/reqwest_client" } rich_text = { path = "crates/rich_text" } rope = { path = "crates/rope" } rpc = { path = "crates/rpc" } -assistant_scripting = { path = "crates/assistant_scripting" } +scripting_tool = { path = "crates/scripting_tool" } search = { path = "crates/search" } semantic_index = { path = "crates/semantic_index" } semantic_version = { path = "crates/semantic_version" } diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 74266fb9f1..66aca29947 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -21,7 +21,6 @@ test-support = [ [dependencies] anyhow.workspace = true assistant_context_editor.workspace = true -assistant_scripting.workspace = true assistant_settings.workspace = true assistant_slash_command.workspace = true assistant_tool.workspace = true diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 8f3677b6d0..e63ae4cd05 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use assistant_scripting::{ScriptId, ScriptState}; -use collections::{HashMap, HashSet}; +use collections::HashMap; use editor::{Editor, MultiBuffer}; use gpui::{ list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, - Task, TextStyleRefinement, UnderlineStyle, WeakEntity, + Task, TextStyleRefinement, UnderlineStyle, }; use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; @@ -15,7 +14,6 @@ use settings::Settings as _; use theme::ThemeSettings; use ui::{prelude::*, Disclosure, KeyBinding}; use util::ResultExt as _; -use workspace::Workspace; use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent}; use crate::thread_store::ThreadStore; @@ -23,7 +21,6 @@ use crate::tool_use::{ToolUse, ToolUseStatus}; use crate::ui::ContextPill; pub struct ActiveThread { - workspace: WeakEntity, language_registry: Arc, thread_store: Entity, thread: Entity, @@ -33,7 +30,6 @@ pub struct ActiveThread { rendered_messages_by_id: HashMap>, editing_message: Option<(MessageId, EditMessageState)>, expanded_tool_uses: HashMap, - expanded_scripts: HashSet, last_error: Option, _subscriptions: Vec, } @@ -44,7 +40,6 @@ struct EditMessageState { impl ActiveThread { pub fn new( - workspace: WeakEntity, thread: Entity, thread_store: Entity, language_registry: Arc, @@ -57,7 +52,6 @@ impl ActiveThread { ]; let mut this = Self { - workspace, language_registry, thread_store, thread: thread.clone(), @@ -65,7 +59,6 @@ impl ActiveThread { messages: Vec::new(), rendered_messages_by_id: HashMap::default(), expanded_tool_uses: HashMap::default(), - expanded_scripts: HashSet::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); move |ix, window: &mut Window, cx: &mut App| { @@ -466,10 +459,7 @@ impl ActiveThread { let tool_uses = thread.tool_uses_for_message(message_id); // Don't render user messages that are just there for returning tool results. - if message.role == Role::User - && (thread.message_has_tool_results(message_id) - || thread.message_has_script_output(message_id)) - { + if message.role == Role::User && thread.message_has_tool_results(message_id) { return Empty.into_any(); } @@ -618,7 +608,6 @@ impl ActiveThread { Role::Assistant => div() .id(("message-container", ix)) .child(message_content) - .children(self.render_script(message_id, cx)) .map(|parent| { if tool_uses.is_empty() { return parent; @@ -738,139 +727,6 @@ impl ActiveThread { }), ) } - - fn render_script(&self, message_id: MessageId, cx: &mut Context) -> Option { - let script = self.thread.read(cx).script_for_message(message_id, cx)?; - - let is_open = self.expanded_scripts.contains(&script.id); - let colors = cx.theme().colors(); - - let element = div().px_2p5().child( - v_flex() - .gap_1() - .rounded_lg() - .border_1() - .border_color(colors.border) - .child( - h_flex() - .justify_between() - .py_0p5() - .pl_1() - .pr_2() - .bg(colors.editor_foreground.opacity(0.02)) - .when(is_open, |element| element.border_b_1().rounded_t(px(6.))) - .when(!is_open, |element| element.rounded_md()) - .border_color(colors.border) - .child( - h_flex() - .gap_1() - .child(Disclosure::new("script-disclosure", is_open).on_click( - cx.listener({ - let script_id = script.id; - move |this, _event, _window, _cx| { - if this.expanded_scripts.contains(&script_id) { - this.expanded_scripts.remove(&script_id); - } else { - this.expanded_scripts.insert(script_id); - } - } - }), - )) - // TODO: Generate script description - .child(Label::new("Script")), - ) - .child( - h_flex() - .gap_1() - .child( - Label::new(match script.state { - ScriptState::Generating => "Generating", - ScriptState::Running { .. } => "Running", - ScriptState::Succeeded { .. } => "Finished", - ScriptState::Failed { .. } => "Error", - }) - .size(LabelSize::XSmall) - .buffer_font(cx), - ) - .child( - IconButton::new("view-source", IconName::Eye) - .icon_color(Color::Muted) - .disabled(matches!(script.state, ScriptState::Generating)) - .on_click(cx.listener({ - let source = script.source.clone(); - move |this, _event, window, cx| { - this.open_script_source(source.clone(), window, cx); - } - })), - ), - ), - ) - .when(is_open, |parent| { - let stdout = script.stdout_snapshot(); - let error = script.error(); - - parent.child( - v_flex() - .p_2() - .bg(colors.editor_background) - .gap_2() - .child(if stdout.is_empty() && error.is_none() { - Label::new("No output yet") - .size(LabelSize::Small) - .color(Color::Muted) - } else { - Label::new(stdout).size(LabelSize::Small).buffer_font(cx) - }) - .children(script.error().map(|err| { - Label::new(err.to_string()) - .size(LabelSize::Small) - .color(Color::Error) - })), - ) - }), - ); - - Some(element.into_any()) - } - - fn open_script_source( - &mut self, - source: SharedString, - window: &mut Window, - cx: &mut Context<'_, ActiveThread>, - ) { - let language_registry = self.language_registry.clone(); - let workspace = self.workspace.clone(); - let source = source.clone(); - - cx.spawn_in(window, |_, mut cx| async move { - let lua = language_registry.language_for_name("Lua").await.log_err(); - - workspace.update_in(&mut cx, |workspace, window, cx| { - let project = workspace.project().clone(); - - let buffer = project.update(cx, |project, cx| { - project.create_local_buffer(&source.trim(), lua, cx) - }); - - let buffer = cx.new(|cx| { - MultiBuffer::singleton(buffer, cx) - // TODO: Generate script description - .with_title("Assistant script".into()) - }); - - let editor = cx.new(|cx| { - let mut editor = - Editor::for_multibuffer(buffer, Some(project), true, window, cx); - editor.set_read_only(true); - editor - }); - - workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx); - }) - }) - .detach_and_log_err(cx); - } } impl Render for ActiveThread { diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 123ae8f459..70ad76fadc 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -168,7 +168,6 @@ impl AssistantPanel { let thread = cx.new(|cx| { ActiveThread::new( - workspace.clone(), thread.clone(), thread_store.clone(), language_registry.clone(), @@ -242,7 +241,6 @@ impl AssistantPanel { self.active_view = ActiveView::Thread; self.thread = cx.new(|cx| { ActiveThread::new( - self.workspace.clone(), thread.clone(), self.thread_store.clone(), self.language_registry.clone(), @@ -376,7 +374,6 @@ impl AssistantPanel { this.active_view = ActiveView::Thread; this.thread = cx.new(|cx| { ActiveThread::new( - this.workspace.clone(), thread.clone(), this.thread_store.clone(), this.language_registry.clone(), diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 4dc194e0e9..9d19a59d0e 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,14 +1,11 @@ use std::sync::Arc; use anyhow::Result; -use assistant_scripting::{ - Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT, -}; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; use futures::StreamExt as _; -use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task}; +use gpui::{App, Context, Entity, EventEmitter, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, @@ -78,21 +75,14 @@ pub struct Thread { project: Entity, tools: Arc, tool_use: ToolUseState, - scripts_by_assistant_message: HashMap, - script_output_messages: HashSet, - script_session: Entity, - _script_session_subscription: Subscription, } impl Thread { pub fn new( project: Entity, tools: Arc, - cx: &mut Context, + _cx: &mut Context, ) -> Self { - let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx)); - let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event); - Self { id: ThreadId::new(), updated_at: Utc::now(), @@ -107,10 +97,6 @@ impl Thread { project, tools, tool_use: ToolUseState::new(), - scripts_by_assistant_message: HashMap::default(), - script_output_messages: HashSet::default(), - script_session, - _script_session_subscription: script_session_subscription, } } @@ -119,7 +105,7 @@ impl Thread { saved: SavedThread, project: Entity, tools: Arc, - cx: &mut Context, + _cx: &mut Context, ) -> Self { let next_message_id = MessageId( saved @@ -129,8 +115,6 @@ impl Thread { .unwrap_or(0), ); let tool_use = ToolUseState::from_saved_messages(&saved.messages); - let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx)); - let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event); Self { id, @@ -154,10 +138,6 @@ impl Thread { project, tools, tool_use, - scripts_by_assistant_message: HashMap::default(), - script_output_messages: HashSet::default(), - script_session, - _script_session_subscription: script_session_subscription, } } @@ -243,10 +223,6 @@ impl Thread { self.tool_use.message_has_tool_results(message_id) } - pub fn message_has_script_output(&self, message_id: MessageId) -> bool { - self.script_output_messages.contains(&message_id) - } - pub fn insert_user_message( &mut self, text: impl Into, @@ -327,39 +303,6 @@ impl Thread { text } - pub fn script_for_message<'a>( - &'a self, - message_id: MessageId, - cx: &'a App, - ) -> Option<&'a Script> { - self.scripts_by_assistant_message - .get(&message_id) - .map(|script_id| self.script_session.read(cx).get(*script_id)) - } - - fn handle_script_event( - &mut self, - _script_session: Entity, - event: &ScriptEvent, - cx: &mut Context, - ) { - match event { - ScriptEvent::Spawned(_) => {} - ScriptEvent::Exited(script_id) => { - if let Some(output_message) = self - .script_session - .read(cx) - .get(*script_id) - .output_message_for_llm() - { - let message_id = self.insert_user_message(output_message, vec![], cx); - self.script_output_messages.insert(message_id); - cx.emit(ThreadEvent::ScriptFinished) - } - } - } - } - pub fn send_to_model( &mut self, model: Arc, @@ -388,7 +331,7 @@ impl Thread { pub fn to_completion_request( &self, request_kind: RequestKind, - cx: &App, + _cx: &App, ) -> LanguageModelRequest { let mut request = LanguageModelRequest { messages: vec![], @@ -397,12 +340,6 @@ impl Thread { temperature: None, }; - request.messages.push(LanguageModelRequestMessage { - role: Role::System, - content: vec![SCRIPTING_PROMPT.to_string().into()], - cache: true, - }); - let mut referenced_context_ids = HashSet::default(); for message in &self.messages { @@ -436,15 +373,6 @@ impl Thread { RequestKind::Chat => { self.tool_use .attach_tool_uses(message.id, &mut request_message); - - if matches!(message.role, Role::Assistant) { - if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id) - { - let script = self.script_session.read(cx).get(*script_id); - - request_message.content.push(script.source_tag().into()); - } - } } RequestKind::Summarize => { // We don't care about tool use during summarization. @@ -486,8 +414,6 @@ impl Thread { let stream_completion = async { let mut events = stream.await?; let mut stop_reason = StopReason::EndTurn; - let mut script_tag_parser = ScriptTagParser::new(); - let mut script_id = None; while let Some(event) = events.next().await { let event = event?; @@ -502,44 +428,20 @@ impl Thread { } LanguageModelCompletionEvent::Text(chunk) => { if let Some(last_message) = thread.messages.last_mut() { - let chunk = script_tag_parser.parse_chunk(&chunk); - - let message_id = if last_message.role == Role::Assistant { - last_message.text.push_str(&chunk.content); + if last_message.role == Role::Assistant { + last_message.text.push_str(&chunk); cx.emit(ThreadEvent::StreamedAssistantText( last_message.id, - chunk.content, + chunk, )); - last_message.id } else { // If we won't have an Assistant message yet, assume this chunk marks the beginning // of a new Assistant response. // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. - thread.insert_message(Role::Assistant, chunk.content, cx) + thread.insert_message(Role::Assistant, chunk, cx); }; - - if script_id.is_none() && script_tag_parser.found_script() { - let id = thread - .script_session - .update(cx, |session, _cx| session.new_script()); - thread.scripts_by_assistant_message.insert(message_id, id); - - script_id = Some(id); - } - - if let (Some(script_source), Some(script_id)) = - (chunk.script_source, script_id) - { - // TODO: move buffer to script and run as it streams - thread - .script_session - .update(cx, |this, cx| { - this.run_script(script_id, script_source, cx) - }) - .detach_and_log_err(cx); - } } } LanguageModelCompletionEvent::ToolUse(tool_use) => { diff --git a/crates/assistant_scripting/src/assistant_scripting.rs b/crates/assistant_scripting/src/assistant_scripting.rs deleted file mode 100644 index fbe335d7fb..0000000000 --- a/crates/assistant_scripting/src/assistant_scripting.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod session; -mod tag; - -pub use session::*; -pub use tag::*; - -pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt"); diff --git a/crates/assistant_scripting/src/system_prompt.txt b/crates/assistant_scripting/src/system_prompt.txt deleted file mode 100644 index 359085b9dd..0000000000 --- a/crates/assistant_scripting/src/system_prompt.txt +++ /dev/null @@ -1,36 +0,0 @@ -You can write a Lua script and I'll run it on my codebase and tell you what its -output was, including both stdout as well as the git diff of changes it made to -the filesystem. That way, you can get more information about the code base, or -make changes to the code base directly. - -Put the Lua script inside of an `` tag like so: - - -print("Hello, world!") - - -The Lua script will have access to `io` and it will run with the current working -directory being in the root of the code base, so you can use it to explore, -search, make changes, etc. You can also have the script print things, and I'll -tell you what the output was. Note that `io` only has `open`, and then the file -it returns only has the methods read, write, and close - it doesn't have popen -or anything else. - -There is a function called `search` which accepts a regex (it's implemented -using Rust's regex crate, so use that regex syntax) and runs that regex on the -contents of every file in the code base (aside from gitignored files), then -returns an array of tables with two fields: "path" (the path to the file that -had the matches) and "matches" (an array of strings, with each string being a -match that was found within the file). - -There is a function called `outline` which accepts the path to a source file, -and returns a string where each line is a declaration. These lines are indented -with 2 spaces to indicate when a declaration is inside another. - -When I send you the script output, do not thank me for running it, -act as if you ran it yourself. - -IMPORTANT! -Only include a maximum of one Lua script at the very end of your message -DO NOT WRITE ANYTHING ELSE AFTER THE SCRIPT. Wait for my response with the script -output to continue. diff --git a/crates/assistant_scripting/src/tag.rs b/crates/assistant_scripting/src/tag.rs deleted file mode 100644 index 369f24ae90..0000000000 --- a/crates/assistant_scripting/src/tag.rs +++ /dev/null @@ -1,260 +0,0 @@ -pub const SCRIPT_START_TAG: &str = ""; -pub const SCRIPT_END_TAG: &str = ""; - -const START_TAG: &[u8] = SCRIPT_START_TAG.as_bytes(); -const END_TAG: &[u8] = SCRIPT_END_TAG.as_bytes(); - -/// Parses a script tag in an assistant message as it is being streamed. -pub struct ScriptTagParser { - state: State, - buffer: Vec, - tag_match_ix: usize, -} - -enum State { - Unstarted, - Streaming, - Ended, -} - -#[derive(Debug, PartialEq)] -pub struct ChunkOutput { - /// The chunk with script tags removed. - pub content: String, - /// The full script tag content. `None` until closed. - pub script_source: Option, -} - -impl ScriptTagParser { - /// Create a new script tag parser. - pub fn new() -> Self { - Self { - state: State::Unstarted, - buffer: Vec::new(), - tag_match_ix: 0, - } - } - - /// Returns true if the parser has found a script tag. - pub fn found_script(&self) -> bool { - match self.state { - State::Unstarted => false, - State::Streaming | State::Ended => true, - } - } - - /// Process a new chunk of input, splitting it into surrounding content and script source. - pub fn parse_chunk(&mut self, input: &str) -> ChunkOutput { - let mut content = Vec::with_capacity(input.len()); - - for byte in input.bytes() { - match self.state { - State::Unstarted => { - if collect_until_tag(byte, START_TAG, &mut self.tag_match_ix, &mut content) { - self.state = State::Streaming; - self.buffer = Vec::with_capacity(1024); - self.tag_match_ix = 0; - } - } - State::Streaming => { - if collect_until_tag(byte, END_TAG, &mut self.tag_match_ix, &mut self.buffer) { - self.state = State::Ended; - } - } - State::Ended => content.push(byte), - } - } - - let content = unsafe { String::from_utf8_unchecked(content) }; - - let script_source = if matches!(self.state, State::Ended) && !self.buffer.is_empty() { - let source = unsafe { String::from_utf8_unchecked(std::mem::take(&mut self.buffer)) }; - - Some(source) - } else { - None - }; - - ChunkOutput { - content, - script_source, - } - } -} - -fn collect_until_tag(byte: u8, tag: &[u8], tag_match_ix: &mut usize, buffer: &mut Vec) -> bool { - // this can't be a method because it'd require a mutable borrow on both self and self.buffer - - if match_tag_byte(byte, tag, tag_match_ix) { - *tag_match_ix >= tag.len() - } else { - if *tag_match_ix > 0 { - // push the partially matched tag to the buffer - buffer.extend_from_slice(&tag[..*tag_match_ix]); - *tag_match_ix = 0; - - // the tag might start to match again - if match_tag_byte(byte, tag, tag_match_ix) { - return *tag_match_ix >= tag.len(); - } - } - - buffer.push(byte); - - false - } -} - -fn match_tag_byte(byte: u8, tag: &[u8], tag_match_ix: &mut usize) -> bool { - if byte == tag[*tag_match_ix] { - *tag_match_ix += 1; - true - } else { - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_complete_tag() { - let mut parser = ScriptTagParser::new(); - let input = "print(\"Hello, World!\")"; - let result = parser.parse_chunk(input); - assert_eq!(result.content, ""); - assert_eq!( - result.script_source, - Some("print(\"Hello, World!\")".to_string()) - ); - } - - #[test] - fn test_no_tag() { - let mut parser = ScriptTagParser::new(); - let input = "No tags here, just plain text"; - let result = parser.parse_chunk(input); - assert_eq!(result.content, "No tags here, just plain text"); - assert_eq!(result.script_source, None); - } - - #[test] - fn test_partial_end_tag() { - let mut parser = ScriptTagParser::new(); - - // Start the tag - let result = parser.parse_chunk("let x = '"); - assert_eq!(result.content, ""); - assert_eq!( - result.script_source, - Some("let x = 'print(\"Hello\") After tag"; - let result = parser.parse_chunk(input); - assert_eq!(result.content, "Before tag After tag"); - assert_eq!(result.script_source, Some("print(\"Hello\")".to_string())); - } - - #[test] - fn test_multiple_chunks_with_surrounding_text() { - let mut parser = ScriptTagParser::new(); - - // First chunk with text before - let result = parser.parse_chunk("Before script local x = 10"); - assert_eq!(result.content, "Before script "); - assert_eq!(result.script_source, None); - - // Second chunk with script content - let result = parser.parse_chunk("\nlocal y = 20"); - assert_eq!(result.content, ""); - assert_eq!(result.script_source, None); - - // Last chunk with text after - let result = parser.parse_chunk("\nprint(x + y) After script"); - assert_eq!(result.content, " After script"); - assert_eq!( - result.script_source, - Some("local x = 10\nlocal y = 20\nprint(x + y)".to_string()) - ); - - let result = parser.parse_chunk(" there's more text"); - assert_eq!(result.content, " there's more text"); - assert_eq!(result.script_source, None); - } - - #[test] - fn test_partial_start_tag_matching() { - let mut parser = ScriptTagParser::new(); - - // partial match of start tag... - let result = parser.parse_chunk("script content"); - // ...so it gets pushed to content - assert_eq!(result.content, "print(\"Hello\") After", - "No tags here at all", - "local x = 10\nlocal y = 20\nprint(x + y)", - "Text if true then\nprint(\"nested more", - ]; - - let seed = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - eprintln!("Using random seed: {}", seed); - let mut rng = StdRng::seed_from_u64(seed); - - for test_input in &test_inputs { - let mut reference_parser = ScriptTagParser::new(); - let expected = reference_parser.parse_chunk(test_input); - - let mut chunked_parser = ScriptTagParser::new(); - let mut remaining = test_input.as_bytes(); - let mut actual_content = String::new(); - let mut actual_script = None; - - while !remaining.is_empty() { - let chunk_size = rng.gen_range(1..=remaining.len().min(5)); - let (chunk, rest) = remaining.split_at(chunk_size); - remaining = rest; - - let chunk_str = std::str::from_utf8(chunk).unwrap(); - let result = chunked_parser.parse_chunk(chunk_str); - - actual_content.push_str(&result.content); - if result.script_source.is_some() { - actual_script = result.script_source; - } - } - - assert_eq!(actual_content, expected.content); - assert_eq!(actual_script, expected.script_source); - } - } -} diff --git a/crates/assistant_scripting/Cargo.toml b/crates/scripting_tool/Cargo.toml similarity index 87% rename from crates/assistant_scripting/Cargo.toml rename to crates/scripting_tool/Cargo.toml index 2993b646f4..18ea42eb87 100644 --- a/crates/assistant_scripting/Cargo.toml +++ b/crates/scripting_tool/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "assistant_scripting" +name = "scripting_tool" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,11 +9,12 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/assistant_scripting.rs" +path = "src/scripting_tool.rs" doctest = false [dependencies] anyhow.workspace = true +assistant_tool.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true @@ -22,6 +23,7 @@ mlua.workspace = true parking_lot.workspace = true project.workspace = true regex.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/assistant_scripting/LICENSE-GPL b/crates/scripting_tool/LICENSE-GPL similarity index 100% rename from crates/assistant_scripting/LICENSE-GPL rename to crates/scripting_tool/LICENSE-GPL diff --git a/crates/assistant_scripting/src/sandbox_preamble.lua b/crates/scripting_tool/src/sandbox_preamble.lua similarity index 100% rename from crates/assistant_scripting/src/sandbox_preamble.lua rename to crates/scripting_tool/src/sandbox_preamble.lua diff --git a/crates/scripting_tool/src/scripting_tool.rs b/crates/scripting_tool/src/scripting_tool.rs new file mode 100644 index 0000000000..53a4cf944c --- /dev/null +++ b/crates/scripting_tool/src/scripting_tool.rs @@ -0,0 +1,74 @@ +mod session; + +use project::Project; +use session::*; + +use assistant_tool::{Tool, ToolRegistry}; +use gpui::{App, AppContext as _, Entity, Task}; +use schemars::JsonSchema; +use serde::Deserialize; +use std::sync::Arc; + +pub fn init(cx: &App) { + let registry = ToolRegistry::global(cx); + registry.register_tool(ScriptingTool); +} + +#[derive(Debug, Deserialize, JsonSchema)] +struct ScriptingToolInput { + lua_script: String, +} + +struct ScriptingTool; + +impl Tool for ScriptingTool { + fn name(&self) -> String { + "lua-interpreter".into() + } + + fn description(&self) -> String { + include_str!("scripting_tool_description.txt").into() + } + + fn input_schema(&self) -> serde_json::Value { + let schema = schemars::schema_for!(ScriptingToolInput); + serde_json::to_value(&schema).unwrap() + } + + fn run( + self: Arc, + input: serde_json::Value, + project: Entity, + cx: &mut App, + ) -> Task> { + let input = match serde_json::from_value::(input) { + Err(err) => return Task::ready(Err(err.into())), + Ok(input) => input, + }; + + // TODO: Store a session per thread + let session = cx.new(|cx| ScriptSession::new(project, cx)); + let lua_script = input.lua_script; + + let (script_id, script_task) = + session.update(cx, |session, cx| session.run_script(lua_script, cx)); + + cx.spawn(|cx| async move { + script_task.await; + + let message = session.read_with(&cx, |session, _cx| { + // Using a id to get the script output seems impractical. + // Why not just include it in the Task result? + // This is because we'll later report the script state as it runs, + // currently not supported by the `Tool` interface. + session + .get(script_id) + .output_message_for_llm() + .expect("Script shouldn't still be running") + })?; + + drop(session); + Ok(message) + }) + } +} diff --git a/crates/scripting_tool/src/scripting_tool_description.txt b/crates/scripting_tool/src/scripting_tool_description.txt new file mode 100644 index 0000000000..cd983336cf --- /dev/null +++ b/crates/scripting_tool/src/scripting_tool_description.txt @@ -0,0 +1,22 @@ +You can write a Lua script and I'll run it on my codebase and tell you what its +output was, including both stdout as well as the git diff of changes it made to + the filesystem. That way, you can get more information about the code base, or + make changes to the code base directly. + + The Lua script will have access to `io` and it will run with the current working + directory being in the root of the code base, so you can use it to explore, + search, make changes, etc. You can also have the script print things, and I'll + tell you what the output was. Note that `io` only has `open`, and then the file + it returns only has the methods read, write, and close - it doesn't have popen + or anything else. + + Also, I'm going to be putting this Lua script into JSON, so please don't use + Lua's double quote syntax for string literals - use one of Lua's other syntaxes + for string literals, so I don't have to escape the double quotes. + + There will be a global called `search` which accepts a regex (it's implemented + using Rust's regex crate, so use that regex syntax) and runs that regex on the + contents of every file in the code base (aside from gitignored files), then + returns an array of tables with two fields: "path" (the path to the file that + had the matches) and "matches" (an array of strings, with each string being a + match that was found within the file). diff --git a/crates/assistant_scripting/src/session.rs b/crates/scripting_tool/src/session.rs similarity index 94% rename from crates/assistant_scripting/src/session.rs rename to crates/scripting_tool/src/session.rs index 1e26bb5beb..1f92214fe8 100644 --- a/crates/assistant_scripting/src/session.rs +++ b/crates/scripting_tool/src/session.rs @@ -4,7 +4,7 @@ use futures::{ channel::{mpsc, oneshot}, pin_mut, SinkExt, StreamExt, }; -use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; +use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods}; use parking_lot::Mutex; use project::{search::SearchQuery, Fs, Project}; @@ -16,8 +16,6 @@ use std::{ }; use util::{paths::PathMatcher, ResultExt}; -use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG}; - struct ForegroundFn(Box, AsyncApp) + Send>); pub struct ScriptSession { @@ -45,50 +43,41 @@ impl ScriptSession { } } - pub fn new_script(&mut self) -> ScriptId { - let id = ScriptId(self.scripts.len() as u32); - let script = Script { - id, - state: ScriptState::Generating, - source: SharedString::new_static(""), - }; - self.scripts.push(script); - id - } - pub fn run_script( &mut self, - script_id: ScriptId, script_src: String, cx: &mut Context, - ) -> Task> { - let script = self.get_mut(script_id); + ) -> (ScriptId, Task<()>) { + let id = ScriptId(self.scripts.len() as u32); let stdout = Arc::new(Mutex::new(String::new())); - script.source = script_src.clone().into(); - script.state = ScriptState::Running { - stdout: stdout.clone(), + + let script = Script { + state: ScriptState::Running { + stdout: stdout.clone(), + }, }; + self.scripts.push(script); let task = self.run_lua(script_src, stdout, cx); - cx.emit(ScriptEvent::Spawned(script_id)); - - cx.spawn(|session, mut cx| async move { + let task = cx.spawn(|session, mut cx| async move { let result = task.await; - session.update(&mut cx, |session, cx| { - let script = session.get_mut(script_id); - let stdout = script.stdout_snapshot(); + session + .update(&mut cx, |session, _cx| { + let script = session.get_mut(id); + let stdout = script.stdout_snapshot(); - script.state = match result { - Ok(()) => ScriptState::Succeeded { stdout }, - Err(error) => ScriptState::Failed { stdout, error }, - }; + script.state = match result { + Ok(()) => ScriptState::Succeeded { stdout }, + Err(error) => ScriptState::Failed { stdout, error }, + }; + }) + .log_err(); + }); - cx.emit(ScriptEvent::Exited(script_id)) - }) - }) + (id, task) } fn run_lua( @@ -808,25 +797,14 @@ impl UserData for FileContent { } } -#[derive(Debug)] -pub enum ScriptEvent { - Spawned(ScriptId), - Exited(ScriptId), -} - -impl EventEmitter for ScriptSession {} - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ScriptId(u32); pub struct Script { - pub id: ScriptId, pub state: ScriptState, - pub source: SharedString, } pub enum ScriptState { - Generating, Running { stdout: Arc>, }, @@ -840,14 +818,9 @@ pub enum ScriptState { } impl Script { - pub fn source_tag(&self) -> String { - format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG) - } - /// If exited, returns a message with the output for the LLM pub fn output_message_for_llm(&self) -> Option { match &self.state { - ScriptState::Generating { .. } => None, ScriptState::Running { .. } => None, ScriptState::Succeeded { stdout } => { format!("Here's the script output:\n{}", stdout).into() @@ -863,22 +836,11 @@ impl Script { /// Get a snapshot of the script's stdout pub fn stdout_snapshot(&self) -> String { match &self.state { - ScriptState::Generating { .. } => String::new(), ScriptState::Running { stdout } => stdout.lock().clone(), ScriptState::Succeeded { stdout } => stdout.clone(), ScriptState::Failed { stdout, .. } => stdout.clone(), } } - - /// Returns the error if the script failed, otherwise None - pub fn error(&self) -> Option<&anyhow::Error> { - match &self.state { - ScriptState::Generating { .. } => None, - ScriptState::Running { .. } => None, - ScriptState::Succeeded { .. } => None, - ScriptState::Failed { error, .. } => Some(error), - } - } } #[cfg(test)] @@ -933,14 +895,10 @@ mod tests { let project = Project::test(fs, [Path::new("/")], cx).await; let session = cx.new(|cx| ScriptSession::new(project, cx)); - let (script_id, task) = session.update(cx, |session, cx| { - let script_id = session.new_script(); - let task = session.run_script(script_id, source.to_string(), cx); + let (script_id, task) = + session.update(cx, |session, cx| session.run_script(source.to_string(), cx)); - (script_id, task) - }); - - task.await?; + task.await; Ok(session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot())) } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 4fcce5cd15..4699a42824 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -98,6 +98,7 @@ remote.workspace = true repl.workspace = true reqwest_client.workspace = true rope.workspace = true +scripting_tool.workspace = true search.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 4133a15867..901cecae4b 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -476,6 +476,7 @@ fn main() { cx, ); assistant_tools::init(cx); + scripting_tool::init(cx); repl::init(app_state.fs.clone(), cx); extension_host::init( extension_host_proxy,