assistant: Update SlashCommand trait with streaming return type (#19652)

This PR updates the `SlashCommand` trait to use a streaming return type.

This change is just at the trait layer. The goal here is to decouple
changing the trait's API while preserving behavior on either side.

The `SlashCommandOutput` type now has two methods for converting two and
from a stream to use in cases where we're not yet doing streaming.

On the `SlashCommand` implementer side, the implements can call
`to_event_stream` to produce a stream of events based off the
`SlashCommandOutput`.

On the slash command consumer side we use
`SlashCommandOutput::from_event_stream` to convert a stream of events
back into a `SlashCommandOutput`.

The `/file` slash command has been updated to emit `SlashCommandEvent`s
directly in order for it to work properly.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Marshall Bowers 2024-10-23 21:26:50 -04:00 committed by GitHub
parent 510c71d41b
commit d30361537e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 516 additions and 88 deletions

View file

@ -1,6 +1,8 @@
mod slash_command_registry;
use anyhow::Result;
use futures::stream::{self, BoxStream};
use futures::StreamExt;
use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
use serde::{Deserialize, Serialize};
@ -56,7 +58,7 @@ pub struct ArgumentCompletion {
pub replace_previous_arguments: bool,
}
pub type SlashCommandResult = Result<SlashCommandOutput>;
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
pub trait SlashCommand: 'static + Send + Sync {
fn name(&self) -> String;
@ -98,13 +100,146 @@ pub type RenderFoldPlaceholder = Arc<
+ Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
>;
#[derive(Debug, Default, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
pub enum SlashCommandContent {
Text {
text: String,
run_commands_in_text: bool,
},
}
#[derive(Debug, PartialEq, Eq)]
pub enum SlashCommandEvent {
StartSection {
icon: IconName,
label: SharedString,
metadata: Option<serde_json::Value>,
},
Content(SlashCommandContent),
EndSection {
metadata: Option<serde_json::Value>,
},
}
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
pub text: String,
pub sections: Vec<SlashCommandOutputSection<usize>>,
pub run_commands_in_text: bool,
}
impl SlashCommandOutput {
pub fn ensure_valid_section_ranges(&mut self) {
for section in &mut self.sections {
section.range.start = section.range.start.min(self.text.len());
section.range.end = section.range.end.min(self.text.len());
while !self.text.is_char_boundary(section.range.start) {
section.range.start -= 1;
}
while !self.text.is_char_boundary(section.range.end) {
section.range.end += 1;
}
}
}
/// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
self.ensure_valid_section_ranges();
let mut events = Vec::new();
let mut last_section_end = 0;
for section in self.sections {
if last_section_end < section.range.start {
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: self
.text
.get(last_section_end..section.range.start)
.unwrap_or_default()
.to_string(),
run_commands_in_text: self.run_commands_in_text,
})));
}
events.push(Ok(SlashCommandEvent::StartSection {
icon: section.icon,
label: section.label,
metadata: section.metadata.clone(),
}));
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: self
.text
.get(section.range.start..section.range.end)
.unwrap_or_default()
.to_string(),
run_commands_in_text: self.run_commands_in_text,
})));
events.push(Ok(SlashCommandEvent::EndSection {
metadata: section.metadata,
}));
last_section_end = section.range.end;
}
if last_section_end < self.text.len() {
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: self.text[last_section_end..].to_string(),
run_commands_in_text: self.run_commands_in_text,
})));
}
stream::iter(events).boxed()
}
pub async fn from_event_stream(
mut events: BoxStream<'static, Result<SlashCommandEvent>>,
) -> Result<SlashCommandOutput> {
let mut output = SlashCommandOutput::default();
let mut section_stack = Vec::new();
while let Some(event) = events.next().await {
match event? {
SlashCommandEvent::StartSection {
icon,
label,
metadata,
} => {
let start = output.text.len();
section_stack.push(SlashCommandOutputSection {
range: start..start,
icon,
label,
metadata,
});
}
SlashCommandEvent::Content(SlashCommandContent::Text {
text,
run_commands_in_text,
}) => {
output.text.push_str(&text);
output.run_commands_in_text = run_commands_in_text;
if let Some(section) = section_stack.last_mut() {
section.range.end = output.text.len();
}
}
SlashCommandEvent::EndSection { metadata } => {
if let Some(mut section) = section_stack.pop() {
section.metadata = metadata;
output.sections.push(section);
}
}
}
}
while let Some(section) = section_stack.pop() {
output.sections.push(section);
}
Ok(output)
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
pub range: Range<T>,
@ -118,3 +253,243 @@ impl SlashCommandOutputSection<language::Anchor> {
self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;
use super::*;
#[gpui::test]
async fn test_slash_command_output_to_events_round_trip() {
// Test basic output consisting of a single section.
{
let text = "Hello, world!".to_string();
let range = 0..text.len();
let output = SlashCommandOutput {
text,
sections: vec![SlashCommandOutputSection {
range,
icon: IconName::Code,
label: "Section 1".into(),
metadata: None,
}],
run_commands_in_text: false,
};
let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
let events = events
.into_iter()
.filter_map(|event| event.ok())
.collect::<Vec<_>>();
assert_eq!(
events,
vec![
SlashCommandEvent::StartSection {
icon: IconName::Code,
label: "Section 1".into(),
metadata: None
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Hello, world!".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection { metadata: None }
]
);
let new_output =
SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
.await
.unwrap();
assert_eq!(new_output, output);
}
// Test output where the sections do not comprise all of the text.
{
let text = "Apple\nCucumber\nBanana\n".to_string();
let output = SlashCommandOutput {
text,
sections: vec![
SlashCommandOutputSection {
range: 0..6,
icon: IconName::Check,
label: "Fruit".into(),
metadata: None,
},
SlashCommandOutputSection {
range: 15..22,
icon: IconName::Check,
label: "Fruit".into(),
metadata: None,
},
],
run_commands_in_text: false,
};
let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
let events = events
.into_iter()
.filter_map(|event| event.ok())
.collect::<Vec<_>>();
assert_eq!(
events,
vec![
SlashCommandEvent::StartSection {
icon: IconName::Check,
label: "Fruit".into(),
metadata: None
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Apple\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection { metadata: None },
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Cucumber\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::Check,
label: "Fruit".into(),
metadata: None
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Banana\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection { metadata: None }
]
);
let new_output =
SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
.await
.unwrap();
assert_eq!(new_output, output);
}
// Test output consisting of multiple sections.
{
let text = "Line 1\nLine 2\nLine 3\nLine 4\n".to_string();
let output = SlashCommandOutput {
text,
sections: vec![
SlashCommandOutputSection {
range: 0..6,
icon: IconName::FileCode,
label: "Section 1".into(),
metadata: Some(json!({ "a": true })),
},
SlashCommandOutputSection {
range: 7..13,
icon: IconName::FileDoc,
label: "Section 2".into(),
metadata: Some(json!({ "b": true })),
},
SlashCommandOutputSection {
range: 14..20,
icon: IconName::FileGit,
label: "Section 3".into(),
metadata: Some(json!({ "c": true })),
},
SlashCommandOutputSection {
range: 21..27,
icon: IconName::FileToml,
label: "Section 4".into(),
metadata: Some(json!({ "d": true })),
},
],
run_commands_in_text: false,
};
let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
let events = events
.into_iter()
.filter_map(|event| event.ok())
.collect::<Vec<_>>();
assert_eq!(
events,
vec![
SlashCommandEvent::StartSection {
icon: IconName::FileCode,
label: "Section 1".into(),
metadata: Some(json!({ "a": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 1".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection {
metadata: Some(json!({ "a": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::FileDoc,
label: "Section 2".into(),
metadata: Some(json!({ "b": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 2".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection {
metadata: Some(json!({ "b": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::FileGit,
label: "Section 3".into(),
metadata: Some(json!({ "c": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 3".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection {
metadata: Some(json!({ "c": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::FileToml,
label: "Section 4".into(),
metadata: Some(json!({ "d": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 4".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection {
metadata: Some(json!({ "d": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
]
);
let new_output =
SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
.await
.unwrap();
assert_eq!(new_output, output);
}
}
}