Assistant grouping (#11479)
Groups collections of assistant messages with their tool calls as children of the assistant message container.  Release Notes: - N/A
This commit is contained in:
parent
c77d2eb73f
commit
72c47b7f01
3 changed files with 128 additions and 117 deletions
|
@ -342,8 +342,8 @@ impl AssistantChat {
|
|||
}
|
||||
|
||||
if self.pending_completion.take().is_some() {
|
||||
if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
|
||||
if message.body.text.is_empty() {
|
||||
if let Some(ChatMessage::Assistant(grouping)) = self.messages.last() {
|
||||
if grouping.messages.is_empty() {
|
||||
self.pop_message(cx);
|
||||
}
|
||||
}
|
||||
|
@ -478,22 +478,30 @@ impl AssistantChat {
|
|||
while let Some(delta) = stream.next().await {
|
||||
let delta = delta?;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage {
|
||||
body: message_body,
|
||||
tool_calls: message_tool_calls,
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
messages,
|
||||
..
|
||||
})) = this.messages.last_mut()
|
||||
{
|
||||
if messages.is_empty() {
|
||||
messages.push(AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
let message = messages.last_mut().unwrap();
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
body.push_str(content);
|
||||
}
|
||||
|
||||
for tool_call in delta.tool_calls {
|
||||
let index = tool_call.index as usize;
|
||||
if index >= message_tool_calls.len() {
|
||||
message_tool_calls.resize_with(index + 1, Default::default);
|
||||
if index >= message.tool_calls.len() {
|
||||
message.tool_calls.resize_with(index + 1, Default::default);
|
||||
}
|
||||
let call = &mut message_tool_calls[index];
|
||||
let call = &mut message.tool_calls[index];
|
||||
|
||||
if let Some(id) = &tool_call.id {
|
||||
call.id.push_str(id);
|
||||
|
@ -512,7 +520,7 @@ impl AssistantChat {
|
|||
}
|
||||
}
|
||||
|
||||
*message_body =
|
||||
message.body =
|
||||
RichText::new(body.clone(), &[], &this.language_registry);
|
||||
cx.notify();
|
||||
} else {
|
||||
|
@ -527,9 +535,9 @@ impl AssistantChat {
|
|||
|
||||
let mut tool_tasks = Vec::new();
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
error: message_error,
|
||||
tool_calls,
|
||||
messages,
|
||||
..
|
||||
})) = this.messages.last_mut()
|
||||
{
|
||||
|
@ -537,8 +545,10 @@ impl AssistantChat {
|
|||
message_error.replace(SharedString::from(error.to_string()));
|
||||
cx.notify();
|
||||
} else {
|
||||
for tool_call in tool_calls.iter() {
|
||||
tool_tasks.push(this.tool_registry.call(tool_call, cx));
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
for tool_call in current_message.tool_calls.iter() {
|
||||
tool_tasks.push(this.tool_registry.call(tool_call, cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -554,21 +564,38 @@ impl AssistantChat {
|
|||
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
this.messages.last_mut()
|
||||
{
|
||||
*tool_calls = tools;
|
||||
cx.notify();
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
current_message.tool_calls = tools;
|
||||
cx.notify();
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let message = ChatMessage::Assistant(AssistantMessage {
|
||||
// If the last message is a grouped assistant message, add to the grouped message
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
self.messages.last_mut()
|
||||
{
|
||||
messages.push(AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let message = ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
id: self.next_message_id.post_inc(),
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
messages: vec![AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
}],
|
||||
error: None,
|
||||
});
|
||||
self.push_message(message, cx);
|
||||
|
@ -687,15 +714,14 @@ impl AssistantChat {
|
|||
crate::ui::ChatMessage::new(
|
||||
*id,
|
||||
UserOrAssistant::User(self.user_store.read(cx).current_user()),
|
||||
Some(
|
||||
// todo!(): clean up the vec usage
|
||||
vec![
|
||||
RichText::new(
|
||||
body.read(cx).text(cx),
|
||||
&[],
|
||||
&self.language_registry,
|
||||
)
|
||||
.element(ElementId::from(id.0), cx),
|
||||
),
|
||||
Some(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.children(
|
||||
|
@ -704,7 +730,7 @@ impl AssistantChat {
|
|||
.map(|attachment| attachment.view.clone()),
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
],
|
||||
self.is_message_collapsed(id),
|
||||
Box::new(cx.listener({
|
||||
let id = *id;
|
||||
|
@ -719,33 +745,34 @@ impl AssistantChat {
|
|||
}
|
||||
})
|
||||
.into_any(),
|
||||
ChatMessage::Assistant(AssistantMessage {
|
||||
ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
id,
|
||||
body,
|
||||
messages,
|
||||
error,
|
||||
tool_calls,
|
||||
..
|
||||
}) => {
|
||||
let assistant_body = if body.text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
div()
|
||||
.child(body.element(ElementId::from(id.0), cx))
|
||||
.into_any_element(),
|
||||
)
|
||||
};
|
||||
let mut message_elements = Vec::new();
|
||||
|
||||
let tools = tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||
.collect::<Vec<AnyElement>>();
|
||||
for message in messages {
|
||||
if !message.body.text.is_empty() {
|
||||
message_elements.push(
|
||||
div()
|
||||
// todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
|
||||
.child(message.body.element(ElementId::from(id.0), cx))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
let tools_body = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(div().children(tools).into_any_element())
|
||||
};
|
||||
let tools = message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||
.collect::<Vec<AnyElement>>();
|
||||
|
||||
if !tools.is_empty() {
|
||||
message_elements.push(div().children(tools).into_any_element())
|
||||
}
|
||||
}
|
||||
|
||||
div()
|
||||
.when(is_first, |this| this.pt(padding))
|
||||
|
@ -753,8 +780,7 @@ impl AssistantChat {
|
|||
crate::ui::ChatMessage::new(
|
||||
*id,
|
||||
UserOrAssistant::Assistant,
|
||||
assistant_body,
|
||||
tools_body,
|
||||
message_elements,
|
||||
self.is_message_collapsed(id),
|
||||
Box::new(cx.listener({
|
||||
let id = *id;
|
||||
|
@ -796,46 +822,47 @@ impl AssistantChat {
|
|||
content: body.read(cx).text(cx),
|
||||
});
|
||||
}
|
||||
ChatMessage::Assistant(AssistantMessage {
|
||||
body, tool_calls, ..
|
||||
}) => {
|
||||
// In no case do we want to send an empty message. This shouldn't happen, but we might as well
|
||||
// not break the Chat API if it does.
|
||||
if body.text.is_empty() && tool_calls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
|
||||
for message in messages {
|
||||
let body = message.body.clone();
|
||||
|
||||
let tool_calls_from_assistant = tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| ToolCall {
|
||||
content: ToolCallContent::Function {
|
||||
function: FunctionContent {
|
||||
name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
if body.text.is_empty() && message.tool_calls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let tool_calls_from_assistant = message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| ToolCall {
|
||||
content: ToolCallContent::Function {
|
||||
function: FunctionContent {
|
||||
name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
},
|
||||
},
|
||||
},
|
||||
id: tool_call.id.clone(),
|
||||
})
|
||||
.collect();
|
||||
id: tool_call.id.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
completion_messages.push(CompletionMessage::Assistant {
|
||||
content: Some(body.text.to_string()),
|
||||
tool_calls: tool_calls_from_assistant,
|
||||
});
|
||||
|
||||
for tool_call in tool_calls {
|
||||
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
||||
let content = match &tool_call.result {
|
||||
Some(result) => {
|
||||
result.generate(&tool_call.name, &mut project_context, cx)
|
||||
}
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
completion_messages.push(CompletionMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
completion_messages.push(CompletionMessage::Assistant {
|
||||
content: Some(body.text.to_string()),
|
||||
tool_calls: tool_calls_from_assistant,
|
||||
});
|
||||
|
||||
for tool_call in &message.tool_calls {
|
||||
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
||||
let content = match &tool_call.result {
|
||||
Some(result) => {
|
||||
result.generate(&tool_call.name, &mut project_context, cx)
|
||||
}
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
completion_messages.push(CompletionMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -885,7 +912,7 @@ impl MessageId {
|
|||
|
||||
enum ChatMessage {
|
||||
User(UserMessage),
|
||||
Assistant(AssistantMessage),
|
||||
Assistant(GroupedAssistantMessage),
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
|
@ -904,8 +931,12 @@ struct UserMessage {
|
|||
}
|
||||
|
||||
struct AssistantMessage {
|
||||
id: MessageId,
|
||||
body: RichText,
|
||||
tool_calls: Vec<ToolFunctionCall>,
|
||||
}
|
||||
|
||||
struct GroupedAssistantMessage {
|
||||
id: MessageId,
|
||||
messages: Vec<AssistantMessage>,
|
||||
error: Option<SharedString>,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue