Assistant grouping (#11479)

Groups collections of assistant messages with their tool calls as
children of the assistant message container.


![image](https://github.com/zed-industries/zed/assets/836375/b26b7c90-4c8d-4bbd-972a-1e769d78a455)

Release Notes:

- N/A
This commit is contained in:
Kyle Kelley 2024-05-07 08:21:57 -07:00 committed by GitHub
parent c77d2eb73f
commit 72c47b7f01
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 128 additions and 117 deletions

View file

@ -342,8 +342,8 @@ impl AssistantChat {
} }
if self.pending_completion.take().is_some() { if self.pending_completion.take().is_some() {
if let Some(ChatMessage::Assistant(message)) = self.messages.last() { if let Some(ChatMessage::Assistant(grouping)) = self.messages.last() {
if message.body.text.is_empty() { if grouping.messages.is_empty() {
self.pop_message(cx); self.pop_message(cx);
} }
} }
@ -478,22 +478,30 @@ impl AssistantChat {
while let Some(delta) = stream.next().await { while let Some(delta) = stream.next().await {
let delta = delta?; let delta = delta?;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage { if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
body: message_body, messages,
tool_calls: message_tool_calls,
.. ..
})) = this.messages.last_mut() })) = this.messages.last_mut()
{ {
if messages.is_empty() {
messages.push(AssistantMessage {
body: RichText::default(),
tool_calls: Vec::new(),
})
}
let message = messages.last_mut().unwrap();
if let Some(content) = &delta.content { if let Some(content) = &delta.content {
body.push_str(content); body.push_str(content);
} }
for tool_call in delta.tool_calls { for tool_call in delta.tool_calls {
let index = tool_call.index as usize; let index = tool_call.index as usize;
if index >= message_tool_calls.len() { if index >= message.tool_calls.len() {
message_tool_calls.resize_with(index + 1, Default::default); message.tool_calls.resize_with(index + 1, Default::default);
} }
let call = &mut message_tool_calls[index]; let call = &mut message.tool_calls[index];
if let Some(id) = &tool_call.id { if let Some(id) = &tool_call.id {
call.id.push_str(id); call.id.push_str(id);
@ -512,7 +520,7 @@ impl AssistantChat {
} }
} }
*message_body = message.body =
RichText::new(body.clone(), &[], &this.language_registry); RichText::new(body.clone(), &[], &this.language_registry);
cx.notify(); cx.notify();
} else { } else {
@ -527,9 +535,9 @@ impl AssistantChat {
let mut tool_tasks = Vec::new(); let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage { if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
error: message_error, error: message_error,
tool_calls, messages,
.. ..
})) = this.messages.last_mut() })) = this.messages.last_mut()
{ {
@ -537,8 +545,10 @@ impl AssistantChat {
message_error.replace(SharedString::from(error.to_string())); message_error.replace(SharedString::from(error.to_string()));
cx.notify(); cx.notify();
} else { } else {
for tool_call in tool_calls.iter() { if let Some(current_message) = messages.last_mut() {
tool_tasks.push(this.tool_registry.call(tool_call, cx)); for tool_call in current_message.tool_calls.iter() {
tool_tasks.push(this.tool_registry.call(tool_call, cx));
}
} }
} }
} }
@ -554,21 +564,38 @@ impl AssistantChat {
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect(); let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) = if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
this.messages.last_mut() this.messages.last_mut()
{ {
*tool_calls = tools; if let Some(current_message) = messages.last_mut() {
cx.notify(); current_message.tool_calls = tools;
cx.notify();
} else {
unreachable!()
}
} }
})?; })?;
} }
} }
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) { fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
let message = ChatMessage::Assistant(AssistantMessage { // If the last message is a grouped assistant message, add to the grouped message
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
self.messages.last_mut()
{
messages.push(AssistantMessage {
body: RichText::default(),
tool_calls: Vec::new(),
});
return;
}
let message = ChatMessage::Assistant(GroupedAssistantMessage {
id: self.next_message_id.post_inc(), id: self.next_message_id.post_inc(),
body: RichText::default(), messages: vec![AssistantMessage {
tool_calls: Vec::new(), body: RichText::default(),
tool_calls: Vec::new(),
}],
error: None, error: None,
}); });
self.push_message(message, cx); self.push_message(message, cx);
@ -687,15 +714,14 @@ impl AssistantChat {
crate::ui::ChatMessage::new( crate::ui::ChatMessage::new(
*id, *id,
UserOrAssistant::User(self.user_store.read(cx).current_user()), UserOrAssistant::User(self.user_store.read(cx).current_user()),
Some( // todo!(): clean up the vec usage
vec![
RichText::new( RichText::new(
body.read(cx).text(cx), body.read(cx).text(cx),
&[], &[],
&self.language_registry, &self.language_registry,
) )
.element(ElementId::from(id.0), cx), .element(ElementId::from(id.0), cx),
),
Some(
h_flex() h_flex()
.gap_2() .gap_2()
.children( .children(
@ -704,7 +730,7 @@ impl AssistantChat {
.map(|attachment| attachment.view.clone()), .map(|attachment| attachment.view.clone()),
) )
.into_any_element(), .into_any_element(),
), ],
self.is_message_collapsed(id), self.is_message_collapsed(id),
Box::new(cx.listener({ Box::new(cx.listener({
let id = *id; let id = *id;
@ -719,33 +745,34 @@ impl AssistantChat {
} }
}) })
.into_any(), .into_any(),
ChatMessage::Assistant(AssistantMessage { ChatMessage::Assistant(GroupedAssistantMessage {
id, id,
body, messages,
error, error,
tool_calls,
.. ..
}) => { }) => {
let assistant_body = if body.text.is_empty() { let mut message_elements = Vec::new();
None
} else {
Some(
div()
.child(body.element(ElementId::from(id.0), cx))
.into_any_element(),
)
};
let tools = tool_calls for message in messages {
.iter() if !message.body.text.is_empty() {
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx)) message_elements.push(
.collect::<Vec<AnyElement>>(); div()
// todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
.child(message.body.element(ElementId::from(id.0), cx))
.into_any_element(),
)
}
let tools_body = if tools.is_empty() { let tools = message
None .tool_calls
} else { .iter()
Some(div().children(tools).into_any_element()) .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
}; .collect::<Vec<AnyElement>>();
if !tools.is_empty() {
message_elements.push(div().children(tools).into_any_element())
}
}
div() div()
.when(is_first, |this| this.pt(padding)) .when(is_first, |this| this.pt(padding))
@ -753,8 +780,7 @@ impl AssistantChat {
crate::ui::ChatMessage::new( crate::ui::ChatMessage::new(
*id, *id,
UserOrAssistant::Assistant, UserOrAssistant::Assistant,
assistant_body, message_elements,
tools_body,
self.is_message_collapsed(id), self.is_message_collapsed(id),
Box::new(cx.listener({ Box::new(cx.listener({
let id = *id; let id = *id;
@ -796,46 +822,47 @@ impl AssistantChat {
content: body.read(cx).text(cx), content: body.read(cx).text(cx),
}); });
} }
ChatMessage::Assistant(AssistantMessage { ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
body, tool_calls, .. for message in messages {
}) => { let body = message.body.clone();
// In no case do we want to send an empty message. This shouldn't happen, but we might as well
// not break the Chat API if it does.
if body.text.is_empty() && tool_calls.is_empty() {
continue;
}
let tool_calls_from_assistant = tool_calls if body.text.is_empty() && message.tool_calls.is_empty() {
.iter() continue;
.map(|tool_call| ToolCall { }
content: ToolCallContent::Function {
function: FunctionContent { let tool_calls_from_assistant = message
name: tool_call.name.clone(), .tool_calls
arguments: tool_call.arguments.clone(), .iter()
.map(|tool_call| ToolCall {
content: ToolCallContent::Function {
function: FunctionContent {
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
},
}, },
}, id: tool_call.id.clone(),
id: tool_call.id.clone(), })
}) .collect();
.collect();
completion_messages.push(CompletionMessage::Assistant { completion_messages.push(CompletionMessage::Assistant {
content: Some(body.text.to_string()), content: Some(body.text.to_string()),
tool_calls: tool_calls_from_assistant, tool_calls: tool_calls_from_assistant,
});
for tool_call in tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result {
Some(result) => {
result.generate(&tool_call.name, &mut project_context, cx)
}
None => "".to_string(),
};
completion_messages.push(CompletionMessage::Tool {
content,
tool_call_id: tool_call.id.clone(),
}); });
for tool_call in &message.tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result {
Some(result) => {
result.generate(&tool_call.name, &mut project_context, cx)
}
None => "".to_string(),
};
completion_messages.push(CompletionMessage::Tool {
content,
tool_call_id: tool_call.id.clone(),
});
}
} }
} }
} }
@ -885,7 +912,7 @@ impl MessageId {
enum ChatMessage { enum ChatMessage {
User(UserMessage), User(UserMessage),
Assistant(AssistantMessage), Assistant(GroupedAssistantMessage),
} }
impl ChatMessage { impl ChatMessage {
@ -904,8 +931,12 @@ struct UserMessage {
} }
struct AssistantMessage { struct AssistantMessage {
id: MessageId,
body: RichText, body: RichText,
tool_calls: Vec<ToolFunctionCall>, tool_calls: Vec<ToolFunctionCall>,
}
struct GroupedAssistantMessage {
id: MessageId,
messages: Vec<AssistantMessage>,
error: Option<SharedString>, error: Option<SharedString>,
} }

View file

@ -15,8 +15,7 @@ pub enum UserOrAssistant {
pub struct ChatMessage { pub struct ChatMessage {
id: MessageId, id: MessageId,
player: UserOrAssistant, player: UserOrAssistant,
message: Option<AnyElement>, messages: Vec<AnyElement>,
tools_used: Option<AnyElement>,
selected: bool, selected: bool,
collapsed: bool, collapsed: bool,
on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>, on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
@ -26,16 +25,14 @@ impl ChatMessage {
pub fn new( pub fn new(
id: MessageId, id: MessageId,
player: UserOrAssistant, player: UserOrAssistant,
message: Option<AnyElement>, messages: Vec<AnyElement>,
tools_used: Option<AnyElement>,
collapsed: bool, collapsed: bool,
on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>, on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
) -> Self { ) -> Self {
Self { Self {
id, id,
player, player,
message, messages,
tools_used,
selected: false, selected: false,
collapsed, collapsed,
on_collapse_handle_click, on_collapse_handle_click,
@ -117,19 +114,10 @@ impl RenderOnce for ChatMessage {
.icon_color(Color::Muted) .icon_color(Color::Muted)
.on_click(self.on_collapse_handle_click) .on_click(self.on_collapse_handle_click)
.tooltip(|cx| Tooltip::text("Collapse Message", cx)), .tooltip(|cx| Tooltip::text("Collapse Message", cx)),
), // .child( ),
// IconButton::new("copy-message", IconName::Copy)
// .icon_color(Color::Muted)
// .icon_size(IconSize::XSmall),
// )
// .child(
// IconButton::new("menu", IconName::Ellipsis)
// .icon_color(Color::Muted)
// .icon_size(IconSize::XSmall),
// ),
), ),
) )
.when(self.message.is_some() || self.tools_used.is_some(), |el| { .when(self.messages.len() > 0, |el| {
el.child( el.child(
h_flex().child( h_flex().child(
v_flex() v_flex()
@ -144,8 +132,7 @@ impl RenderOnce for ChatMessage {
this.bg(background_color) this.bg(background_color)
}) })
.when(self.collapsed, |this| this.h(collapsed_height)) .when(self.collapsed, |this| this.h(collapsed_height))
.children(self.message) .children(self.messages),
.when_some(self.tools_used, |this, tools_used| this.child(tools_used)),
), ),
) )
}) })

View file

@ -28,8 +28,7 @@ impl Render for ChatMessageStory {
ChatMessage::new( ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::User(Some(user_1.clone())), UserOrAssistant::User(Some(user_1.clone())),
Some(div().child("What can I do here?").into_any_element()), vec![div().child("What can I do here?").into_any_element()],
None,
false, false,
Box::new(|_, _| {}), Box::new(|_, _| {}),
), ),
@ -39,8 +38,7 @@ impl Render for ChatMessageStory {
ChatMessage::new( ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::User(Some(user_1.clone())), UserOrAssistant::User(Some(user_1.clone())),
Some(div().child("What can I do here?").into_any_element()), vec![div().child("What can I do here?").into_any_element()],
None,
true, true,
Box::new(|_, _| {}), Box::new(|_, _| {}),
), ),
@ -53,8 +51,7 @@ impl Render for ChatMessageStory {
ChatMessage::new( ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::Assistant, UserOrAssistant::Assistant,
Some(div().child("You can talk to me!").into_any_element()), vec![div().child("You can talk to me!").into_any_element()],
None,
false, false,
Box::new(|_, _| {}), Box::new(|_, _| {}),
), ),
@ -64,8 +61,7 @@ impl Render for ChatMessageStory {
ChatMessage::new( ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::Assistant, UserOrAssistant::Assistant,
Some(div().child(MULTI_LINE_MESSAGE).into_any_element()), vec![div().child(MULTI_LINE_MESSAGE).into_any_element()],
None,
true, true,
Box::new(|_, _| {}), Box::new(|_, _| {}),
), ),
@ -79,24 +75,21 @@ impl Render for ChatMessageStory {
.child(ChatMessage::new( .child(ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::User(Some(user_1.clone())), UserOrAssistant::User(Some(user_1.clone())),
Some(div().child("What is Rust??").into_any_element()), vec![div().child("What is Rust??").into_any_element()],
None,
false, false,
Box::new(|_, _| {}), Box::new(|_, _| {}),
)) ))
.child(ChatMessage::new( .child(ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::Assistant, UserOrAssistant::Assistant,
Some(div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()), vec![div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()],
None,
false, false,
Box::new(|_, _| {}), Box::new(|_, _| {}),
)) ))
.child(ChatMessage::new( .child(ChatMessage::new(
MessageId(0), MessageId(0),
UserOrAssistant::User(Some(user_1)), UserOrAssistant::User(Some(user_1)),
Some(div().child("Sounds pretty cool!").into_any_element()), vec![div().child("Sounds pretty cool!").into_any_element()],
None,
false, false,
Box::new(|_, _| {}), Box::new(|_, _| {}),
)), )),