Have read_file support images (#30435)

This is very basic support for them. There are a number of other TODOs
before this is really a first-class supported feature, so not adding any
release notes for it; for now, this PR just makes it so that if
read_file tries to read a PNG (which has come up in practice), it at
least correctly sends it to Anthropic instead of messing up.

This also lays the groundwork for future PRs for more first-class
support for images in tool calls across more image file formats and LLM
providers.

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Richard Feldman 2025-05-13 10:58:00 +02:00 committed by GitHub
parent f01af006e1
commit 8fdf309a4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 557 additions and 194 deletions

View file

@ -33,7 +33,9 @@ use language_model::{
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason,
}; };
use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; use markdown::{
HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange,
};
use project::{ProjectEntryId, ProjectItem as _}; use project::{ProjectEntryId, ProjectItem as _};
use rope::Point; use rope::Point;
use settings::{Settings as _, SettingsStore, update_settings_file}; use settings::{Settings as _, SettingsStore, update_settings_file};
@ -430,49 +432,8 @@ fn render_markdown_code_block(
let path_range = path_range.clone(); let path_range = path_range.clone();
move |_, window, cx| { move |_, window, cx| {
workspace workspace
.update(cx, { .update(cx, |workspace, cx| {
|workspace, cx| { open_path(&path_range, window, workspace, cx)
let Some(project_path) = workspace
.project()
.read(cx)
.find_project_path(&path_range.path, cx)
else {
return;
};
let Some(target) = path_range.range.as_ref().map(|range| {
Point::new(
// Line number is 1-based
range.start.line.saturating_sub(1),
range.start.col.unwrap_or(0),
)
}) else {
return;
};
let open_task = workspace.open_path(
project_path,
None,
true,
window,
cx,
);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) =
item.downcast::<Editor>()
{
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
target, window, cx,
);
})
.ok();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}) })
.ok(); .ok();
} }
@ -598,6 +559,45 @@ fn render_markdown_code_block(
.when(can_expand && !is_expanded, |this| this.max_h_80()) .when(can_expand && !is_expanded, |this| this.max_h_80())
} }
fn open_path(
path_range: &PathWithRange,
window: &mut Window,
workspace: &mut Workspace,
cx: &mut Context<'_, Workspace>,
) {
let Some(project_path) = workspace
.project()
.read(cx)
.find_project_path(&path_range.path, cx)
else {
return; // TODO instead of just bailing out, open that path in a buffer.
};
let Some(target) = path_range.range.as_ref().map(|range| {
Point::new(
// Line number is 1-based
range.start.line.saturating_sub(1),
range.start.col.unwrap_or(0),
)
}) else {
return;
};
let open_task = workspace.open_path(project_path, None, true, window, cx);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) = item.downcast::<Editor>() {
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(target, window, cx);
})
.ok();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
fn render_code_language( fn render_code_language(
language: Option<&Arc<Language>>, language: Option<&Arc<Language>>,
name_fallback: SharedString, name_fallback: SharedString,

View file

@ -22,9 +22,9 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, LanguageModelToolResultContent, LanguageModelToolUseId, MaxMonthlySpendReachedError,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role,
StopReason, TokenUsage, SelectedModel, StopReason, TokenUsage,
}; };
use postage::stream::Stream as _; use postage::stream::Stream as _;
use project::Project; use project::Project;
@ -880,7 +880,13 @@ impl Thread {
} }
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> { pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
Some(&self.tool_use.tool_result(id)?.content) match &self.tool_use.tool_result(id)?.content {
LanguageModelToolResultContent::Text(str) => Some(str),
LanguageModelToolResultContent::Image(_) => {
// TODO: We should display image
None
}
}
} }
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> { pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
@ -2502,7 +2508,15 @@ impl Thread {
} }
writeln!(markdown, "**\n")?; writeln!(markdown, "**\n")?;
writeln!(markdown, "{}", tool_result.content)?; match &tool_result.content {
LanguageModelToolResultContent::Text(str) => {
writeln!(markdown, "{}", str)?;
}
LanguageModelToolResultContent::Image(image) => {
writeln!(markdown, "![Image](data:base64,{})", image.source)?;
}
}
if let Some(output) = tool_result.output.as_ref() { if let Some(output) = tool_result.output.as_ref() {
writeln!( writeln!(
markdown, markdown,

View file

@ -19,7 +19,7 @@ use gpui::{
}; };
use heed::Database; use heed::Database;
use heed::types::SerdeBincode; use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore}; use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree}; use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{ use prompt_store::{
@ -775,7 +775,7 @@ pub struct SerializedToolUse {
pub struct SerializedToolResult { pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId, pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool, pub is_error: bool,
pub content: Arc<str>, pub content: LanguageModelToolResultContent,
pub output: Option<serde_json::Value>, pub output: Option<serde_json::Value>,
} }

View file

@ -1,14 +1,16 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet}; use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap; use collections::HashMap;
use futures::FutureExt as _; use futures::FutureExt as _;
use futures::future::Shared; use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task}; use gpui::{App, Entity, SharedString, Task};
use language_model::{ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult, ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, Role, LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
}; };
use project::Project; use project::Project;
use ui::{IconName, Window}; use ui::{IconName, Window};
@ -165,10 +167,16 @@ impl ToolUseState {
let status = (|| { let status = (|| {
if let Some(tool_result) = tool_result { if let Some(tool_result) = tool_result {
let content = tool_result
.content
.to_str()
.map(|str| str.to_owned().into())
.unwrap_or_default();
return if tool_result.is_error { return if tool_result.is_error {
ToolUseStatus::Error(tool_result.content.clone().into()) ToolUseStatus::Error(content)
} else { } else {
ToolUseStatus::Finished(tool_result.content.clone().into()) ToolUseStatus::Finished(content)
}; };
} }
@ -399,21 +407,44 @@ impl ToolUseState {
let tool_result = output.content; let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3; const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
// Protect from clearly large output let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
// Protect from overly large output
let tool_output_limit = configured_model let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE) .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX); .unwrap_or(usize::MAX);
let tool_result = if tool_result.len() <= tool_output_limit { let content = match tool_result {
tool_result ToolResultContent::Text(text) => {
} else { let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
LanguageModelToolResultContent::Text(
format!( format!(
"Tool result too long. The first {} bytes:\n\n{}", "Tool result too long. The first {} bytes:\n\n{}",
truncated.len(), truncated.len(),
truncated truncated
) )
.into(),
)
}
ToolResultContent::Image(language_model_image) => {
if language_model_image.estimate_tokens() < tool_output_limit {
LanguageModelToolResultContent::Image(language_model_image)
} else {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
                            content: "Tool responded with an image that would exceed the remaining tokens".into(),
is_error: true,
output: None,
},
);
return old_use;
}
}
}; };
self.tool_results.insert( self.tool_results.insert(
@ -421,12 +452,13 @@ impl ToolUseState {
LanguageModelToolResult { LanguageModelToolResult {
tool_use_id: tool_use_id.clone(), tool_use_id: tool_use_id.clone(),
tool_name, tool_name,
content: tool_result.into(), content,
is_error: false, is_error: false,
output: output.output, output: output.output,
}, },
); );
self.pending_tool_uses_by_id.remove(&tool_use_id)
old_use
} }
Err(err) => { Err(err) => {
self.tool_results.insert( self.tool_results.insert(
@ -434,7 +466,7 @@ impl ToolUseState {
LanguageModelToolResult { LanguageModelToolResult {
tool_use_id: tool_use_id.clone(), tool_use_id: tool_use_id.clone(),
tool_name, tool_name,
content: err.to_string().into(), content: LanguageModelToolResultContent::Text(err.to_string().into()),
is_error: true, is_error: true,
output: None, output: None,
}, },

View file

@ -534,12 +534,26 @@ pub enum RequestContent {
ToolResult { ToolResult {
tool_use_id: String, tool_use_id: String,
is_error: bool, is_error: bool,
content: String, content: ToolResultContent,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>, cache_control: Option<CacheControl>,
}, },
} }
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolResultContent {
JustText(String),
Multipart(Vec<ToolResultPart>),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolResultPart {
Text { text: String },
Image { source: ImageSource },
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum ResponseContent { pub enum ResponseContent {

View file

@ -19,6 +19,7 @@ use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity}; use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName; use icons::IconName;
use language_model::LanguageModel; use language_model::LanguageModel;
use language_model::LanguageModelImage;
use language_model::LanguageModelRequest; use language_model::LanguageModelRequest;
use language_model::LanguageModelToolSchemaFormat; use language_model::LanguageModelToolSchemaFormat;
use project::Project; use project::Project;
@ -65,21 +66,50 @@ impl ToolUseStatus {
#[derive(Debug)] #[derive(Debug)]
pub struct ToolResultOutput { pub struct ToolResultOutput {
pub content: String, pub content: ToolResultContent,
pub output: Option<serde_json::Value>, pub output: Option<serde_json::Value>,
} }
#[derive(Debug, PartialEq, Eq)]
pub enum ToolResultContent {
Text(String),
Image(LanguageModelImage),
}
impl ToolResultContent {
pub fn len(&self) -> usize {
match self {
ToolResultContent::Text(str) => str.len(),
ToolResultContent::Image(image) => image.len(),
}
}
pub fn is_empty(&self) -> bool {
match self {
ToolResultContent::Text(str) => str.is_empty(),
ToolResultContent::Image(image) => image.is_empty(),
}
}
pub fn as_str(&self) -> Option<&str> {
match self {
ToolResultContent::Text(str) => Some(str),
ToolResultContent::Image(_) => None,
}
}
}
impl From<String> for ToolResultOutput { impl From<String> for ToolResultOutput {
fn from(value: String) -> Self { fn from(value: String) -> Self {
ToolResultOutput { ToolResultOutput {
content: value, content: ToolResultContent::Text(value),
output: None, output: None,
} }
} }
} }
impl Deref for ToolResultOutput { impl Deref for ToolResultOutput {
type Target = String; type Target = ToolResultContent;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.content &self.content

View file

@ -10,8 +10,8 @@ use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext}; use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc}; use indoc::{formatdoc, indoc};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId,
}; };
use project::Project; use project::Project;
use rand::prelude::*; use rand::prelude::*;
@ -951,7 +951,7 @@ fn tool_result(
tool_use_id: LanguageModelToolUseId::from(id.into()), tool_use_id: LanguageModelToolUseId::from(id.into()),
tool_name: name.into(), tool_name: name.into(),
is_error: false, is_error: false,
content: result.into(), content: LanguageModelToolResultContent::Text(result.into()),
output: None, output: None,
}) })
} }

View file

@ -5,7 +5,8 @@ use crate::{
}; };
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ use assistant_tool::{
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus, ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput,
ToolUseStatus,
}; };
use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MultiBuffer, PathKey}; use editor::{Editor, EditorMode, MultiBuffer, PathKey};
@ -292,7 +293,10 @@ impl Tool for EditFileTool {
} }
} else { } else {
Ok(ToolResultOutput { Ok(ToolResultOutput {
content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff), content: ToolResultContent::Text(format!(
"Edited {}:\n\n```diff\n{}\n```",
input_path, diff
)),
output: serde_json::to_value(output).ok(), output: serde_json::to_value(output).ok(),
}) })
} }

View file

@ -1,6 +1,8 @@
use crate::{schema::json_schema_for, ui::ToolCallCardHeader}; use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus}; use assistant_tool::{
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
};
use editor::Editor; use editor::Editor;
use futures::channel::oneshot::{self, Receiver}; use futures::channel::oneshot::{self, Receiver};
use gpui::{ use gpui::{
@ -126,7 +128,7 @@ impl Tool for FindPathTool {
write!(&mut message, "\n{}", mat.display()).unwrap(); write!(&mut message, "\n{}", mat.display()).unwrap();
} }
Ok(ToolResultOutput { Ok(ToolResultOutput {
content: message, content: ToolResultContent::Text(message),
output: Some(serde_json::to_value(output)?), output: Some(serde_json::to_value(output)?),
}) })
} }

View file

@ -752,9 +752,9 @@ mod tests {
match task.output.await { match task.output.await {
Ok(result) => { Ok(result) => {
if cfg!(windows) { if cfg!(windows) {
result.content.replace("root\\", "root/") result.content.as_str().unwrap().replace("root\\", "root/")
} else { } else {
result.content result.content.as_str().unwrap().to_string()
} }
} }
Err(e) => panic!("Failed to run grep tool: {}", e), Err(e) => panic!("Failed to run grep tool: {}", e),

View file

@ -1,13 +1,17 @@
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ToolResultContent, outline};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use project::{ImageItem, image_store};
use assistant_tool::ToolResultOutput;
use indoc::formatdoc; use indoc::formatdoc;
use itertools::Itertools; use itertools::Itertools;
use language::{Anchor, Point}; use language::{Anchor, Point};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use language_model::{
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
};
use project::{AgentLocation, Project}; use project::{AgentLocation, Project};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -86,7 +90,7 @@ impl Tool for ReadFileTool {
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
@ -100,6 +104,42 @@ impl Tool for ReadFileTool {
}; };
let file_path = input.path.clone(); let file_path = input.path.clone();
if image_store::is_image_file(&project, &project_path, cx) {
if !model.supports_images() {
return Task::ready(Err(anyhow!(
                        "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
model.name().0
)))
.into();
}
let task = cx.spawn(async move |cx| -> Result<ToolResultOutput> {
let image_entity: Entity<ImageItem> = cx
.update(|cx| {
project.update(cx, |project, cx| {
project.open_image(project_path.clone(), cx)
})
})?
.await?;
let image =
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
let language_model_image = cx
.update(|cx| LanguageModelImage::from_image(image, cx))?
.await
.ok_or_else(|| anyhow!("Failed to process image"))?;
Ok(ToolResultOutput {
content: ToolResultContent::Image(language_model_image),
output: None,
})
});
return task.into();
}
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let buffer = cx let buffer = cx
.update(|cx| { .update(|cx| {
@ -282,7 +322,10 @@ mod test {
.output .output
}) })
.await; .await;
assert_eq!(result.unwrap().content, "This is a small file content"); assert_eq!(
result.unwrap().content.as_str(),
Some("This is a small file content")
);
} }
#[gpui::test] #[gpui::test]
@ -322,6 +365,7 @@ mod test {
}) })
.await; .await;
let content = result.unwrap(); let content = result.unwrap();
let content = content.as_str().unwrap();
assert_eq!( assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(), content.lines().skip(4).take(6).collect::<Vec<_>>(),
vec![ vec![
@ -365,6 +409,8 @@ mod test {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
pretty_assertions::assert_eq!( pretty_assertions::assert_eq!(
content content
.as_str()
.unwrap()
.lines() .lines()
.skip(4) .skip(4)
.take(expected_content.len()) .take(expected_content.len())
@ -408,7 +454,10 @@ mod test {
.output .output
}) })
.await; .await;
assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4"); assert_eq!(
result.unwrap().content.as_str(),
Some("Line 2\nLine 3\nLine 4")
);
} }
#[gpui::test] #[gpui::test]
@ -448,7 +497,7 @@ mod test {
.output .output
}) })
.await; .await;
assert_eq!(result.unwrap().content, "Line 1\nLine 2"); assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2"));
// end_line of 0 should result in at least 1 line // end_line of 0 should result in at least 1 line
let result = cx let result = cx
@ -471,7 +520,7 @@ mod test {
.output .output
}) })
.await; .await;
assert_eq!(result.unwrap().content, "Line 1"); assert_eq!(result.unwrap().content.as_str(), Some("Line 1"));
// when start_line > end_line, should still return at least 1 line // when start_line > end_line, should still return at least 1 line
let result = cx let result = cx
@ -494,7 +543,7 @@ mod test {
.output .output
}) })
.await; .await;
assert_eq!(result.unwrap().content, "Line 3"); assert_eq!(result.unwrap().content.as_str(), Some("Line 3"));
} }
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {

View file

@ -1,5 +1,5 @@
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus}; use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared}; use futures::{FutureExt as _, future::Shared};
use gpui::{ use gpui::{
@ -125,18 +125,24 @@ impl Tool for TerminalTool {
Err(err) => return Task::ready(Err(anyhow!(err))).into(), Err(err) => return Task::ready(Err(anyhow!(err))).into(),
}; };
let input_path = Path::new(&input.cd); let working_dir = match working_dir(&input, &project, cx) {
let working_dir = match working_dir(&input, &project, input_path, cx) {
Ok(dir) => dir, Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(), Err(err) => return Task::ready(Err(err)).into(),
}; };
let program = self.determine_shell.clone(); let program = self.determine_shell.clone();
let command = if cfg!(windows) { let command = if cfg!(windows) {
format!("$null | & {{{}}}", input.command.replace("\"", "'")) format!("$null | & {{{}}}", input.command.replace("\"", "'"))
} else if let Some(cwd) = working_dir
.as_ref()
.and_then(|cwd| cwd.as_os_str().to_str())
{
// Make sure once we're *inside* the shell, we cd into `cwd`
format!("(cd {cwd}; {}) </dev/null", input.command)
} else { } else {
format!("({}) </dev/null", input.command) format!("({}) </dev/null", input.command)
}; };
let args = vec!["-c".into(), command]; let args = vec!["-c".into(), command];
let cwd = working_dir.clone(); let cwd = working_dir.clone();
let env = match &working_dir { let env = match &working_dir {
Some(dir) => project.update(cx, |project, cx| { Some(dir) => project.update(cx, |project, cx| {
@ -319,19 +325,13 @@ fn process_content(
} else { } else {
content content
}; };
let is_empty = content.trim().is_empty(); let content = content.trim();
let is_empty = content.is_empty();
let content = format!( let content = format!("```\n{content}\n```");
"```\n{}{}```",
content,
if content.ends_with('\n') { "" } else { "\n" }
);
let content = if should_truncate { let content = if should_truncate {
format!( format!(
"Command output too long. The first {} bytes:\n\n{}", "Command output too long. The first {} bytes:\n\n{content}",
content.len(), content.len(),
content,
) )
} else { } else {
content content
@ -371,42 +371,47 @@ fn process_content(
fn working_dir( fn working_dir(
input: &TerminalToolInput, input: &TerminalToolInput,
project: &Entity<Project>, project: &Entity<Project>,
input_path: &Path,
cx: &mut App, cx: &mut App,
) -> Result<Option<PathBuf>> { ) -> Result<Option<PathBuf>> {
let project = project.read(cx); let project = project.read(cx);
let cd = &input.cd;
if input.cd == "." { if cd == "." || cd == "" {
// Accept "." as meaning "the one worktree" if we only have one worktree. // Accept "." or "" as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx); let mut worktrees = project.worktrees(cx);
match worktrees.next() { match worktrees.next() {
Some(worktree) => { Some(worktree) => {
if worktrees.next().is_some() { if worktrees.next().is_none() {
bail!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.",
);
}
Ok(Some(worktree.read(cx).abs_path().to_path_buf())) Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
} else {
Err(anyhow!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.",
))
}
} }
None => Ok(None), None => Ok(None),
} }
} else if input_path.is_absolute() { } else {
let input_path = Path::new(cd);
if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees. // Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project if project
.worktrees(cx) .worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path())) .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{ {
bail!("The absolute path must be within one of the project's worktrees"); return Ok(Some(input_path.into()));
}
} else {
if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
} }
Ok(Some(input_path.into())) Err(anyhow!(
} else { "`cd` directory {cd:?} was not in any of the project's worktrees."
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else { ))
bail!("`cd` directory {:?} not found in the project", input.cd);
};
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
} }
} }
@ -727,8 +732,8 @@ mod tests {
) )
}); });
let output = result.output.await.log_err().map(|output| output.content); let output = result.output.await.log_err().unwrap().content;
assert_eq!(output, Some("Command executed successfully.".into())); assert_eq!(output.as_str().unwrap(), "Command executed successfully.");
} }
#[gpui::test] #[gpui::test]
@ -761,12 +766,13 @@ mod tests {
cx, cx,
); );
cx.spawn(async move |_| { cx.spawn(async move |_| {
let output = headless_result let output = headless_result.output.await.map(|output| output.content);
.output assert_eq!(
.await output
.log_err() .ok()
.map(|output| output.content); .and_then(|content| content.as_str().map(ToString::to_string)),
assert_eq!(output, expected); expected
);
}) })
}; };
@ -774,7 +780,7 @@ mod tests {
check( check(
TerminalToolInput { TerminalToolInput {
command: "pwd".into(), command: "pwd".into(),
cd: "project".into(), cd: ".".into(),
}, },
Some(format!( Some(format!(
"```\n{}\n```", "```\n{}\n```",
@ -789,12 +795,9 @@ mod tests {
check( check(
TerminalToolInput { TerminalToolInput {
command: "pwd".into(), command: "pwd".into(),
cd: ".".into(), cd: "other-project".into(),
}, },
Some(format!( None, // other-project is a dir, but *not* a worktree (yet)
"```\n{}\n```",
tree.path().join("project").display()
)),
cx, cx,
) )
}) })

View file

@ -3,7 +3,9 @@ use std::{sync::Arc, time::Duration};
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use crate::ui::ToolCallCardHeader; use crate::ui::ToolCallCardHeader;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus}; use assistant_tool::{
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
};
use futures::{Future, FutureExt, TryFutureExt}; use futures::{Future, FutureExt, TryFutureExt};
use gpui::{ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
@ -74,8 +76,10 @@ impl Tool for WebSearchTool {
async move { async move {
let response = search_task.await.map_err(|err| anyhow!(err))?; let response = search_task.await.map_err(|err| anyhow!(err))?;
Ok(ToolResultOutput { Ok(ToolResultOutput {
content: serde_json::to_string(&response) content: ToolResultContent::Text(
serde_json::to_string(&response)
.context("Failed to serialize search results")?, .context("Failed to serialize search results")?,
),
output: Some(serde_json::to_value(response)?), output: Some(serde_json::to_value(response)?),
}) })
} }

View file

@ -113,7 +113,7 @@ pub enum ModelVendor {
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum ChatMessageContent { pub enum ChatMessagePart {
#[serde(rename = "text")] #[serde(rename = "text")]
Text { text: String }, Text { text: String },
#[serde(rename = "image_url")] #[serde(rename = "image_url")]
@ -194,26 +194,55 @@ pub enum ToolChoice {
None, None,
} }
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")] #[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage { pub enum ChatMessage {
Assistant { Assistant {
content: Option<String>, content: ChatMessageContent,
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
}, },
User { User {
content: Vec<ChatMessageContent>, content: ChatMessageContent,
}, },
System { System {
content: String, content: String,
}, },
Tool { Tool {
content: String, content: ChatMessageContent,
tool_call_id: String, tool_call_id: String,
}, },
} }
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
OnlyText(String),
Multipart(Vec<ChatMessagePart>),
}
impl ChatMessageContent {
pub fn empty() -> Self {
ChatMessageContent::Multipart(vec![])
}
}
impl From<Vec<ChatMessagePart>> for ChatMessageContent {
fn from(mut parts: Vec<ChatMessagePart>) -> Self {
if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
ChatMessageContent::OnlyText(std::mem::take(text))
} else {
ChatMessageContent::Multipart(parts)
}
}
}
impl From<String> for ChatMessageContent {
fn from(text: String) -> Self {
ChatMessageContent::OnlyText(text)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall { pub struct ToolCall {
pub id: String, pub id: String,

View file

@ -9,7 +9,7 @@ use handlebars::Handlebars;
use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _}; use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, TokenUsage, LanguageModelToolResultContent, MessageContent, Role, TokenUsage,
}; };
use project::lsp_store::OpenLspBufferHandle; use project::lsp_store::OpenLspBufferHandle;
use project::{DiagnosticSummary, Project, ProjectPath}; use project::{DiagnosticSummary, Project, ProjectPath};
@ -964,7 +964,15 @@ impl RequestMarkdown {
if tool_result.is_error { if tool_result.is_error {
messages.push_str("**ERROR:**\n"); messages.push_str("**ERROR:**\n");
} }
messages.push_str(&format!("{}\n\n", tool_result.content));
match &tool_result.content {
LanguageModelToolResultContent::Text(str) => {
writeln!(messages, "{}\n", str).ok();
}
LanguageModelToolResultContent::Image(image) => {
writeln!(messages, "![Image](data:base64,{})\n", image.source).ok();
}
}
if let Some(output) = tool_result.output.as_ref() { if let Some(output) = tool_result.output.as_ref() {
writeln!( writeln!(

View file

@ -157,6 +157,10 @@ impl LanguageModel for FakeLanguageModel {
false false
} }
/// The fake model used in tests reports no image support.
fn supports_images(&self) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
"fake".to_string() "fake".to_string()
} }

View file

@ -243,6 +243,9 @@ pub trait LanguageModel: Send + Sync {
LanguageModelAvailability::Public LanguageModelAvailability::Public
} }
/// Whether this model supports images in message content.
///
/// NOTE(review): presumably callers must check this before including image
/// parts in a request — confirm where this is enforced.
fn supports_images(&self) -> bool;
/// Whether this model supports tools. /// Whether this model supports tools.
fn supports_tools(&self) -> bool; fn supports_tools(&self) -> bool;

View file

@ -21,6 +21,16 @@ pub struct LanguageModelImage {
size: Size<DevicePixels>, size: Size<DevicePixels>,
} }
impl LanguageModelImage {
/// Length of the image's `source` data in bytes.
/// NOTE(review): `source` appears to hold base64-encoded image data (it is
/// sent as a base64 payload elsewhere in this file), so this is the size of
/// the encoded form, not the decoded image — confirm.
pub fn len(&self) -> usize {
self.source.len()
}
/// Whether the image's `source` data is empty.
pub fn is_empty(&self) -> bool {
self.source.is_empty()
}
}
impl std::fmt::Debug for LanguageModelImage { impl std::fmt::Debug for LanguageModelImage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LanguageModelImage") f.debug_struct("LanguageModelImage")
@ -134,10 +144,45 @@ pub struct LanguageModelToolResult {
pub tool_use_id: LanguageModelToolUseId, pub tool_use_id: LanguageModelToolUseId,
pub tool_name: Arc<str>, pub tool_name: Arc<str>,
pub is_error: bool, pub is_error: bool,
pub content: Arc<str>, pub content: LanguageModelToolResultContent,
pub output: Option<serde_json::Value>, pub output: Option<serde_json::Value>,
} }
/// The payload a tool call produced: either plain text or an image.
///
/// `#[serde(untagged)]` serializes the variant's inner value directly, with
/// no variant tag.
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
#[serde(untagged)]
pub enum LanguageModelToolResultContent {
    /// A textual tool result.
    Text(Arc<str>),
    /// An image tool result (e.g. `read_file` on a PNG).
    Image(LanguageModelImage),
}
impl LanguageModelToolResultContent {
    /// The result's text, or `None` when the result is an image.
    pub fn to_str(&self) -> Option<&str> {
        if let Self::Text(text) = self {
            Some(text.as_ref())
        } else {
            None
        }
    }

    /// True when there is no meaningful content: blank or whitespace-only
    /// text. An image is never considered empty.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Text(text) => text.trim().is_empty(),
            Self::Image(_) => false,
        }
    }
}
impl From<&str> for LanguageModelToolResultContent {
fn from(value: &str) -> Self {
Self::Text(Arc::from(value))
}
}
impl From<String> for LanguageModelToolResultContent {
fn from(value: String) -> Self {
Self::Text(Arc::from(value))
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent { pub enum MessageContent {
Text(String), Text(String),
@ -151,6 +196,29 @@ pub enum MessageContent {
ToolResult(LanguageModelToolResult), ToolResult(LanguageModelToolResult),
} }
impl MessageContent {
    /// The plain-text view of this content: the text itself for text and
    /// thinking variants, the tool result's text (when textual), and `None`
    /// for everything else.
    pub fn to_str(&self) -> Option<&str> {
        match self {
            MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                Some(text.as_str())
            }
            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => None,
        }
    }

    /// Whether this content is effectively empty. Text-bearing variants are
    /// empty when blank or whitespace-only; tool results delegate to their
    /// content; all other variants always count as non-empty.
    pub fn is_empty(&self) -> bool {
        match self {
            MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                text.trim().is_empty()
            }
            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => false,
        }
    }
}
impl From<String> for MessageContent { impl From<String> for MessageContent {
fn from(value: String) -> Self { fn from(value: String) -> Self {
MessageContent::Text(value) MessageContent::Text(value)
@ -173,13 +241,7 @@ pub struct LanguageModelRequestMessage {
impl LanguageModelRequestMessage { impl LanguageModelRequestMessage {
pub fn string_contents(&self) -> String { pub fn string_contents(&self) -> String {
let mut buffer = String::new(); let mut buffer = String::new();
for string in self.content.iter().filter_map(|content| match content { for string in self.content.iter().filter_map(|content| content.to_str()) {
MessageContent::Text(text) => Some(text.as_str()),
MessageContent::Thinking { text, .. } => Some(text.as_str()),
MessageContent::RedactedThinking(_) => None,
MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
}) {
buffer.push_str(string); buffer.push_str(string);
} }
@ -187,16 +249,7 @@ impl LanguageModelRequestMessage {
} }
pub fn contents_empty(&self) -> bool { pub fn contents_empty(&self) -> bool {
self.content.iter().all(|content| match content { self.content.iter().all(|content| content.is_empty())
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
MessageContent::ToolResult(tool_result) => {
tool_result.content.chars().all(|c| c.is_whitespace())
}
MessageContent::RedactedThinking(_)
| MessageContent::ToolUse(_)
| MessageContent::Image(_) => false,
})
} }
} }

View file

@ -759,6 +759,10 @@ mod tests {
false false
} }
/// The test model reports no image support.
fn supports_images(&self) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("{}/{}", self.provider_id.0, self.name.0) format!("{}/{}", self.provider_id.0, self.name.0)
} }

View file

@ -1,6 +1,9 @@
use crate::AllLanguageModelSettings; use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem; use crate::ui::InstructionListItem;
use anthropic::{AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, Usage}; use anthropic::{
AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
ToolResultPart, Usage,
};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap}; use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider; use credentials_provider::CredentialsProvider;
@ -15,8 +18,8 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, MessageContent, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
RateLimiter, Role, LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
}; };
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -346,9 +349,14 @@ pub fn count_anthropic_tokens(
MessageContent::ToolUse(_tool_use) => { MessageContent::ToolUse(_tool_use) => {
// TODO: Estimate token usage from tool uses. // TODO: Estimate token usage from tool uses.
} }
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => match &tool_result.content {
string_contents.push_str(&tool_result.content); LanguageModelToolResultContent::Text(txt) => {
string_contents.push_str(txt);
} }
LanguageModelToolResultContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
},
} }
} }
@ -421,6 +429,10 @@ impl LanguageModel for AnthropicModel {
true true
} }
/// Anthropic models accept image content; `into_anthropic` encodes images
/// as base64 image blocks in tool results.
fn supports_images(&self) -> bool {
true
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice { match choice {
LanguageModelToolChoice::Auto LanguageModelToolChoice::Auto
@ -575,7 +587,20 @@ pub fn into_anthropic(
Some(anthropic::RequestContent::ToolResult { Some(anthropic::RequestContent::ToolResult {
tool_use_id: tool_result.tool_use_id.to_string(), tool_use_id: tool_result.tool_use_id.to_string(),
is_error: tool_result.is_error, is_error: tool_result.is_error,
content: tool_result.content.to_string(), content: match tool_result.content {
LanguageModelToolResultContent::Text(text) => {
ToolResultContent::JustText(text.to_string())
}
LanguageModelToolResultContent::Image(image) => {
ToolResultContent::Multipart(vec![ToolResultPart::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
}])
}
},
cache_control, cache_control,
}) })
} }

View file

@ -36,7 +36,8 @@ use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role,
TokenUsage,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -490,6 +491,10 @@ impl LanguageModel for BedrockModel {
self.model.supports_tool_use() self.model.supports_tool_use()
} }
/// Image content is not yet wired up for Bedrock (see the TODO in
/// `into_bedrock`, which substitutes placeholder text for images).
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice { match choice {
LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => { LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
@ -635,9 +640,17 @@ pub fn into_bedrock(
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => {
BedrockToolResultBlock::builder() BedrockToolResultBlock::builder()
.tool_use_id(tool_result.tool_use_id.to_string()) .tool_use_id(tool_result.tool_use_id.to_string())
.content(BedrockToolResultContentBlock::Text( .content(match tool_result.content {
tool_result.content.to_string(), LanguageModelToolResultContent::Text(text) => {
)) BedrockToolResultContentBlock::Text(text.to_string())
}
LanguageModelToolResultContent::Image(_) => {
BedrockToolResultContentBlock::Text(
// TODO: Bedrock image support
"[Tool responded with an image, but Zed doesn't support these in Bedrock models yet]".to_string()
)
}
})
.status({ .status({
if tool_result.is_error { if tool_result.is_error {
BedrockToolResultStatus::Error BedrockToolResultStatus::Error
@ -762,9 +775,14 @@ pub fn get_bedrock_tokens(
MessageContent::ToolUse(_tool_use) => { MessageContent::ToolUse(_tool_use) => {
// TODO: Estimate token usage from tool uses. // TODO: Estimate token usage from tool uses.
} }
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => match tool_result.content {
string_contents.push_str(&tool_result.content); LanguageModelToolResultContent::Text(text) => {
string_contents.push_str(&text);
} }
LanguageModelToolResultContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
},
} }
} }

View file

@ -686,6 +686,14 @@ impl LanguageModel for CloudLanguageModel {
} }
} }
/// Image input is enabled for Anthropic and Google cloud models, but not
/// (yet) for OpenAI ones. The match is exhaustive so a new `CloudModel`
/// variant forces an explicit decision here.
fn supports_images(&self) -> bool {
    match self.model {
        CloudModel::Anthropic(_) | CloudModel::Google(_) => true,
        CloudModel::OpenAi(_) => false,
    }
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice { match choice {
LanguageModelToolChoice::Auto LanguageModelToolChoice::Auto

View file

@ -5,8 +5,9 @@ use std::sync::Arc;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use collections::HashMap; use collections::HashMap;
use copilot::copilot_chat::{ use copilot::copilot_chat::{
ChatMessage, ChatMessageContent, CopilotChat, ImageUrl, Model as CopilotChatModel, ModelVendor, ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
Request as CopilotChatRequest, ResponseEvent, Tool, ToolCall, Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
ToolCall,
}; };
use copilot::{Copilot, Status}; use copilot::{Copilot, Status};
use futures::future::BoxFuture; use futures::future::BoxFuture;
@ -20,12 +21,14 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason,
}; };
use settings::SettingsStore; use settings::SettingsStore;
use std::time::Duration; use std::time::Duration;
use ui::prelude::*; use ui::prelude::*;
use util::debug_panic;
use super::anthropic::count_anthropic_tokens; use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens; use super::google::count_google_tokens;
@ -198,6 +201,10 @@ impl LanguageModel for CopilotChatLanguageModel {
self.model.supports_tools() self.model.supports_tools()
} }
/// Image support follows the underlying Copilot model's vision capability.
fn supports_images(&self) -> bool {
self.model.supports_vision()
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
match self.model.vendor() { match self.model.vendor() {
ModelVendor::OpenAI | ModelVendor::Anthropic => { ModelVendor::OpenAI | ModelVendor::Anthropic => {
@ -447,9 +454,28 @@ fn into_copilot_chat(
Role::User => { Role::User => {
for content in &message.content { for content in &message.content {
if let MessageContent::ToolResult(tool_result) = content { if let MessageContent::ToolResult(tool_result) = content {
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
LanguageModelToolResultContent::Image(image) => {
if model.supports_vision() {
ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
image_url: ImageUrl {
url: image.to_base64_url(),
},
}])
} else {
debug_panic!(
"This should be caught at {} level",
tool_result.tool_name
);
"[Tool responded with an image, but this model does not support vision]".to_string().into()
}
}
};
messages.push(ChatMessage::Tool { messages.push(ChatMessage::Tool {
tool_call_id: tool_result.tool_use_id.to_string(), tool_call_id: tool_result.tool_use_id.to_string(),
content: tool_result.content.to_string(), content,
}); });
} }
} }
@ -460,18 +486,18 @@ fn into_copilot_chat(
MessageContent::Text(text) | MessageContent::Thinking { text, .. } MessageContent::Text(text) | MessageContent::Thinking { text, .. }
if !text.is_empty() => if !text.is_empty() =>
{ {
if let Some(ChatMessageContent::Text { text: text_content }) = if let Some(ChatMessagePart::Text { text: text_content }) =
content_parts.last_mut() content_parts.last_mut()
{ {
text_content.push_str(text); text_content.push_str(text);
} else { } else {
content_parts.push(ChatMessageContent::Text { content_parts.push(ChatMessagePart::Text {
text: text.to_string(), text: text.to_string(),
}); });
} }
} }
MessageContent::Image(image) if model.supports_vision() => { MessageContent::Image(image) if model.supports_vision() => {
content_parts.push(ChatMessageContent::Image { content_parts.push(ChatMessagePart::Image {
image_url: ImageUrl { image_url: ImageUrl {
url: image.to_base64_url(), url: image.to_base64_url(),
}, },
@ -483,7 +509,7 @@ fn into_copilot_chat(
if !content_parts.is_empty() { if !content_parts.is_empty() {
messages.push(ChatMessage::User { messages.push(ChatMessage::User {
content: content_parts, content: content_parts.into(),
}); });
} }
} }
@ -523,9 +549,9 @@ fn into_copilot_chat(
messages.push(ChatMessage::Assistant { messages.push(ChatMessage::Assistant {
content: if text_content.is_empty() { content: if text_content.is_empty() {
None ChatMessageContent::empty()
} else { } else {
Some(text_content) text_content.into()
}, },
tool_calls, tool_calls,
}); });

View file

@ -287,6 +287,10 @@ impl LanguageModel for DeepSeekLanguageModel {
false false
} }
/// DeepSeek models are not treated as image-capable here.
fn supports_images(&self) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("deepseek/{}", self.model.id()) format!("deepseek/{}", self.model.id())
} }

View file

@ -313,6 +313,10 @@ impl LanguageModel for GoogleLanguageModel {
true true
} }
/// Google models are treated as image-capable.
fn supports_images(&self) -> bool {
true
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice { match choice {
LanguageModelToolChoice::Auto LanguageModelToolChoice::Auto

View file

@ -285,6 +285,10 @@ impl LanguageModel for LmStudioLanguageModel {
false false
} }
/// LM Studio models are not treated as image-capable here.
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false false
} }

View file

@ -303,6 +303,10 @@ impl LanguageModel for MistralLanguageModel {
false false
} }
/// Mistral models are not treated as image-capable here.
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false false
} }

View file

@ -325,6 +325,10 @@ impl LanguageModel for OllamaLanguageModel {
self.model.supports_tools.unwrap_or(false) self.model.supports_tools.unwrap_or(false)
} }
/// Ollama models are not treated as image-capable here.
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice { match choice {
LanguageModelToolChoice::Auto => false, LanguageModelToolChoice::Auto => false,

View file

@ -12,7 +12,8 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
}; };
use open_ai::{Model, ResponseStreamEvent, stream_completion}; use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -295,6 +296,10 @@ impl LanguageModel for OpenAiLanguageModel {
true true
} }
/// Image content is not yet wired up for OpenAI (see the TODO in
/// `into_open_ai`, which substitutes placeholder text for images).
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice { match choice {
LanguageModelToolChoice::Auto => true, LanguageModelToolChoice::Auto => true,
@ -392,8 +397,16 @@ pub fn into_open_ai(
} }
} }
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => {
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string(),
LanguageModelToolResultContent::Image(_) => {
// TODO: Open AI image support
"[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
}
};
messages.push(open_ai::RequestMessage::Tool { messages.push(open_ai::RequestMessage::Tool {
content: tool_result.content.to_string(), content,
tool_call_id: tool_result.tool_use_id.to_string(), tool_call_id: tool_result.tool_use_id.to_string(),
}); });
} }

View file

@ -2,7 +2,7 @@
/// The tests in this file assume that server_cx is running on Windows too. /// The tests in this file assume that server_cx is running on Windows too.
/// We neead to find a way to test Windows-Non-Windows interactions. /// We neead to find a way to test Windows-Non-Windows interactions.
use crate::headless_project::HeadlessProject; use crate::headless_project::HeadlessProject;
use assistant_tool::Tool as _; use assistant_tool::{Tool as _, ToolResultContent};
use assistant_tools::{ReadFileTool, ReadFileToolInput}; use assistant_tools::{ReadFileTool, ReadFileToolInput};
use client::{Client, UserStore}; use client::{Client, UserStore};
use clock::FakeSystemClock; use clock::FakeSystemClock;
@ -1593,7 +1593,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
) )
}); });
let output = exists_result.output.await.unwrap().content; let output = exists_result.output.await.unwrap().content;
assert_eq!(output, "B"); assert_eq!(output, ToolResultContent::Text("B".to_string()));
let input = ReadFileToolInput { let input = ReadFileToolInput {
path: "project/c.txt".into(), path: "project/c.txt".into(),