
This PR unships (removes) tool use from Assistant1. Tool use was only ever partially implemented and was never released to end users; Assistant2 will support it.

Release Notes:

- N/A
1657 lines
52 KiB
Rust
1657 lines
52 KiB
Rust
use crate::{
|
|
AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
|
|
ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
|
|
};
|
|
use anyhow::Result;
|
|
use assistant_slash_command::{
|
|
ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
|
|
SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
|
|
};
|
|
use assistant_slash_commands::FileSlashCommand;
|
|
use collections::{HashMap, HashSet};
|
|
use fs::FakeFs;
|
|
use futures::{
|
|
channel::mpsc,
|
|
stream::{self, StreamExt},
|
|
};
|
|
use gpui::{prelude::*, App, Entity, SharedString, Task, TestAppContext, WeakEntity};
|
|
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
|
|
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
|
|
use parking_lot::Mutex;
|
|
use pretty_assertions::assert_eq;
|
|
use project::Project;
|
|
use prompt_library::PromptBuilder;
|
|
use rand::prelude::*;
|
|
use serde_json::json;
|
|
use settings::SettingsStore;
|
|
use std::{
|
|
cell::RefCell,
|
|
env,
|
|
ops::Range,
|
|
path::Path,
|
|
rc::Rc,
|
|
sync::{atomic::AtomicBool, Arc},
|
|
};
|
|
use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
|
|
use ui::{IconName, Window};
|
|
use unindent::Unindent;
|
|
use util::{
|
|
test::{generate_marked_text, marked_text_ranges},
|
|
RandomCharIter,
|
|
};
|
|
use workspace::Workspace;
|
|
|
|
#[gpui::test]
|
|
fn test_inserting_and_removing_messages(cx: &mut App) {
|
|
let settings_store = SettingsStore::test(cx);
|
|
LanguageModelRegistry::test(cx);
|
|
cx.set_global(settings_store);
|
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
|
let context = cx.new(|cx| {
|
|
AssistantContext::local(
|
|
registry,
|
|
None,
|
|
None,
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
cx,
|
|
)
|
|
});
|
|
let buffer = context.read(cx).buffer.clone();
|
|
|
|
let message_1 = context.read(cx).message_anchors[0].clone();
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![(message_1.id, Role::User, 0..0)]
|
|
);
|
|
|
|
let message_2 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..1),
|
|
(message_2.id, Role::Assistant, 1..1)
|
|
]
|
|
);
|
|
|
|
buffer.update(cx, |buffer, cx| {
|
|
buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
|
|
});
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..2),
|
|
(message_2.id, Role::Assistant, 2..3)
|
|
]
|
|
);
|
|
|
|
let message_3 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..2),
|
|
(message_2.id, Role::Assistant, 2..4),
|
|
(message_3.id, Role::User, 4..4)
|
|
]
|
|
);
|
|
|
|
let message_4 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..2),
|
|
(message_2.id, Role::Assistant, 2..4),
|
|
(message_4.id, Role::User, 4..5),
|
|
(message_3.id, Role::User, 5..5),
|
|
]
|
|
);
|
|
|
|
buffer.update(cx, |buffer, cx| {
|
|
buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
|
|
});
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..2),
|
|
(message_2.id, Role::Assistant, 2..4),
|
|
(message_4.id, Role::User, 4..6),
|
|
(message_3.id, Role::User, 6..7),
|
|
]
|
|
);
|
|
|
|
// Deleting across message boundaries merges the messages.
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..3),
|
|
(message_3.id, Role::User, 3..4),
|
|
]
|
|
);
|
|
|
|
// Undoing the deletion should also undo the merge.
|
|
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..2),
|
|
(message_2.id, Role::Assistant, 2..4),
|
|
(message_4.id, Role::User, 4..6),
|
|
(message_3.id, Role::User, 6..7),
|
|
]
|
|
);
|
|
|
|
// Redoing the deletion should also redo the merge.
|
|
buffer.update(cx, |buffer, cx| buffer.redo(cx));
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..3),
|
|
(message_3.id, Role::User, 3..4),
|
|
]
|
|
);
|
|
|
|
// Ensure we can still insert after a merged message.
|
|
let message_5 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..3),
|
|
(message_5.id, Role::System, 3..4),
|
|
(message_3.id, Role::User, 4..5)
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
fn test_message_splitting(cx: &mut App) {
|
|
let settings_store = SettingsStore::test(cx);
|
|
cx.set_global(settings_store);
|
|
LanguageModelRegistry::test(cx);
|
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
|
|
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
|
let context = cx.new(|cx| {
|
|
AssistantContext::local(
|
|
registry.clone(),
|
|
None,
|
|
None,
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
cx,
|
|
)
|
|
});
|
|
let buffer = context.read(cx).buffer.clone();
|
|
|
|
let message_1 = context.read(cx).message_anchors[0].clone();
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![(message_1.id, Role::User, 0..0)]
|
|
);
|
|
|
|
buffer.update(cx, |buffer, cx| {
|
|
buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
|
|
});
|
|
|
|
let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
|
|
let message_2 = message_2.unwrap();
|
|
|
|
// We recycle newlines in the middle of a split message
|
|
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_2.id, Role::User, 4..16),
|
|
]
|
|
);
|
|
|
|
let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
|
|
let message_3 = message_3.unwrap();
|
|
|
|
// We don't recycle newlines at the end of a split message
|
|
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_3.id, Role::User, 4..5),
|
|
(message_2.id, Role::User, 5..17),
|
|
]
|
|
);
|
|
|
|
let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
|
|
let message_4 = message_4.unwrap();
|
|
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_3.id, Role::User, 4..5),
|
|
(message_2.id, Role::User, 5..9),
|
|
(message_4.id, Role::User, 9..17),
|
|
]
|
|
);
|
|
|
|
let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
|
|
let message_5 = message_5.unwrap();
|
|
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_3.id, Role::User, 4..5),
|
|
(message_2.id, Role::User, 5..9),
|
|
(message_4.id, Role::User, 9..10),
|
|
(message_5.id, Role::User, 10..18),
|
|
]
|
|
);
|
|
|
|
let (message_6, message_7) =
|
|
context.update(cx, |context, cx| context.split_message(14..16, cx));
|
|
let message_6 = message_6.unwrap();
|
|
let message_7 = message_7.unwrap();
|
|
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_3.id, Role::User, 4..5),
|
|
(message_2.id, Role::User, 5..9),
|
|
(message_4.id, Role::User, 9..10),
|
|
(message_5.id, Role::User, 10..14),
|
|
(message_6.id, Role::User, 14..17),
|
|
(message_7.id, Role::User, 17..19),
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
fn test_messages_for_offsets(cx: &mut App) {
|
|
let settings_store = SettingsStore::test(cx);
|
|
LanguageModelRegistry::test(cx);
|
|
cx.set_global(settings_store);
|
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
|
let context = cx.new(|cx| {
|
|
AssistantContext::local(
|
|
registry,
|
|
None,
|
|
None,
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
cx,
|
|
)
|
|
});
|
|
let buffer = context.read(cx).buffer.clone();
|
|
|
|
let message_1 = context.read(cx).message_anchors[0].clone();
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![(message_1.id, Role::User, 0..0)]
|
|
);
|
|
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
|
let message_2 = context
|
|
.update(cx, |context, cx| {
|
|
context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
|
|
})
|
|
.unwrap();
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
|
|
|
|
let message_3 = context
|
|
.update(cx, |context, cx| {
|
|
context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
|
})
|
|
.unwrap();
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
|
|
|
|
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_2.id, Role::User, 4..8),
|
|
(message_3.id, Role::User, 8..11)
|
|
]
|
|
);
|
|
|
|
assert_eq!(
|
|
message_ids_for_offsets(&context, &[0, 4, 9], cx),
|
|
[message_1.id, message_2.id, message_3.id]
|
|
);
|
|
assert_eq!(
|
|
message_ids_for_offsets(&context, &[0, 1, 11], cx),
|
|
[message_1.id, message_3.id]
|
|
);
|
|
|
|
let message_4 = context
|
|
.update(cx, |context, cx| {
|
|
context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
|
|
})
|
|
.unwrap();
|
|
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
|
|
assert_eq!(
|
|
messages(&context, cx),
|
|
vec![
|
|
(message_1.id, Role::User, 0..4),
|
|
(message_2.id, Role::User, 4..8),
|
|
(message_3.id, Role::User, 8..12),
|
|
(message_4.id, Role::User, 12..12)
|
|
]
|
|
);
|
|
assert_eq!(
|
|
message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
|
|
[message_1.id, message_2.id, message_3.id, message_4.id]
|
|
);
|
|
|
|
fn message_ids_for_offsets(
|
|
context: &Entity<AssistantContext>,
|
|
offsets: &[usize],
|
|
cx: &App,
|
|
) -> Vec<MessageId> {
|
|
context
|
|
.read(cx)
|
|
.messages_for_offsets(offsets.iter().copied(), cx)
|
|
.into_iter()
|
|
.map(|message| message.id)
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
/// Verifies slash-command parsing and streamed command output in a context buffer.
///
/// The test types a `/file` command, edits its argument and name, undoes an edit,
/// then streams output for the command through an mpsc channel. At each step it
/// checks the ranges reported by `ContextEvent`s. In the expected marked text:
/// - `⟦…⟧` wraps invoked command output ranges (`command_outputs`),
/// - `«…»` wraps parsed command source ranges (`parsed_commands`),
/// - `⟪…⟫` wraps slash command output sections (`output_sections`).
#[gpui::test]
async fn test_slash_commands(cx: &mut TestAppContext) {
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.background_executor.clone());

    fs.insert_tree(
        "/test",
        json!({
            "src": {
                "lib.rs": "fn one() -> usize { 1 }",
                "main.rs": "
                    use crate::one;
                    fn main() { one(); }
                ".unindent(),
            }
        }),
    )
    .await;

    // Register `/file` globally. Presumably this is what makes `/file` parse as a
    // known command below while `/unknown` does not (the context itself gets an
    // empty `SlashCommandWorkingSet`) — NOTE(review): confirm the fallback.
    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
    slash_command_registry.register_command(FileSlashCommand, false);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            None,
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });

    /// Ranges observed through context events, rendered as markers by
    /// `assert_text_and_context_ranges`.
    #[derive(Default)]
    struct ContextRanges {
        // Source ranges of currently parsed slash commands (`«…»`).
        parsed_commands: HashSet<Range<language::Anchor>>,
        // Latest range reported for each invoked command (`⟦…⟧`).
        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
        // Ranges of output sections added so far (`⟪…⟫`).
        output_sections: HashSet<Range<language::Anchor>>,
    }

    // Mirror slash-command-related context events into `context_ranges`.
    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
    context.update(cx, |_, cx| {
        cx.subscribe(&context, {
            let context_ranges = context_ranges.clone();
            move |context, _, event, _| {
                let mut context_ranges = context_ranges.borrow_mut();
                match event {
                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
                        let command = context.invoked_slash_command(command_id).unwrap();
                        context_ranges
                            .command_outputs
                            .insert(*command_id, command.range.clone());
                    }
                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            context_ranges.parsed_commands.remove(range);
                        }
                        for command in updated {
                            context_ranges
                                .parsed_commands
                                .insert(command.source_range.clone());
                        }
                    }
                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
                        context_ranges.output_sections.insert(section.range.clone());
                    }
                    _ => {}
                }
            }
        })
        .detach();
    });

    let buffer = context.read_with(cx, |context, _| context.buffer.clone());

    // Insert a slash command
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/lib.rs»"
            .unindent(),
        cx,
    );

    // Edit the argument of the slash command.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("lib.rs").unwrap();
        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Edit the name of the slash command, using one that doesn't exist.
    buffer.update(cx, |buffer, cx| {
        let edit_offset = buffer.text().find("/file").unwrap();
        buffer.edit(
            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
            None,
            cx,
        );
    });
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        /unknown src/main.rs"
            .unindent(),
        cx,
    );

    // Undoing the insertion of a non-existent slash command restores the previous one.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        «/file src/main.rs»"
            .unindent(),
        cx,
    );

    // Run the parsed command with a hand-driven output stream. The receiver is
    // handed to the context; events are fed through `command_output_tx` below.
    // The `true` argument's meaning is not visible from this file.
    let (command_output_tx, command_output_rx) = mpsc::unbounded();
    context.update(cx, |context, cx| {
        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
        context.insert_command_output(
            command_source_range,
            "file",
            Task::ready(Ok(command_output_rx.boxed())),
            true,
            cx,
        );
    });
    // While the stream is open the buffer contains a `…` after the output
    // (it is gone once the sender is dropped at the end of the test).
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        …⟧
        "
        .unindent(),
        cx,
    );

    // Content streamed inside a started section is inserted before the `…`.
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::StartSection {
            icon: IconName::Ai,
            label: "src/main.rs".into(),
            metadata: None,
        }))
        .unwrap();
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs…⟧
        "
        .unindent(),
        cx,
    );

    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        src/main.rs
        fn main() {}…⟧
        "
        .unindent(),
        cx,
    );

    // Ending the section reports it as an output section (`⟪…⟫`).
    command_output_tx
        .unbounded_send(Ok(SlashCommandEvent::EndSection))
        .unwrap();
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦«/file src/main.rs»
        ⟪src/main.rs
        fn main() {}⟫…⟧
        "
        .unindent(),
        cx,
    );

    // Closing the stream finishes the command: the command text and the `…`
    // disappear, leaving only the output section inside the command range.
    drop(command_output_tx);
    cx.run_until_parked();
    assert_text_and_context_ranges(
        &buffer,
        &context_ranges,
        &"
        ⟦⟪src/main.rs
        fn main() {}⟫⟧
        "
        .unindent(),
        cx,
    );

    /// Renders `buffer` with a marker character inserted at each endpoint of
    /// every range in `ranges`, and asserts it equals `expected_marked_text`.
    ///
    /// Endpoints are ordered by offset with a stable sort, so markers at the
    /// same offset keep their push order: opening markers `⟦ « ⟪` first, then
    /// closing markers `⟫ » ⟧`.
    #[track_caller]
    fn assert_text_and_context_ranges(
        buffer: &Entity<Buffer>,
        ranges: &RefCell<ContextRanges>,
        expected_marked_text: &str,
        cx: &mut TestAppContext,
    ) {
        let mut actual_marked_text = String::new();
        buffer.update(cx, |buffer, _| {
            struct Endpoint {
                offset: usize,
                marker: char,
            }

            let ranges = ranges.borrow();
            let mut endpoints = Vec::new();
            // Opening markers, outermost kind first.
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟦',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '«',
                });
            }
            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.start.to_offset(buffer),
                    marker: '⟪',
                });
            }

            // Closing markers, innermost kind first.
            for range in ranges.output_sections.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟫',
                });
            }
            for range in ranges.parsed_commands.iter() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '»',
                });
            }
            for range in ranges.command_outputs.values() {
                endpoints.push(Endpoint {
                    offset: range.end.to_offset(buffer),
                    marker: '⟧',
                });
            }

            // Stable sort: ties keep the push order established above.
            endpoints.sort_by_key(|endpoint| endpoint.offset);
            let mut offset = 0;
            for endpoint in endpoints {
                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
                actual_marked_text.push(endpoint.marker);
                offset = endpoint.offset;
            }
            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
        });

        assert_eq!(actual_marked_text, expected_marked_text);
    }
}
|
|
|
|
/// Exercises parsing of `<patch>`/`<edit>` tags in an assistant message.
///
/// Covered cases:
/// - incremental parsing as the text is typed in;
/// - re-parsing after a manual edit of the patch;
/// - clearing and restoring patches when the message role is cycled;
/// - re-parsing after the context is serialized and deserialized.
///
/// Patch ranges are shown with `«…»` in the expected text.
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
    cx.update(prompt_library::init);
    // Enable `assistant.enable_experimental_live_diffs` before the store becomes
    // global. NOTE(review): presumably required for patch parsing — confirm.
    let mut settings_store = cx.update(SettingsStore::test);
    cx.update(|cx| {
        settings_store
            .set_user_settings(
                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
                cx,
            )
            .unwrap()
    });
    cx.set_global(settings_store);
    cx.update(language::init);
    cx.update(Project::init_settings);
    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [Path::new("/root")], cx).await;
    cx.update(LanguageModelRegistry::test);

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });

    // Insert an assistant message to simulate a response.
    let assistant_message_id = context.update(cx, |context, cx| {
        let user_message_id = context.messages(cx).next().unwrap().id;
        context
            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
            .id
    });

    // No edit tags
    edit(
        &context,
        "

        «one
        two
        »",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two
        ",
        &[],
        cx,
    );

    // Partial edit step tag is added
    edit(
        &context,
        "

        one
        two
        «
        <patch»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        <patch",
        &[],
        cx,
    );

    // The rest of the step tag is added. The unclosed
    // step is treated as incomplete.
    edit(
        &context,
        "

        one
        two

        <patch«>
        <edit>»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>»",
        &[&[]],
        cx,
    );

    // The full patch is added
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>«
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn one".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // The step is manually edited.
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>«fn zero»</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // When setting the message role to User, the steps are cleared.
    // (Two role cycles take the Assistant message to User, per this comment
    // and the "back to Assistant" step below.)
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        &[],
        cx,
    );

    // When setting the message role back to Assistant, the steps are reparsed.
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // Ensure steps are re-parsed when deserializing.
    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new(|cx| {
        AssistantContext::deserialize(
            serialized_context,
            Default::default(),
            registry.clone(),
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            None,
            None,
            cx,
        )
    });
    expect_patches(
        &deserialized_context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    /// Applies `new_text_marked_with_edits` (after `unindent`) to the context's
    /// buffer via `edit_via_marked_text`, then runs the executor until parked.
    /// The `«…»` markers delimit the text being changed (see the test helper
    /// `edit_via_marked_text` for the exact semantics).
    fn edit(
        context: &Entity<AssistantContext>,
        new_text_marked_with_edits: &str,
        cx: &mut TestAppContext,
    ) {
        context.update(cx, |context, cx| {
            context.buffer.update(cx, |buffer, cx| {
                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
            });
        });
        cx.executor().run_until_parked();
    }

    /// Asserts the context's buffer text, its patch ranges (rendered as `«…»`
    /// in `expected_marked_text`), and the edits of each patch.
    ///
    /// `expected_suggestions` has one inner slice per patch. Each entry in a
    /// patch's `edits` is unwrapped via `as_ref().unwrap()`, so this also panics
    /// if any edit is not in its success state.
    #[track_caller]
    fn expect_patches(
        context: &Entity<AssistantContext>,
        expected_marked_text: &str,
        expected_suggestions: &[&[AssistantEdit]],
        cx: &mut TestAppContext,
    ) {
        let expected_marked_text = expected_marked_text.unindent();
        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);

        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
            context.buffer.read_with(cx, |buffer, _| {
                let ranges = context
                    .patches
                    .iter()
                    .map(|entry| entry.range.to_offset(buffer))
                    .collect::<Vec<_>>();
                (
                    buffer.text(),
                    ranges,
                    context
                        .patches
                        .iter()
                        .map(|step| step.edits.clone())
                        .collect::<Vec<_>>(),
                )
            })
        });

        assert_eq!(buffer_text, expected_text);

        // Re-render the patch ranges as markers over the (already verified) text.
        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
        assert_eq!(actual_marked_text, expected_marked_text);

        assert_eq!(
            patches
                .iter()
                .map(|patch| {
                    patch
                        .iter()
                        .map(|edit| {
                            let edit = edit.as_ref().unwrap();
                            AssistantEdit {
                                path: edit.path.clone(),
                                kind: edit.kind.clone(),
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>(),
            expected_suggestions
        );
    }
}
|
|
|
|
#[gpui::test]
|
|
async fn test_serialization(cx: &mut TestAppContext) {
|
|
let settings_store = cx.update(SettingsStore::test);
|
|
cx.set_global(settings_store);
|
|
cx.update(LanguageModelRegistry::test);
|
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
|
let context = cx.new(|cx| {
|
|
AssistantContext::local(
|
|
registry.clone(),
|
|
None,
|
|
None,
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
cx,
|
|
)
|
|
});
|
|
let buffer = context.read_with(cx, |context, _| context.buffer.clone());
|
|
let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
|
|
let message_1 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
let message_2 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
buffer.update(cx, |buffer, cx| {
|
|
buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
|
|
buffer.finalize_last_transaction();
|
|
});
|
|
let _message_3 = context.update(cx, |context, cx| {
|
|
context
|
|
.insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
|
|
.unwrap()
|
|
});
|
|
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
|
assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
|
|
assert_eq!(
|
|
cx.read(|cx| messages(&context, cx)),
|
|
[
|
|
(message_0, Role::User, 0..2),
|
|
(message_1.id, Role::Assistant, 2..6),
|
|
(message_2.id, Role::System, 6..6),
|
|
]
|
|
);
|
|
|
|
let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
|
|
let deserialized_context = cx.new(|cx| {
|
|
AssistantContext::deserialize(
|
|
serialized_context,
|
|
Default::default(),
|
|
registry.clone(),
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
None,
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
let deserialized_buffer =
|
|
deserialized_context.read_with(cx, |context, _| context.buffer.clone());
|
|
assert_eq!(
|
|
deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
|
|
"a\nb\nc\n"
|
|
);
|
|
assert_eq!(
|
|
cx.read(|cx| messages(&deserialized_context, cx)),
|
|
[
|
|
(message_0, Role::User, 0..2),
|
|
(message_1.id, Role::Assistant, 2..6),
|
|
(message_2.id, Role::System, 6..6),
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test(iterations = 100)]
|
|
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
|
|
let min_peers = env::var("MIN_PEERS")
|
|
.map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
|
|
.unwrap_or(2);
|
|
let max_peers = env::var("MAX_PEERS")
|
|
.map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
|
|
.unwrap_or(5);
|
|
let operations = env::var("OPERATIONS")
|
|
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
|
|
.unwrap_or(50);
|
|
|
|
let settings_store = cx.update(SettingsStore::test);
|
|
cx.set_global(settings_store);
|
|
cx.update(LanguageModelRegistry::test);
|
|
|
|
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
|
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
|
|
slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
|
|
slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
|
|
|
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
|
|
let network = Arc::new(Mutex::new(Network::new(rng.clone())));
|
|
let mut contexts = Vec::new();
|
|
|
|
let num_peers = rng.gen_range(min_peers..=max_peers);
|
|
let context_id = ContextId::new();
|
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
|
for i in 0..num_peers {
|
|
let context = cx.new(|cx| {
|
|
AssistantContext::new(
|
|
context_id.clone(),
|
|
i as ReplicaId,
|
|
language::Capability::ReadWrite,
|
|
registry.clone(),
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
None,
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
cx.update(|cx| {
|
|
cx.subscribe(&context, {
|
|
let network = network.clone();
|
|
move |_, event, _| {
|
|
if let ContextEvent::Operation(op) = event {
|
|
network
|
|
.lock()
|
|
.broadcast(i as ReplicaId, vec![op.to_proto()]);
|
|
}
|
|
}
|
|
})
|
|
.detach();
|
|
});
|
|
|
|
contexts.push(context);
|
|
network.lock().add_peer(i as ReplicaId);
|
|
}
|
|
|
|
let mut mutation_count = operations;
|
|
|
|
while mutation_count > 0
|
|
|| !network.lock().is_idle()
|
|
|| network.lock().contains_disconnected_peers()
|
|
{
|
|
let context_index = rng.gen_range(0..contexts.len());
|
|
let context = &contexts[context_index];
|
|
|
|
match rng.gen_range(0..100) {
|
|
0..=29 if mutation_count > 0 => {
|
|
log::info!("Context {}: edit buffer", context_index);
|
|
context.update(cx, |context, cx| {
|
|
context
|
|
.buffer
|
|
.update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
|
|
});
|
|
mutation_count -= 1;
|
|
}
|
|
30..=44 if mutation_count > 0 => {
|
|
context.update(cx, |context, cx| {
|
|
let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
|
|
log::info!("Context {}: split message at {:?}", context_index, range);
|
|
context.split_message(range, cx);
|
|
});
|
|
mutation_count -= 1;
|
|
}
|
|
45..=59 if mutation_count > 0 => {
|
|
context.update(cx, |context, cx| {
|
|
if let Some(message) = context.messages(cx).choose(&mut rng) {
|
|
let role = *[Role::User, Role::Assistant, Role::System]
|
|
.choose(&mut rng)
|
|
.unwrap();
|
|
log::info!(
|
|
"Context {}: insert message after {:?} with {:?}",
|
|
context_index,
|
|
message.id,
|
|
role
|
|
);
|
|
context.insert_message_after(message.id, role, MessageStatus::Done, cx);
|
|
}
|
|
});
|
|
mutation_count -= 1;
|
|
}
|
|
60..=74 if mutation_count > 0 => {
|
|
context.update(cx, |context, cx| {
|
|
let command_text = "/".to_string()
|
|
+ slash_commands
|
|
.command_names()
|
|
.choose(&mut rng)
|
|
.unwrap()
|
|
.clone()
|
|
.as_ref();
|
|
|
|
let command_range = context.buffer.update(cx, |buffer, cx| {
|
|
let offset = buffer.random_byte_range(0, &mut rng).start;
|
|
buffer.edit(
|
|
[(offset..offset, format!("\n{}\n", command_text))],
|
|
None,
|
|
cx,
|
|
);
|
|
offset + 1..offset + 1 + command_text.len()
|
|
});
|
|
|
|
let output_text = RandomCharIter::new(&mut rng)
|
|
.filter(|c| *c != '\r')
|
|
.take(10)
|
|
.collect::<String>();
|
|
|
|
let mut events = vec![Ok(SlashCommandEvent::StartMessage {
|
|
role: Role::User,
|
|
merge_same_roles: true,
|
|
})];
|
|
|
|
let num_sections = rng.gen_range(0..=3);
|
|
let mut section_start = 0;
|
|
for _ in 0..num_sections {
|
|
let mut section_end = rng.gen_range(section_start..=output_text.len());
|
|
while !output_text.is_char_boundary(section_end) {
|
|
section_end += 1;
|
|
}
|
|
events.push(Ok(SlashCommandEvent::StartSection {
|
|
icon: IconName::Ai,
|
|
label: "section".into(),
|
|
metadata: None,
|
|
}));
|
|
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
|
|
text: output_text[section_start..section_end].to_string(),
|
|
run_commands_in_text: false,
|
|
})));
|
|
events.push(Ok(SlashCommandEvent::EndSection));
|
|
section_start = section_end;
|
|
}
|
|
|
|
if section_start < output_text.len() {
|
|
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
|
|
text: output_text[section_start..].to_string(),
|
|
run_commands_in_text: false,
|
|
})));
|
|
}
|
|
|
|
log::info!(
|
|
"Context {}: insert slash command output at {:?} with {:?} events",
|
|
context_index,
|
|
command_range,
|
|
events.len()
|
|
);
|
|
|
|
let command_range = context.buffer.read(cx).anchor_after(command_range.start)
|
|
..context.buffer.read(cx).anchor_after(command_range.end);
|
|
context.insert_command_output(
|
|
command_range,
|
|
"/command",
|
|
Task::ready(Ok(stream::iter(events).boxed())),
|
|
true,
|
|
cx,
|
|
);
|
|
});
|
|
cx.run_until_parked();
|
|
mutation_count -= 1;
|
|
}
|
|
75..=84 if mutation_count > 0 => {
|
|
context.update(cx, |context, cx| {
|
|
if let Some(message) = context.messages(cx).choose(&mut rng) {
|
|
let new_status = match rng.gen_range(0..3) {
|
|
0 => MessageStatus::Done,
|
|
1 => MessageStatus::Pending,
|
|
_ => MessageStatus::Error(SharedString::from("Random error")),
|
|
};
|
|
log::info!(
|
|
"Context {}: update message {:?} status to {:?}",
|
|
context_index,
|
|
message.id,
|
|
new_status
|
|
);
|
|
context.update_metadata(message.id, cx, |metadata| {
|
|
metadata.status = new_status;
|
|
});
|
|
}
|
|
});
|
|
mutation_count -= 1;
|
|
}
|
|
_ => {
|
|
let replica_id = context_index as ReplicaId;
|
|
if network.lock().is_disconnected(replica_id) {
|
|
network.lock().reconnect_peer(replica_id, 0);
|
|
|
|
let (ops_to_send, ops_to_receive) = cx.read(|cx| {
|
|
let host_context = &contexts[0].read(cx);
|
|
let guest_context = context.read(cx);
|
|
(
|
|
guest_context.serialize_ops(&host_context.version(cx), cx),
|
|
host_context.serialize_ops(&guest_context.version(cx), cx),
|
|
)
|
|
});
|
|
let ops_to_send = ops_to_send.await;
|
|
let ops_to_receive = ops_to_receive
|
|
.await
|
|
.into_iter()
|
|
.map(ContextOperation::from_proto)
|
|
.collect::<Result<Vec<_>>>()
|
|
.unwrap();
|
|
log::info!(
|
|
"Context {}: reconnecting. Sent {} operations, received {} operations",
|
|
context_index,
|
|
ops_to_send.len(),
|
|
ops_to_receive.len()
|
|
);
|
|
|
|
network.lock().broadcast(replica_id, ops_to_send);
|
|
context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
|
|
} else if rng.gen_bool(0.1) && replica_id != 0 {
|
|
log::info!("Context {}: disconnecting", context_index);
|
|
network.lock().disconnect_peer(replica_id);
|
|
} else if network.lock().has_unreceived(replica_id) {
|
|
log::info!("Context {}: applying operations", context_index);
|
|
let ops = network.lock().receive(replica_id);
|
|
let ops = ops
|
|
.into_iter()
|
|
.map(ContextOperation::from_proto)
|
|
.collect::<Result<Vec<_>>>()
|
|
.unwrap();
|
|
context.update(cx, |context, cx| context.apply_ops(ops, cx));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
cx.read(|cx| {
|
|
let first_context = contexts[0].read(cx);
|
|
for context in &contexts[1..] {
|
|
let context = context.read(cx);
|
|
assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
|
|
assert_eq!(
|
|
context.buffer.read(cx).text(),
|
|
first_context.buffer.read(cx).text(),
|
|
"Context {} text != Context 0 text",
|
|
context.buffer.read(cx).replica_id()
|
|
);
|
|
assert_eq!(
|
|
context.message_anchors,
|
|
first_context.message_anchors,
|
|
"Context {} messages != Context 0 messages",
|
|
context.buffer.read(cx).replica_id()
|
|
);
|
|
assert_eq!(
|
|
context.messages_metadata,
|
|
first_context.messages_metadata,
|
|
"Context {} message metadata != Context 0 message metadata",
|
|
context.buffer.read(cx).replica_id()
|
|
);
|
|
assert_eq!(
|
|
context.slash_command_output_sections,
|
|
first_context.slash_command_output_sections,
|
|
"Context {} slash command output sections != Context 0 slash command output sections",
|
|
context.buffer.read(cx).replica_id()
|
|
);
|
|
}
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
fn test_mark_cache_anchors(cx: &mut App) {
|
|
let settings_store = SettingsStore::test(cx);
|
|
LanguageModelRegistry::test(cx);
|
|
cx.set_global(settings_store);
|
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
|
let context = cx.new(|cx| {
|
|
AssistantContext::local(
|
|
registry,
|
|
None,
|
|
None,
|
|
prompt_builder.clone(),
|
|
Arc::new(SlashCommandWorkingSet::default()),
|
|
cx,
|
|
)
|
|
});
|
|
let buffer = context.read(cx).buffer.clone();
|
|
|
|
// Create a test cache configuration
|
|
let cache_configuration = &Some(LanguageModelCacheConfiguration {
|
|
max_cache_anchors: 3,
|
|
should_speculate: true,
|
|
min_total_token: 10,
|
|
});
|
|
|
|
let message_1 = context.read(cx).message_anchors[0].clone();
|
|
|
|
context.update(cx, |context, cx| {
|
|
context.mark_cache_anchors(cache_configuration, false, cx)
|
|
});
|
|
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
|
.count(),
|
|
0,
|
|
"Empty messages should not have any cache anchors."
|
|
);
|
|
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
|
let message_2 = context
|
|
.update(cx, |context, cx| {
|
|
context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
|
|
})
|
|
.unwrap();
|
|
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
|
|
let message_3 = context
|
|
.update(cx, |context, cx| {
|
|
context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
|
|
})
|
|
.unwrap();
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
|
|
|
|
context.update(cx, |context, cx| {
|
|
context.mark_cache_anchors(cache_configuration, false, cx)
|
|
});
|
|
assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
|
.count(),
|
|
0,
|
|
"Messages should not be marked for cache before going over the token minimum."
|
|
);
|
|
context.update(cx, |context, _| {
|
|
context.token_count = Some(20);
|
|
});
|
|
|
|
context.update(cx, |context, cx| {
|
|
context.mark_cache_anchors(cache_configuration, true, cx)
|
|
});
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
|
.collect::<Vec<bool>>(),
|
|
vec![true, true, false],
|
|
"Last message should not be an anchor on speculative request."
|
|
);
|
|
|
|
context
|
|
.update(cx, |context, cx| {
|
|
context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
|
|
})
|
|
.unwrap();
|
|
|
|
context.update(cx, |context, cx| {
|
|
context.mark_cache_anchors(cache_configuration, false, cx)
|
|
});
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
|
.collect::<Vec<bool>>(),
|
|
vec![false, true, true, false],
|
|
"Most recent message should also be cached if not a speculative request."
|
|
);
|
|
context.update(cx, |context, cx| {
|
|
context.update_cache_status_for_completion(cx)
|
|
});
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.map(|(_, cache)| cache
|
|
.as_ref()
|
|
.map_or(None, |cache| Some(cache.status.clone())))
|
|
.collect::<Vec<Option<CacheStatus>>>(),
|
|
vec![
|
|
Some(CacheStatus::Cached),
|
|
Some(CacheStatus::Cached),
|
|
Some(CacheStatus::Cached),
|
|
None
|
|
],
|
|
"All user messages prior to anchor should be marked as cached."
|
|
);
|
|
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
|
|
context.update(cx, |context, cx| {
|
|
context.mark_cache_anchors(cache_configuration, false, cx)
|
|
});
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.map(|(_, cache)| cache
|
|
.as_ref()
|
|
.map_or(None, |cache| Some(cache.status.clone())))
|
|
.collect::<Vec<Option<CacheStatus>>>(),
|
|
vec![
|
|
Some(CacheStatus::Cached),
|
|
Some(CacheStatus::Cached),
|
|
Some(CacheStatus::Pending),
|
|
None
|
|
],
|
|
"Modifying a message should invalidate it's cache but leave previous messages."
|
|
);
|
|
buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
|
|
context.update(cx, |context, cx| {
|
|
context.mark_cache_anchors(cache_configuration, false, cx)
|
|
});
|
|
assert_eq!(
|
|
messages_cache(&context, cx)
|
|
.iter()
|
|
.map(|(_, cache)| cache
|
|
.as_ref()
|
|
.map_or(None, |cache| Some(cache.status.clone())))
|
|
.collect::<Vec<Option<CacheStatus>>>(),
|
|
vec![
|
|
Some(CacheStatus::Pending),
|
|
Some(CacheStatus::Pending),
|
|
Some(CacheStatus::Pending),
|
|
None
|
|
],
|
|
"Modifying a message should invalidate all future messages."
|
|
);
|
|
}
|
|
|
|
/// Returns `(id, role, offset range)` for every message in `context`, in the
/// order produced by `AssistantContext::messages`.
fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
    let snapshot = context.read(cx);
    let mut rows = Vec::new();
    for msg in snapshot.messages(cx) {
        rows.push((msg.id, msg.role, msg.offset_range));
    }
    rows
}
|
|
|
|
fn messages_cache(
|
|
context: &Entity<AssistantContext>,
|
|
cx: &App,
|
|
) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
|
|
context
|
|
.read(cx)
|
|
.messages(cx)
|
|
.map(|message| (message.id, message.cache.clone()))
|
|
.collect()
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct FakeSlashCommand(String);
|
|
|
|
impl SlashCommand for FakeSlashCommand {
|
|
fn name(&self) -> String {
|
|
self.0.clone()
|
|
}
|
|
|
|
fn description(&self) -> String {
|
|
format!("Fake slash command: {}", self.0)
|
|
}
|
|
|
|
fn menu_text(&self) -> String {
|
|
format!("Run fake command: {}", self.0)
|
|
}
|
|
|
|
fn complete_argument(
|
|
self: Arc<Self>,
|
|
_arguments: &[String],
|
|
_cancel: Arc<AtomicBool>,
|
|
_workspace: Option<WeakEntity<Workspace>>,
|
|
_window: &mut Window,
|
|
_cx: &mut App,
|
|
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
|
Task::ready(Ok(vec![]))
|
|
}
|
|
|
|
fn requires_argument(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
fn run(
|
|
self: Arc<Self>,
|
|
_arguments: &[String],
|
|
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
|
|
_context_buffer: BufferSnapshot,
|
|
_workspace: WeakEntity<Workspace>,
|
|
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
|
|
_window: &mut Window,
|
|
_cx: &mut App,
|
|
) -> Task<SlashCommandResult> {
|
|
Task::ready(Ok(SlashCommandOutput {
|
|
text: format!("Executed fake command: {}", self.0),
|
|
sections: vec![],
|
|
run_commands_in_text: false,
|
|
}
|
|
.to_event_stream()))
|
|
}
|
|
}
|