assistant: Pass up tool results in LLM request messages (#17656)

This PR makes it so we pass up the tool results in the `tool_results`
field in the request message to the LLM.

This required reworking how we track non-text content in the context
editor.

We also removed serialization of images in context history, as we were
never deserializing them, and thus it was unneeded.

Release Notes:

- N/A

---------

Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
Marshall Bowers 2024-09-10 15:25:57 -04:00 committed by GitHub
parent 1b627925d3
commit a23e381096
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 220 additions and 216 deletions

View file

@ -11,7 +11,7 @@ use crate::{
},
slash_command_picker,
terminal_inline_assistant::TerminalInlineAssistant,
Assist, CacheStatus, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
Assist, CacheStatus, ConfirmCommand, Content, Context, ContextEvent, ContextId, ContextStore,
ContextStoreEvent, CycleMessageRole, DeployHistory, DeployPromptLibrary, InlineAssistId,
InlineAssistant, InsertDraggedFiles, InsertIntoEditor, Message, MessageId, MessageMetadata,
MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, PendingSlashCommand,
@ -46,6 +46,7 @@ use indexed_docs::IndexedDocsStore;
use language::{
language_settings::SoftWrap, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset,
};
use language_model::LanguageModelToolUse;
use language_model::{
provider::cloud::PROVIDER_ID, LanguageModelProvider, LanguageModelProviderId,
LanguageModelRegistry, Role,
@ -1995,6 +1996,20 @@ impl ContextEditor {
let buffer_row = MultiBufferRow(start.to_point(&buffer).row);
buffer_rows_to_fold.insert(buffer_row);
self.context.update(cx, |context, cx| {
context.insert_content(
Content::ToolUse {
range: tool_use.source_range.clone(),
tool_use: LanguageModelToolUse {
id: tool_use.id.to_string(),
name: tool_use.name.clone(),
input: tool_use.input.clone(),
},
},
cx,
);
});
Crease::new(
start..end,
placeholder,
@ -3538,7 +3553,7 @@ impl ContextEditor {
let image_id = image.id();
context.insert_image(image, cx);
for image_position in image_positions.iter() {
context.insert_image_anchor(image_id, image_position.text_anchor, cx);
context.insert_image_content(image_id, image_position.text_anchor, cx);
}
}
});
@ -3553,11 +3568,23 @@ impl ContextEditor {
let new_blocks = self
.context
.read(cx)
.images(cx)
.filter_map(|image| {
.contents(cx)
.filter_map(|content| {
if let Content::Image {
anchor,
render_image,
..
} = content
{
Some((anchor, render_image))
} else {
None
}
})
.filter_map(|(anchor, render_image)| {
const MAX_HEIGHT_IN_LINES: u32 = 8;
let anchor = buffer.anchor_in_excerpt(excerpt_id, image.anchor).unwrap();
let image = image.render_image.clone();
let anchor = buffer.anchor_in_excerpt(excerpt_id, anchor).unwrap();
let image = render_image.clone();
anchor.is_valid(&buffer).then(|| BlockProperties {
position: anchor,
height: MAX_HEIGHT_IN_LINES,

View file

@ -17,7 +17,6 @@ use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use fs::{Fs, RemoveOptions};
use futures::{
future::{self, Shared},
stream::FuturesUnordered,
FutureExt, StreamExt,
};
use gpui::{
@ -29,10 +28,11 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, MessageContent, Role, StopReason,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
StopReason,
};
use open_ai::Model as OpenAiModel;
use paths::{context_images_dir, contexts_dir};
use paths::contexts_dir;
use project::Project;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
@ -377,23 +377,8 @@ impl MessageMetadata {
}
}
#[derive(Clone, Debug)]
pub struct MessageImage {
image_id: u64,
image: Shared<Task<Option<LanguageModelImage>>>,
}
impl PartialEq for MessageImage {
fn eq(&self, other: &Self) -> bool {
self.image_id == other.image_id
}
}
impl Eq for MessageImage {}
#[derive(Clone, Debug)]
pub struct Message {
pub image_offsets: SmallVec<[(usize, MessageImage); 1]>,
pub offset_range: Range<usize>,
pub index_range: Range<usize>,
pub anchor_range: Range<language::Anchor>,
@ -403,60 +388,43 @@ pub struct Message {
pub cache: Option<MessageCacheMetadata>,
}
impl Message {
/// Converts this message into a `LanguageModelRequestMessage`, interleaving
/// text runs with any images anchored inside the message's offset range.
/// Returns `None` when the message yields no non-empty content at all.
fn to_request_message(&self, buffer: &Buffer) -> Option<LanguageModelRequestMessage> {
let mut content = Vec::new();
let mut range_start = self.offset_range.start;
// Walk image offsets in order, emitting the text preceding each image first.
for (image_offset, message_image) in self.image_offsets.iter() {
if *image_offset != range_start {
if let Some(text) = Self::collect_text_content(buffer, range_start..*image_offset) {
content.push(text);
}
}
// `now_or_never` yields the image only if its loading task has already
// resolved; images still loading are silently omitted from the request.
if let Some(image) = message_image.image.clone().now_or_never().flatten() {
content.push(language_model::MessageContent::Image(image));
}
range_start = *image_offset;
}
// Emit any trailing text after the last image.
if range_start != self.offset_range.end {
if let Some(text) =
Self::collect_text_content(buffer, range_start..self.offset_range.end)
{
content.push(text);
}
}
if content.is_empty() {
return None;
}
Some(LanguageModelRequestMessage {
role: self.role,
content,
// A message is marked cacheable only when its cache metadata is an anchor.
cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor),
})
}
/// Collects the buffer text in `range`, returning `None` when it is empty or
/// whitespace-only so callers can skip emitting blank content chunks.
fn collect_text_content(buffer: &Buffer, range: Range<usize>) -> Option<MessageContent> {
let text: String = buffer.text_for_range(range.clone()).collect();
if text.trim().is_empty() {
None
} else {
Some(MessageContent::Text(text))
}
}
#[derive(Debug, Clone)]
pub enum Content {
Image {
anchor: language::Anchor,
image_id: u64,
render_image: Arc<RenderImage>,
image: Shared<Task<Option<LanguageModelImage>>>,
},
ToolUse {
range: Range<language::Anchor>,
tool_use: LanguageModelToolUse,
},
ToolResult {
range: Range<language::Anchor>,
tool_use_id: Arc<str>,
},
}
#[derive(Clone, Debug)]
pub struct ImageAnchor {
pub anchor: language::Anchor,
pub image_id: u64,
pub render_image: Arc<RenderImage>,
pub image: Shared<Task<Option<LanguageModelImage>>>,
impl Content {
/// Returns the anchor range this content occupies in the buffer.
/// Images are point-like, so they yield an empty range at their anchor.
fn range(&self) -> Range<language::Anchor> {
match self {
Self::Image { anchor, .. } => *anchor..*anchor,
Self::ToolUse { range, .. } | Self::ToolResult { range, .. } => range.clone(),
}
}
/// Orders two contents by buffer position. Note this is not a total order:
/// overlapping (or touching) ranges compare as `Equal`, which lets
/// `binary_search_by` callers detect and replace overlapping content.
fn cmp(&self, other: &Self, buffer: &BufferSnapshot) -> Ordering {
let self_range = self.range();
let other_range = other.range();
if self_range.end.cmp(&other_range.start, buffer).is_lt() {
Ordering::Less
} else if self_range.start.cmp(&other_range.end, buffer).is_gt() {
Ordering::Greater
} else {
Ordering::Equal
}
}
}
struct PendingCompletion {
@ -501,7 +469,7 @@ pub struct Context {
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
message_anchors: Vec<MessageAnchor>,
images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
image_anchors: Vec<ImageAnchor>,
contents: Vec<Content>,
messages_metadata: HashMap<MessageId, MessageMetadata>,
summary: Option<ContextSummary>,
pending_summary: Task<Option<()>>,
@ -595,7 +563,7 @@ impl Context {
pending_ops: Vec::new(),
operations: Vec::new(),
message_anchors: Default::default(),
image_anchors: Default::default(),
contents: Default::default(),
images: Default::default(),
messages_metadata: Default::default(),
pending_slash_commands: Vec::new(),
@ -659,11 +627,6 @@ impl Context {
id: message.id,
start: message.offset_range.start,
metadata: self.messages_metadata[&message.id].clone(),
image_offsets: message
.image_offsets
.iter()
.map(|image_offset| (image_offset.0, image_offset.1.image_id))
.collect(),
})
.collect(),
summary: self
@ -1957,6 +1920,14 @@ impl Context {
output_range
});
this.insert_content(
Content::ToolResult {
range: anchor_range.clone(),
tool_use_id: tool_use_id.clone(),
},
cx,
);
cx.emit(ContextEvent::ToolFinished {
tool_use_id,
output_range: anchor_range,
@ -2038,6 +2009,7 @@ impl Context {
let stream_completion = async {
let request_start = Instant::now();
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
while let Some(event) = events.next().await {
if response_latency.is_none() {
@ -2050,7 +2022,7 @@ impl Context {
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
let event_to_emit = this.buffer.update(cx, |buffer, cx| {
this.buffer.update(cx, |buffer, cx| {
let message_old_end_offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
@ -2059,13 +2031,9 @@ impl Context {
});
match event {
LanguageModelCompletionEvent::Stop(reason) => match reason {
StopReason::ToolUse => {
return Some(ContextEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
},
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
}
LanguageModelCompletionEvent::Text(chunk) => {
buffer.edit(
[(
@ -2116,14 +2084,9 @@ impl Context {
);
}
}
None
});
cx.emit(ContextEvent::StreamedCompletion);
if let Some(event) = event_to_emit {
cx.emit(event);
}
Some(())
})?;
@ -2136,13 +2099,14 @@ impl Context {
this.update_cache_status_for_completion(cx);
})?;
anyhow::Ok(())
anyhow::Ok(stop_reason)
};
let result = stream_completion.await;
this.update(&mut cx, |this, cx| {
let error_message = result
.as_ref()
.err()
.map(|error| error.to_string().trim().to_string());
@ -2170,6 +2134,16 @@ impl Context {
error_message,
);
}
if let Ok(stop_reason) = result {
match stop_reason {
StopReason::ToolUse => {
cx.emit(ContextEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
}
}
})
.ok();
}
@ -2186,18 +2160,94 @@ impl Context {
pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
let buffer = self.buffer.read(cx);
let request_messages = self
.messages(cx)
.filter(|message| message.status == MessageStatus::Done)
.filter_map(|message| message.to_request_message(&buffer))
.collect();
LanguageModelRequest {
messages: request_messages,
let mut contents = self.contents(cx).peekable();
fn collect_text_content(buffer: &Buffer, range: Range<usize>) -> Option<String> {
let text: String = buffer.text_for_range(range.clone()).collect();
if text.trim().is_empty() {
None
} else {
Some(text)
}
}
let mut completion_request = LanguageModelRequest {
messages: Vec::new(),
tools: Vec::new(),
stop: Vec::new(),
temperature: 1.0,
};
for message in self.messages(cx) {
if message.status != MessageStatus::Done {
continue;
}
let mut offset = message.offset_range.start;
let mut request_message = LanguageModelRequestMessage {
role: message.role,
content: Vec::new(),
cache: message
.cache
.as_ref()
.map_or(false, |cache| cache.is_anchor),
};
while let Some(content) = contents.peek() {
if content
.range()
.end
.cmp(&message.anchor_range.end, buffer)
.is_lt()
{
let content = contents.next().unwrap();
let range = content.range().to_offset(buffer);
request_message.content.extend(
collect_text_content(buffer, offset..range.start).map(MessageContent::Text),
);
match content {
Content::Image { image, .. } => {
if let Some(image) = image.clone().now_or_never().flatten() {
request_message
.content
.push(language_model::MessageContent::Image(image));
}
}
Content::ToolUse { tool_use, .. } => {
request_message
.content
.push(language_model::MessageContent::ToolUse(tool_use.clone()));
}
Content::ToolResult { tool_use_id, .. } => {
request_message.content.push(
language_model::MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_use_id.to_string(),
is_error: false,
content: collect_text_content(buffer, range.clone())
.unwrap_or_default(),
},
),
);
}
}
offset = range.end;
} else {
break;
}
}
request_message.content.extend(
collect_text_content(buffer, offset..message.offset_range.end)
.map(MessageContent::Text),
);
completion_request.messages.push(request_message);
}
completion_request
}
pub fn cancel_last_assist(&mut self, cx: &mut ModelContext<Self>) -> bool {
@ -2335,42 +2385,50 @@ impl Context {
Some(())
}
pub fn insert_image_anchor(
pub fn insert_image_content(
&mut self,
image_id: u64,
anchor: language::Anchor,
cx: &mut ModelContext<Self>,
) -> bool {
cx.emit(ContextEvent::MessagesEdited);
let buffer = self.buffer.read(cx);
let insertion_ix = match self
.image_anchors
.binary_search_by(|existing_anchor| anchor.cmp(&existing_anchor.anchor, buffer))
{
Ok(ix) => ix,
Err(ix) => ix,
};
) {
if let Some((render_image, image)) = self.images.get(&image_id) {
self.image_anchors.insert(
insertion_ix,
ImageAnchor {
self.insert_content(
Content::Image {
anchor,
image_id,
image: image.clone(),
render_image: render_image.clone(),
},
cx,
);
true
} else {
false
}
}
pub fn images<'a>(&'a self, _cx: &'a AppContext) -> impl 'a + Iterator<Item = ImageAnchor> {
self.image_anchors.iter().cloned()
/// Inserts `content` into the sorted `contents` list, keeping it ordered by
/// buffer position. If existing content overlaps the new content's range
/// (`Content::cmp` returns `Equal` for overlaps), the old entry is replaced.
pub fn insert_content(&mut self, content: Content, cx: &mut ModelContext<Self>) {
let buffer = self.buffer.read(cx);
let insertion_ix = match self
.contents
.binary_search_by(|probe| probe.cmp(&content, buffer))
{
// An overlapping entry exists: remove it and reuse its slot.
Ok(ix) => {
self.contents.remove(ix);
ix
}
Err(ix) => ix,
};
self.contents.insert(insertion_ix, content);
cx.emit(ContextEvent::MessagesEdited);
}
/// Iterates over the context's non-text content (images, tool uses, tool
/// results), skipping entries whose anchors have been invalidated by buffer
/// edits (e.g. the anchored text was deleted).
pub fn contents<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Content> {
let buffer = self.buffer.read(cx);
self.contents
.iter()
.filter(|content| {
let range = content.range();
range.start.is_valid(buffer) && range.end.is_valid(buffer)
})
.cloned()
}
pub fn split_message(
@ -2533,22 +2591,14 @@ impl Context {
return;
}
let messages = self
.messages(cx)
.filter_map(|message| message.to_request_message(self.buffer.read(cx)))
.chain(Some(LanguageModelRequestMessage {
role: Role::User,
content: vec![
"Summarize the context into a short title without punctuation.".into(),
],
cache: false,
}));
let request = LanguageModelRequest {
messages: messages.collect(),
tools: Vec::new(),
stop: Vec::new(),
temperature: 1.0,
};
let mut request = self.to_completion_request(cx);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
"Summarize the context into a short title without punctuation.".into(),
],
cache: false,
});
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
@ -2648,10 +2698,8 @@ impl Context {
cx: &'a AppContext,
) -> impl 'a + Iterator<Item = Message> {
let buffer = self.buffer.read(cx);
let messages = message_anchors.enumerate();
let images = self.image_anchors.iter();
Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
Self::messages_from_iters(buffer, &self.messages_metadata, message_anchors.enumerate())
}
pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@ -2662,10 +2710,8 @@ impl Context {
buffer: &'a Buffer,
metadata: &'a HashMap<MessageId, MessageMetadata>,
messages: impl Iterator<Item = (usize, &'a MessageAnchor)> + 'a,
images: impl Iterator<Item = &'a ImageAnchor> + 'a,
) -> impl 'a + Iterator<Item = Message> {
let mut messages = messages.peekable();
let mut images = images.peekable();
iter::from_fn(move || {
if let Some((start_ix, message_anchor)) = messages.next() {
@ -2686,22 +2732,6 @@ impl Context {
let message_end_anchor = message_end.unwrap_or(language::Anchor::MAX);
let message_end = message_end_anchor.to_offset(buffer);
let mut image_offsets = SmallVec::new();
while let Some(image_anchor) = images.peek() {
if image_anchor.anchor.cmp(&message_end_anchor, buffer).is_lt() {
image_offsets.push((
image_anchor.anchor.to_offset(buffer),
MessageImage {
image_id: image_anchor.image_id,
image: image_anchor.image.clone(),
},
));
images.next();
} else {
break;
}
}
return Some(Message {
index_range: start_ix..end_ix,
offset_range: message_start..message_end,
@ -2710,7 +2740,6 @@ impl Context {
role: metadata.role,
status: metadata.status.clone(),
cache: metadata.cache.clone(),
image_offsets,
});
}
None
@ -2748,9 +2777,6 @@ impl Context {
})?;
if let Some(summary) = summary {
this.read_with(&cx, |this, cx| this.serialize_images(fs.clone(), cx))?
.await;
let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
let mut discriminant = 1;
let mut new_path;
@ -2790,45 +2816,6 @@ impl Context {
});
}
/// Writes every fully-loaded context image to disk as
/// `{context_images_dir}/{id}.png.base64`, skipping images whose file already
/// exists. Images whose loading task resolves to `None` are ignored, and all
/// filesystem errors are logged and swallowed (best-effort persistence).
pub fn serialize_images(&self, fs: Arc<dyn Fs>, cx: &AppContext) -> Task<()> {
let mut images_to_save = self
.images
.iter()
.map(|(id, (_, llm_image))| {
// Clone per-image so each future is 'static and independently awaitable.
let fs = fs.clone();
let llm_image = llm_image.clone();
let id = *id;
async move {
if let Some(llm_image) = llm_image.await {
let path: PathBuf =
context_images_dir().join(&format!("{}.png.base64", id));
// Only write if the file doesn't already exist on disk.
if fs
.metadata(path.as_path())
.await
.log_err()
.flatten()
.is_none()
{
fs.atomic_write(path, llm_image.source.to_string())
.await
.log_err();
}
}
}
})
.collect::<FuturesUnordered<_>>();
// Create the images directory first; if that fails, skip all writes.
cx.background_executor().spawn(async move {
if fs
.create_dir(context_images_dir().as_ref())
.await
.log_err()
.is_some()
{
// Drive all image-save futures to completion concurrently.
while let Some(_) = images_to_save.next().await {}
}
})
}
pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
let timestamp = self.next_timestamp();
let summary = self.summary.get_or_insert(ContextSummary::default());
@ -2914,9 +2901,6 @@ pub struct SavedMessage {
pub id: MessageId,
pub start: usize,
pub metadata: MessageMetadata,
#[serde(default)]
// This is defaulted for backwards compatibility with JSON files created before August 2024. We didn't always have this field.
pub image_offsets: Vec<(usize, u64)>,
}
#[derive(Serialize, Deserialize)]
@ -3102,7 +3086,6 @@ impl SavedContextV0_3_0 {
timestamp,
cache: None,
},
image_offsets: Vec::new(),
})
})
.collect(),

View file

@ -170,12 +170,6 @@ pub fn contexts_dir() -> &'static PathBuf {
})
}
/// Returns the path within the contexts directory where images from contexts are stored.
pub fn context_images_dir() -> &'static PathBuf {
// Lazily initialized on first call; lives for the program's lifetime.
static CONTEXT_IMAGES_DIR: OnceLock<PathBuf> = OnceLock::new();
CONTEXT_IMAGES_DIR.get_or_init(|| contexts_dir().join("images"))
}
/// Returns the path to the contexts directory.
///
/// This is where the prompts for use with the Assistant are stored.