use crate::{
AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
};
use anyhow::Result;
use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
};
use assistant_slash_commands::FileSlashCommand;
use collections::{HashMap, HashSet};
use fs::FakeFs;
use futures::{
channel::mpsc,
stream::{self, StreamExt},
};
use gpui::{prelude::*, App, Entity, SharedString, Task, TestAppContext, WeakEntity};
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::Project;
use prompt_library::PromptBuilder;
use rand::prelude::*;
use serde_json::json;
use settings::SettingsStore;
use std::{
cell::RefCell,
env,
ops::Range,
path::Path,
rc::Rc,
sync::{atomic::AtomicBool, Arc},
};
use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
use ui::{IconName, Window};
use unindent::Unindent;
use util::{
test::{generate_marked_text, marked_text_ranges},
RandomCharIter,
};
use workspace::Workspace;
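
// Verifies that messages can be inserted after existing messages, that their
// offset ranges track buffer edits, and that deleting across a message
// boundary merges messages (with undo/redo restoring or reapplying the merge).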
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
LanguageModelRegistry::test(cx);
cx.set_global(settings_store);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
assert_eq!(
messages(&context, cx),
vec![(message_1.id, Role::User, 0..0)]
);
let message_2 = context.update(cx, |context, cx| {
context
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..1),
(message_2.id, Role::Assistant, 1..1)
]
);
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
});
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..3)
]
);
let message_3 = context.update(cx, |context, cx| {
context
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_3.id, Role::User, 4..4)
]
);
let message_4 = context.update(cx, |context, cx| {
context
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_4.id, Role::User, 4..5),
(message_3.id, Role::User, 5..5),
]
);
buffer.update(cx, |buffer, cx| {
buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
});
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_4.id, Role::User, 4..6),
(message_3.id, Role::User, 6..7),
]
);
// Deleting across message boundaries merges the messages.
buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..3),
(message_3.id, Role::User, 3..4),
]
);
// Undoing the deletion should also undo the merge.
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..2),
(message_2.id, Role::Assistant, 2..4),
(message_4.id, Role::User, 4..6),
(message_3.id, Role::User, 6..7),
]
);
// Redoing the deletion should also redo the merge.
buffer.update(cx, |buffer, cx| buffer.redo(cx));
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..3),
(message_3.id, Role::User, 3..4),
]
);
// Ensure we can still insert after a merged message.
let message_5 = context.update(cx, |context, cx| {
context
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..3),
(message_5.id, Role::System, 3..4),
(message_3.id, Role::User, 4..5)
]
);
}
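
// Verifies that splitting a message at various offsets produces the expected
// message boundaries and inserts newlines only where needed.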
#[gpui::test]
fn test_message_splitting(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
LanguageModelRegistry::test(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry.clone(),
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
assert_eq!(
messages(&context, cx),
vec![(message_1.id, Role::User, 0..0)]
);
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
});
let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
let message_2 = message_2.unwrap();
// We recycle newlines in the middle of a split message
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_2.id, Role::User, 4..16),
]
);
let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
let message_3 = message_3.unwrap();
// We don't recycle newlines at the end of a split message
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_3.id, Role::User, 4..5),
(message_2.id, Role::User, 5..17),
]
);
let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
let message_4 = message_4.unwrap();
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_3.id, Role::User, 4..5),
(message_2.id, Role::User, 5..9),
(message_4.id, Role::User, 9..17),
]
);
let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
let message_5 = message_5.unwrap();
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_3.id, Role::User, 4..5),
(message_2.id, Role::User, 5..9),
(message_4.id, Role::User, 9..10),
(message_5.id, Role::User, 10..18),
]
);
let (message_6, message_7) =
context.update(cx, |context, cx| context.split_message(14..16, cx));
let message_6 = message_6.unwrap();
let message_7 = message_7.unwrap();
assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_3.id, Role::User, 4..5),
(message_2.id, Role::User, 5..9),
(message_4.id, Role::User, 9..10),
(message_5.id, Role::User, 10..14),
(message_6.id, Role::User, 14..17),
(message_7.id, Role::User, 17..19),
]
);
}
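
// Verifies that `messages_for_offsets` maps buffer offsets back to the
// messages containing them, deduplicating offsets within the same message.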
#[gpui::test]
fn test_messages_for_offsets(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
LanguageModelRegistry::test(cx);
cx.set_global(settings_store);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
assert_eq!(
messages(&context, cx),
vec![(message_1.id, Role::User, 0..0)]
);
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = context
.update(cx, |context, cx| {
context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
let message_3 = context
.update(cx, |context, cx| {
context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_2.id, Role::User, 4..8),
(message_3.id, Role::User, 8..11)
]
);
assert_eq!(
message_ids_for_offsets(&context, &[0, 4, 9], cx),
[message_1.id, message_2.id, message_3.id]
);
assert_eq!(
message_ids_for_offsets(&context, &[0, 1, 11], cx),
[message_1.id, message_3.id]
);
let message_4 = context
.update(cx, |context, cx| {
context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
})
.unwrap();
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
assert_eq!(
messages(&context, cx),
vec![
(message_1.id, Role::User, 0..4),
(message_2.id, Role::User, 4..8),
(message_3.id, Role::User, 8..12),
(message_4.id, Role::User, 12..12)
]
);
assert_eq!(
message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
[message_1.id, message_2.id, message_3.id, message_4.id]
);
fn message_ids_for_offsets(
context: &Entity<AssistantContext>,
offsets: &[usize],
cx: &App,
) -> Vec<MessageId> {
context
.read(cx)
.messages_for_offsets(offsets.iter().copied(), cx)
.into_iter()
.map(|message| message.id)
.collect()
}
}
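
// Exercises parsing, editing, and streaming output of slash commands,
// asserting on the tracked source, output, and section ranges.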
#[gpui::test]
async fn test_slash_commands(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(LanguageModelRegistry::test);
cx.update(Project::init_settings);
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
"/test",
json!({
"src": {
"lib.rs": "fn one() -> usize { 1 }",
"main.rs": "
use crate::one;
fn main() { one(); }
".unindent(),
}
}),
)
.await;
let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
slash_command_registry.register_command(FileSlashCommand, false);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry.clone(),
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
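// Collects the ranges reported by context events so the test can assert on
// them: parsed command sources, invoked command outputs, and output sections.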
#[derive(Default)]
struct ContextRanges {
parsed_commands: HashSet<Range<language::Anchor>>,
command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
output_sections: HashSet<Range<language::Anchor>>,
}
let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
context.update(cx, |_, cx| {
cx.subscribe(&context, {
let context_ranges = context_ranges.clone();
move |context, _, event, _| {
let mut context_ranges = context_ranges.borrow_mut();
match event {
ContextEvent::InvokedSlashCommandChanged { command_id } => {
let command = context.invoked_slash_command(command_id).unwrap();
context_ranges
.command_outputs
.insert(*command_id, command.range.clone());
}
ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
for range in removed {
context_ranges.parsed_commands.remove(range);
}
for command in updated {
context_ranges
.parsed_commands
.insert(command.source_range.clone());
}
}
ContextEvent::SlashCommandOutputSectionAdded { section } => {
context_ranges.output_sections.insert(section.range.clone());
}
_ => {}
}
}
})
.detach();
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone());
// Insert a slash command
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
});
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/lib.rs»"
.unindent(),
cx,
);
// Edit the argument of the slash command.
buffer.update(cx, |buffer, cx| {
let edit_offset = buffer.text().find("lib.rs").unwrap();
buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
});
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»"
.unindent(),
cx,
);
// Edit the name of the slash command, using one that doesn't exist.
buffer.update(cx, |buffer, cx| {
let edit_offset = buffer.text().find("/file").unwrap();
buffer.edit(
[(edit_offset..edit_offset + "/file".len(), "/unknown")],
None,
cx,
);
});
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
/unknown src/main.rs"
.unindent(),
cx,
);
// Undoing the insertion of a non-existent slash command restores the previous one.
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
«/file src/main.rs»"
.unindent(),
cx,
);
let (command_output_tx, command_output_rx) = mpsc::unbounded();
context.update(cx, |context, cx| {
let command_source_range = context.parsed_slash_commands[0].source_range.clone();
context.insert_command_output(
command_source_range,
"file",
Task::ready(Ok(command_output_rx.boxed())),
true,
cx,
);
});
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
⟦«/file src/main.rs»
…⟧
"
.unindent(),
cx,
);
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::StartSection {
icon: IconName::Ai,
label: "src/main.rs".into(),
metadata: None,
}))
.unwrap();
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
.unwrap();
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
⟦«/file src/main.rs»
src/main.rs…⟧
"
.unindent(),
cx,
);
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
.unwrap();
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
⟦«/file src/main.rs»
src/main.rs
fn main() {}…⟧
"
.unindent(),
cx,
);
command_output_tx
.unbounded_send(Ok(SlashCommandEvent::EndSection))
.unwrap();
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
⟦«/file src/main.rs»
⟪src/main.rs
fn main() {}⟫…⟧
"
.unindent(),
cx,
);
drop(command_output_tx);
cx.run_until_parked();
assert_text_and_context_ranges(
&buffer,
&context_ranges,
&"
⟦⟪src/main.rs
fn main() {}⟫⟧
"
.unindent(),
cx,
);
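// Renders the buffer with markers around the tracked ranges and compares it
// against the expected marked text: ⟦ ⟧ wraps an invoked command's output
// range, « » wraps a parsed command's source range, and ⟪ ⟫ wraps an output
// section.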
#[track_caller]
fn assert_text_and_context_ranges(
buffer: &Entity<Buffer>,
ranges: &RefCell<ContextRanges>,
expected_marked_text: &str,
cx: &mut TestAppContext,
) {
let mut actual_marked_text = String::new();
buffer.update(cx, |buffer, _| {
struct Endpoint {
offset: usize,
marker: char,
}
let ranges = ranges.borrow();
let mut endpoints = Vec::new();
for range in ranges.command_outputs.values() {
endpoints.push(Endpoint {
offset: range.start.to_offset(buffer),
marker: '⟦',
});
}
for range in ranges.parsed_commands.iter() {
endpoints.push(Endpoint {
offset: range.start.to_offset(buffer),
marker: '«',
});
}
for range in ranges.output_sections.iter() {
endpoints.push(Endpoint {
offset: range.start.to_offset(buffer),
marker: '⟪',
});
}
for range in ranges.output_sections.iter() {
endpoints.push(Endpoint {
offset: range.end.to_offset(buffer),
marker: '⟫',
});
}
for range in ranges.parsed_commands.iter() {
endpoints.push(Endpoint {
offset: range.end.to_offset(buffer),
marker: '»',
});
}
for range in ranges.command_outputs.values() {
endpoints.push(Endpoint {
offset: range.end.to_offset(buffer),
marker: '⟧',
});
}
endpoints.sort_by_key(|endpoint| endpoint.offset);
let mut offset = 0;
for endpoint in endpoints {
actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
actual_marked_text.push(endpoint.marker);
offset = endpoint.offset;
}
actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
});
assert_eq!(actual_marked_text, expected_marked_text);
}
}
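
// Verifies that `<patch>` blocks in assistant messages are parsed into
// patches as they stream in, survive manual edits, are cleared when the
// message role changes to User, and are re-parsed on switching the role back
// and on deserialization.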
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx.update(prompt_library::init);
let mut settings_store = cx.update(SettingsStore::test);
cx.update(|cx| {
settings_store
.set_user_settings(
r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
cx,
)
.unwrap()
});
cx.set_global(settings_store);
cx.update(language::init);
cx.update(Project::init_settings);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [Path::new("/root")], cx).await;
cx.update(LanguageModelRegistry::test);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
// Create a new context
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry.clone(),
Some(project),
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
// Insert an assistant message to simulate a response.
let assistant_message_id = context.update(cx, |context, cx| {
let user_message_id = context.messages(cx).next().unwrap().id;
context
.insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
.id
});
// No edit tags
edit(
&context,
"
«one
two
»",
cx,
);
expect_patches(
&context,
"
one
two
",
&[],
cx,
);
// A partial `<patch>` tag is added.
edit(
&context,
"
one
two
«
<patch»",
cx,
);
expect_patches(
&context,
"
one
two
<patch",
&[],
cx,
);
// The rest of the `<patch>` tag is added. The unclosed
// patch is treated as incomplete.
edit(
&context,
"
one
two
<patch«>
<edit>»",
cx,
);
expect_patches(
&context,
"
one
two
«<patch>
<edit>»",
&[&[]],
cx,
);
// The full patch is added
edit(
&context,
"
one
two
<patch>
<edit>«
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn one</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
also,»",
cx,
);
expect_patches(
&context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn one</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn one".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
// The patch is manually edited.
edit(
&context,
"
one
two
<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>«fn zero»</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
also,",
cx,
);
expect_patches(
&context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
// When setting the message role to User, the patches are cleared.
context.update(cx, |context, cx| {
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
});
expect_patches(
&context,
"
one
two
<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
also,",
&[],
cx,
);
// When setting the message role back to Assistant, the patches are reparsed.
context.update(cx, |context, cx| {
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
});
expect_patches(
&context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
// Ensure patches are re-parsed when deserializing.
let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
let deserialized_context = cx.new(|cx| {
AssistantContext::deserialize(
serialized_context,
Default::default(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
None,
None,
cx,
)
});
expect_patches(
&deserialized_context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
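// Applies a marked-text edit to the context buffer and lets background
// parsing settle.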
fn edit(
context: &Entity<AssistantContext>,
new_text_marked_with_edits: &str,
cx: &mut TestAppContext,
) {
context.update(cx, |context, cx| {
context.buffer.update(cx, |buffer, cx| {
buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
});
});
cx.executor().run_until_parked();
}
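// Asserts the buffer text, the patch ranges (as marked text), and the parsed
// edits of each patch.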
#[track_caller]
fn expect_patches(
context: &Entity<AssistantContext>,
expected_marked_text: &str,
expected_suggestions: &[&[AssistantEdit]],
cx: &mut TestAppContext,
) {
let expected_marked_text = expected_marked_text.unindent();
let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
context.buffer.read_with(cx, |buffer, _| {
let ranges = context
.patches
.iter()
.map(|entry| entry.range.to_offset(buffer))
.collect::<Vec<_>>();
(
buffer.text(),
ranges,
context
.patches
.iter()
.map(|step| step.edits.clone())
.collect::<Vec<_>>(),
)
})
});
assert_eq!(buffer_text, expected_text);
let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
assert_eq!(actual_marked_text, expected_marked_text);
assert_eq!(
patches
.iter()
.map(|patch| {
patch
.iter()
.map(|edit| {
let edit = edit.as_ref().unwrap();
AssistantEdit {
path: edit.path.clone(),
kind: edit.kind.clone(),
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>(),
expected_suggestions
);
}
}
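
// Verifies that serializing and deserializing a context preserves its buffer
// text and message boundaries.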
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(LanguageModelRegistry::test);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry.clone(),
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone());
let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
let message_1 = context.update(cx, |context, cx| {
context
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
});
let message_2 = context.update(cx, |context, cx| {
context
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
buffer.finalize_last_transaction();
});
let _message_3 = context.update(cx, |context, cx| {
context
.insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
assert_eq!(
cx.read(|cx| messages(&context, cx)),
[
(message_0, Role::User, 0..2),
(message_1.id, Role::Assistant, 2..6),
(message_2.id, Role::System, 6..6),
]
);
let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
let deserialized_context = cx.new(|cx| {
AssistantContext::deserialize(
serialized_context,
Default::default(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
None,
None,
cx,
)
});
let deserialized_buffer =
deserialized_context.read_with(cx, |context, _| context.buffer.clone());
assert_eq!(
deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
"a\nb\nc\n"
);
assert_eq!(
cx.read(|cx| messages(&deserialized_context, cx)),
[
(message_0, Role::User, 0..2),
(message_1.id, Role::Assistant, 2..6),
(message_2.id, Role::System, 6..6),
]
);
}
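
// Randomized collaboration test: multiple replicas concurrently edit the
// buffer, split and insert messages, run slash commands, and disconnect or
// reconnect, after which all replicas must converge to the same state.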
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
let min_peers = env::var("MIN_PEERS")
.map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
.unwrap_or(2);
let max_peers = env::var("MAX_PEERS")
.map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
.unwrap_or(5);
let operations = env::var("OPERATIONS")
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
.unwrap_or(50);
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(LanguageModelRegistry::test);
let slash_commands = cx.update(SlashCommandRegistry::default_global);
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
let network = Arc::new(Mutex::new(Network::new(rng.clone())));
let mut contexts = Vec::new();
let num_peers = rng.gen_range(min_peers..=max_peers);
let context_id = ContextId::new();
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
for i in 0..num_peers {
let context = cx.new(|cx| {
AssistantContext::new(
context_id.clone(),
i as ReplicaId,
language::Capability::ReadWrite,
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
None,
None,
cx,
)
});
cx.update(|cx| {
cx.subscribe(&context, {
let network = network.clone();
move |_, event, _| {
if let ContextEvent::Operation(op) = event {
network
.lock()
.broadcast(i as ReplicaId, vec![op.to_proto()]);
}
}
})
.detach();
});
contexts.push(context);
network.lock().add_peer(i as ReplicaId);
}
let mut mutation_count = operations;
while mutation_count > 0
|| !network.lock().is_idle()
|| network.lock().contains_disconnected_peers()
{
let context_index = rng.gen_range(0..contexts.len());
let context = &contexts[context_index];
match rng.gen_range(0..100) {
0..=29 if mutation_count > 0 => {
log::info!("Context {}: edit buffer", context_index);
context.update(cx, |context, cx| {
context
.buffer
.update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
});
mutation_count -= 1;
}
30..=44 if mutation_count > 0 => {
context.update(cx, |context, cx| {
let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
log::info!("Context {}: split message at {:?}", context_index, range);
context.split_message(range, cx);
});
mutation_count -= 1;
}
45..=59 if mutation_count > 0 => {
context.update(cx, |context, cx| {
if let Some(message) = context.messages(cx).choose(&mut rng) {
let role = *[Role::User, Role::Assistant, Role::System]
.choose(&mut rng)
.unwrap();
log::info!(
"Context {}: insert message after {:?} with {:?}",
context_index,
message.id,
role
);
context.insert_message_after(message.id, role, MessageStatus::Done, cx);
}
});
mutation_count -= 1;
}
60..=74 if mutation_count > 0 => {
context.update(cx, |context, cx| {
let command_text = "/".to_string()
+ slash_commands
.command_names()
.choose(&mut rng)
.unwrap()
.clone()
.as_ref();
let command_range = context.buffer.update(cx, |buffer, cx| {
let offset = buffer.random_byte_range(0, &mut rng).start;
buffer.edit(
[(offset..offset, format!("\n{}\n", command_text))],
None,
cx,
);
offset + 1..offset + 1 + command_text.len()
});
let output_text = RandomCharIter::new(&mut rng)
.filter(|c| *c != '\r')
.take(10)
.collect::<String>();
let mut events = vec![Ok(SlashCommandEvent::StartMessage {
role: Role::User,
merge_same_roles: true,
})];
let num_sections = rng.gen_range(0..=3);
let mut section_start = 0;
for _ in 0..num_sections {
let mut section_end = rng.gen_range(section_start..=output_text.len());
while !output_text.is_char_boundary(section_end) {
section_end += 1;
}
events.push(Ok(SlashCommandEvent::StartSection {
icon: IconName::Ai,
label: "section".into(),
metadata: None,
}));
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: output_text[section_start..section_end].to_string(),
run_commands_in_text: false,
})));
events.push(Ok(SlashCommandEvent::EndSection));
section_start = section_end;
}
if section_start < output_text.len() {
events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
text: output_text[section_start..].to_string(),
run_commands_in_text: false,
})));
}
log::info!(
"Context {}: insert slash command output at {:?} with {:?} events",
context_index,
command_range,
events.len()
);
let command_range = context.buffer.read(cx).anchor_after(command_range.start)
..context.buffer.read(cx).anchor_after(command_range.end);
context.insert_command_output(
command_range,
"/command",
Task::ready(Ok(stream::iter(events).boxed())),
true,
cx,
);
});
cx.run_until_parked();
mutation_count -= 1;
}
75..=84 if mutation_count > 0 => {
context.update(cx, |context, cx| {
if let Some(message) = context.messages(cx).choose(&mut rng) {
let new_status = match rng.gen_range(0..3) {
0 => MessageStatus::Done,
1 => MessageStatus::Pending,
_ => MessageStatus::Error(SharedString::from("Random error")),
};
log::info!(
"Context {}: update message {:?} status to {:?}",
context_index,
message.id,
new_status
);
context.update_metadata(message.id, cx, |metadata| {
metadata.status = new_status;
});
}
});
mutation_count -= 1;
}
_ => {
let replica_id = context_index as ReplicaId;
if network.lock().is_disconnected(replica_id) {
network.lock().reconnect_peer(replica_id, 0);
let (ops_to_send, ops_to_receive) = cx.read(|cx| {
let host_context = &contexts[0].read(cx);
let guest_context = context.read(cx);
(
guest_context.serialize_ops(&host_context.version(cx), cx),
host_context.serialize_ops(&guest_context.version(cx), cx),
)
});
let ops_to_send = ops_to_send.await;
let ops_to_receive = ops_to_receive
.await
.into_iter()
.map(ContextOperation::from_proto)
.collect::<Result<Vec<_>>>()
.unwrap();
log::info!(
"Context {}: reconnecting. Sent {} operations, received {} operations",
context_index,
ops_to_send.len(),
ops_to_receive.len()
);
network.lock().broadcast(replica_id, ops_to_send);
context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
} else if rng.gen_bool(0.1) && replica_id != 0 {
log::info!("Context {}: disconnecting", context_index);
network.lock().disconnect_peer(replica_id);
} else if network.lock().has_unreceived(replica_id) {
log::info!("Context {}: applying operations", context_index);
let ops = network.lock().receive(replica_id);
let ops = ops
.into_iter()
.map(ContextOperation::from_proto)
.collect::<Result<Vec<_>>>()
.unwrap();
context.update(cx, |context, cx| context.apply_ops(ops, cx));
}
}
}
}
cx.read(|cx| {
let first_context = contexts[0].read(cx);
for context in &contexts[1..] {
let context = context.read(cx);
assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
assert_eq!(
context.buffer.read(cx).text(),
first_context.buffer.read(cx).text(),
"Context {} text != Context 0 text",
context.buffer.read(cx).replica_id()
);
assert_eq!(
context.message_anchors,
first_context.message_anchors,
"Context {} messages != Context 0 messages",
context.buffer.read(cx).replica_id()
);
assert_eq!(
context.messages_metadata,
first_context.messages_metadata,
"Context {} message metadata != Context 0 message metadata",
context.buffer.read(cx).replica_id()
);
assert_eq!(
context.slash_command_output_sections,
first_context.slash_command_output_sections,
"Context {} slash command output sections != Context 0 slash command output sections",
context.buffer.read(cx).replica_id()
);
}
});
}
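
// Verifies cache-anchor placement and cache-status transitions as messages
// grow past the token minimum, complete, and are later edited.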
#[gpui::test]
fn test_mark_cache_anchors(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
LanguageModelRegistry::test(cx);
cx.set_global(settings_store);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
// Create a test cache configuration
let cache_configuration = &Some(LanguageModelCacheConfiguration {
max_cache_anchors: 3,
should_speculate: true,
min_total_token: 10,
});
let message_1 = context.read(cx).message_anchors[0].clone();
context.update(cx, |context, cx| {
context.mark_cache_anchors(cache_configuration, false, cx)
});
assert_eq!(
messages_cache(&context, cx)
.iter()
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.count(),
0,
"Empty messages should not have any cache anchors."
);
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = context
.update(cx, |context, cx| {
context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
let message_3 = context
.update(cx, |context, cx| {
context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
})
.unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
context.update(cx, |context, cx| {
context.mark_cache_anchors(cache_configuration, false, cx)
});
assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
assert_eq!(
messages_cache(&context, cx)
.iter()
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.count(),
0,
"Messages should not be marked for cache before going over the token minimum."
);
context.update(cx, |context, _| {
context.token_count = Some(20);
});
context.update(cx, |context, cx| {
context.mark_cache_anchors(cache_configuration, true, cx)
});
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.collect::<Vec<bool>>(),
vec![true, true, false],
"Last message should not be an anchor on speculative request."
);
context
.update(cx, |context, cx| {
context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
})
.unwrap();
context.update(cx, |context, cx| {
context.mark_cache_anchors(cache_configuration, false, cx)
});
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
.collect::<Vec<bool>>(),
vec![false, true, true, false],
"Most recent message should also be cached if not a speculative request."
);
context.update(cx, |context, cx| {
context.update_cache_status_for_completion(cx)
});
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache
.as_ref()
.map_or(None, |cache| Some(cache.status.clone())))
.collect::<Vec<Option<CacheStatus>>>(),
vec![
Some(CacheStatus::Cached),
Some(CacheStatus::Cached),
Some(CacheStatus::Cached),
None
],
"All user messages prior to anchor should be marked as cached."
);
buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
context.update(cx, |context, cx| {
context.mark_cache_anchors(cache_configuration, false, cx)
});
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache
.as_ref()
.map_or(None, |cache| Some(cache.status.clone())))
.collect::<Vec<Option<CacheStatus>>>(),
vec![
Some(CacheStatus::Cached),
Some(CacheStatus::Cached),
Some(CacheStatus::Pending),
None
],
"Modifying a message should invalidate it's cache but leave previous messages."
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
context.update(cx, |context, cx| {
context.mark_cache_anchors(cache_configuration, false, cx)
});
assert_eq!(
messages_cache(&context, cx)
.iter()
.map(|(_, cache)| cache
.as_ref()
.map_or(None, |cache| Some(cache.status.clone())))
.collect::<Vec<Option<CacheStatus>>>(),
vec![
Some(CacheStatus::Pending),
Some(CacheStatus::Pending),
Some(CacheStatus::Pending),
None
],
"Modifying a message should invalidate all future messages."
);
}
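
// Returns each message's id, role, and offset range in buffer order.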
fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
context
.read(cx)
.messages(cx)
.map(|message| (message.id, message.role, message.offset_range))
.collect()
}
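
// Returns each message's id and cache metadata in buffer order.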
fn messages_cache(
context: &Entity<AssistantContext>,
cx: &App,
) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
context
.read(cx)
.messages(cx)
.map(|message| (message.id, message.cache.clone()))
.collect()
}
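
// A no-op slash command used by the randomized collaboration test; it takes
// no arguments and returns a fixed output with no sections.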
#[derive(Clone)]
struct FakeSlashCommand(String);
impl SlashCommand for FakeSlashCommand {
fn name(&self) -> String {
self.0.clone()
}
fn description(&self) -> String {
format!("Fake slash command: {}", self.0)
}
fn menu_text(&self) -> String {
format!("Run fake command: {}", self.0)
}
fn complete_argument(
self: Arc<Self>,
_arguments: &[String],
_cancel: Arc<AtomicBool>,
_workspace: Option<WeakEntity<Workspace>>,
_window: &mut Window,
_cx: &mut App,
) -> Task<Result<Vec<ArgumentCompletion>>> {
Task::ready(Ok(vec![]))
}
fn requires_argument(&self) -> bool {
false
}
fn run(
self: Arc<Self>,
_arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
_context_buffer: BufferSnapshot,
_workspace: WeakEntity<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
_window: &mut Window,
_cx: &mut App,
) -> Task<SlashCommandResult> {
Task::ready(Ok(SlashCommandOutput {
text: format!("Executed fake command: {}", self.0),
sections: vec![],
run_commands_in_text: false,
}
.to_event_stream()))
}
}