Assistant grouping (#11479)

Groups collections of assistant messages with their tool calls as
children of the assistant message container.


![image](https://github.com/zed-industries/zed/assets/836375/b26b7c90-4c8d-4bbd-972a-1e769d78a455)

Release Notes:

- N/A
This commit is contained in:
Kyle Kelley 2024-05-07 08:21:57 -07:00 committed by GitHub
parent c77d2eb73f
commit 72c47b7f01
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 128 additions and 117 deletions

View file

@ -342,8 +342,8 @@ impl AssistantChat {
}
if self.pending_completion.take().is_some() {
if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
if message.body.text.is_empty() {
if let Some(ChatMessage::Assistant(grouping)) = self.messages.last() {
if grouping.messages.is_empty() {
self.pop_message(cx);
}
}
@ -478,22 +478,30 @@ impl AssistantChat {
while let Some(delta) = stream.next().await {
let delta = delta?;
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage {
body: message_body,
tool_calls: message_tool_calls,
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
messages,
..
})) = this.messages.last_mut()
{
if messages.is_empty() {
messages.push(AssistantMessage {
body: RichText::default(),
tool_calls: Vec::new(),
})
}
let message = messages.last_mut().unwrap();
if let Some(content) = &delta.content {
body.push_str(content);
}
for tool_call in delta.tool_calls {
let index = tool_call.index as usize;
if index >= message_tool_calls.len() {
message_tool_calls.resize_with(index + 1, Default::default);
if index >= message.tool_calls.len() {
message.tool_calls.resize_with(index + 1, Default::default);
}
let call = &mut message_tool_calls[index];
let call = &mut message.tool_calls[index];
if let Some(id) = &tool_call.id {
call.id.push_str(id);
@ -512,7 +520,7 @@ impl AssistantChat {
}
}
*message_body =
message.body =
RichText::new(body.clone(), &[], &this.language_registry);
cx.notify();
} else {
@ -527,9 +535,9 @@ impl AssistantChat {
let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
error: message_error,
tool_calls,
messages,
..
})) = this.messages.last_mut()
{
@ -537,8 +545,10 @@ impl AssistantChat {
message_error.replace(SharedString::from(error.to_string()));
cx.notify();
} else {
for tool_call in tool_calls.iter() {
tool_tasks.push(this.tool_registry.call(tool_call, cx));
if let Some(current_message) = messages.last_mut() {
for tool_call in current_message.tool_calls.iter() {
tool_tasks.push(this.tool_registry.call(tool_call, cx));
}
}
}
}
@ -554,21 +564,38 @@ impl AssistantChat {
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
this.messages.last_mut()
{
*tool_calls = tools;
cx.notify();
if let Some(current_message) = messages.last_mut() {
current_message.tool_calls = tools;
cx.notify();
} else {
unreachable!()
}
}
})?;
}
}
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
let message = ChatMessage::Assistant(AssistantMessage {
// If the last message is a grouped assistant message, add to the grouped message
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
self.messages.last_mut()
{
messages.push(AssistantMessage {
body: RichText::default(),
tool_calls: Vec::new(),
});
return;
}
let message = ChatMessage::Assistant(GroupedAssistantMessage {
id: self.next_message_id.post_inc(),
body: RichText::default(),
tool_calls: Vec::new(),
messages: vec![AssistantMessage {
body: RichText::default(),
tool_calls: Vec::new(),
}],
error: None,
});
self.push_message(message, cx);
@ -687,15 +714,14 @@ impl AssistantChat {
crate::ui::ChatMessage::new(
*id,
UserOrAssistant::User(self.user_store.read(cx).current_user()),
Some(
// todo!(): clean up the vec usage
vec![
RichText::new(
body.read(cx).text(cx),
&[],
&self.language_registry,
)
.element(ElementId::from(id.0), cx),
),
Some(
h_flex()
.gap_2()
.children(
@ -704,7 +730,7 @@ impl AssistantChat {
.map(|attachment| attachment.view.clone()),
)
.into_any_element(),
),
],
self.is_message_collapsed(id),
Box::new(cx.listener({
let id = *id;
@ -719,33 +745,34 @@ impl AssistantChat {
}
})
.into_any(),
ChatMessage::Assistant(AssistantMessage {
ChatMessage::Assistant(GroupedAssistantMessage {
id,
body,
messages,
error,
tool_calls,
..
}) => {
let assistant_body = if body.text.is_empty() {
None
} else {
Some(
div()
.child(body.element(ElementId::from(id.0), cx))
.into_any_element(),
)
};
let mut message_elements = Vec::new();
let tools = tool_calls
.iter()
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
.collect::<Vec<AnyElement>>();
for message in messages {
if !message.body.text.is_empty() {
message_elements.push(
div()
// todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
.child(message.body.element(ElementId::from(id.0), cx))
.into_any_element(),
)
}
let tools_body = if tools.is_empty() {
None
} else {
Some(div().children(tools).into_any_element())
};
let tools = message
.tool_calls
.iter()
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
.collect::<Vec<AnyElement>>();
if !tools.is_empty() {
message_elements.push(div().children(tools).into_any_element())
}
}
div()
.when(is_first, |this| this.pt(padding))
@ -753,8 +780,7 @@ impl AssistantChat {
crate::ui::ChatMessage::new(
*id,
UserOrAssistant::Assistant,
assistant_body,
tools_body,
message_elements,
self.is_message_collapsed(id),
Box::new(cx.listener({
let id = *id;
@ -796,46 +822,47 @@ impl AssistantChat {
content: body.read(cx).text(cx),
});
}
ChatMessage::Assistant(AssistantMessage {
body, tool_calls, ..
}) => {
// In no case do we want to send an empty message. This shouldn't happen, but we might as well
// not break the Chat API if it does.
if body.text.is_empty() && tool_calls.is_empty() {
continue;
}
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
for message in messages {
let body = message.body.clone();
let tool_calls_from_assistant = tool_calls
.iter()
.map(|tool_call| ToolCall {
content: ToolCallContent::Function {
function: FunctionContent {
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
if body.text.is_empty() && message.tool_calls.is_empty() {
continue;
}
let tool_calls_from_assistant = message
.tool_calls
.iter()
.map(|tool_call| ToolCall {
content: ToolCallContent::Function {
function: FunctionContent {
name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
},
},
},
id: tool_call.id.clone(),
})
.collect();
id: tool_call.id.clone(),
})
.collect();
completion_messages.push(CompletionMessage::Assistant {
content: Some(body.text.to_string()),
tool_calls: tool_calls_from_assistant,
});
for tool_call in tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result {
Some(result) => {
result.generate(&tool_call.name, &mut project_context, cx)
}
None => "".to_string(),
};
completion_messages.push(CompletionMessage::Tool {
content,
tool_call_id: tool_call.id.clone(),
completion_messages.push(CompletionMessage::Assistant {
content: Some(body.text.to_string()),
tool_calls: tool_calls_from_assistant,
});
for tool_call in &message.tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result {
Some(result) => {
result.generate(&tool_call.name, &mut project_context, cx)
}
None => "".to_string(),
};
completion_messages.push(CompletionMessage::Tool {
content,
tool_call_id: tool_call.id.clone(),
});
}
}
}
}
@ -885,7 +912,7 @@ impl MessageId {
enum ChatMessage {
User(UserMessage),
Assistant(AssistantMessage),
Assistant(GroupedAssistantMessage),
}
impl ChatMessage {
@ -904,8 +931,12 @@ struct UserMessage {
}
struct AssistantMessage {
id: MessageId,
body: RichText,
tool_calls: Vec<ToolFunctionCall>,
}
struct GroupedAssistantMessage {
id: MessageId,
messages: Vec<AssistantMessage>,
error: Option<SharedString>,
}

View file

@ -15,8 +15,7 @@ pub enum UserOrAssistant {
pub struct ChatMessage {
id: MessageId,
player: UserOrAssistant,
message: Option<AnyElement>,
tools_used: Option<AnyElement>,
messages: Vec<AnyElement>,
selected: bool,
collapsed: bool,
on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
@ -26,16 +25,14 @@ impl ChatMessage {
pub fn new(
id: MessageId,
player: UserOrAssistant,
message: Option<AnyElement>,
tools_used: Option<AnyElement>,
messages: Vec<AnyElement>,
collapsed: bool,
on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
) -> Self {
Self {
id,
player,
message,
tools_used,
messages,
selected: false,
collapsed,
on_collapse_handle_click,
@ -117,19 +114,10 @@ impl RenderOnce for ChatMessage {
.icon_color(Color::Muted)
.on_click(self.on_collapse_handle_click)
.tooltip(|cx| Tooltip::text("Collapse Message", cx)),
), // .child(
// IconButton::new("copy-message", IconName::Copy)
// .icon_color(Color::Muted)
// .icon_size(IconSize::XSmall),
// )
// .child(
// IconButton::new("menu", IconName::Ellipsis)
// .icon_color(Color::Muted)
// .icon_size(IconSize::XSmall),
// ),
),
),
)
.when(self.message.is_some() || self.tools_used.is_some(), |el| {
.when(self.messages.len() > 0, |el| {
el.child(
h_flex().child(
v_flex()
@ -144,8 +132,7 @@ impl RenderOnce for ChatMessage {
this.bg(background_color)
})
.when(self.collapsed, |this| this.h(collapsed_height))
.children(self.message)
.when_some(self.tools_used, |this, tools_used| this.child(tools_used)),
.children(self.messages),
),
)
})

View file

@ -28,8 +28,7 @@ impl Render for ChatMessageStory {
ChatMessage::new(
MessageId(0),
UserOrAssistant::User(Some(user_1.clone())),
Some(div().child("What can I do here?").into_any_element()),
None,
vec![div().child("What can I do here?").into_any_element()],
false,
Box::new(|_, _| {}),
),
@ -39,8 +38,7 @@ impl Render for ChatMessageStory {
ChatMessage::new(
MessageId(0),
UserOrAssistant::User(Some(user_1.clone())),
Some(div().child("What can I do here?").into_any_element()),
None,
vec![div().child("What can I do here?").into_any_element()],
true,
Box::new(|_, _| {}),
),
@ -53,8 +51,7 @@ impl Render for ChatMessageStory {
ChatMessage::new(
MessageId(0),
UserOrAssistant::Assistant,
Some(div().child("You can talk to me!").into_any_element()),
None,
vec![div().child("You can talk to me!").into_any_element()],
false,
Box::new(|_, _| {}),
),
@ -64,8 +61,7 @@ impl Render for ChatMessageStory {
ChatMessage::new(
MessageId(0),
UserOrAssistant::Assistant,
Some(div().child(MULTI_LINE_MESSAGE).into_any_element()),
None,
vec![div().child(MULTI_LINE_MESSAGE).into_any_element()],
true,
Box::new(|_, _| {}),
),
@ -79,24 +75,21 @@ impl Render for ChatMessageStory {
.child(ChatMessage::new(
MessageId(0),
UserOrAssistant::User(Some(user_1.clone())),
Some(div().child("What is Rust??").into_any_element()),
None,
vec![div().child("What is Rust??").into_any_element()],
false,
Box::new(|_, _| {}),
))
.child(ChatMessage::new(
MessageId(0),
UserOrAssistant::Assistant,
Some(div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()),
None,
vec![div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()],
false,
Box::new(|_, _| {}),
))
.child(ChatMessage::new(
MessageId(0),
UserOrAssistant::User(Some(user_1)),
Some(div().child("Sounds pretty cool!").into_any_element()),
None,
vec![div().child("Sounds pretty cool!").into_any_element()],
false,
Box::new(|_, _| {}),
)),