diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index d24a5f9001..171f5fa819 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -68,7 +68,7 @@ pub enum StopReason { ToolUse, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] pub struct LanguageModelToolUse { pub id: String, pub name: String, diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index f0554970b3..5d4d4c4548 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -267,6 +267,9 @@ pub fn count_anthropic_tokens( MessageContent::Image(image) => { tokens_from_images += image.estimate_tokens(); } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. + } MessageContent::ToolResult(tool_result) => { string_contents.push_str(&tool_result.content); } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 00e0c38d21..64ce33a21f 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,6 +1,7 @@ use std::io::{Cursor, Write}; use crate::role::Role; +use crate::LanguageModelToolUse; use base64::write::EncoderWriter; use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task}; use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, ImageDecoder}; @@ -171,6 +172,7 @@ pub struct LanguageModelToolResult { pub enum MessageContent { Text(String), Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), ToolResult(LanguageModelToolResult), } @@ -198,8 +200,8 @@ impl LanguageModelRequestMessage { let mut string_buffer = String::new(); for string in self.content.iter().filter_map(|content| match content { MessageContent::Text(text) => Some(text), - MessageContent::Image(_) => None, MessageContent::ToolResult(tool_result) => Some(&tool_result.content), + MessageContent::ToolUse(_) | MessageContent::Image(_) => None, }) { string_buffer.push_str(string.as_str()) } @@ -213,10 +215,10 @@ impl LanguageModelRequestMessage { .get(0) .map(|content| match content { MessageContent::Text(text) => text.trim().is_empty(), - MessageContent::Image(_) => true, MessageContent::ToolResult(tool_result) => { tool_result.content.trim().is_empty() } + MessageContent::ToolUse(_) | MessageContent::Image(_) => true, }) .unwrap_or(false) } @@ -337,6 +339,14 @@ impl LanguageModelRequest { cache_control, }) } + MessageContent::ToolUse(tool_use) => { + Some(anthropic::RequestContent::ToolUse { + id: tool_use.id, + name: tool_use.name, + input: tool_use.input, + cache_control, + }) + } MessageContent::ToolResult(tool_result) => { Some(anthropic::RequestContent::ToolResult { tool_use_id: tool_result.tool_use_id,