Assistant grouping (#11479)
Groups collections of assistant messages with their tool calls as children of the assistant message container.  Release Notes: - N/A
This commit is contained in:
parent
c77d2eb73f
commit
72c47b7f01
3 changed files with 128 additions and 117 deletions
|
@ -342,8 +342,8 @@ impl AssistantChat {
|
|||
}
|
||||
|
||||
if self.pending_completion.take().is_some() {
|
||||
if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
|
||||
if message.body.text.is_empty() {
|
||||
if let Some(ChatMessage::Assistant(grouping)) = self.messages.last() {
|
||||
if grouping.messages.is_empty() {
|
||||
self.pop_message(cx);
|
||||
}
|
||||
}
|
||||
|
@ -478,22 +478,30 @@ impl AssistantChat {
|
|||
while let Some(delta) = stream.next().await {
|
||||
let delta = delta?;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage {
|
||||
body: message_body,
|
||||
tool_calls: message_tool_calls,
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
messages,
|
||||
..
|
||||
})) = this.messages.last_mut()
|
||||
{
|
||||
if messages.is_empty() {
|
||||
messages.push(AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
let message = messages.last_mut().unwrap();
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
body.push_str(content);
|
||||
}
|
||||
|
||||
for tool_call in delta.tool_calls {
|
||||
let index = tool_call.index as usize;
|
||||
if index >= message_tool_calls.len() {
|
||||
message_tool_calls.resize_with(index + 1, Default::default);
|
||||
if index >= message.tool_calls.len() {
|
||||
message.tool_calls.resize_with(index + 1, Default::default);
|
||||
}
|
||||
let call = &mut message_tool_calls[index];
|
||||
let call = &mut message.tool_calls[index];
|
||||
|
||||
if let Some(id) = &tool_call.id {
|
||||
call.id.push_str(id);
|
||||
|
@ -512,7 +520,7 @@ impl AssistantChat {
|
|||
}
|
||||
}
|
||||
|
||||
*message_body =
|
||||
message.body =
|
||||
RichText::new(body.clone(), &[], &this.language_registry);
|
||||
cx.notify();
|
||||
} else {
|
||||
|
@ -527,9 +535,9 @@ impl AssistantChat {
|
|||
|
||||
let mut tool_tasks = Vec::new();
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
error: message_error,
|
||||
tool_calls,
|
||||
messages,
|
||||
..
|
||||
})) = this.messages.last_mut()
|
||||
{
|
||||
|
@ -537,8 +545,10 @@ impl AssistantChat {
|
|||
message_error.replace(SharedString::from(error.to_string()));
|
||||
cx.notify();
|
||||
} else {
|
||||
for tool_call in tool_calls.iter() {
|
||||
tool_tasks.push(this.tool_registry.call(tool_call, cx));
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
for tool_call in current_message.tool_calls.iter() {
|
||||
tool_tasks.push(this.tool_registry.call(tool_call, cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -554,21 +564,38 @@ impl AssistantChat {
|
|||
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
this.messages.last_mut()
|
||||
{
|
||||
*tool_calls = tools;
|
||||
cx.notify();
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
current_message.tool_calls = tools;
|
||||
cx.notify();
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let message = ChatMessage::Assistant(AssistantMessage {
|
||||
// If the last message is a grouped assistant message, add to the grouped message
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
self.messages.last_mut()
|
||||
{
|
||||
messages.push(AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let message = ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
id: self.next_message_id.post_inc(),
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
messages: vec![AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
}],
|
||||
error: None,
|
||||
});
|
||||
self.push_message(message, cx);
|
||||
|
@ -687,15 +714,14 @@ impl AssistantChat {
|
|||
crate::ui::ChatMessage::new(
|
||||
*id,
|
||||
UserOrAssistant::User(self.user_store.read(cx).current_user()),
|
||||
Some(
|
||||
// todo!(): clean up the vec usage
|
||||
vec![
|
||||
RichText::new(
|
||||
body.read(cx).text(cx),
|
||||
&[],
|
||||
&self.language_registry,
|
||||
)
|
||||
.element(ElementId::from(id.0), cx),
|
||||
),
|
||||
Some(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.children(
|
||||
|
@ -704,7 +730,7 @@ impl AssistantChat {
|
|||
.map(|attachment| attachment.view.clone()),
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
],
|
||||
self.is_message_collapsed(id),
|
||||
Box::new(cx.listener({
|
||||
let id = *id;
|
||||
|
@ -719,33 +745,34 @@ impl AssistantChat {
|
|||
}
|
||||
})
|
||||
.into_any(),
|
||||
ChatMessage::Assistant(AssistantMessage {
|
||||
ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
id,
|
||||
body,
|
||||
messages,
|
||||
error,
|
||||
tool_calls,
|
||||
..
|
||||
}) => {
|
||||
let assistant_body = if body.text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
div()
|
||||
.child(body.element(ElementId::from(id.0), cx))
|
||||
.into_any_element(),
|
||||
)
|
||||
};
|
||||
let mut message_elements = Vec::new();
|
||||
|
||||
let tools = tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||
.collect::<Vec<AnyElement>>();
|
||||
for message in messages {
|
||||
if !message.body.text.is_empty() {
|
||||
message_elements.push(
|
||||
div()
|
||||
// todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
|
||||
.child(message.body.element(ElementId::from(id.0), cx))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
let tools_body = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(div().children(tools).into_any_element())
|
||||
};
|
||||
let tools = message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||
.collect::<Vec<AnyElement>>();
|
||||
|
||||
if !tools.is_empty() {
|
||||
message_elements.push(div().children(tools).into_any_element())
|
||||
}
|
||||
}
|
||||
|
||||
div()
|
||||
.when(is_first, |this| this.pt(padding))
|
||||
|
@ -753,8 +780,7 @@ impl AssistantChat {
|
|||
crate::ui::ChatMessage::new(
|
||||
*id,
|
||||
UserOrAssistant::Assistant,
|
||||
assistant_body,
|
||||
tools_body,
|
||||
message_elements,
|
||||
self.is_message_collapsed(id),
|
||||
Box::new(cx.listener({
|
||||
let id = *id;
|
||||
|
@ -796,46 +822,47 @@ impl AssistantChat {
|
|||
content: body.read(cx).text(cx),
|
||||
});
|
||||
}
|
||||
ChatMessage::Assistant(AssistantMessage {
|
||||
body, tool_calls, ..
|
||||
}) => {
|
||||
// In no case do we want to send an empty message. This shouldn't happen, but we might as well
|
||||
// not break the Chat API if it does.
|
||||
if body.text.is_empty() && tool_calls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
|
||||
for message in messages {
|
||||
let body = message.body.clone();
|
||||
|
||||
let tool_calls_from_assistant = tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| ToolCall {
|
||||
content: ToolCallContent::Function {
|
||||
function: FunctionContent {
|
||||
name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
if body.text.is_empty() && message.tool_calls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let tool_calls_from_assistant = message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| ToolCall {
|
||||
content: ToolCallContent::Function {
|
||||
function: FunctionContent {
|
||||
name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
},
|
||||
},
|
||||
},
|
||||
id: tool_call.id.clone(),
|
||||
})
|
||||
.collect();
|
||||
id: tool_call.id.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
completion_messages.push(CompletionMessage::Assistant {
|
||||
content: Some(body.text.to_string()),
|
||||
tool_calls: tool_calls_from_assistant,
|
||||
});
|
||||
|
||||
for tool_call in tool_calls {
|
||||
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
||||
let content = match &tool_call.result {
|
||||
Some(result) => {
|
||||
result.generate(&tool_call.name, &mut project_context, cx)
|
||||
}
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
completion_messages.push(CompletionMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
completion_messages.push(CompletionMessage::Assistant {
|
||||
content: Some(body.text.to_string()),
|
||||
tool_calls: tool_calls_from_assistant,
|
||||
});
|
||||
|
||||
for tool_call in &message.tool_calls {
|
||||
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
||||
let content = match &tool_call.result {
|
||||
Some(result) => {
|
||||
result.generate(&tool_call.name, &mut project_context, cx)
|
||||
}
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
completion_messages.push(CompletionMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -885,7 +912,7 @@ impl MessageId {
|
|||
|
||||
enum ChatMessage {
|
||||
User(UserMessage),
|
||||
Assistant(AssistantMessage),
|
||||
Assistant(GroupedAssistantMessage),
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
|
@ -904,8 +931,12 @@ struct UserMessage {
|
|||
}
|
||||
|
||||
struct AssistantMessage {
|
||||
id: MessageId,
|
||||
body: RichText,
|
||||
tool_calls: Vec<ToolFunctionCall>,
|
||||
}
|
||||
|
||||
struct GroupedAssistantMessage {
|
||||
id: MessageId,
|
||||
messages: Vec<AssistantMessage>,
|
||||
error: Option<SharedString>,
|
||||
}
|
||||
|
|
|
@ -15,8 +15,7 @@ pub enum UserOrAssistant {
|
|||
pub struct ChatMessage {
|
||||
id: MessageId,
|
||||
player: UserOrAssistant,
|
||||
message: Option<AnyElement>,
|
||||
tools_used: Option<AnyElement>,
|
||||
messages: Vec<AnyElement>,
|
||||
selected: bool,
|
||||
collapsed: bool,
|
||||
on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
|
||||
|
@ -26,16 +25,14 @@ impl ChatMessage {
|
|||
pub fn new(
|
||||
id: MessageId,
|
||||
player: UserOrAssistant,
|
||||
message: Option<AnyElement>,
|
||||
tools_used: Option<AnyElement>,
|
||||
messages: Vec<AnyElement>,
|
||||
collapsed: bool,
|
||||
on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
player,
|
||||
message,
|
||||
tools_used,
|
||||
messages,
|
||||
selected: false,
|
||||
collapsed,
|
||||
on_collapse_handle_click,
|
||||
|
@ -117,19 +114,10 @@ impl RenderOnce for ChatMessage {
|
|||
.icon_color(Color::Muted)
|
||||
.on_click(self.on_collapse_handle_click)
|
||||
.tooltip(|cx| Tooltip::text("Collapse Message", cx)),
|
||||
), // .child(
|
||||
// IconButton::new("copy-message", IconName::Copy)
|
||||
// .icon_color(Color::Muted)
|
||||
// .icon_size(IconSize::XSmall),
|
||||
// )
|
||||
// .child(
|
||||
// IconButton::new("menu", IconName::Ellipsis)
|
||||
// .icon_color(Color::Muted)
|
||||
// .icon_size(IconSize::XSmall),
|
||||
// ),
|
||||
),
|
||||
),
|
||||
)
|
||||
.when(self.message.is_some() || self.tools_used.is_some(), |el| {
|
||||
.when(self.messages.len() > 0, |el| {
|
||||
el.child(
|
||||
h_flex().child(
|
||||
v_flex()
|
||||
|
@ -144,8 +132,7 @@ impl RenderOnce for ChatMessage {
|
|||
this.bg(background_color)
|
||||
})
|
||||
.when(self.collapsed, |this| this.h(collapsed_height))
|
||||
.children(self.message)
|
||||
.when_some(self.tools_used, |this, tools_used| this.child(tools_used)),
|
||||
.children(self.messages),
|
||||
),
|
||||
)
|
||||
})
|
||||
|
|
|
@ -28,8 +28,7 @@ impl Render for ChatMessageStory {
|
|||
ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::User(Some(user_1.clone())),
|
||||
Some(div().child("What can I do here?").into_any_element()),
|
||||
None,
|
||||
vec![div().child("What can I do here?").into_any_element()],
|
||||
false,
|
||||
Box::new(|_, _| {}),
|
||||
),
|
||||
|
@ -39,8 +38,7 @@ impl Render for ChatMessageStory {
|
|||
ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::User(Some(user_1.clone())),
|
||||
Some(div().child("What can I do here?").into_any_element()),
|
||||
None,
|
||||
vec![div().child("What can I do here?").into_any_element()],
|
||||
true,
|
||||
Box::new(|_, _| {}),
|
||||
),
|
||||
|
@ -53,8 +51,7 @@ impl Render for ChatMessageStory {
|
|||
ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::Assistant,
|
||||
Some(div().child("You can talk to me!").into_any_element()),
|
||||
None,
|
||||
vec![div().child("You can talk to me!").into_any_element()],
|
||||
false,
|
||||
Box::new(|_, _| {}),
|
||||
),
|
||||
|
@ -64,8 +61,7 @@ impl Render for ChatMessageStory {
|
|||
ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::Assistant,
|
||||
Some(div().child(MULTI_LINE_MESSAGE).into_any_element()),
|
||||
None,
|
||||
vec![div().child(MULTI_LINE_MESSAGE).into_any_element()],
|
||||
true,
|
||||
Box::new(|_, _| {}),
|
||||
),
|
||||
|
@ -79,24 +75,21 @@ impl Render for ChatMessageStory {
|
|||
.child(ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::User(Some(user_1.clone())),
|
||||
Some(div().child("What is Rust??").into_any_element()),
|
||||
None,
|
||||
vec![div().child("What is Rust??").into_any_element()],
|
||||
false,
|
||||
Box::new(|_, _| {}),
|
||||
))
|
||||
.child(ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::Assistant,
|
||||
Some(div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()),
|
||||
None,
|
||||
vec![div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()],
|
||||
false,
|
||||
Box::new(|_, _| {}),
|
||||
))
|
||||
.child(ChatMessage::new(
|
||||
MessageId(0),
|
||||
UserOrAssistant::User(Some(user_1)),
|
||||
Some(div().child("Sounds pretty cool!").into_any_element()),
|
||||
None,
|
||||
vec![div().child("Sounds pretty cool!").into_any_element()],
|
||||
false,
|
||||
Box::new(|_, _| {}),
|
||||
)),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue