Insert reply after assistant message when hitting cmd-enter
Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
ef7ec265c8
commit
16090c35ae
1 changed files with 73 additions and 31 deletions
|
@ -459,7 +459,7 @@ impl Assistant {
|
|||
api_key,
|
||||
buffer,
|
||||
};
|
||||
this.push_message(Role::User, cx);
|
||||
this.insert_message_after(ExcerptId::max(), Role::User, cx);
|
||||
this.count_remaining_tokens(cx);
|
||||
this
|
||||
}
|
||||
|
@ -498,7 +498,7 @@ impl Assistant {
|
|||
})
|
||||
.collect::<Vec<_>>();
|
||||
let model = self.model.clone();
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| {
|
||||
self.pending_token_count = cx.spawn_weak(|this, mut cx| {
|
||||
async move {
|
||||
cx.background().timer(Duration::from_millis(200)).await;
|
||||
let token_count = cx
|
||||
|
@ -506,11 +506,13 @@ impl Assistant {
|
|||
.spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
});
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
});
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
|
@ -547,9 +549,10 @@ impl Assistant {
|
|||
let api_key = self.api_key.borrow().clone();
|
||||
if let Some(api_key) = api_key {
|
||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||
let (excerpt_id, content) = self.push_message(Role::Assistant, cx);
|
||||
self.push_message(Role::User, cx);
|
||||
let task = cx.spawn(|this, mut cx| async move {
|
||||
let (excerpt_id, content) =
|
||||
self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
|
||||
self.insert_message_after(ExcerptId::max(), Role::User, cx);
|
||||
let task = cx.spawn_weak(|this, mut cx| async move {
|
||||
let stream_completion = async {
|
||||
let mut messages = stream.await?;
|
||||
|
||||
|
@ -564,22 +567,26 @@ impl Assistant {
|
|||
}
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.pending_completions
|
||||
.retain(|completion| completion.id != this.completion_count);
|
||||
this.summarize(cx);
|
||||
});
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
this.pending_completions
|
||||
.retain(|completion| completion.id != this.completion_count);
|
||||
this.summarize(cx);
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
if let Err(error) = stream_completion.await {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
|
||||
metadata.error = Some(error.to_string().trim().into());
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
|
||||
metadata.error = Some(error.to_string().trim().into());
|
||||
cx.notify();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -632,8 +639,9 @@ impl Assistant {
|
|||
}
|
||||
}
|
||||
|
||||
fn push_message(
|
||||
fn insert_message_after(
|
||||
&mut self,
|
||||
excerpt_id: ExcerptId,
|
||||
role: Role,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> (ExcerptId, ModelHandle<Buffer>) {
|
||||
|
@ -654,9 +662,10 @@ impl Assistant {
|
|||
buffer.set_language_registry(self.languages.clone());
|
||||
buffer
|
||||
});
|
||||
let excerpt_id = self.buffer.update(cx, |buffer, cx| {
|
||||
let new_excerpt_id = self.buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.push_excerpts(
|
||||
.insert_excerpts_after(
|
||||
excerpt_id,
|
||||
content.clone(),
|
||||
vec![ExcerptRange {
|
||||
context: 0..0,
|
||||
|
@ -668,19 +677,27 @@ impl Assistant {
|
|||
.unwrap()
|
||||
});
|
||||
|
||||
self.messages.push(Message {
|
||||
excerpt_id,
|
||||
content: content.clone(),
|
||||
});
|
||||
let ix = self
|
||||
.messages
|
||||
.iter()
|
||||
.position(|message| message.excerpt_id == excerpt_id)
|
||||
.map_or(self.messages.len(), |ix| ix + 1);
|
||||
self.messages.insert(
|
||||
ix,
|
||||
Message {
|
||||
excerpt_id: new_excerpt_id,
|
||||
content: content.clone(),
|
||||
},
|
||||
);
|
||||
self.messages_metadata.insert(
|
||||
excerpt_id,
|
||||
new_excerpt_id,
|
||||
MessageMetadata {
|
||||
role,
|
||||
sent_at: Local::now(),
|
||||
error: None,
|
||||
},
|
||||
);
|
||||
(excerpt_id, content)
|
||||
(new_excerpt_id, content)
|
||||
}
|
||||
|
||||
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
|
||||
|
@ -882,7 +899,7 @@ impl AssistantEditor {
|
|||
if metadata.role == Role::User {
|
||||
assistant.assist(cx);
|
||||
} else {
|
||||
assistant.push_message(Role::User, cx);
|
||||
assistant.insert_message_after(excerpt_id, Role::User, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1227,3 +1244,28 @@ async fn stream_completion(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::AppContext;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
|
||||
cx.add_model(|cx| {
|
||||
let mut assistant = Assistant::new(Default::default(), registry, cx);
|
||||
let (excerpt_1, _) =
|
||||
assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
|
||||
let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
|
||||
let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
|
||||
assistant.remove_empty_messages(
|
||||
HashSet::from_iter([excerpt_2, excerpt_3]),
|
||||
Default::default(),
|
||||
cx,
|
||||
);
|
||||
assistant
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue