diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index c3674ffc91..95cc09fa20 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -33,7 +33,9 @@ use language_model::{ LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason, }; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; -use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; +use markdown::{ + HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange, +}; use project::{ProjectEntryId, ProjectItem as _}; use rope::Point; use settings::{Settings as _, SettingsStore, update_settings_file}; @@ -430,49 +432,8 @@ fn render_markdown_code_block( let path_range = path_range.clone(); move |_, window, cx| { workspace - .update(cx, { - |workspace, cx| { - let Some(project_path) = workspace - .project() - .read(cx) - .find_project_path(&path_range.path, cx) - else { - return; - }; - let Some(target) = path_range.range.as_ref().map(|range| { - Point::new( - // Line number is 1-based - range.start.line.saturating_sub(1), - range.start.col.unwrap_or(0), - ) - }) else { - return; - }; - let open_task = workspace.open_path( - project_path, - None, - true, - window, - cx, - ); - window - .spawn(cx, async move |cx| { - let item = open_task.await?; - if let Some(active_editor) = - item.downcast::() - { - active_editor - .update_in(cx, |editor, window, cx| { - editor.go_to_singleton_buffer_point( - target, window, cx, - ); - }) - .ok(); - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } + .update(cx, |workspace, cx| { + open_path(&path_range, window, workspace, cx) }) .ok(); } @@ -598,6 +559,45 @@ fn render_markdown_code_block( .when(can_expand && !is_expanded, |this| this.max_h_80()) } +fn open_path( + path_range: &PathWithRange, + window: &mut Window, + workspace: &mut Workspace, + cx: &mut Context<'_, Workspace>, +) { + let Some(project_path) 
= workspace + .project() + .read(cx) + .find_project_path(&path_range.path, cx) + else { + return; // TODO instead of just bailing out, open that path in a buffer. + }; + + let Some(target) = path_range.range.as_ref().map(|range| { + Point::new( + // Line number is 1-based + range.start.line.saturating_sub(1), + range.start.col.unwrap_or(0), + ) + }) else { + return; + }; + let open_task = workspace.open_path(project_path, None, true, window, cx); + window + .spawn(cx, async move |cx| { + let item = open_task.await?; + if let Some(active_editor) = item.downcast::() { + active_editor + .update_in(cx, |editor, window, cx| { + editor.go_to_singleton_buffer_point(target, window, cx); + }) + .ok(); + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); +} + fn render_code_language( language: Option<&Arc>, name_fallback: SharedString, diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 93b9a73d94..2b6269a053 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -22,9 +22,9 @@ use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, - StopReason, TokenUsage, + LanguageModelToolResultContent, LanguageModelToolUseId, MaxMonthlySpendReachedError, + MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, + SelectedModel, StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::Project; @@ -880,7 +880,13 @@ impl Thread { } pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { - Some(&self.tool_use.tool_result(id)?.content) + match &self.tool_use.tool_result(id)?.content { + 
LanguageModelToolResultContent::Text(str) => Some(str), + LanguageModelToolResultContent::Image(_) => { + // TODO: We should display image + None + } + } } pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option { @@ -2502,7 +2508,15 @@ impl Thread { } writeln!(markdown, "**\n")?; - writeln!(markdown, "{}", tool_result.content)?; + match &tool_result.content { + LanguageModelToolResultContent::Text(str) => { + writeln!(markdown, "{}", str)?; + } + LanguageModelToolResultContent::Image(image) => { + writeln!(markdown, "![Image](data:base64,{})", image.source)?; + } + } + if let Some(output) = tool_result.output.as_ref() { writeln!( markdown, diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 508ddfa051..c43e452152 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -19,7 +19,7 @@ use gpui::{ }; use heed::Database; use heed::types::SerdeBincode; -use language_model::{LanguageModelToolUseId, Role, TokenUsage}; +use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage}; use project::context_server_store::{ContextServerStatus, ContextServerStore}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ @@ -775,7 +775,7 @@ pub struct SerializedToolUse { pub struct SerializedToolResult { pub tool_use_id: LanguageModelToolUseId, pub is_error: bool, - pub content: Arc, + pub content: LanguageModelToolResultContent, pub output: Option, } diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 690e169c96..5ed330b29d 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -1,14 +1,16 @@ use std::sync::Arc; use anyhow::Result; -use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet}; +use assistant_tool::{ + AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet, +}; use collections::HashMap; use futures::FutureExt as _; use 
futures::future::Shared; use gpui::{App, Entity, SharedString, Task}; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult, - LanguageModelToolUse, LanguageModelToolUseId, Role, + LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role, }; use project::Project; use ui::{IconName, Window}; @@ -165,10 +167,16 @@ impl ToolUseState { let status = (|| { if let Some(tool_result) = tool_result { + let content = tool_result + .content + .to_str() + .map(|str| str.to_owned().into()) + .unwrap_or_default(); + return if tool_result.is_error { - ToolUseStatus::Error(tool_result.content.clone().into()) + ToolUseStatus::Error(content) } else { - ToolUseStatus::Finished(tool_result.content.clone().into()) + ToolUseStatus::Finished(content) }; } @@ -399,21 +407,44 @@ impl ToolUseState { let tool_result = output.content; const BYTES_PER_TOKEN_ESTIMATE: usize = 3; - // Protect from clearly large output + let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id); + + // Protect from overly large output let tool_output_limit = configured_model .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE) .unwrap_or(usize::MAX); - let tool_result = if tool_result.len() <= tool_output_limit { - tool_result - } else { - let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit); + let content = match tool_result { + ToolResultContent::Text(text) => { + let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit); - format!( - "Tool result too long. The first {} bytes:\n\n{}", - truncated.len(), - truncated - ) + LanguageModelToolResultContent::Text( + format!( + "Tool result too long. 
The first {} bytes:\n\n{}", + truncated.len(), + truncated + ) + .into(), + ) + } + ToolResultContent::Image(language_model_image) => { + if language_model_image.estimate_tokens() < tool_output_limit { + LanguageModelToolResultContent::Image(language_model_image) + } else { + self.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use_id.clone(), + tool_name, + content: "Tool responded with an image that would exceed the remaining tokens".into(), + is_error: true, + output: None, + }, + ); + + return old_use; + } + } }; self.tool_results.insert( @@ -421,12 +452,13 @@ impl ToolUseState { LanguageModelToolResult { tool_use_id: tool_use_id.clone(), tool_name, - content: tool_result.into(), + content, is_error: false, output: output.output, }, ); - self.pending_tool_uses_by_id.remove(&tool_use_id) + + old_use } Err(err) => { self.tool_results.insert( @@ -434,7 +466,7 @@ impl ToolUseState { LanguageModelToolResult { tool_use_id: tool_use_id.clone(), tool_name, - content: err.to_string().into(), + content: LanguageModelToolResultContent::Text(err.to_string().into()), is_error: true, output: None, }, diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index d99517d844..b323b595ba 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -534,12 +534,26 @@ pub enum RequestContent { ToolResult { tool_use_id: String, is_error: bool, - content: String, + content: ToolResultContent, #[serde(skip_serializing_if = "Option::is_none")] cache_control: Option, }, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolResultContent { + JustText(String), + Multipart(Vec), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolResultPart { + Text { text: String }, + Image { source: ImageSource }, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ResponseContent { diff --git 
a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 0f248807a0..ecda105f6d 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -19,6 +19,7 @@ use gpui::Window; use gpui::{App, Entity, SharedString, Task, WeakEntity}; use icons::IconName; use language_model::LanguageModel; +use language_model::LanguageModelImage; use language_model::LanguageModelRequest; use language_model::LanguageModelToolSchemaFormat; use project::Project; @@ -65,21 +66,50 @@ impl ToolUseStatus { #[derive(Debug)] pub struct ToolResultOutput { - pub content: String, + pub content: ToolResultContent, pub output: Option, } +#[derive(Debug, PartialEq, Eq)] +pub enum ToolResultContent { + Text(String), + Image(LanguageModelImage), +} + +impl ToolResultContent { + pub fn len(&self) -> usize { + match self { + ToolResultContent::Text(str) => str.len(), + ToolResultContent::Image(image) => image.len(), + } + } + + pub fn is_empty(&self) -> bool { + match self { + ToolResultContent::Text(str) => str.is_empty(), + ToolResultContent::Image(image) => image.is_empty(), + } + } + + pub fn as_str(&self) -> Option<&str> { + match self { + ToolResultContent::Text(str) => Some(str), + ToolResultContent::Image(_) => None, + } + } +} + impl From for ToolResultOutput { fn from(value: String) -> Self { ToolResultOutput { - content: value, + content: ToolResultContent::Text(value), output: None, } } } impl Deref for ToolResultOutput { - type Target = String; + type Target = ToolResultContent; fn deref(&self) -> &Self::Target { &self.content diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 894da7ad34..9b7d3e8aca 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -10,8 +10,8 @@ use futures::{FutureExt, future::LocalBoxFuture}; use gpui::{AppContext, TestAppContext}; use indoc::{formatdoc, 
indoc}; use language_model::{ - LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, - LanguageModelToolUseId, + LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, }; use project::Project; use rand::prelude::*; @@ -951,7 +951,7 @@ fn tool_result( tool_use_id: LanguageModelToolUseId::from(id.into()), tool_name: name.into(), is_error: false, - content: result.into(), + content: LanguageModelToolResultContent::Text(result.into()), output: None, }) } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 8c60f980da..8c38534bee 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -5,7 +5,8 @@ use crate::{ }; use anyhow::{Result, anyhow}; use assistant_tool::{ - ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus, + ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, + ToolUseStatus, }; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{Editor, EditorMode, MultiBuffer, PathKey}; @@ -292,7 +293,10 @@ impl Tool for EditFileTool { } } else { Ok(ToolResultOutput { - content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff), + content: ToolResultContent::Text(format!( + "Edited {}:\n\n```diff\n{}\n```", + input_path, diff + )), output: serde_json::to_value(output).ok(), }) } diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index 2004508a47..9061b4a45c 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -1,6 +1,8 @@ use crate::{schema::json_schema_for, ui::ToolCallCardHeader}; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus}; +use 
assistant_tool::{ + ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, +}; use editor::Editor; use futures::channel::oneshot::{self, Receiver}; use gpui::{ @@ -126,7 +128,7 @@ impl Tool for FindPathTool { write!(&mut message, "\n{}", mat.display()).unwrap(); } Ok(ToolResultOutput { - content: message, + content: ToolResultContent::Text(message), output: Some(serde_json::to_value(output)?), }) } diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 88d26df3e5..3f6c87f5dc 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -752,9 +752,9 @@ mod tests { match task.output.await { Ok(result) => { if cfg!(windows) { - result.content.replace("root\\", "root/") + result.content.as_str().unwrap().replace("root\\", "root/") } else { - result.content + result.content.as_str().unwrap().to_string() } } Err(e) => panic!("Failed to run grep tool: {}", e), diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 08c7adb737..ec237eb873 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -1,13 +1,17 @@ use crate::schema::json_schema_for; use anyhow::{Result, anyhow}; -use assistant_tool::outline; use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{ToolResultContent, outline}; use gpui::{AnyWindowHandle, App, Entity, Task}; +use project::{ImageItem, image_store}; +use assistant_tool::ToolResultOutput; use indoc::formatdoc; use itertools::Itertools; use language::{Anchor, Point}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; +use language_model::{ + LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat, +}; use project::{AgentLocation, Project}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -86,7 +90,7 @@ impl Tool for 
ReadFileTool { _request: Arc, project: Entity, action_log: Entity, - _model: Arc, + model: Arc, _window: Option, cx: &mut App, ) -> ToolResult { @@ -100,6 +104,42 @@ impl Tool for ReadFileTool { }; let file_path = input.path.clone(); + + if image_store::is_image_file(&project, &project_path, cx) { + if !model.supports_images() { + return Task::ready(Err(anyhow!( + "Attempted to read an image, but Zed doesn't currently support sending images to {}.", + model.name().0 + ))) + .into(); + } + + let task = cx.spawn(async move |cx| -> Result { + let image_entity: Entity = cx + .update(|cx| { + project.update(cx, |project, cx| { + project.open_image(project_path.clone(), cx) + }) + })? + .await?; + + let image = + image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; + + let language_model_image = cx + .update(|cx| LanguageModelImage::from_image(image, cx))? + .await + .ok_or_else(|| anyhow!("Failed to process image"))?; + + Ok(ToolResultOutput { + content: ToolResultContent::Image(language_model_image), + output: None, + }) + }); + + return task.into(); + } + cx.spawn(async move |cx| { let buffer = cx .update(|cx| { @@ -282,7 +322,10 @@ mod test { .output }) .await; - assert_eq!(result.unwrap().content, "This is a small file content"); + assert_eq!( + result.unwrap().content.as_str(), + Some("This is a small file content") + ); } #[gpui::test] @@ -322,6 +365,7 @@ mod test { }) .await; let content = result.unwrap(); + let content = content.as_str().unwrap(); assert_eq!( content.lines().skip(4).take(6).collect::>(), vec![ @@ -365,6 +409,8 @@ mod test { .collect::>(); pretty_assertions::assert_eq!( content + .as_str() + .unwrap() .lines() .skip(4) .take(expected_content.len()) @@ -408,7 +454,10 @@ mod test { .output }) .await; - assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4"); + assert_eq!( + result.unwrap().content.as_str(), + Some("Line 2\nLine 3\nLine 4") + ); } #[gpui::test] @@ -448,7 +497,7 @@ mod test { .output }) .await; - 
assert_eq!(result.unwrap().content, "Line 1\nLine 2"); + assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2")); // end_line of 0 should result in at least 1 line let result = cx @@ -471,7 +520,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap().content, "Line 1"); + assert_eq!(result.unwrap().content.as_str(), Some("Line 1")); // when start_line > end_line, should still return at least 1 line let result = cx @@ -494,7 +543,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap().content, "Line 3"); + assert_eq!(result.unwrap().content.as_str(), Some("Line 3")); } fn init_test(cx: &mut TestAppContext) { diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 82c98f4419..41fe3a4fe5 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -1,5 +1,5 @@ use crate::schema::json_schema_for; -use anyhow::{Context as _, Result, anyhow, bail}; +use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus}; use futures::{FutureExt as _, future::Shared}; use gpui::{ @@ -125,18 +125,24 @@ impl Tool for TerminalTool { Err(err) => return Task::ready(Err(anyhow!(err))).into(), }; - let input_path = Path::new(&input.cd); - let working_dir = match working_dir(&input, &project, input_path, cx) { + let working_dir = match working_dir(&input, &project, cx) { Ok(dir) => dir, Err(err) => return Task::ready(Err(err)).into(), }; let program = self.determine_shell.clone(); let command = if cfg!(windows) { format!("$null | & {{{}}}", input.command.replace("\"", "'")) + } else if let Some(cwd) = working_dir + .as_ref() + .and_then(|cwd| cwd.as_os_str().to_str()) + { + // Make sure once we're *inside* the shell, we cd into `cwd` + format!("(cd {cwd}; {}) project.update(cx, |project, cx| { @@ -319,19 +325,13 @@ fn process_content( } else { content }; - let is_empty = content.trim().is_empty(); - - let 
content = format!( - "```\n{}{}```", - content, - if content.ends_with('\n') { "" } else { "\n" } - ); - + let content = content.trim(); + let is_empty = content.is_empty(); + let content = format!("```\n{content}\n```"); let content = if should_truncate { format!( - "Command output too long. The first {} bytes:\n\n{}", + "Command output too long. The first {} bytes:\n\n{content}", content.len(), - content, ) } else { content @@ -371,42 +371,47 @@ fn process_content( fn working_dir( input: &TerminalToolInput, project: &Entity, - input_path: &Path, cx: &mut App, ) -> Result> { let project = project.read(cx); + let cd = &input.cd; - if input.cd == "." { - // Accept "." as meaning "the one worktree" if we only have one worktree. + if cd == "." || cd == "" { + // Accept "." or "" as meaning "the one worktree" if we only have one worktree. let mut worktrees = project.worktrees(cx); match worktrees.next() { Some(worktree) => { - if worktrees.next().is_some() { - bail!( + if worktrees.next().is_none() { + Ok(Some(worktree.read(cx).abs_path().to_path_buf())) + } else { + Err(anyhow!( "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.", - ); + )) } - Ok(Some(worktree.read(cx).abs_path().to_path_buf())) } None => Ok(None), } - } else if input_path.is_absolute() { - // Absolute paths are allowed, but only if they're in one of the project's worktrees. - if !project - .worktrees(cx) - .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path())) - { - bail!("The absolute path must be within one of the project's worktrees"); + } else { + let input_path = Path::new(cd); + + if input_path.is_absolute() { + // Absolute paths are allowed, but only if they're in one of the project's worktrees. 
+ if project + .worktrees(cx) + .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path())) + { + return Ok(Some(input_path.into())); + } + } else { + if let Some(worktree) = project.worktree_for_root_name(cd, cx) { + return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); + } } - Ok(Some(input_path.into())) - } else { - let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else { - bail!("`cd` directory {:?} not found in the project", input.cd); - }; - - Ok(Some(worktree.read(cx).abs_path().to_path_buf())) + Err(anyhow!( + "`cd` directory {cd:?} was not in any of the project's worktrees." + )) } } @@ -727,8 +732,8 @@ mod tests { ) }); - let output = result.output.await.log_err().map(|output| output.content); - assert_eq!(output, Some("Command executed successfully.".into())); + let output = result.output.await.log_err().unwrap().content; + assert_eq!(output.as_str().unwrap(), "Command executed successfully."); } #[gpui::test] @@ -761,12 +766,13 @@ mod tests { cx, ); cx.spawn(async move |_| { - let output = headless_result - .output - .await - .log_err() - .map(|output| output.content); - assert_eq!(output, expected); + let output = headless_result.output.await.map(|output| output.content); + assert_eq!( + output + .ok() + .and_then(|content| content.as_str().map(ToString::to_string)), + expected + ); }) }; @@ -774,7 +780,7 @@ mod tests { check( TerminalToolInput { command: "pwd".into(), - cd: "project".into(), + cd: ".".into(), }, Some(format!( "```\n{}\n```", @@ -789,12 +795,9 @@ mod tests { check( TerminalToolInput { command: "pwd".into(), - cd: ".".into(), + cd: "other-project".into(), }, - Some(format!( - "```\n{}\n```", - tree.path().join("project").display() - )), + None, // other-project is a dir, but *not* a worktree (yet) cx, ) }) diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index d7e71e940f..46f7a79285 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ 
b/crates/assistant_tools/src/web_search_tool.rs @@ -3,7 +3,9 @@ use std::{sync::Arc, time::Duration}; use crate::schema::json_schema_for; use crate::ui::ToolCallCardHeader; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus}; +use assistant_tool::{ + ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, +}; use futures::{Future, FutureExt, TryFutureExt}; use gpui::{ AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, @@ -74,8 +76,10 @@ impl Tool for WebSearchTool { async move { let response = search_task.await.map_err(|err| anyhow!(err))?; Ok(ToolResultOutput { - content: serde_json::to_string(&response) - .context("Failed to serialize search results")?, + content: ToolResultContent::Text( + serde_json::to_string(&response) + .context("Failed to serialize search results")?, + ), output: Some(serde_json::to_value(response)?), }) } diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index ce5449d6f1..2ac6bfe5a7 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -113,7 +113,7 @@ pub enum ModelVendor { #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] #[serde(tag = "type")] -pub enum ChatMessageContent { +pub enum ChatMessagePart { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image_url")] @@ -194,26 +194,55 @@ pub enum ToolChoice { None, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "role", rename_all = "lowercase")] pub enum ChatMessage { Assistant { - content: Option, + content: ChatMessageContent, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, User { - content: Vec, + content: ChatMessageContent, }, System { content: String, }, Tool { - content: String, + content: ChatMessageContent, tool_call_id: 
String, }, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + OnlyText(String), + Multipart(Vec), +} + +impl ChatMessageContent { + pub fn empty() -> Self { + ChatMessageContent::Multipart(vec![]) + } +} + +impl From> for ChatMessageContent { + fn from(mut parts: Vec) -> Self { + if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() { + ChatMessageContent::OnlyText(std::mem::take(text)) + } else { + ChatMessageContent::Multipart(parts) + } + } +} + +impl From for ChatMessageContent { + fn from(text: String) -> Self { + ChatMessageContent::OnlyText(text) + } +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index bc0e2ac7b2..f7ba4a43ad 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -9,7 +9,7 @@ use handlebars::Handlebars; use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - MessageContent, Role, TokenUsage, + LanguageModelToolResultContent, MessageContent, Role, TokenUsage, }; use project::lsp_store::OpenLspBufferHandle; use project::{DiagnosticSummary, Project, ProjectPath}; @@ -964,7 +964,15 @@ impl RequestMarkdown { if tool_result.is_error { messages.push_str("**ERROR:**\n"); } - messages.push_str(&format!("{}\n\n", tool_result.content)); + + match &tool_result.content { + LanguageModelToolResultContent::Text(str) => { + writeln!(messages, "{}\n", str).ok(); + } + LanguageModelToolResultContent::Image(image) => { + writeln!(messages, "![Image](data:base64,{})\n", image.source).ok(); + } + } if let Some(output) = tool_result.output.as_ref() { writeln!( diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index b68cd39731..e94322608c 100644 --- 
a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -157,6 +157,10 @@ impl LanguageModel for FakeLanguageModel { false } + fn supports_images(&self) -> bool { + false + } + fn telemetry_id(&self) -> String { "fake".to_string() } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 2f234a7aaf..538ef95c5a 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -243,6 +243,9 @@ pub trait LanguageModel: Send + Sync { LanguageModelAvailability::Public } + /// Whether this model supports images + fn supports_images(&self) -> bool; + /// Whether this model supports tools. fn supports_tools(&self) -> bool; diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 11befb5101..a78c6b4ce2 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -21,6 +21,16 @@ pub struct LanguageModelImage { size: Size, } +impl LanguageModelImage { + pub fn len(&self) -> usize { + self.source.len() + } + + pub fn is_empty(&self) -> bool { + self.source.is_empty() + } +} + impl std::fmt::Debug for LanguageModelImage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LanguageModelImage") @@ -134,10 +144,45 @@ pub struct LanguageModelToolResult { pub tool_use_id: LanguageModelToolUseId, pub tool_name: Arc, pub is_error: bool, - pub content: Arc, + pub content: LanguageModelToolResultContent, pub output: Option, } +#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)] +#[serde(untagged)] +pub enum LanguageModelToolResultContent { + Text(Arc), + Image(LanguageModelImage), +} + +impl LanguageModelToolResultContent { + pub fn to_str(&self) -> Option<&str> { + match self { + Self::Text(text) => Some(&text), + Self::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + Self::Text(text) => 
text.chars().all(|c| c.is_whitespace()), + Self::Image(_) => false, + } + } +} + +impl From<&str> for LanguageModelToolResultContent { + fn from(value: &str) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From for LanguageModelToolResultContent { + fn from(value: String) -> Self { + Self::Text(Arc::from(value)) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String), @@ -151,6 +196,29 @@ pub enum MessageContent { ToolResult(LanguageModelToolResult), } +impl MessageContent { + pub fn to_str(&self) -> Option<&str> { + match self { + MessageContent::Text(text) => Some(text.as_str()), + MessageContent::Thinking { text, .. } => Some(text.as_str()), + MessageContent::RedactedThinking(_) => None, + MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), + MessageContent::ToolUse(_) | MessageContent::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), + MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), + MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), + MessageContent::RedactedThinking(_) + | MessageContent::ToolUse(_) + | MessageContent::Image(_) => false, + } + } +} + impl From for MessageContent { fn from(value: String) -> Self { MessageContent::Text(value) @@ -173,13 +241,7 @@ pub struct LanguageModelRequestMessage { impl LanguageModelRequestMessage { pub fn string_contents(&self) -> String { let mut buffer = String::new(); - for string in self.content.iter().filter_map(|content| match content { - MessageContent::Text(text) => Some(text.as_str()), - MessageContent::Thinking { text, .. 
} => Some(text.as_str()), - MessageContent::RedactedThinking(_) => None, - MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()), - MessageContent::ToolUse(_) | MessageContent::Image(_) => None, - }) { + for string in self.content.iter().filter_map(|content| content.to_str()) { buffer.push_str(string); } @@ -187,16 +249,7 @@ impl LanguageModelRequestMessage { } pub fn contents_empty(&self) -> bool { - self.content.iter().all(|content| match content { - MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), - MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), - MessageContent::ToolResult(tool_result) => { - tool_result.content.chars().all(|c| c.is_whitespace()) - } - MessageContent::RedactedThinking(_) - | MessageContent::ToolUse(_) - | MessageContent::Image(_) => false, - }) + self.content.iter().all(|content| content.is_empty()) } } diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index e1dbb1cc42..49939b91b5 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -759,6 +759,10 @@ mod tests { false } + fn supports_images(&self) -> bool { + false + } + fn telemetry_id(&self) -> String { format!("{}/{}", self.provider_id.0, self.name.0) } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 3a36a8339b..eccde976d3 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,6 +1,9 @@ use crate::AllLanguageModelSettings; use crate::ui::InstructionListItem; -use anthropic::{AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, Usage}; +use anthropic::{ + AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, + ToolResultPart, Usage, +}; 
use anyhow::{Context as _, Result, anyhow}; use collections::{BTreeMap, HashMap}; use credentials_provider::CredentialsProvider; @@ -15,8 +18,8 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, MessageContent, - RateLimiter, Role, + LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -346,9 +349,14 @@ pub fn count_anthropic_tokens( MessageContent::ToolUse(_tool_use) => { // TODO: Estimate token usage from tool uses. } - MessageContent::ToolResult(tool_result) => { - string_contents.push_str(&tool_result.content); - } + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(txt) => { + string_contents.push_str(txt); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, } } @@ -421,6 +429,10 @@ impl LanguageModel for AnthropicModel { true } + fn supports_images(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto @@ -575,7 +587,20 @@ pub fn into_anthropic( Some(anthropic::RequestContent::ToolResult { tool_use_id: tool_result.tool_use_id.to_string(), is_error: tool_result.is_error, - content: tool_result.content.to_string(), + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::JustText(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ToolResultPart::Image { + 
source: anthropic::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }]) + } + }, cache_control, }) } diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index f675250609..f4f8e2dce4 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -36,7 +36,8 @@ use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage, + LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role, + TokenUsage, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -490,6 +491,10 @@ impl LanguageModel for BedrockModel { self.model.supports_tool_use() } + fn supports_images(&self) -> bool { + false + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => { @@ -635,9 +640,17 @@ pub fn into_bedrock( MessageContent::ToolResult(tool_result) => { BedrockToolResultBlock::builder() .tool_use_id(tool_result.tool_use_id.to_string()) - .content(BedrockToolResultContentBlock::Text( - tool_result.content.to_string(), - )) + .content(match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + BedrockToolResultContentBlock::Text(text.to_string()) + } + LanguageModelToolResultContent::Image(_) => { + BedrockToolResultContentBlock::Text( + // TODO: Bedrock image support + "[Tool responded with an image, but Zed doesn't support these in Bedrock models yet]".to_string() + ) + } + }) .status({ if tool_result.is_error { BedrockToolResultStatus::Error @@ -762,9 +775,14 @@ pub fn 
get_bedrock_tokens( MessageContent::ToolUse(_tool_use) => { // TODO: Estimate token usage from tool uses. } - MessageContent::ToolResult(tool_result) => { - string_contents.push_str(&tool_result.content); - } + MessageContent::ToolResult(tool_result) => match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(&text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, } } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index c35f5f10c1..ffc56c684b 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -686,6 +686,14 @@ impl LanguageModel for CloudLanguageModel { } } + fn supports_images(&self) -> bool { + match self.model { + CloudModel::Anthropic(_) => true, + CloudModel::Google(_) => true, + CloudModel::OpenAi(_) => false, + } + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 0c250f0f47..5c96266178 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -5,8 +5,9 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; use collections::HashMap; use copilot::copilot_chat::{ - ChatMessage, ChatMessageContent, CopilotChat, ImageUrl, Model as CopilotChatModel, ModelVendor, - Request as CopilotChatRequest, ResponseEvent, Tool, ToolCall, + ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, + Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool, + ToolCall, }; use copilot::{Copilot, Status}; use futures::future::BoxFuture; @@ -20,12 +21,14 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, 
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolSchemaFormat, - LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, + LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role, + StopReason, }; use settings::SettingsStore; use std::time::Duration; use ui::prelude::*; +use util::debug_panic; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; @@ -198,6 +201,10 @@ impl LanguageModel for CopilotChatLanguageModel { self.model.supports_tools() } + fn supports_images(&self) -> bool { + self.model.supports_vision() + } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { match self.model.vendor() { ModelVendor::OpenAI | ModelVendor::Anthropic => { @@ -447,9 +454,28 @@ fn into_copilot_chat( Role::User => { for content in &message.content { if let MessageContent::ToolResult(tool_result) = content { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string().into(), + LanguageModelToolResultContent::Image(image) => { + if model.supports_vision() { + ChatMessageContent::Multipart(vec![ChatMessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + }, + }]) + } else { + debug_panic!( + "This should be caught at {} level", + tool_result.tool_name + ); + "[Tool responded with an image, but this model does not support vision]".to_string().into() + } + } + }; + messages.push(ChatMessage::Tool { tool_call_id: tool_result.tool_use_id.to_string(), - content: tool_result.content.to_string(), + content, }); } } @@ -460,18 +486,18 @@ fn into_copilot_chat( MessageContent::Text(text) | MessageContent::Thinking { text, .. 
} if !text.is_empty() => { - if let Some(ChatMessageContent::Text { text: text_content }) = + if let Some(ChatMessagePart::Text { text: text_content }) = content_parts.last_mut() { text_content.push_str(text); } else { - content_parts.push(ChatMessageContent::Text { + content_parts.push(ChatMessagePart::Text { text: text.to_string(), }); } } MessageContent::Image(image) if model.supports_vision() => { - content_parts.push(ChatMessageContent::Image { + content_parts.push(ChatMessagePart::Image { image_url: ImageUrl { url: image.to_base64_url(), }, @@ -483,7 +509,7 @@ fn into_copilot_chat( if !content_parts.is_empty() { messages.push(ChatMessage::User { - content: content_parts, + content: content_parts.into(), }); } } @@ -523,9 +549,9 @@ fn into_copilot_chat( messages.push(ChatMessage::Assistant { content: if text_content.is_empty() { - None + ChatMessageContent::empty() } else { - Some(text_content) + text_content.into() }, tool_calls, }); diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 1d8e22024f..8492741aad 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -287,6 +287,10 @@ impl LanguageModel for DeepSeekLanguageModel { false } + fn supports_images(&self) -> bool { + false + } + fn telemetry_id(&self) -> String { format!("deepseek/{}", self.model.id()) } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index b4753a7636..4f3c0cb112 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -313,6 +313,10 @@ impl LanguageModel for GoogleLanguageModel { true } + fn supports_images(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto diff --git a/crates/language_models/src/provider/lmstudio.rs 
b/crates/language_models/src/provider/lmstudio.rs index 57f9e4ad86..509816272c 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -285,6 +285,10 @@ impl LanguageModel for LmStudioLanguageModel { false } + fn supports_images(&self) -> bool { + false + } + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { false } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 9ca8623e0c..5143767e9e 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -303,6 +303,10 @@ impl LanguageModel for MistralLanguageModel { false } + fn supports_images(&self) -> bool { + false + } + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { false } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index e0a19f1740..1bb46ea482 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -325,6 +325,10 @@ impl LanguageModel for OllamaLanguageModel { self.model.supports_tools.unwrap_or(false) } + fn supports_images(&self) -> bool { + false + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto => false, diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 4f2a750c5a..b19b4653b1 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -12,7 +12,8 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolUse, 
MessageContent, RateLimiter, Role, StopReason, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + RateLimiter, Role, StopReason, }; use open_ai::{Model, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; @@ -295,6 +296,10 @@ impl LanguageModel for OpenAiLanguageModel { true } + fn supports_images(&self) -> bool { + false + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto => true, @@ -392,8 +397,16 @@ pub fn into_open_ai( } } MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string(), + LanguageModelToolResultContent::Image(_) => { + // TODO: Open AI image support + "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string() + } + }; + messages.push(open_ai::RequestMessage::Tool { - content: tool_result.content.to_string(), + content, tool_call_id: tool_result.tool_use_id.to_string(), }); } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 5ed462c934..9c4cfc6743 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2,7 +2,7 @@ /// The tests in this file assume that server_cx is running on Windows too. /// We neead to find a way to test Windows-Non-Windows interactions. 
use crate::headless_project::HeadlessProject; -use assistant_tool::Tool as _; +use assistant_tool::{Tool as _, ToolResultContent}; use assistant_tools::{ReadFileTool, ReadFileToolInput}; use client::{Client, UserStore}; use clock::FakeSystemClock; @@ -1593,7 +1593,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu ) }); let output = exists_result.output.await.unwrap().content; - assert_eq!(output, "B"); + assert_eq!(output, ToolResultContent::Text("B".to_string())); let input = ReadFileToolInput { path: "project/c.txt".into(),