language_model: Add tool results to message content (#17363)

This PR updates the message content for an LLM request to allow it
to contain tool results.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-04 13:29:01 -04:00 committed by GitHub
parent 74907cb3e6
commit 30b2133336
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 73 additions and 38 deletions

View file

@ -423,6 +423,14 @@ pub enum RequestContent {
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
is_error: bool,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
}
#[derive(Debug, Serialize, Deserialize)]

View file

@ -261,12 +261,15 @@ pub fn count_anthropic_tokens(
for content in message.content {
match content {
MessageContent::Text(string) => {
string_contents.push_str(&string);
MessageContent::Text(text) => {
string_contents.push_str(&text);
}
MessageContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
MessageContent::ToolResult(tool_result) => {
string_contents.push_str(&tool_result.content);
}
}
}

View file

@ -8,14 +8,24 @@ use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
/// An image that can be carried as message content (see `MessageContent::Image`).
pub struct LanguageModelImage {
// A base64 encoded PNG image
/// A base64-encoded PNG image.
pub source: SharedString,
/// Dimensions of the image, expressed in device pixels.
size: Size<DevicePixels>,
}
const ANTHROPIC_SIZE_LIMT: f32 = 1568.0; // Anthropic wants uploaded images to be smaller than this in both dimensions
impl std::fmt::Debug for LanguageModelImage {
    /// Formats the image for debugging without dumping the base64 payload:
    /// `source` is rendered as a `<N bytes>` summary of its length, while
    /// `size` is printed with its own `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Summarize the payload by length only; the encoded data itself is
        // deliberately kept out of debug output.
        let source_summary = format!("<{} bytes>", self.source.len());

        let mut builder = f.debug_struct("LanguageModelImage");
        builder.field("source", &source_summary);
        builder.field("size", &self.size);
        builder.finish()
    }
}
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
impl LanguageModelImage {
pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
@ -67,7 +77,7 @@ impl LanguageModelImage {
}
}
// SAFETY: The base64 encoder should not produce non-UTF8
// SAFETY: The base64 encoder should not produce non-UTF8.
let source = unsafe { String::from_utf8_unchecked(base64_image) };
Some(LanguageModelImage {
@ -77,7 +87,7 @@ impl LanguageModelImage {
})
}
/// Resolves image into an LLM-ready format (base64)
/// Resolves image into an LLM-ready format (base64).
pub fn from_render_image(data: &RenderImage) -> Option<Self> {
let image_size = data.size(0);
@ -130,7 +140,7 @@ impl LanguageModelImage {
base64_encoder.write_all(png.as_slice()).log_err()?;
}
// SAFETY: The base64 encoder should not produce non-UTF8
// SAFETY: The base64 encoder should not produce non-UTF8.
let source = unsafe { String::from_utf8_unchecked(base64_image) };
Some(LanguageModelImage {
@ -144,35 +154,32 @@ impl LanguageModelImage {
let height = self.size.height.0.unsigned_abs() as usize;
// From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
// Note that are a lot of conditions on anthropic's API, and OpenAI doesn't use this,
// so this method is more of a rough guess
// Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
// so this method is more of a rough guess.
(width * height) / 750
}
}
#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
/// The result of a tool invocation, carried as a `MessageContent::ToolResult`
/// entry in a request message.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
/// Identifier of the tool use this result belongs to; forwarded unchanged
/// as `tool_use_id` when building the Anthropic request content.
/// NOTE(review): presumably the ID the model issued for the tool call — confirm.
pub tool_use_id: String,
/// Whether this result represents an error; forwarded unchanged as `is_error`.
pub is_error: bool,
/// Textual content of the result. It is treated like message text: it is
/// included in token counting and returned by the string accessors.
pub content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
Text(String),
Image(LanguageModelImage),
}
impl std::fmt::Debug for MessageContent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
MessageContent::Image(i) => f
.debug_struct("MessageContent")
.field("image", &i.source.len())
.finish(),
}
}
ToolResult(LanguageModelToolResult),
}
impl MessageContent {
/// Returns the textual portion of this content as a borrowed string slice.
///
/// Text yields its string, a tool result yields its `content`, and an
/// image yields the empty string.
pub fn as_string(&self) -> &str {
match self {
MessageContent::Text(s) => s.as_str(),
MessageContent::Text(text) => text.as_str(),
// Images contribute no text.
MessageContent::Image(_) => "",
MessageContent::ToolResult(tool_result) => tool_result.content.as_str(),
}
}
}
@ -200,8 +207,9 @@ impl LanguageModelRequestMessage {
pub fn string_contents(&self) -> String {
let mut string_buffer = String::new();
for string in self.content.iter().filter_map(|content| match content {
MessageContent::Text(s) => Some(s),
MessageContent::Text(text) => Some(text),
MessageContent::Image(_) => None,
MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
}) {
string_buffer.push_str(string.as_str())
}
@ -214,8 +222,11 @@ impl LanguageModelRequestMessage {
.content
.get(0)
.map(|content| match content {
MessageContent::Text(s) => s.trim().is_empty(),
MessageContent::Text(text) => text.trim().is_empty(),
MessageContent::Image(_) => true,
MessageContent::ToolResult(tool_result) => {
tool_result.content.trim().is_empty()
}
})
.unwrap_or(false)
}
@ -316,21 +327,34 @@ impl LanguageModelRequest {
.content
.into_iter()
.filter_map(|content| match content {
MessageContent::Text(t) if !t.is_empty() => {
Some(anthropic::RequestContent::Text {
text: t,
MessageContent::Text(text) => {
if !text.is_empty() {
Some(anthropic::RequestContent::Text {
text,
cache_control,
})
} else {
None
}
}
MessageContent::Image(image) => {
Some(anthropic::RequestContent::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
cache_control,
})
}
MessageContent::ToolResult(tool_result) => {
Some(anthropic::RequestContent::ToolResult {
tool_use_id: tool_result.tool_use_id,
is_error: tool_result.is_error,
content: tool_result.content,
cache_control,
})
}
MessageContent::Image(i) => Some(anthropic::RequestContent::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: i.source.to_string(),
},
cache_control,
}),
_ => None,
})
.collect();
let anthropic_role = match message.role {