language_model: Add tool results to message content (#17363)
This PR updates the message content for an LLM request to allow it to contain tool results. Release Notes: - N/A
This commit is contained in:
parent
74907cb3e6
commit
30b2133336
3 changed files with 73 additions and 38 deletions
|
@ -423,6 +423,14 @@ pub enum RequestContent {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
is_error: bool,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
|
@ -261,12 +261,15 @@ pub fn count_anthropic_tokens(
|
|||
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(string) => {
|
||||
string_contents.push_str(&string);
|
||||
MessageContent::Text(text) => {
|
||||
string_contents.push_str(&text);
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
string_contents.push_str(&tool_result.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,14 +8,24 @@ use serde::{Deserialize, Serialize};
|
|||
use ui::{px, SharedString};
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct LanguageModelImage {
|
||||
// A base64 encoded PNG image
|
||||
/// A base64-encoded PNG image.
|
||||
pub source: SharedString,
|
||||
size: Size<DevicePixels>,
|
||||
}
|
||||
|
||||
const ANTHROPIC_SIZE_LIMT: f32 = 1568.0; // Anthropic wants uploaded images to be smaller than this in both dimensions
|
||||
impl std::fmt::Debug for LanguageModelImage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("LanguageModelImage")
|
||||
.field("source", &format!("<{} bytes>", self.source.len()))
|
||||
.field("size", &self.size)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
|
||||
const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
|
||||
|
||||
impl LanguageModelImage {
|
||||
pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
|
||||
|
@ -67,7 +77,7 @@ impl LanguageModelImage {
|
|||
}
|
||||
}
|
||||
|
||||
// SAFETY: The base64 encoder should not produce non-UTF8
|
||||
// SAFETY: The base64 encoder should not produce non-UTF8.
|
||||
let source = unsafe { String::from_utf8_unchecked(base64_image) };
|
||||
|
||||
Some(LanguageModelImage {
|
||||
|
@ -77,7 +87,7 @@ impl LanguageModelImage {
|
|||
})
|
||||
}
|
||||
|
||||
/// Resolves image into an LLM-ready format (base64)
|
||||
/// Resolves image into an LLM-ready format (base64).
|
||||
pub fn from_render_image(data: &RenderImage) -> Option<Self> {
|
||||
let image_size = data.size(0);
|
||||
|
||||
|
@ -130,7 +140,7 @@ impl LanguageModelImage {
|
|||
base64_encoder.write_all(png.as_slice()).log_err()?;
|
||||
}
|
||||
|
||||
// SAFETY: The base64 encoder should not produce non-UTF8
|
||||
// SAFETY: The base64 encoder should not produce non-UTF8.
|
||||
let source = unsafe { String::from_utf8_unchecked(base64_image) };
|
||||
|
||||
Some(LanguageModelImage {
|
||||
|
@ -144,35 +154,32 @@ impl LanguageModelImage {
|
|||
let height = self.size.height.0.unsigned_abs() as usize;
|
||||
|
||||
// From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
|
||||
// Note that are a lot of conditions on anthropic's API, and OpenAI doesn't use this,
|
||||
// so this method is more of a rough guess
|
||||
// Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
|
||||
// so this method is more of a rough guess.
|
||||
(width * height) / 750
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub struct LanguageModelToolResult {
|
||||
pub tool_use_id: String,
|
||||
pub is_error: bool,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub enum MessageContent {
|
||||
Text(String),
|
||||
Image(LanguageModelImage),
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
|
||||
MessageContent::Image(i) => f
|
||||
.debug_struct("MessageContent")
|
||||
.field("image", &i.source.len())
|
||||
.finish(),
|
||||
}
|
||||
}
|
||||
ToolResult(LanguageModelToolResult),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn as_string(&self) -> &str {
|
||||
match self {
|
||||
MessageContent::Text(s) => s.as_str(),
|
||||
MessageContent::Text(text) => text.as_str(),
|
||||
MessageContent::Image(_) => "",
|
||||
MessageContent::ToolResult(tool_result) => tool_result.content.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -200,8 +207,9 @@ impl LanguageModelRequestMessage {
|
|||
pub fn string_contents(&self) -> String {
|
||||
let mut string_buffer = String::new();
|
||||
for string in self.content.iter().filter_map(|content| match content {
|
||||
MessageContent::Text(s) => Some(s),
|
||||
MessageContent::Text(text) => Some(text),
|
||||
MessageContent::Image(_) => None,
|
||||
MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
|
||||
}) {
|
||||
string_buffer.push_str(string.as_str())
|
||||
}
|
||||
|
@ -214,8 +222,11 @@ impl LanguageModelRequestMessage {
|
|||
.content
|
||||
.get(0)
|
||||
.map(|content| match content {
|
||||
MessageContent::Text(s) => s.trim().is_empty(),
|
||||
MessageContent::Text(text) => text.trim().is_empty(),
|
||||
MessageContent::Image(_) => true,
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
tool_result.content.trim().is_empty()
|
||||
}
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
@ -316,21 +327,34 @@ impl LanguageModelRequest {
|
|||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
MessageContent::Text(t) if !t.is_empty() => {
|
||||
Some(anthropic::RequestContent::Text {
|
||||
text: t,
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
Some(anthropic::RequestContent::Text {
|
||||
text,
|
||||
cache_control,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
Some(anthropic::RequestContent::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
Some(anthropic::RequestContent::ToolResult {
|
||||
tool_use_id: tool_result.tool_use_id,
|
||||
is_error: tool_result.is_error,
|
||||
content: tool_result.content,
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
MessageContent::Image(i) => Some(anthropic::RequestContent::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: i.source.to_string(),
|
||||
},
|
||||
cache_control,
|
||||
}),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue