Insert reply after assistant message when hitting cmd-enter

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-06-06 19:15:06 +02:00
parent ef7ec265c8
commit 16090c35ae

View file

@ -459,7 +459,7 @@ impl Assistant {
api_key, api_key,
buffer, buffer,
}; };
this.push_message(Role::User, cx); this.insert_message_after(ExcerptId::max(), Role::User, cx);
this.count_remaining_tokens(cx); this.count_remaining_tokens(cx);
this this
} }
@ -498,7 +498,7 @@ impl Assistant {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let model = self.model.clone(); let model = self.model.clone();
self.pending_token_count = cx.spawn(|this, mut cx| { self.pending_token_count = cx.spawn_weak(|this, mut cx| {
async move { async move {
cx.background().timer(Duration::from_millis(200)).await; cx.background().timer(Duration::from_millis(200)).await;
let token_count = cx let token_count = cx
@ -506,11 +506,13 @@ impl Assistant {
.spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) }) .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
.await?; .await?;
this.update(&mut cx, |this, cx| { this.upgrade(&cx)
this.max_token_count = tiktoken_rs::model::get_context_size(&this.model); .ok_or_else(|| anyhow!("assistant was dropped"))?
this.token_count = Some(token_count); .update(&mut cx, |this, cx| {
cx.notify() this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
}); this.token_count = Some(token_count);
cx.notify()
});
anyhow::Ok(()) anyhow::Ok(())
} }
.log_err() .log_err()
@ -547,9 +549,10 @@ impl Assistant {
let api_key = self.api_key.borrow().clone(); let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key { if let Some(api_key) = api_key {
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
let (excerpt_id, content) = self.push_message(Role::Assistant, cx); let (excerpt_id, content) =
self.push_message(Role::User, cx); self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let task = cx.spawn(|this, mut cx| async move { self.insert_message_after(ExcerptId::max(), Role::User, cx);
let task = cx.spawn_weak(|this, mut cx| async move {
let stream_completion = async { let stream_completion = async {
let mut messages = stream.await?; let mut messages = stream.await?;
@ -564,22 +567,26 @@ impl Assistant {
} }
} }
this.update(&mut cx, |this, cx| { this.upgrade(&cx)
this.pending_completions .ok_or_else(|| anyhow!("assistant was dropped"))?
.retain(|completion| completion.id != this.completion_count); .update(&mut cx, |this, cx| {
this.summarize(cx); this.pending_completions
}); .retain(|completion| completion.id != this.completion_count);
this.summarize(cx);
});
anyhow::Ok(()) anyhow::Ok(())
}; };
if let Err(error) = stream_completion.await { if let Err(error) = stream_completion.await {
this.update(&mut cx, |this, cx| { if let Some(this) = this.upgrade(&cx) {
if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) { this.update(&mut cx, |this, cx| {
metadata.error = Some(error.to_string().trim().into()); if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
cx.notify(); metadata.error = Some(error.to_string().trim().into());
} cx.notify();
}) }
});
}
} }
}); });
@ -632,8 +639,9 @@ impl Assistant {
} }
} }
fn push_message( fn insert_message_after(
&mut self, &mut self,
excerpt_id: ExcerptId,
role: Role, role: Role,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> (ExcerptId, ModelHandle<Buffer>) { ) -> (ExcerptId, ModelHandle<Buffer>) {
@ -654,9 +662,10 @@ impl Assistant {
buffer.set_language_registry(self.languages.clone()); buffer.set_language_registry(self.languages.clone());
buffer buffer
}); });
let excerpt_id = self.buffer.update(cx, |buffer, cx| { let new_excerpt_id = self.buffer.update(cx, |buffer, cx| {
buffer buffer
.push_excerpts( .insert_excerpts_after(
excerpt_id,
content.clone(), content.clone(),
vec![ExcerptRange { vec![ExcerptRange {
context: 0..0, context: 0..0,
@ -668,19 +677,27 @@ impl Assistant {
.unwrap() .unwrap()
}); });
self.messages.push(Message { let ix = self
excerpt_id, .messages
content: content.clone(), .iter()
}); .position(|message| message.excerpt_id == excerpt_id)
.map_or(self.messages.len(), |ix| ix + 1);
self.messages.insert(
ix,
Message {
excerpt_id: new_excerpt_id,
content: content.clone(),
},
);
self.messages_metadata.insert( self.messages_metadata.insert(
excerpt_id, new_excerpt_id,
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, error: None,
}, },
); );
(excerpt_id, content) (new_excerpt_id, content)
} }
fn summarize(&mut self, cx: &mut ModelContext<Self>) { fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@ -882,7 +899,7 @@ impl AssistantEditor {
if metadata.role == Role::User { if metadata.role == Role::User {
assistant.assist(cx); assistant.assist(cx);
} else { } else {
assistant.push_message(Role::User, cx); assistant.insert_message_after(excerpt_id, Role::User, cx);
} }
} }
} }
@ -1227,3 +1244,28 @@ async fn stream_completion(
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use gpui::AppContext;
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
cx.add_model(|cx| {
let mut assistant = Assistant::new(Default::default(), registry, cx);
let (excerpt_1, _) =
assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
assistant.remove_empty_messages(
HashSet::from_iter([excerpt_2, excerpt_3]),
Default::default(),
cx,
);
assistant
});
}
}