Insert reply after assistant message when hitting cmd-enter

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-06-06 19:15:06 +02:00
parent ef7ec265c8
commit 16090c35ae

View file

@ -459,7 +459,7 @@ impl Assistant {
api_key,
buffer,
};
this.push_message(Role::User, cx);
this.insert_message_after(ExcerptId::max(), Role::User, cx);
this.count_remaining_tokens(cx);
this
}
@ -498,7 +498,7 @@ impl Assistant {
})
.collect::<Vec<_>>();
let model = self.model.clone();
self.pending_token_count = cx.spawn(|this, mut cx| {
self.pending_token_count = cx.spawn_weak(|this, mut cx| {
async move {
cx.background().timer(Duration::from_millis(200)).await;
let token_count = cx
@ -506,11 +506,13 @@ impl Assistant {
.spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
.await?;
this.update(&mut cx, |this, cx| {
this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
this.token_count = Some(token_count);
cx.notify()
});
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
this.token_count = Some(token_count);
cx.notify()
});
anyhow::Ok(())
}
.log_err()
@ -547,9 +549,10 @@ impl Assistant {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let stream = stream_completion(api_key, cx.background().clone(), request);
let (excerpt_id, content) = self.push_message(Role::Assistant, cx);
self.push_message(Role::User, cx);
let task = cx.spawn(|this, mut cx| async move {
let (excerpt_id, content) =
self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
self.insert_message_after(ExcerptId::max(), Role::User, cx);
let task = cx.spawn_weak(|this, mut cx| async move {
let stream_completion = async {
let mut messages = stream.await?;
@ -564,22 +567,26 @@ impl Assistant {
}
}
this.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
this.summarize(cx);
});
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
this.summarize(cx);
});
anyhow::Ok(())
};
if let Err(error) = stream_completion.await {
this.update(&mut cx, |this, cx| {
if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
metadata.error = Some(error.to_string().trim().into());
cx.notify();
}
})
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
metadata.error = Some(error.to_string().trim().into());
cx.notify();
}
});
}
}
});
@ -632,8 +639,9 @@ impl Assistant {
}
}
fn push_message(
fn insert_message_after(
&mut self,
excerpt_id: ExcerptId,
role: Role,
cx: &mut ModelContext<Self>,
) -> (ExcerptId, ModelHandle<Buffer>) {
@ -654,9 +662,10 @@ impl Assistant {
buffer.set_language_registry(self.languages.clone());
buffer
});
let excerpt_id = self.buffer.update(cx, |buffer, cx| {
let new_excerpt_id = self.buffer.update(cx, |buffer, cx| {
buffer
.push_excerpts(
.insert_excerpts_after(
excerpt_id,
content.clone(),
vec![ExcerptRange {
context: 0..0,
@ -668,19 +677,27 @@ impl Assistant {
.unwrap()
});
self.messages.push(Message {
excerpt_id,
content: content.clone(),
});
let ix = self
.messages
.iter()
.position(|message| message.excerpt_id == excerpt_id)
.map_or(self.messages.len(), |ix| ix + 1);
self.messages.insert(
ix,
Message {
excerpt_id: new_excerpt_id,
content: content.clone(),
},
);
self.messages_metadata.insert(
excerpt_id,
new_excerpt_id,
MessageMetadata {
role,
sent_at: Local::now(),
error: None,
},
);
(excerpt_id, content)
(new_excerpt_id, content)
}
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@ -882,7 +899,7 @@ impl AssistantEditor {
if metadata.role == Role::User {
assistant.assist(cx);
} else {
assistant.push_message(Role::User, cx);
assistant.insert_message_after(excerpt_id, Role::User, cx);
}
}
}
@ -1227,3 +1244,28 @@ async fn stream_completion(
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::AppContext;
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test());
cx.add_model(|cx| {
let mut assistant = Assistant::new(Default::default(), registry, cx);
let (excerpt_1, _) =
assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
assistant.remove_empty_messages(
HashSet::from_iter([excerpt_2, excerpt_3]),
Default::default(),
cx,
);
assistant
});
}
}