assistant: Use tool interface for scripts (#26377)
We decided to expose scripting as tools again. We are aware of the UX downsides of doing so, but we want to focus on getting it working well first, and the model seems to make better use of it as an actual tool. In the future, the tools API might support streaming. If it doesn't and we need to ship, we can consider reverting this. Release Notes: - N/A
This commit is contained in:
parent
3891381d3e
commit
2fc4dec58f
17 changed files with 163 additions and 652 deletions
|
@ -1,12 +1,11 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use assistant_scripting::{ScriptId, ScriptState};
|
||||
use collections::{HashMap, HashSet};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, MultiBuffer};
|
||||
use gpui::{
|
||||
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
|
||||
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
||||
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
|
||||
Task, TextStyleRefinement, UnderlineStyle,
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||
|
@ -15,7 +14,6 @@ use settings::Settings as _;
|
|||
use theme::ThemeSettings;
|
||||
use ui::{prelude::*, Disclosure, KeyBinding};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
@ -23,7 +21,6 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
|
|||
use crate::ui::ContextPill;
|
||||
|
||||
pub struct ActiveThread {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
|
@ -33,7 +30,6 @@ pub struct ActiveThread {
|
|||
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
|
||||
editing_message: Option<(MessageId, EditMessageState)>,
|
||||
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
||||
expanded_scripts: HashSet<ScriptId>,
|
||||
last_error: Option<ThreadError>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
@ -44,7 +40,6 @@ struct EditMessageState {
|
|||
|
||||
impl ActiveThread {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
|
@ -57,7 +52,6 @@ impl ActiveThread {
|
|||
];
|
||||
|
||||
let mut this = Self {
|
||||
workspace,
|
||||
language_registry,
|
||||
thread_store,
|
||||
thread: thread.clone(),
|
||||
|
@ -65,7 +59,6 @@ impl ActiveThread {
|
|||
messages: Vec::new(),
|
||||
rendered_messages_by_id: HashMap::default(),
|
||||
expanded_tool_uses: HashMap::default(),
|
||||
expanded_scripts: HashSet::default(),
|
||||
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
|
||||
let this = cx.entity().downgrade();
|
||||
move |ix, window: &mut Window, cx: &mut App| {
|
||||
|
@ -466,10 +459,7 @@ impl ActiveThread {
|
|||
let tool_uses = thread.tool_uses_for_message(message_id);
|
||||
|
||||
// Don't render user messages that are just there for returning tool results.
|
||||
if message.role == Role::User
|
||||
&& (thread.message_has_tool_results(message_id)
|
||||
|| thread.message_has_script_output(message_id))
|
||||
{
|
||||
if message.role == Role::User && thread.message_has_tool_results(message_id) {
|
||||
return Empty.into_any();
|
||||
}
|
||||
|
||||
|
@ -618,7 +608,6 @@ impl ActiveThread {
|
|||
Role::Assistant => div()
|
||||
.id(("message-container", ix))
|
||||
.child(message_content)
|
||||
.children(self.render_script(message_id, cx))
|
||||
.map(|parent| {
|
||||
if tool_uses.is_empty() {
|
||||
return parent;
|
||||
|
@ -738,139 +727,6 @@ impl ActiveThread {
|
|||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_script(&self, message_id: MessageId, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
let script = self.thread.read(cx).script_for_message(message_id, cx)?;
|
||||
|
||||
let is_open = self.expanded_scripts.contains(&script.id);
|
||||
let colors = cx.theme().colors();
|
||||
|
||||
let element = div().px_2p5().child(
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.rounded_lg()
|
||||
.border_1()
|
||||
.border_color(colors.border)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
.py_0p5()
|
||||
.pl_1()
|
||||
.pr_2()
|
||||
.bg(colors.editor_foreground.opacity(0.02))
|
||||
.when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
|
||||
.when(!is_open, |element| element.rounded_md())
|
||||
.border_color(colors.border)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Disclosure::new("script-disclosure", is_open).on_click(
|
||||
cx.listener({
|
||||
let script_id = script.id;
|
||||
move |this, _event, _window, _cx| {
|
||||
if this.expanded_scripts.contains(&script_id) {
|
||||
this.expanded_scripts.remove(&script_id);
|
||||
} else {
|
||||
this.expanded_scripts.insert(script_id);
|
||||
}
|
||||
}
|
||||
}),
|
||||
))
|
||||
// TODO: Generate script description
|
||||
.child(Label::new("Script")),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Label::new(match script.state {
|
||||
ScriptState::Generating => "Generating",
|
||||
ScriptState::Running { .. } => "Running",
|
||||
ScriptState::Succeeded { .. } => "Finished",
|
||||
ScriptState::Failed { .. } => "Error",
|
||||
})
|
||||
.size(LabelSize::XSmall)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("view-source", IconName::Eye)
|
||||
.icon_color(Color::Muted)
|
||||
.disabled(matches!(script.state, ScriptState::Generating))
|
||||
.on_click(cx.listener({
|
||||
let source = script.source.clone();
|
||||
move |this, _event, window, cx| {
|
||||
this.open_script_source(source.clone(), window, cx);
|
||||
}
|
||||
})),
|
||||
),
|
||||
),
|
||||
)
|
||||
.when(is_open, |parent| {
|
||||
let stdout = script.stdout_snapshot();
|
||||
let error = script.error();
|
||||
|
||||
parent.child(
|
||||
v_flex()
|
||||
.p_2()
|
||||
.bg(colors.editor_background)
|
||||
.gap_2()
|
||||
.child(if stdout.is_empty() && error.is_none() {
|
||||
Label::new("No output yet")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
} else {
|
||||
Label::new(stdout).size(LabelSize::Small).buffer_font(cx)
|
||||
})
|
||||
.children(script.error().map(|err| {
|
||||
Label::new(err.to_string())
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Error)
|
||||
})),
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
Some(element.into_any())
|
||||
}
|
||||
|
||||
fn open_script_source(
|
||||
&mut self,
|
||||
source: SharedString,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<'_, ActiveThread>,
|
||||
) {
|
||||
let language_registry = self.language_registry.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let source = source.clone();
|
||||
|
||||
cx.spawn_in(window, |_, mut cx| async move {
|
||||
let lua = language_registry.language_for_name("Lua").await.log_err();
|
||||
|
||||
workspace.update_in(&mut cx, |workspace, window, cx| {
|
||||
let project = workspace.project().clone();
|
||||
|
||||
let buffer = project.update(cx, |project, cx| {
|
||||
project.create_local_buffer(&source.trim(), lua, cx)
|
||||
});
|
||||
|
||||
let buffer = cx.new(|cx| {
|
||||
MultiBuffer::singleton(buffer, cx)
|
||||
// TODO: Generate script description
|
||||
.with_title("Assistant script".into())
|
||||
});
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor =
|
||||
Editor::for_multibuffer(buffer, Some(project), true, window, cx);
|
||||
editor.set_read_only(true);
|
||||
editor
|
||||
});
|
||||
|
||||
workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx);
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ActiveThread {
|
||||
|
|
|
@ -168,7 +168,6 @@ impl AssistantPanel {
|
|||
|
||||
let thread = cx.new(|cx| {
|
||||
ActiveThread::new(
|
||||
workspace.clone(),
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
language_registry.clone(),
|
||||
|
@ -242,7 +241,6 @@ impl AssistantPanel {
|
|||
self.active_view = ActiveView::Thread;
|
||||
self.thread = cx.new(|cx| {
|
||||
ActiveThread::new(
|
||||
self.workspace.clone(),
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
|
@ -376,7 +374,6 @@ impl AssistantPanel {
|
|||
this.active_view = ActiveView::Thread;
|
||||
this.thread = cx.new(|cx| {
|
||||
ActiveThread::new(
|
||||
this.workspace.clone(),
|
||||
thread.clone(),
|
||||
this.thread_store.clone(),
|
||||
this.language_registry.clone(),
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_scripting::{
|
||||
Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
|
||||
};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
|
||||
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
|
@ -78,21 +75,14 @@ pub struct Thread {
|
|||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tool_use: ToolUseState,
|
||||
scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
|
||||
script_output_messages: HashSet<MessageId>,
|
||||
script_session: Entity<ScriptSession>,
|
||||
_script_session_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
cx: &mut Context<Self>,
|
||||
_cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
|
||||
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
updated_at: Utc::now(),
|
||||
|
@ -107,10 +97,6 @@ impl Thread {
|
|||
project,
|
||||
tools,
|
||||
tool_use: ToolUseState::new(),
|
||||
scripts_by_assistant_message: HashMap::default(),
|
||||
script_output_messages: HashSet::default(),
|
||||
script_session,
|
||||
_script_session_subscription: script_session_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -119,7 +105,7 @@ impl Thread {
|
|||
saved: SavedThread,
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
cx: &mut Context<Self>,
|
||||
_cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let next_message_id = MessageId(
|
||||
saved
|
||||
|
@ -129,8 +115,6 @@ impl Thread {
|
|||
.unwrap_or(0),
|
||||
);
|
||||
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
||||
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
|
||||
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
|
||||
|
||||
Self {
|
||||
id,
|
||||
|
@ -154,10 +138,6 @@ impl Thread {
|
|||
project,
|
||||
tools,
|
||||
tool_use,
|
||||
scripts_by_assistant_message: HashMap::default(),
|
||||
script_output_messages: HashSet::default(),
|
||||
script_session,
|
||||
_script_session_subscription: script_session_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,10 +223,6 @@ impl Thread {
|
|||
self.tool_use.message_has_tool_results(message_id)
|
||||
}
|
||||
|
||||
pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
|
||||
self.script_output_messages.contains(&message_id)
|
||||
}
|
||||
|
||||
pub fn insert_user_message(
|
||||
&mut self,
|
||||
text: impl Into<String>,
|
||||
|
@ -327,39 +303,6 @@ impl Thread {
|
|||
text
|
||||
}
|
||||
|
||||
pub fn script_for_message<'a>(
|
||||
&'a self,
|
||||
message_id: MessageId,
|
||||
cx: &'a App,
|
||||
) -> Option<&'a Script> {
|
||||
self.scripts_by_assistant_message
|
||||
.get(&message_id)
|
||||
.map(|script_id| self.script_session.read(cx).get(*script_id))
|
||||
}
|
||||
|
||||
fn handle_script_event(
|
||||
&mut self,
|
||||
_script_session: Entity<ScriptSession>,
|
||||
event: &ScriptEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
ScriptEvent::Spawned(_) => {}
|
||||
ScriptEvent::Exited(script_id) => {
|
||||
if let Some(output_message) = self
|
||||
.script_session
|
||||
.read(cx)
|
||||
.get(*script_id)
|
||||
.output_message_for_llm()
|
||||
{
|
||||
let message_id = self.insert_user_message(output_message, vec![], cx);
|
||||
self.script_output_messages.insert(message_id);
|
||||
cx.emit(ThreadEvent::ScriptFinished)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_to_model(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
@ -388,7 +331,7 @@ impl Thread {
|
|||
pub fn to_completion_request(
|
||||
&self,
|
||||
request_kind: RequestKind,
|
||||
cx: &App,
|
||||
_cx: &App,
|
||||
) -> LanguageModelRequest {
|
||||
let mut request = LanguageModelRequest {
|
||||
messages: vec![],
|
||||
|
@ -397,12 +340,6 @@ impl Thread {
|
|||
temperature: None,
|
||||
};
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![SCRIPTING_PROMPT.to_string().into()],
|
||||
cache: true,
|
||||
});
|
||||
|
||||
let mut referenced_context_ids = HashSet::default();
|
||||
|
||||
for message in &self.messages {
|
||||
|
@ -436,15 +373,6 @@ impl Thread {
|
|||
RequestKind::Chat => {
|
||||
self.tool_use
|
||||
.attach_tool_uses(message.id, &mut request_message);
|
||||
|
||||
if matches!(message.role, Role::Assistant) {
|
||||
if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
|
||||
{
|
||||
let script = self.script_session.read(cx).get(*script_id);
|
||||
|
||||
request_message.content.push(script.source_tag().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
RequestKind::Summarize => {
|
||||
// We don't care about tool use during summarization.
|
||||
|
@ -486,8 +414,6 @@ impl Thread {
|
|||
let stream_completion = async {
|
||||
let mut events = stream.await?;
|
||||
let mut stop_reason = StopReason::EndTurn;
|
||||
let mut script_tag_parser = ScriptTagParser::new();
|
||||
let mut script_id = None;
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
|
@ -502,44 +428,20 @@ impl Thread {
|
|||
}
|
||||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
let chunk = script_tag_parser.parse_chunk(&chunk);
|
||||
|
||||
let message_id = if last_message.role == Role::Assistant {
|
||||
last_message.text.push_str(&chunk.content);
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.text.push_str(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||
last_message.id,
|
||||
chunk.content,
|
||||
chunk,
|
||||
));
|
||||
last_message.id
|
||||
} else {
|
||||
// If we won't have an Assistant message yet, assume this chunk marks the beginning
|
||||
// of a new Assistant response.
|
||||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(Role::Assistant, chunk.content, cx)
|
||||
thread.insert_message(Role::Assistant, chunk, cx);
|
||||
};
|
||||
|
||||
if script_id.is_none() && script_tag_parser.found_script() {
|
||||
let id = thread
|
||||
.script_session
|
||||
.update(cx, |session, _cx| session.new_script());
|
||||
thread.scripts_by_assistant_message.insert(message_id, id);
|
||||
|
||||
script_id = Some(id);
|
||||
}
|
||||
|
||||
if let (Some(script_source), Some(script_id)) =
|
||||
(chunk.script_source, script_id)
|
||||
{
|
||||
// TODO: move buffer to script and run as it streams
|
||||
thread
|
||||
.script_session
|
||||
.update(cx, |this, cx| {
|
||||
this.run_script(script_id, script_source, cx)
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue