ZIm/crates/assistant_slash_command/src/assistant_slash_command.rs
2024-11-08 09:29:14 +01:00

496 lines
17 KiB
Rust

mod slash_command_registry;
use anyhow::Result;
use futures::stream::{self, BoxStream};
use futures::StreamExt;
use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
pub use language_model::Role;
use serde::{Deserialize, Serialize};
pub use slash_command_registry::*;
use std::{
ops::Range,
sync::{atomic::AtomicBool, Arc},
};
use workspace::{ui::IconName, Workspace};
/// Calls [`SlashCommandRegistry::default_global`] with the given app context.
///
/// NOTE(review): presumably this installs the default slash command registry
/// as a global on `cx` if one is not already present — confirm in
/// `slash_command_registry`.
pub fn init(cx: &mut AppContext) {
SlashCommandRegistry::default_global(cx);
}
/// What should happen once an argument completion has been accepted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AfterCompletion {
    /// Run the command
    Run,
    /// Continue composing the current argument, doesn't add a space
    Compose,
    /// Continue the command composition, adds a space
    Continue,
}

impl From<bool> for AfterCompletion {
    /// Converts a boolean flag: `true` becomes [`AfterCompletion::Run`] and
    /// `false` becomes [`AfterCompletion::Continue`].
    fn from(value: bool) -> Self {
        match value {
            true => Self::Run,
            false => Self::Continue,
        }
    }
}

impl AfterCompletion {
    /// Returns `true` only for [`AfterCompletion::Run`]; both
    /// [`AfterCompletion::Compose`] and [`AfterCompletion::Continue`] yield
    /// `false`.
    pub fn run(&self) -> bool {
        matches!(self, Self::Run)
    }
}
/// A single completion candidate for a slash command argument, as returned
/// (in a `Vec`) by [`SlashCommand::complete_argument`].
#[derive(Debug)]
pub struct ArgumentCompletion {
/// The label to display for this completion.
pub label: CodeLabel,
/// The new text that should be inserted into the command when this completion is accepted.
pub new_text: String,
/// What should happen after this completion is accepted (see [`AfterCompletion`]).
pub after_completion: AfterCompletion,
/// Whether to replace all previous arguments, or whether to treat this as an independent argument.
pub replace_previous_arguments: bool,
}
/// Outcome of running a slash command: on success, a boxed `'static` stream
/// whose items are fallible [`SlashCommandEvent`]s; on failure, an
/// `anyhow` error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
/// Interface implemented by every slash command.
///
/// A command has a name, descriptive metadata, optional argument completion,
/// and a `run` method that produces a [`SlashCommandResult`] asynchronously.
/// Implementations must be `'static + Send + Sync`.
pub trait SlashCommand: 'static + Send + Sync {
/// The command's name.
fn name(&self) -> String;
/// Icon associated with the command; defaults to `IconName::Slash`.
fn icon(&self) -> IconName {
IconName::Slash
}
/// Label for the command; by default a plain [`CodeLabel`] built from
/// [`SlashCommand::name`].
fn label(&self, _cx: &AppContext) -> CodeLabel {
CodeLabel::plain(self.name(), None)
}
/// Description of the command.
fn description(&self) -> String;
/// Text for the command's menu entry.
/// NOTE(review): which menu consumes this is not visible here — confirm at call sites.
fn menu_text(&self) -> String;
/// Asynchronously computes completion candidates for the given `arguments`.
///
/// NOTE(review): `cancel` is presumably a flag implementers should poll to
/// abort early, and `workspace` the workspace the command is typed in —
/// confirm against callers.
fn complete_argument(
self: Arc<Self>,
arguments: &[String],
cancel: Arc<AtomicBool>,
workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>>;
/// Whether the command requires an argument.
fn requires_argument(&self) -> bool;
/// Whether the command accepts arguments; defaults to the value of
/// [`SlashCommand::requires_argument`].
fn accepts_arguments(&self) -> bool {
self.requires_argument()
}
/// Runs the command with the given `arguments`, returning a task that
/// resolves to a [`SlashCommandResult`].
///
/// NOTE(review): `context_slash_command_output_sections` and
/// `context_buffer` presumably describe the sections already present in
/// the context and a snapshot of its buffer — confirm against callers.
fn run(
self: Arc<Self>,
arguments: &[String],
context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
context_buffer: BufferSnapshot,
workspace: WeakView<Workspace>,
// TODO: We're just using the `LspAdapterDelegate` here because that is
// what the extension API is already expecting.
//
// It may be that `LspAdapterDelegate` needs a more general name, or
// perhaps another kind of delegate is needed here.
delegate: Option<Arc<dyn LspAdapterDelegate>>,
cx: &mut WindowContext,
) -> Task<SlashCommandResult>;
}
/// Shared, thread-safe (`Send + Sync`) callback that builds an [`AnyElement`]
/// from an [`ElementId`], a shared `Fn(&mut WindowContext)` callback, and the
/// window context.
///
/// NOTE(review): judging by the name this renders the placeholder shown for a
/// folded region, with the inner callback presumably acting as an unfold
/// handler — confirm at call sites.
pub type RenderFoldPlaceholder = Arc<
dyn Send
+ Sync
+ Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
>;
/// A piece of content produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// A chunk of text, together with its `run_commands_in_text` flag.
    Text {
        text: String,
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps the borrowed string in an owned [`SlashCommandContent::Text`]
    /// whose `run_commands_in_text` flag is `false`.
    fn from(text: &str) -> Self {
        let text = text.to_owned();
        Self::Text {
            text,
            run_commands_in_text: false,
        }
    }
}
/// One event in the output stream of a slash command.
///
/// [`SlashCommandOutput::to_event_stream`] and
/// [`SlashCommandOutput::from_event_stream`] convert between a stream of
/// these events and the flat text-plus-sections form of
/// [`SlashCommandOutput`].
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
/// Marks the start of a message with the given role.
///
/// [`SlashCommandOutput::from_event_stream`] ignores this event.
/// NOTE(review): `merge_same_roles` presumably asks the consumer to merge
/// with an adjacent message of the same role — confirm at the consumer.
StartMessage {
role: Role,
merge_same_roles: bool,
},
/// Opens a section carrying an icon, a label and optional JSON metadata.
/// Content that follows belongs to the section until it is closed by
/// [`SlashCommandEvent::EndSection`].
StartSection {
icon: IconName,
label: SharedString,
metadata: Option<serde_json::Value>,
},
/// A piece of content (currently only text).
Content(SlashCommandContent),
/// Closes the most recently opened section.
EndSection,
}
/// The complete, non-streaming output of a slash command: a block of text
/// plus the sections that annotate ranges of it.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
/// The full text produced by the command.
pub text: String,
/// Sections whose ranges are byte offsets into `text`.
pub sections: Vec<SlashCommandOutputSection<usize>>,
/// Flag copied onto every text `Content` event emitted by `to_event_stream`.
/// NOTE(review): presumably tells the consumer to execute slash commands
/// found inside `text` — confirm at the consumer.
pub run_commands_in_text: bool,
}
impl SlashCommandOutput {
    /// Clamps every section range to the bounds of `text` and snaps both
    /// endpoints onto UTF-8 character boundaries.
    ///
    /// The start is moved backwards and the end is moved forwards, so a
    /// clamped range is never shrunk by the snapping step.
    pub fn ensure_valid_section_ranges(&mut self) {
        let text = self.text.as_str();
        let len = text.len();
        for section in &mut self.sections {
            section.range.start = section.range.start.min(len);
            section.range.end = section.range.end.min(len);
            // Offset 0 is always a char boundary, so this cannot underflow.
            while !text.is_char_boundary(section.range.start) {
                section.range.start -= 1;
            }
            // `text.len()` is always a char boundary, so this cannot run past
            // the end of the text.
            while !text.is_char_boundary(section.range.end) {
                section.range.end += 1;
            }
        }
    }

    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
    ///
    /// Section ranges are validated first. Each section contributes a
    /// `StartSection` event at its start offset and an `EndSection` event at
    /// its end offset; the text between consecutive endpoints (and after the
    /// last one) is emitted as `Content` events. Empty text slices are never
    /// emitted.
    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
        self.ensure_valid_section_ranges();

        // Two endpoints (start and end) per section.
        let mut section_endpoints = Vec::with_capacity(self.sections.len() * 2);
        for section in self.sections {
            section_endpoints.push((
                section.range.start,
                SlashCommandEvent::StartSection {
                    icon: section.icon,
                    label: section.label,
                    metadata: section.metadata,
                },
            ));
            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
        }
        // `sort_by_key` is stable: endpoints sharing an offset keep their
        // insertion order, so an empty section still yields Start before End.
        section_endpoints.sort_by_key(|(offset, _)| *offset);

        let mut events = Vec::new();
        let mut content_offset = 0;
        for (endpoint_offset, endpoint) in section_endpoints {
            // Flush the text that precedes this endpoint, if any.
            if content_offset < endpoint_offset {
                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                    text: self.text[content_offset..endpoint_offset].to_string(),
                    run_commands_in_text: self.run_commands_in_text,
                })));
                content_offset = endpoint_offset;
            }
            events.push(Ok(endpoint));
        }
        // Trailing text that follows the last endpoint.
        if content_offset < self.text.len() {
            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                text: self.text[content_offset..].to_string(),
                run_commands_in_text: self.run_commands_in_text,
            })));
        }

        stream::iter(events).boxed()
    }

    /// Collects a stream of [`SlashCommandEvent`]s into a [`SlashCommandOutput`].
    ///
    /// Text content is concatenated in order. A section spans from the text
    /// length at its `StartSection` to the text length at its `EndSection`,
    /// so an enclosing section also covers the content of sections nested
    /// inside it. An `EndSection` without an open section is ignored, sections
    /// still open when the stream ends are closed at the end of the text, and
    /// `StartMessage` events are ignored. `run_commands_in_text` takes the
    /// value of the last `Content` event seen.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by the stream.
    pub async fn from_event_stream(
        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
    ) -> Result<SlashCommandOutput> {
        let mut output = SlashCommandOutput::default();
        // Sections that have been started but not yet ended, innermost last.
        let mut section_stack = Vec::new();

        while let Some(event) = events.next().await {
            match event? {
                SlashCommandEvent::StartSection {
                    icon,
                    label,
                    metadata,
                } => {
                    let start = output.text.len();
                    section_stack.push(SlashCommandOutputSection {
                        range: start..start,
                        icon,
                        label,
                        metadata,
                    });
                }
                SlashCommandEvent::Content(SlashCommandContent::Text {
                    text,
                    run_commands_in_text,
                }) => {
                    output.text.push_str(&text);
                    output.run_commands_in_text = run_commands_in_text;
                }
                SlashCommandEvent::EndSection => {
                    if let Some(mut section) = section_stack.pop() {
                        // Fix the end when the section closes rather than on
                        // each content chunk, so that text pushed while a
                        // nested section was innermost is still counted.
                        section.range.end = output.text.len();
                        output.sections.push(section);
                    }
                }
                SlashCommandEvent::StartMessage { .. } => {}
            }
        }

        // Close any sections the stream left open, innermost first.
        let end = output.text.len();
        while let Some(mut section) = section_stack.pop() {
            section.range.end = end;
            output.sections.push(section);
        }

        Ok(output)
    }
}
/// A labelled range of a slash command's output.
///
/// `T` is the coordinate type of the range; this file uses `usize` (byte
/// offsets into [`SlashCommandOutput::text`]) and `language::Anchor`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
/// The range covered by the section.
pub range: Range<T>,
/// Icon associated with the section.
pub icon: IconName,
/// Label associated with the section.
pub label: SharedString,
/// Optional JSON metadata attached to the section.
pub metadata: Option<serde_json::Value>,
}
impl SlashCommandOutputSection<language::Anchor> {
    /// Reports whether this section is usable in `buffer`: its start anchor
    /// must be valid there, and the range must be non-empty once resolved to
    /// offsets. The range is only resolved when the start anchor is valid.
    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
        if !self.range.start.is_valid(buffer) {
            return false;
        }
        let resolved = self.range.to_offset(buffer);
        !resolved.is_empty()
    }
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;
use super::*;
// Checks, for three fixtures, that `to_event_stream` yields the expected
// event sequence and that feeding that stream back through
// `from_event_stream` reproduces the original `SlashCommandOutput`.
#[gpui::test]
async fn test_slash_command_output_to_events_round_trip() {
// Test basic output consisting of a single section.
{
let text = "Hello, world!".to_string();
let range = 0..text.len();
let output = SlashCommandOutput {
text,
sections: vec![SlashCommandOutputSection {
range,
icon: IconName::Code,
label: "Section 1".into(),
metadata: None,
}],
run_commands_in_text: false,
};
let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
// Drop any `Err` items so the events can be compared with `assert_eq!`.
let events = events
.into_iter()
.filter_map(|event| event.ok())
.collect::<Vec<_>>();
assert_eq!(
events,
vec![
SlashCommandEvent::StartSection {
icon: IconName::Code,
label: "Section 1".into(),
metadata: None
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Hello, world!".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection
]
);
// Round trip: events -> output must equal the original output.
let new_output =
SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
.await
.unwrap();
assert_eq!(new_output, output);
}
// Test output where the sections do not comprise all of the text.
{
let text = "Apple\nCucumber\nBanana\n".to_string();
let output = SlashCommandOutput {
text,
sections: vec![
SlashCommandOutputSection {
range: 0..6,
icon: IconName::Check,
label: "Fruit".into(),
metadata: None,
},
SlashCommandOutputSection {
range: 15..22,
icon: IconName::Check,
label: "Fruit".into(),
metadata: None,
},
],
run_commands_in_text: false,
};
let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
let events = events
.into_iter()
.filter_map(|event| event.ok())
.collect::<Vec<_>>();
// The text between the two sections ("Cucumber\n") is expected as a
// `Content` event outside of any section.
assert_eq!(
events,
vec![
SlashCommandEvent::StartSection {
icon: IconName::Check,
label: "Fruit".into(),
metadata: None
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Apple\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection,
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Cucumber\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::Check,
label: "Fruit".into(),
metadata: None
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Banana\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection
]
);
let new_output =
SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
.await
.unwrap();
assert_eq!(new_output, output);
}
// Test output consisting of multiple sections.
{
let text = "Line 1\nLine 2\nLine 3\nLine 4\n".to_string();
let output = SlashCommandOutput {
text,
sections: vec![
SlashCommandOutputSection {
range: 0..6,
icon: IconName::FileCode,
label: "Section 1".into(),
metadata: Some(json!({ "a": true })),
},
SlashCommandOutputSection {
range: 7..13,
icon: IconName::FileDoc,
label: "Section 2".into(),
metadata: Some(json!({ "b": true })),
},
SlashCommandOutputSection {
range: 14..20,
icon: IconName::FileGit,
label: "Section 3".into(),
metadata: Some(json!({ "c": true })),
},
SlashCommandOutputSection {
range: 21..27,
icon: IconName::FileToml,
label: "Section 4".into(),
metadata: Some(json!({ "d": true })),
},
],
run_commands_in_text: false,
};
let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
let events = events
.into_iter()
.filter_map(|event| event.ok())
.collect::<Vec<_>>();
// Each section covers a line without its trailing newline, so every
// section is followed by a separate "\n" `Content` event.
assert_eq!(
events,
vec![
SlashCommandEvent::StartSection {
icon: IconName::FileCode,
label: "Section 1".into(),
metadata: Some(json!({ "a": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 1".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection,
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::FileDoc,
label: "Section 2".into(),
metadata: Some(json!({ "b": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 2".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection,
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::FileGit,
label: "Section 3".into(),
metadata: Some(json!({ "c": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 3".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection,
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
SlashCommandEvent::StartSection {
icon: IconName::FileToml,
label: "Section 4".into(),
metadata: Some(json!({ "d": true }))
},
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "Line 4".into(),
run_commands_in_text: false
}),
SlashCommandEvent::EndSection,
SlashCommandEvent::Content(SlashCommandContent::Text {
text: "\n".into(),
run_commands_in_text: false
}),
]
);
let new_output =
SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
.await
.unwrap();
assert_eq!(new_output, output);
}
}
}