diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 3fa8d25dd6..3dff299e32 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -342,8 +342,8 @@ impl AssistantChat { } if self.pending_completion.take().is_some() { - if let Some(ChatMessage::Assistant(message)) = self.messages.last() { - if message.body.text.is_empty() { + if let Some(ChatMessage::Assistant(grouping)) = self.messages.last() { + if grouping.messages.is_empty() { self.pop_message(cx); } } @@ -478,22 +478,30 @@ impl AssistantChat { while let Some(delta) = stream.next().await { let delta = delta?; this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(AssistantMessage { - body: message_body, - tool_calls: message_tool_calls, + if let Some(ChatMessage::Assistant(GroupedAssistantMessage { + messages, .. })) = this.messages.last_mut() { + if messages.is_empty() { + messages.push(AssistantMessage { + body: RichText::default(), + tool_calls: Vec::new(), + }) + } + + let message = messages.last_mut().unwrap(); + if let Some(content) = &delta.content { body.push_str(content); } for tool_call in delta.tool_calls { let index = tool_call.index as usize; - if index >= message_tool_calls.len() { - message_tool_calls.resize_with(index + 1, Default::default); + if index >= message.tool_calls.len() { + message.tool_calls.resize_with(index + 1, Default::default); } - let call = &mut message_tool_calls[index]; + let call = &mut message.tool_calls[index]; if let Some(id) = &tool_call.id { call.id.push_str(id); @@ -512,7 +520,7 @@ impl AssistantChat { } } - *message_body = + message.body = RichText::new(body.clone(), &[], &this.language_registry); cx.notify(); } else { @@ -527,9 +535,9 @@ impl AssistantChat { let mut tool_tasks = Vec::new(); this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(AssistantMessage { + if let Some(ChatMessage::Assistant(GroupedAssistantMessage { error: message_error, - tool_calls, + messages, .. })) = this.messages.last_mut() { @@ -537,8 +545,10 @@ impl AssistantChat { message_error.replace(SharedString::from(error.to_string())); cx.notify(); } else { - for tool_call in tool_calls.iter() { - tool_tasks.push(this.tool_registry.call(tool_call, cx)); + if let Some(current_message) = messages.last_mut() { + for tool_call in current_message.tool_calls.iter() { + tool_tasks.push(this.tool_registry.call(tool_call, cx)); + } } } } @@ -554,21 +564,38 @@ impl AssistantChat { let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect(); this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) = + if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) = this.messages.last_mut() { - *tool_calls = tools; - cx.notify(); + if let Some(current_message) = messages.last_mut() { + current_message.tool_calls = tools; + cx.notify(); + } else { + unreachable!() + } } })?; } } fn push_new_assistant_message(&mut self, cx: &mut ViewContext) { - let message = ChatMessage::Assistant(AssistantMessage { + // If the last message is a grouped assistant message, add to the grouped message + if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) = + self.messages.last_mut() + { + messages.push(AssistantMessage { + body: RichText::default(), + tool_calls: Vec::new(), + }); + return; + } + + let message = ChatMessage::Assistant(GroupedAssistantMessage { id: self.next_message_id.post_inc(), - body: RichText::default(), - tool_calls: Vec::new(), + messages: vec![AssistantMessage { + body: RichText::default(), + tool_calls: Vec::new(), + }], error: None, }); self.push_message(message, cx); @@ -687,15 +714,14 @@ impl AssistantChat { crate::ui::ChatMessage::new( *id, UserOrAssistant::User(self.user_store.read(cx).current_user()), - Some( + // todo!(): clean up the vec usage + vec![ RichText::new( body.read(cx).text(cx), &[], &self.language_registry, ) .element(ElementId::from(id.0), cx), - ), - Some( h_flex() .gap_2() .children( @@ -704,7 +730,7 @@ impl AssistantChat { .map(|attachment| attachment.view.clone()), ) .into_any_element(), - ), + ], self.is_message_collapsed(id), Box::new(cx.listener({ let id = *id; @@ -719,33 +745,34 @@ impl AssistantChat { } }) .into_any(), - ChatMessage::Assistant(AssistantMessage { + ChatMessage::Assistant(GroupedAssistantMessage { id, - body, + messages, error, - tool_calls, .. }) => { - let assistant_body = if body.text.is_empty() { - None - } else { - Some( - div() - .child(body.element(ElementId::from(id.0), cx)) - .into_any_element(), - ) - }; + let mut message_elements = Vec::new(); - let tools = tool_calls - .iter() - .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx)) - .collect::>(); + for message in messages { + if !message.body.text.is_empty() { + message_elements.push( + div() + // todo!(): The element Id will need to be a combo of the base ID and the index within the grouping + .child(message.body.element(ElementId::from(id.0), cx)) + .into_any_element(), + ) + } - let tools_body = if tools.is_empty() { - None - } else { - Some(div().children(tools).into_any_element()) - }; + let tools = message + .tool_calls + .iter() + .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx)) + .collect::>(); + + if !tools.is_empty() { + message_elements.push(div().children(tools).into_any_element()) + } + } div() .when(is_first, |this| this.pt(padding)) @@ -753,8 +780,7 @@ impl AssistantChat { crate::ui::ChatMessage::new( *id, UserOrAssistant::Assistant, - assistant_body, - tools_body, + message_elements, self.is_message_collapsed(id), Box::new(cx.listener({ let id = *id; @@ -796,46 +822,47 @@ impl AssistantChat { content: body.read(cx).text(cx), }); } - ChatMessage::Assistant(AssistantMessage { - body, tool_calls, .. - }) => { - // In no case do we want to send an empty message. This shouldn't happen, but we might as well - // not break the Chat API if it does. - if body.text.is_empty() && tool_calls.is_empty() { - continue; - } + ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => { + for message in messages { + let body = message.body.clone(); - let tool_calls_from_assistant = tool_calls - .iter() - .map(|tool_call| ToolCall { - content: ToolCallContent::Function { - function: FunctionContent { - name: tool_call.name.clone(), - arguments: tool_call.arguments.clone(), + if body.text.is_empty() && message.tool_calls.is_empty() { + continue; + } + + let tool_calls_from_assistant = message + .tool_calls + .iter() + .map(|tool_call| ToolCall { + content: ToolCallContent::Function { + function: FunctionContent { + name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + }, }, - }, - id: tool_call.id.clone(), - }) - .collect(); + id: tool_call.id.clone(), + }) + .collect(); - completion_messages.push(CompletionMessage::Assistant { - content: Some(body.text.to_string()), - tool_calls: tool_calls_from_assistant, - }); - - for tool_call in tool_calls { - // Every tool call _must_ have a result by ID, otherwise OpenAI will error. - let content = match &tool_call.result { - Some(result) => { - result.generate(&tool_call.name, &mut project_context, cx) - } - None => "".to_string(), - }; - - completion_messages.push(CompletionMessage::Tool { - content, - tool_call_id: tool_call.id.clone(), + completion_messages.push(CompletionMessage::Assistant { + content: Some(body.text.to_string()), + tool_calls: tool_calls_from_assistant, }); + + for tool_call in &message.tool_calls { + // Every tool call _must_ have a result by ID, otherwise OpenAI will error. + let content = match &tool_call.result { + Some(result) => { + result.generate(&tool_call.name, &mut project_context, cx) + } + None => "".to_string(), + }; + + completion_messages.push(CompletionMessage::Tool { + content, + tool_call_id: tool_call.id.clone(), + }); + } } } } @@ -885,7 +912,7 @@ impl MessageId { enum ChatMessage { User(UserMessage), - Assistant(AssistantMessage), + Assistant(GroupedAssistantMessage), } impl ChatMessage { @@ -904,8 +931,12 @@ struct UserMessage { } struct AssistantMessage { - id: MessageId, body: RichText, tool_calls: Vec, +} + +struct GroupedAssistantMessage { + id: MessageId, + messages: Vec, error: Option, } diff --git a/crates/assistant2/src/ui/chat_message.rs b/crates/assistant2/src/ui/chat_message.rs index 7d8043ca61..a3d3f11fdc 100644 --- a/crates/assistant2/src/ui/chat_message.rs +++ b/crates/assistant2/src/ui/chat_message.rs @@ -15,8 +15,7 @@ pub enum UserOrAssistant { pub struct ChatMessage { id: MessageId, player: UserOrAssistant, - message: Option, - tools_used: Option, + messages: Vec, selected: bool, collapsed: bool, on_collapse_handle_click: Box, @@ -26,16 +25,14 @@ impl ChatMessage { pub fn new( id: MessageId, player: UserOrAssistant, - message: Option, - tools_used: Option, + messages: Vec, collapsed: bool, on_collapse_handle_click: Box, ) -> Self { Self { id, player, - message, - tools_used, + messages, selected: false, collapsed, on_collapse_handle_click, @@ -117,19 +114,10 @@ impl RenderOnce for ChatMessage { .icon_color(Color::Muted) .on_click(self.on_collapse_handle_click) .tooltip(|cx| Tooltip::text("Collapse Message", cx)), - ), // .child( - // IconButton::new("copy-message", IconName::Copy) - // .icon_color(Color::Muted) - // .icon_size(IconSize::XSmall), - // ) - // .child( - // IconButton::new("menu", IconName::Ellipsis) - // .icon_color(Color::Muted) - // .icon_size(IconSize::XSmall), - // ), + ), ), ) - .when(self.message.is_some() || self.tools_used.is_some(), |el| { + .when(self.messages.len() > 0, |el| { el.child( h_flex().child( v_flex() @@ -144,8 +132,7 @@ impl RenderOnce for ChatMessage { this.bg(background_color) }) .when(self.collapsed, |this| this.h(collapsed_height)) - .children(self.message) - .when_some(self.tools_used, |this, tools_used| this.child(tools_used)), + .children(self.messages), ), ) }) diff --git a/crates/assistant2/src/ui/stories/chat_message.rs b/crates/assistant2/src/ui/stories/chat_message.rs index 3058d0cdea..1d63ae78c4 100644 --- a/crates/assistant2/src/ui/stories/chat_message.rs +++ b/crates/assistant2/src/ui/stories/chat_message.rs @@ -28,8 +28,7 @@ impl Render for ChatMessageStory { ChatMessage::new( MessageId(0), UserOrAssistant::User(Some(user_1.clone())), - Some(div().child("What can I do here?").into_any_element()), - None, + vec![div().child("What can I do here?").into_any_element()], false, Box::new(|_, _| {}), ), @@ -39,8 +38,7 @@ impl Render for ChatMessageStory { ChatMessage::new( MessageId(0), UserOrAssistant::User(Some(user_1.clone())), - Some(div().child("What can I do here?").into_any_element()), - None, + vec![div().child("What can I do here?").into_any_element()], true, Box::new(|_, _| {}), ), @@ -53,8 +51,7 @@ impl Render for ChatMessageStory { ChatMessage::new( MessageId(0), UserOrAssistant::Assistant, - Some(div().child("You can talk to me!").into_any_element()), - None, + vec![div().child("You can talk to me!").into_any_element()], false, Box::new(|_, _| {}), ), @@ -64,8 +61,7 @@ impl Render for ChatMessageStory { ChatMessage::new( MessageId(0), UserOrAssistant::Assistant, - Some(div().child(MULTI_LINE_MESSAGE).into_any_element()), - None, + vec![div().child(MULTI_LINE_MESSAGE).into_any_element()], true, Box::new(|_, _| {}), ), @@ -79,24 +75,21 @@ impl Render for ChatMessageStory { .child(ChatMessage::new( MessageId(0), UserOrAssistant::User(Some(user_1.clone())), - Some(div().child("What is Rust??").into_any_element()), - None, + vec![div().child("What is Rust??").into_any_element()], false, Box::new(|_, _| {}), )) .child(ChatMessage::new( MessageId(0), UserOrAssistant::Assistant, - Some(div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()), - None, + vec![div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()], false, Box::new(|_, _| {}), )) .child(ChatMessage::new( MessageId(0), UserOrAssistant::User(Some(user_1)), - Some(div().child("Sounds pretty cool!").into_any_element()), - None, + vec![div().child("Sounds pretty cool!").into_any_element()], false, Box::new(|_, _| {}), )),