Support images in read file tool

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Ben Brandt 2025-08-08 14:28:03 +02:00
parent 4d5b22a583
commit ebc7df2c2e
No known key found for this signature in database
GPG key ID: D4618C5D3B500571
5 changed files with 137 additions and 86 deletions

View file

@ -272,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
tool_name: ToolRequiringPermission.name().into(), tool_name: ToolRequiringPermission.name().into(),
is_error: false, is_error: false,
content: "Allowed".into(), content: "Allowed".into(),
output: None output: Some("Allowed".into())
}), }),
MessageContent::ToolResult(LanguageModelToolResult { MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),

View file

@ -510,15 +510,27 @@ impl Thread {
status: Some(acp::ToolCallStatus::InProgress), status: Some(acp::ToolCallStatus::InProgress),
..Default::default() ..Default::default()
}); });
let supports_images = self.selected_model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx); let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
Some(cx.foreground_executor().spawn(async move { Some(cx.foreground_executor().spawn(async move {
match tool_result.await { let tool_result = tool_result.await.and_then(|output| {
Ok(tool_output) => LanguageModelToolResult { if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
if !supports_images {
return Err(anyhow!(
"Attempted to read an image, but this model doesn't support it.",
));
}
}
Ok(output)
});
match tool_result {
Ok(output) => LanguageModelToolResult {
tool_use_id: tool_use.id, tool_use_id: tool_use.id,
tool_name: tool_use.name, tool_name: tool_use.name,
is_error: false, is_error: false,
content: tool_output.llm_output, content: output.llm_output,
output: Some(tool_output.raw_output), output: Some(output.raw_output),
}, },
Err(error) => LanguageModelToolResult { Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id, tool_use_id: tool_use.id,

View file

@ -882,7 +882,7 @@ mod tests {
} }
#[gpui::test] #[gpui::test]
async fn test_needs_confirmation(cx: &mut TestAppContext) { async fn test_authorize(cx: &mut TestAppContext) {
init_test(cx); init_test(cx);
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
@ -967,23 +967,7 @@ mod tests {
let event = stream_rx.expect_tool_authorization().await; let event = stream_rx.expect_tool_authorization().await;
assert_eq!(event.tool_call.title, "test 4 (local settings)"); assert_eq!(event.tool_call.title, "test 4 (local settings)");
// Test 5: Path outside of the project should require confirmation. // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let _auth = cx.update(|cx| {
tool.authorize(
&EditFileToolInput {
display_description: "test 5".into(),
path: paths::config_dir().join("tasks.json"),
mode: EditFileMode::Edit,
},
&stream_tx,
cx,
)
});
let event = stream_rx.expect_tool_authorization().await;
assert_eq!(event.tool_call.title, "test 5 (global settings)");
// Test 6: When always_allow_tool_actions is enabled, no confirmation needed
cx.update(|cx| { cx.update(|cx| {
let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
settings.always_allow_tool_actions = true; settings.always_allow_tool_actions = true;
@ -994,7 +978,7 @@ mod tests {
cx.update(|cx| { cx.update(|cx| {
tool.authorize( tool.authorize(
&EditFileToolInput { &EditFileToolInput {
display_description: "test 6.1".into(), display_description: "test 5.1".into(),
path: ".zed/settings.json".into(), path: ".zed/settings.json".into(),
mode: EditFileMode::Edit, mode: EditFileMode::Edit,
}, },
@ -1010,7 +994,7 @@ mod tests {
cx.update(|cx| { cx.update(|cx| {
tool.authorize( tool.authorize(
&EditFileToolInput { &EditFileToolInput {
display_description: "test 6.2".into(), display_description: "test 5.2".into(),
path: "/etc/hosts".into(), path: "/etc/hosts".into(),
mode: EditFileMode::Edit, mode: EditFileMode::Edit,
}, },
@ -1023,6 +1007,72 @@ mod tests {
assert!(stream_rx.try_next().is_err()); assert!(stream_rx.try_next().is_err());
} }
/// Exercises `EditFileTool::authorize` against a table of target paths.
///
/// Each case states whether an authorization request is expected on the
/// event stream: the two absolute system paths are expected to prompt, while
/// the project-relative path is expected to resolve without any prompt.
#[gpui::test]
async fn test_authorize_global_config(cx: &mut TestAppContext) {
    init_test(cx);

    // Build a thread over an empty fake worktree rooted at `/project`.
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/project", json!({})).await;
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|_| {
        Thread::new(
            project,
            Rc::default(),
            action_log.clone(),
            Templates::new(),
            model.clone(),
        )
    });
    let tool = Arc::new(EditFileTool { thread });

    // (path to edit, whether a confirmation is expected, case description)
    let cases = [
        (
            "/etc/hosts",
            true,
            "System file should require confirmation",
        ),
        (
            "/usr/local/bin/script",
            true,
            "System bin file should require confirmation",
        ),
        (
            "project/normal_file.rs",
            false,
            "Normal project file should not require confirmation",
        ),
    ];

    for (target_path, expects_prompt, case_label) in cases {
        let (events_tx, mut events_rx) = ToolCallEventStream::test();
        let input = EditFileToolInput {
            display_description: "Edit file".into(),
            path: target_path.into(),
            mode: EditFileMode::Edit,
        };
        let authorization = cx.update(|cx| tool.authorize(&input, &events_tx, cx));

        if expects_prompt {
            // Only the arrival of the authorization request is checked here;
            // the pending authorization itself is dropped unanswered.
            events_rx.expect_tool_authorization().await;
            continue;
        }

        // No prompt expected: authorization must complete on its own and the
        // stream must not have received an authorization request.
        authorization.await.unwrap();
        assert!(
            events_rx.try_next().is_err(),
            "Failed for case: {} - path: {} - expected no confirmation but got one",
            case_label,
            target_path
        );
    }
}
#[gpui::test] #[gpui::test]
async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
init_test(cx); init_test(cx);

View file

@ -1,10 +1,11 @@
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context, Result};
use assistant_tool::{outline, ActionLog}; use assistant_tool::{outline, ActionLog};
use gpui::{Entity, Task}; use gpui::{Entity, Task};
use indoc::formatdoc; use indoc::formatdoc;
use language::{Anchor, Point}; use language::{Anchor, Point};
use project::{AgentLocation, Project, WorktreeSettings}; use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::Settings; use settings::Settings;
@ -59,7 +60,7 @@ impl ReadFileTool {
impl AgentTool for ReadFileTool { impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput; type Input = ReadFileToolInput;
type Output = String; type Output = LanguageModelToolResultContent;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"read_file".into() "read_file".into()
@ -92,9 +93,9 @@ impl AgentTool for ReadFileTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, _event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
}; };
@ -133,51 +134,27 @@ impl AgentTool for ReadFileTool {
let file_path = input.path.clone(); let file_path = input.path.clone();
event_stream.send_update(acp::ToolCallUpdateFields { if image_store::is_image_file(&self.project, &project_path, cx) {
locations: Some(vec![acp::ToolCallLocation { return cx.spawn(async move |cx| {
path: project_path.path.to_path_buf(), let image_entity: Entity<ImageItem> = cx
line: input.start_line, .update(|cx| {
// TODO (tracked): use full range self.project.update(cx, |project, cx| {
}]), project.open_image(project_path.clone(), cx)
..Default::default() })
}); })?
.await?;
// TODO (tracked): images let image =
// if image_store::is_image_file(&self.project, &project_path, cx) { image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// let model = &self.thread.read(cx).selected_model;
// if !model.supports_images() { let language_model_image = cx
// return Task::ready(Err(anyhow!( .update(|cx| LanguageModelImage::from_image(image, cx))?
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.", .await
// model.name().0 .context("processing image")?;
// )))
// .into();
// }
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> { Ok(language_model_image.into())
// let image_entity: Entity<ImageItem> = cx });
// .update(|cx| { }
// self.project.update(cx, |project, cx| {
// project.open_image(project_path.clone(), cx)
// })
// })?
// .await?;
// let image =
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// let language_model_image = cx
// .update(|cx| LanguageModelImage::from_image(image, cx))?
// .await
// .context("processing image")?;
// Ok(ToolResultOutput {
// content: ToolResultContent::Image(language_model_image),
// output: None,
// })
// });
// }
//
let project = self.project.clone(); let project = self.project.clone();
let action_log = self.action_log.clone(); let action_log = self.action_log.clone();
@ -245,7 +222,7 @@ impl AgentTool for ReadFileTool {
})?; })?;
} }
Ok(result) Ok(result.into())
} else { } else {
// No line ranges specified, so check file size to see if it's too big. // No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
@ -258,7 +235,7 @@ impl AgentTool for ReadFileTool {
log.buffer_read(buffer, cx); log.buffer_read(buffer, cx);
})?; })?;
Ok(result) Ok(result.into())
} else { } else {
// File is too big, so return the outline // File is too big, so return the outline
// and a suggestion to read again with line numbers. // and a suggestion to read again with line numbers.
@ -277,7 +254,8 @@ impl AgentTool for ReadFileTool {
Alternatively, you can fall back to the `grep` tool (if available) Alternatively, you can fall back to the `grep` tool (if available)
to search the file for specific content." to search the file for specific content."
}) }
.into())
} }
} }
}) })
@ -346,7 +324,7 @@ mod test {
tool.run(input, ToolCallEventStream::test().0, cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "This is a small file content"); assert_eq!(result.unwrap(), "This is a small file content".into());
} }
#[gpui::test] #[gpui::test]
@ -366,7 +344,7 @@ mod test {
language_registry.add(Arc::new(rust_lang())); language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let content = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
path: "root/large_file.rs".into(), path: "root/large_file.rs".into(),
@ -377,6 +355,7 @@ mod test {
}) })
.await .await
.unwrap(); .unwrap();
let content = result.to_str().unwrap();
assert_eq!( assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(), content.lines().skip(4).take(6).collect::<Vec<_>>(),
@ -399,8 +378,9 @@ mod test {
}; };
tool.run(input, ToolCallEventStream::test().0, cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await
let content = result.unwrap(); .unwrap();
let content = result.to_str().unwrap();
let expected_content = (0..1000) let expected_content = (0..1000)
.flat_map(|i| { .flat_map(|i| {
vec![ vec![
@ -446,7 +426,7 @@ mod test {
tool.run(input, ToolCallEventStream::test().0, cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
} }
#[gpui::test] #[gpui::test]
@ -476,7 +456,7 @@ mod test {
tool.clone().run(input, ToolCallEventStream::test().0, cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1\nLine 2"); assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
// end_line of 0 should result in at least 1 line // end_line of 0 should result in at least 1 line
let result = cx let result = cx
@ -489,7 +469,7 @@ mod test {
tool.clone().run(input, ToolCallEventStream::test().0, cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1"); assert_eq!(result.unwrap(), "Line 1".into());
// when start_line > end_line, should still return at least 1 line // when start_line > end_line, should still return at least 1 line
let result = cx let result = cx
@ -502,7 +482,7 @@ mod test {
tool.clone().run(input, ToolCallEventStream::test().0, cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 3"); assert_eq!(result.unwrap(), "Line 3".into());
} }
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {
@ -730,7 +710,7 @@ mod test {
}) })
.await; .await;
assert!(result.is_ok(), "Should be able to read normal files"); assert!(result.is_ok(), "Should be able to read normal files");
assert_eq!(result.unwrap(), "Normal file content"); assert_eq!(result.unwrap(), "Normal file content".into());
// Path traversal attempts with .. should fail // Path traversal attempts with .. should fail
let result = cx let result = cx
@ -835,7 +815,10 @@ mod test {
.await .await
.unwrap(); .unwrap();
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }"); assert_eq!(
result,
"fn main() { println!(\"Hello from worktree1\"); }".into()
);
// Test reading private file in worktree1 should fail // Test reading private file in worktree1 should fail
let result = cx let result = cx
@ -894,7 +877,7 @@ mod test {
assert_eq!( assert_eq!(
result, result,
"export function greet() { return 'Hello from worktree2'; }" "export function greet() { return 'Hello from worktree2'; }".into()
); );
// Test reading private file in worktree2 should fail // Test reading private file in worktree2 should fail

View file

@ -297,6 +297,12 @@ impl From<String> for LanguageModelToolResultContent {
} }
} }
impl From<LanguageModelImage> for LanguageModelToolResultContent {
fn from(image: LanguageModelImage) -> Self {
Self::Image(image)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent { pub enum MessageContent {
Text(String), Text(String),