Make slash command output streamable (#19632)

This PR adds support for streaming output from slash commands

In this PR we are focused primarily on the interface of the
`SlashCommand` trait to support streaming the output. We will follow up
later with support for extensions and context servers to take advantage
of the streaming nature.

Release Notes:

- N/A

---------

Co-authored-by: David Soria Parra <davidsp@anthropic.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: David <david@anthropic.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Will <will@zed.dev>
This commit is contained in:
Marshall Bowers 2024-11-06 19:24:43 -05:00 committed by GitHub
parent f6fbf662b4
commit b129e18396
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1130 additions and 501 deletions

View file

@ -2,14 +2,19 @@ use super::{AssistantEdit, MessageCacheMetadata};
use crate::{
assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
SlashCommandId,
};
use anyhow::Result;
use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
SlashCommandRegistry, SlashCommandResult,
ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult,
};
use collections::HashSet;
use collections::{HashMap, HashSet};
use fs::FakeFs;
use futures::{
channel::mpsc,
stream::{self, StreamExt},
};
use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
@ -27,8 +32,8 @@ use std::{
rc::Rc,
sync::{atomic::AtomicBool, Arc},
};
use text::{network::Network, OffsetRangeExt as _, ReplicaId};
use ui::{Context as _, WindowContext};
use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
use ui::{Context as _, IconName, WindowContext};
use unindent::Unindent;
use util::{
test::{generate_marked_text, marked_text_ranges},
@ -381,20 +386,41 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
let context =
cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
#[derive(Default)]
struct ContextRanges {
parsed_commands: HashSet<Range<language::Anchor>>,
command_outputs: HashMap<SlashCommandId, Range<language::Anchor>>,
output_sections: HashSet<Range<language::Anchor>>,
}
let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
context.update(cx, |_, cx| {
cx.subscribe(&context, {
let ranges = output_ranges.clone();
move |_, _, event, _| match event {
ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
for range in removed {
ranges.borrow_mut().remove(range);
let context_ranges = context_ranges.clone();
move |context, _, event, _| {
let mut context_ranges = context_ranges.borrow_mut();
match event {
ContextEvent::InvokedSlashCommandChanged { command_id } => {
let command = context.invoked_slash_command(command_id).unwrap();
context_ranges
.command_outputs
.insert(*command_id, command.range.clone());
}
for command in updated {
ranges.borrow_mut().insert(command.source_range.clone());
ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
for range in removed {
context_ranges.parsed_commands.remove(range);
}
for command in updated {
context_ranges
.parsed_commands
.insert(command.source_range.clone());
}
}
ContextEvent::SlashCommandOutputSectionAdded { section } => {
context_ranges.output_sections.insert(section.range.clone());
}
_ => {}
}
_ => {}
}
})
.detach();
@ -406,14 +432,12 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
});
assert_text_and_output_ranges(
assert_text_and_context_ranges(
&buffer,
&output_ranges.borrow(),
"
«/file src/lib.rs»
"
.unindent()
.trim_end(),
&context_ranges,
&"
«/file src/lib.rs»"
.unindent(),
cx,
);
@ -422,14 +446,12 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
let edit_offset = buffer.text().find("lib.rs").unwrap();
buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
});
assert_text_and_output_ranges(
assert_text_and_context_ranges(
&buffer,
&output_ranges.borrow(),
"
«/file src/main.rs»
"
.unindent()
.trim_end(),
&context_ranges,
&"
«/file src/main.rs»"
.unindent(),
cx,
);
@ -442,36 +464,180 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
cx,
);
});
assert_text_and_output_ranges(
assert_text_and_context_ranges(
&buffer,
&output_ranges.borrow(),
&context_ranges,
&"
/unknown src/main.rs"
.unindent(),
cx,
);
// Undoing the insertion of an non-existent slash command resorts the previous one.
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»"
.unindent(),
cx,
);
let (command_output_tx, command_output_rx) = mpsc::unbounded();
context.update(cx, |context, cx| {
let command_source_range = context.parsed_slash_commands[0].source_range.clone();
context.insert_command_output(
command_source_range,
"file",
Task::ready(Ok(command_output_rx.boxed())),
true,
cx,
);
});
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»
"
/unknown src/main.rs
.unindent(),
cx,
);
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::StartSection {
icon: IconName::Ai,
label: "src/main.rs".into(),
metadata: None,
}))
.unwrap();
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
.unwrap();
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»
src/main.rs
"
.unindent()
.trim_end(),
.unindent(),
cx,
);
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
.unwrap();
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»
src/main.rs
fn main() {}
"
.unindent(),
cx,
);
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::EndSection { metadata: None }))
.unwrap();
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»
src/main.rs
fn main() {}
"
.unindent(),
cx,
);
drop(command_output_tx);
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
src/main.rs
fn main() {}
"
.unindent(),
cx,
);
#[track_caller]
fn assert_text_and_output_ranges(
fn assert_text_and_context_ranges(
buffer: &Model<Buffer>,
ranges: &HashSet<Range<language::Anchor>>,
ranges: &RefCell<ContextRanges>,
expected_marked_text: &str,
cx: &mut TestAppContext,
) {
let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
let mut ranges = ranges
.iter()
.map(|range| range.to_offset(buffer))
.collect::<Vec<_>>();
ranges.sort_by_key(|a| a.start);
(buffer.text(), ranges)
let mut actual_marked_text = String::new();
buffer.update(cx, |buffer, _| {
struct Endpoint {
offset: usize,
marker: char,
}
let ranges = ranges.borrow();
let mut endpoints = Vec::new();
for range in ranges.command_outputs.values() {
endpoints.push(Endpoint {
offset: range.start.to_offset(buffer),
marker: '',
});
}
for range in ranges.parsed_commands.iter() {
endpoints.push(Endpoint {
offset: range.start.to_offset(buffer),
marker: '«',
});
}
for range in ranges.output_sections.iter() {
endpoints.push(Endpoint {
offset: range.start.to_offset(buffer),
marker: '',
});
}
for range in ranges.output_sections.iter() {
endpoints.push(Endpoint {
offset: range.end.to_offset(buffer),
marker: '',
});
}
for range in ranges.parsed_commands.iter() {
endpoints.push(Endpoint {
offset: range.end.to_offset(buffer),
marker: '»',
});
}
for range in ranges.command_outputs.values() {
endpoints.push(Endpoint {
offset: range.end.to_offset(buffer),
marker: '',
});
}
endpoints.sort_by_key(|endpoint| endpoint.offset);
let mut offset = 0;
for endpoint in endpoints {
actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
actual_marked_text.push(endpoint.marker);
offset = endpoint.offset;
}
actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
});
assert_eq!(actual_text, expected_text);
assert_eq!(actual_ranges, expected_ranges);
assert_eq!(actual_marked_text, expected_marked_text);
}
}
@ -1063,44 +1229,57 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
offset + 1..offset + 1 + command_text.len()
});
let output_len = rng.gen_range(1..=10);
let output_text = RandomCharIter::new(&mut rng)
.filter(|c| *c != '\r')
.take(output_len)
.take(10)
.collect::<String>();
let mut events = vec![Ok(SlashCommandEvent::StartMessage {
role: Role::User,
merge_same_roles: true,
})];
let num_sections = rng.gen_range(0..=3);
let mut sections = Vec::with_capacity(num_sections);
let mut section_start = 0;
for _ in 0..num_sections {
let section_start = rng.gen_range(0..output_len);
let section_end = rng.gen_range(section_start..=output_len);
sections.push(SlashCommandOutputSection {
range: section_start..section_end,
icon: ui::IconName::Ai,
let mut section_end = rng.gen_range(section_start..=output_text.len());
while !output_text.is_char_boundary(section_end) {
section_end += 1;
}
events.push(Ok(SlashCommandEvent::StartSection {
icon: IconName::Ai,
label: "section".into(),
metadata: None,
});
}));
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: output_text[section_start..section_end].to_string(),
run_commands_in_text: false,
})));
events.push(Ok(SlashCommandEvent::EndSection { metadata: None }));
section_start = section_end;
}
if section_start < output_text.len() {
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: output_text[section_start..].to_string(),
run_commands_in_text: false,
})));
}
log::info!(
"Context {}: insert slash command output at {:?} with {:?}",
"Context {}: insert slash command output at {:?} with {:?} events",
context_index,
command_range,
sections
events.len()
);
let command_range = context.buffer.read(cx).anchor_after(command_range.start)
..context.buffer.read(cx).anchor_after(command_range.end);
context.insert_command_output(
command_range,
Task::ready(Ok(SlashCommandOutput {
text: output_text,
sections,
run_commands_in_text: false,
}
.to_event_stream())),
"/command",
Task::ready(Ok(stream::iter(events).boxed())),
true,
false,
cx,
);
});