assistant: Insert creases for tool uses (#17330)

This PR makes it so we create creases for each of the tool uses in the
context editor.

<img width="1290" alt="Screenshot 2024-09-03 at 5 37 33 PM"
src="https://github.com/user-attachments/assets/94e943fd-3f05-4bc4-9672-94bff42ec500">

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-03 17:52:52 -04:00 committed by GitHub
parent be657377a2
commit c2448e1673
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 119 additions and 1 deletions

View file

@ -1416,6 +1416,7 @@ pub struct ContextEditor {
remote_id: Option<workspace::ViewId>,
pending_slash_command_creases: HashMap<Range<language::Anchor>, CreaseId>,
pending_slash_command_blocks: HashMap<Range<language::Anchor>, CustomBlockId>,
pending_tool_use_creases: HashMap<Range<language::Anchor>, CreaseId>,
_subscriptions: Vec<Subscription>,
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
active_workflow_step: Option<ActiveWorkflowStep>,
@ -1480,6 +1481,7 @@ impl ContextEditor {
project,
pending_slash_command_creases: HashMap::default(),
pending_slash_command_blocks: HashMap::default(),
pending_tool_use_creases: HashMap::default(),
_subscriptions,
workflow_steps: HashMap::default(),
active_workflow_step: None,
@ -1855,6 +1857,72 @@ impl ContextEditor {
cx,
);
}
let new_tool_uses = self
.context
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| {
!self
.pending_tool_use_creases
.contains_key(&tool_use.source_range)
})
.cloned()
.collect::<Vec<_>>();
let buffer = editor.buffer().read(cx).snapshot(cx);
let (excerpt_id, _buffer_id, _) = buffer.as_singleton().unwrap();
let excerpt_id = *excerpt_id;
let mut buffer_rows_to_fold = BTreeSet::new();
let creases = new_tool_uses
.iter()
.map(|tool_use| {
let placeholder = FoldPlaceholder {
render: render_fold_icon_button(
cx.view().downgrade(),
IconName::PocketKnife,
tool_use.name.clone().into(),
),
constrain_width: false,
merge_adjacent: false,
};
let render_trailer =
move |_row, _unfold, _cx: &mut WindowContext| Empty.into_any();
let start = buffer
.anchor_in_excerpt(excerpt_id, tool_use.source_range.start)
.unwrap();
let end = buffer
.anchor_in_excerpt(excerpt_id, tool_use.source_range.end)
.unwrap();
let buffer_row = MultiBufferRow(start.to_point(&buffer).row);
buffer_rows_to_fold.insert(buffer_row);
Crease::new(
start..end,
placeholder,
fold_toggle("tool-use"),
render_trailer,
)
})
.collect::<Vec<_>>();
let crease_ids = editor.insert_creases(creases, cx);
for buffer_row in buffer_rows_to_fold.into_iter().rev() {
editor.fold_at(&FoldAt { buffer_row }, cx);
}
self.pending_tool_use_creases.extend(
new_tool_uses
.iter()
.map(|tool_use| tool_use.source_range.clone())
.zip(crease_ids),
);
});
}
ContextEvent::WorkflowStepsUpdated { removed, updated } => {

View file

@ -490,6 +490,7 @@ pub struct Context {
edits_since_last_parse: language::Subscription,
finished_slash_commands: HashSet<SlashCommandId>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
pending_tool_uses_by_id: HashMap<String, PendingToolUse>,
message_anchors: Vec<MessageAnchor>,
images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
image_anchors: Vec<ImageAnchor>,
@ -591,6 +592,7 @@ impl Context {
messages_metadata: Default::default(),
pending_slash_commands: Vec::new(),
finished_slash_commands: HashSet::default(),
pending_tool_uses_by_id: HashMap::default(),
slash_command_output_sections: Vec::new(),
edits_since_last_parse: edits_since_last_slash_command_parse,
summary: None,
@ -1004,6 +1006,14 @@ impl Context {
&self.slash_command_output_sections
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn get_tool_use_by_id(&self, id: &String) -> Option<&PendingToolUse> {
self.pending_tool_uses_by_id.get(id)
}
fn set_language(&mut self, cx: &mut ModelContext<Self>) {
let markdown = self.language_registry.language_for_name("Markdown");
cx.spawn(|this, mut cx| async move {
@ -1984,12 +1994,16 @@ impl Context {
);
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
const NEWLINE: char = '\n';
let mut text = String::new();
text.push('\n');
text.push(NEWLINE);
text.push_str(
&serde_json::to_string_pretty(&tool_use)
.expect("failed to serialize tool use to JSON"),
);
text.push(NEWLINE);
let text_len = text.len();
buffer.edit(
[(
@ -1999,6 +2013,23 @@ impl Context {
None,
cx,
);
let start_ix = message_old_end_offset + NEWLINE.len_utf8();
let end_ix =
message_old_end_offset + text_len - NEWLINE.len_utf8();
let source_range = buffer.anchor_after(start_ix)
..buffer.anchor_after(end_ix);
this.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
status: PendingToolUseStatus::Idle,
source_range,
},
);
}
}
});
@ -2757,6 +2788,22 @@ pub enum PendingSlashCommandStatus {
Error(String),
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: String,
pub name: String,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
pub source_range: Range<language::Anchor>,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
Idle,
Running { _task: Shared<Task<()>> },
Error(String),
}
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,