diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index a96528039c..be3a49ce18 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -543,7 +543,7 @@ impl Assistant { fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { let messages = self - .open_ai_request_messages(cx) + .messages(cx) .into_iter() .filter_map(|message| { Some(tiktoken_rs::ChatCompletionRequestMessage { @@ -552,7 +552,7 @@ impl Assistant { Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: message.content, + content: self.buffer.read(cx).text_for_range(message.range).collect(), name: None, }) }) @@ -596,7 +596,10 @@ impl Assistant { ) -> Option<(MessageAnchor, MessageAnchor)> { let request = OpenAIRequest { model: self.model.clone(), - messages: self.open_ai_request_messages(cx), + messages: self + .messages(cx) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .collect(), stream: true, }; @@ -841,16 +844,19 @@ impl Assistant { if self.message_anchors.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { - let mut messages = self.open_ai_request_messages(cx); - messages.truncate(2); - messages.push(RequestMessage { - role: Role::User, - content: "Summarize the conversation into a short title without punctuation" - .into(), - }); + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: + "Summarize the conversation into a short title without punctuation" + .into(), + })); let request = OpenAIRequest { model: self.model.clone(), - messages, + messages: messages.collect(), stream: true, }; @@ -878,16 +884,6 @@ impl Assistant { } } - fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { - let buffer = self.buffer.read(cx); - self.messages(cx) - .map(|message| RequestMessage { - role: message.role, - content: buffer.text_for_range(message.range).collect(), - }) - .collect() - } - fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option { let mut messages = self.messages(cx).peekable(); while let Some(message) = messages.next() { @@ -1446,6 +1442,15 @@ pub struct Message { error: Option>, } +impl Message { + fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { + RequestMessage { + role: self.role, + content: buffer.text_for_range(self.range.clone()).collect(), + } + } +} + async fn stream_completion( api_key: String, executor: Arc,