diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index d450536f43..67199e141f 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use editor::actions::MoveUp; use editor::{Editor, EditorElement, EditorEvent, EditorStyle}; +use file_icons::FileIcons; use fs::Fs; use gpui::{ Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle, @@ -15,8 +16,8 @@ use std::time::Duration; use text::Bias; use theme::ThemeSettings; use ui::{ - prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Switch, - Tooltip, + prelude::*, ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, + Switch, Tooltip, }; use vim_mode_setting::VimModeSetting; use workspace::Workspace; @@ -39,6 +40,7 @@ pub struct MessageEditor { inline_context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, use_tools: bool, + edits_expanded: bool, _subscriptions: Vec, } @@ -117,6 +119,7 @@ impl MessageEditor { ) }), use_tools: false, + edits_expanded: false, _subscriptions: subscriptions, } } @@ -303,6 +306,9 @@ impl Render for MessageEditor { px(64.) }; + let changed_buffers = self.thread.read(cx).scripting_changed_buffers(cx); + let changed_buffers_count = changed_buffers.len(); + v_flex() .size_full() .when(is_streaming_completion, |parent| { @@ -363,6 +369,109 @@ impl Render for MessageEditor { ), ) }) + .when(changed_buffers_count > 0, |parent| { + parent.child( + v_flex() + .mx_2() + .bg(cx.theme().colors().element_background) + .border_1() + .border_b_0() + .border_color(cx.theme().colors().border) + .rounded_t_md() + .child( + h_flex() + .gap_2() + .p_2() + .child( + Disclosure::new("edits-disclosure", self.edits_expanded) + .on_click(cx.listener(|this, _ev, _window, cx| { + this.edits_expanded = !this.edits_expanded; + cx.notify(); + })), + ) + .child( + Label::new("Edits") + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted)) + .child( + Label::new(format!( + "{} {}", + changed_buffers_count, + if changed_buffers_count == 1 { + "file" + } else { + "files" + } + )) + .size(LabelSize::XSmall) + .color(Color::Muted), + ), + ) + .when(self.edits_expanded, |parent| { + parent.child( + v_flex().bg(cx.theme().colors().editor_background).children( + changed_buffers.enumerate().flat_map(|(index, buffer)| { + let file = buffer.read(cx).file()?; + let path = file.path(); + + let parent_label = path.parent().and_then(|parent| { + let parent_str = parent.to_string_lossy(); + + if parent_str.is_empty() { + None + } else { + Some( + Label::new(format!( + "{}{}", + parent_str, + std::path::MAIN_SEPARATOR_STR + )) + .color(Color::Muted) + .size(LabelSize::Small), + ) + } + }); + + let name_label = path.file_name().map(|name| { + Label::new(name.to_string_lossy().to_string()) + .size(LabelSize::Small) + }); + + let file_icon = FileIcons::get_icon(&path, cx) + .map(Icon::from_path) + .unwrap_or_else(|| Icon::new(IconName::File)); + + let element = div() + .p_2() + .when(index + 1 < changed_buffers_count, |parent| { + parent + .border_color(cx.theme().colors().border) + .border_b_1() + }) + .child( + h_flex() + .gap_2() + .child(file_icon) + .child( + // TODO: handle overflow + h_flex() + .children(parent_label) + .children(name_label), + ) + // TODO: show lines changed + .child(Label::new("+").color(Color::Created)) + .child(Label::new("-").color(Color::Deleted)), + ); + + Some(element) + }), + ), + ) + }), + ) + }) .child( v_flex() .key_context("MessageEditor") diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index bfc7e852c9..6f8d72ef7c 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -5,7 +5,7 @@ use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; use futures::StreamExt as _; -use gpui::{App, Context, Entity, EventEmitter, SharedString, Task}; +use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, @@ -13,7 +13,7 @@ use language_model::{ Role, StopReason, }; use project::Project; -use scripting_tool::ScriptingTool; +use scripting_tool::{ScriptingSession, ScriptingTool}; use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; @@ -76,6 +76,7 @@ pub struct Thread { project: Entity, tools: Arc, tool_use: ToolUseState, + scripting_session: Entity, scripting_tool_use: ToolUseState, } @@ -83,8 +84,10 @@ impl Thread { pub fn new( project: Entity, tools: Arc, - _cx: &mut Context, + cx: &mut Context, ) -> Self { + let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx)); + Self { id: ThreadId::new(), updated_at: Utc::now(), @@ -99,6 +102,7 @@ impl Thread { project, tools, tool_use: ToolUseState::new(), + scripting_session, scripting_tool_use: ToolUseState::new(), } } @@ -108,7 +112,7 @@ impl Thread { saved: SavedThread, project: Entity, tools: Arc, - _cx: &mut Context, + cx: &mut Context, ) -> Self { let next_message_id = MessageId( saved @@ -121,6 +125,7 @@ impl Thread { ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME); let scripting_tool_use = ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME); + let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx)); Self { id, @@ -144,6 +149,7 @@ impl Thread { project, tools, tool_use, + scripting_session, scripting_tool_use, } } @@ -237,6 +243,13 @@ impl Thread { self.scripting_tool_use.tool_results_for_message(id) } + pub fn scripting_changed_buffers<'a>( + &self, + cx: &'a App, + ) -> impl ExactSizeIterator> { + self.scripting_session.read(cx).changed_buffers() + } + pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { self.tool_use.message_has_tool_results(message_id) } @@ -637,7 +650,32 @@ impl Thread { .collect::>(); for scripting_tool_use in pending_scripting_tool_uses { - let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx); + let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) { + Err(err) => Task::ready(Err(err.into())), + Ok(input) => { + let (script_id, script_task) = + self.scripting_session.update(cx, move |session, cx| { + session.run_script(input.lua_script, cx) + }); + + let session = self.scripting_session.clone(); + cx.spawn(|_, cx| async move { + script_task.await; + + let message = session.read_with(&cx, |session, _cx| { + // Using a id to get the script output seems impractical. + // Why not just include it in the Task result? + // This is because we'll later report the script state as it runs, + session + .get(script_id) + .output_message_for_llm() + .expect("Script shouldn't still be running") + })?; + + Ok(message) + }) + } + }; self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx); } diff --git a/crates/scripting_tool/src/session.rs b/crates/scripting_tool/src/scripting_session.rs similarity index 82% rename from crates/scripting_tool/src/session.rs rename to crates/scripting_tool/src/scripting_session.rs index 6cf862c669..465398d52f 100644 --- a/crates/scripting_tool/src/session.rs +++ b/crates/scripting_tool/src/scripting_session.rs @@ -5,6 +5,7 @@ use futures::{ pin_mut, SinkExt, StreamExt, }; use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; +use language::Buffer; use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods}; use parking_lot::Mutex; use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId}; @@ -19,9 +20,10 @@ struct ForegroundFn(Box, AsyncApp) + Sen pub struct ScriptingSession { project: Entity, + scripts: Vec