diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 26cf38741b..b70f54ac0a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -272,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { tool_name: ToolRequiringPermission.name().into(), is_error: false, content: "Allowed".into(), - output: None + output: Some("Allowed".into()) }), MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index e60387bd44..98f2d0651d 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -510,15 +510,27 @@ impl Thread { status: Some(acp::ToolCallStatus::InProgress), ..Default::default() }); + let supports_images = self.selected_model.supports_images(); let tool_result = tool.run(tool_use.input, tool_event_stream, cx); Some(cx.foreground_executor().spawn(async move { - match tool_result.await { - Ok(tool_output) => LanguageModelToolResult { + let tool_result = tool_result.await.and_then(|output| { + if let LanguageModelToolResultContent::Image(_) = &output.llm_output { + if !supports_images { + return Err(anyhow!( + "Attempted to read an image, but this model doesn't support it.", + )); + } + } + Ok(output) + }); + + match tool_result { + Ok(output) => LanguageModelToolResult { tool_use_id: tool_use.id, tool_name: tool_use.name, is_error: false, - content: tool_output.llm_output, - output: Some(tool_output.raw_output), + content: output.llm_output, + output: Some(output.raw_output), }, Err(error) => LanguageModelToolResult { tool_use_id: tool_use.id, diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index fdb057c683..0dbe0be217 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -882,7 +882,7 @@ mod tests { } #[gpui::test] - async fn test_needs_confirmation(cx: &mut TestAppContext) { + async fn test_authorize(cx: &mut TestAppContext) { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; @@ -967,23 +967,7 @@ mod tests { let event = stream_rx.expect_tool_authorization().await; assert_eq!(event.tool_call.title, "test 4 (local settings)"); - // Test 5: Path outside of the project should require confirmation. - let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = cx.update(|cx| { - tool.authorize( - &EditFileToolInput { - display_description: "test 5".into(), - path: paths::config_dir().join("tasks.json"), - mode: EditFileMode::Edit, - }, - &stream_tx, - cx, - ) - }); - let event = stream_rx.expect_tool_authorization().await; - assert_eq!(event.tool_call.title, "test 5 (global settings)"); - - // Test 6: When always_allow_tool_actions is enabled, no confirmation needed + // Test 5: When always_allow_tool_actions is enabled, no confirmation needed cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); settings.always_allow_tool_actions = true; @@ -994,7 +978,7 @@ mod tests { cx.update(|cx| { tool.authorize( &EditFileToolInput { - display_description: "test 6.1".into(), + display_description: "test 5.1".into(), path: ".zed/settings.json".into(), mode: EditFileMode::Edit, }, @@ -1010,7 +994,7 @@ mod tests { cx.update(|cx| { tool.authorize( &EditFileToolInput { - display_description: "test 6.2".into(), + display_description: "test 5.2".into(), path: "/etc/hosts".into(), mode: EditFileMode::Edit, }, @@ -1023,6 +1007,72 @@ mod tests { assert!(stream_rx.try_next().is_err()); } + #[gpui::test] + async fn test_authorize_global_config(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test global config paths - these should require confirmation if they exist and are outside the project + let test_cases = vec![ + ( + "/etc/hosts", + true, + "System file should require confirmation", + ), + ( + "/usr/local/bin/script", + true, + "System bin file should require confirmation", + ), + ( + "project/normal_file.rs", + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + #[gpui::test] async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs index f85efed286..3d91e3dc74 100644 --- a/crates/agent2/src/tools/read_file_tool.rs +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -1,10 +1,11 @@ use agent_client_protocol::{self as acp}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use assistant_tool::{outline, ActionLog}; use gpui::{Entity, Task}; use indoc::formatdoc; use language::{Anchor, Point}; -use project::{AgentLocation, Project, WorktreeSettings}; +use language_model::{LanguageModelImage, LanguageModelToolResultContent}; +use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -59,7 +60,7 @@ impl ReadFileTool { impl AgentTool for ReadFileTool { type Input = ReadFileToolInput; - type Output = String; + type Output = LanguageModelToolResultContent; fn name(&self) -> SharedString { "read_file".into() @@ -92,9 +93,9 @@ impl AgentTool for ReadFileTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task> { + ) -> Task> { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); }; @@ -133,51 +134,27 @@ impl AgentTool for ReadFileTool { let file_path = input.path.clone(); - event_stream.send_update(acp::ToolCallUpdateFields { - locations: Some(vec![acp::ToolCallLocation { - path: project_path.path.to_path_buf(), - line: input.start_line, - // TODO (tracked): use full range - }]), - ..Default::default() - }); + if image_store::is_image_file(&self.project, &project_path, cx) { + return cx.spawn(async move |cx| { + let image_entity: Entity = cx + .update(|cx| { + self.project.update(cx, |project, cx| { + project.open_image(project_path.clone(), cx) + }) + })? + .await?; - // TODO (tracked): images - // if image_store::is_image_file(&self.project, &project_path, cx) { - // let model = &self.thread.read(cx).selected_model; + let image = + image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; - // if !model.supports_images() { - // return Task::ready(Err(anyhow!( - // "Attempted to read an image, but Zed doesn't currently support sending images to {}.", - // model.name().0 - // ))) - // .into(); - // } + let language_model_image = cx + .update(|cx| LanguageModelImage::from_image(image, cx))? + .await + .context("processing image")?; - // return cx.spawn(async move |cx| -> Result { - // let image_entity: Entity = cx - // .update(|cx| { - // self.project.update(cx, |project, cx| { - // project.open_image(project_path.clone(), cx) - // }) - // })? - // .await?; - - // let image = - // image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; - - // let language_model_image = cx - // .update(|cx| LanguageModelImage::from_image(image, cx))? - // .await - // .context("processing image")?; - - // Ok(ToolResultOutput { - // content: ToolResultContent::Image(language_model_image), - // output: None, - // }) - // }); - // } - // + Ok(language_model_image.into()) + }); + } let project = self.project.clone(); let action_log = self.action_log.clone(); @@ -245,7 +222,7 @@ impl AgentTool for ReadFileTool { })?; } - Ok(result) + Ok(result.into()) } else { // No line ranges specified, so check file size to see if it's too big. let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; @@ -258,7 +235,7 @@ impl AgentTool for ReadFileTool { log.buffer_read(buffer, cx); })?; - Ok(result) + Ok(result.into()) } else { // File is too big, so return the outline // and a suggestion to read again with line numbers. @@ -277,7 +254,8 @@ impl AgentTool for ReadFileTool { Alternatively, you can fall back to the `grep` tool (if available) to search the file for specific content." - }) + } + .into()) } } }) @@ -346,7 +324,7 @@ mod test { tool.run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "This is a small file content"); + assert_eq!(result.unwrap(), "This is a small file content".into()); } #[gpui::test] @@ -366,7 +344,7 @@ mod test { language_registry.add(Arc::new(rust_lang())); let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let content = cx + let result = cx .update(|cx| { let input = ReadFileToolInput { path: "root/large_file.rs".into(), @@ -377,6 +355,7 @@ mod test { }) .await .unwrap(); + let content = result.to_str().unwrap(); assert_eq!( content.lines().skip(4).take(6).collect::>(), @@ -399,8 +378,9 @@ mod test { }; tool.run(input, ToolCallEventStream::test().0, cx) }) - .await; - let content = result.unwrap(); + .await + .unwrap(); + let content = result.to_str().unwrap(); let expected_content = (0..1000) .flat_map(|i| { vec![ @@ -446,7 +426,7 @@ mod test { tool.run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); + assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into()); } #[gpui::test] @@ -476,7 +456,7 @@ mod test { tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 1\nLine 2"); + assert_eq!(result.unwrap(), "Line 1\nLine 2".into()); // end_line of 0 should result in at least 1 line let result = cx @@ -489,7 +469,7 @@ mod test { tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 1"); + assert_eq!(result.unwrap(), "Line 1".into()); // when start_line > end_line, should still return at least 1 line let result = cx @@ -502,7 +482,7 @@ mod test { tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 3"); + assert_eq!(result.unwrap(), "Line 3".into()); } fn init_test(cx: &mut TestAppContext) { @@ -730,7 +710,7 @@ mod test { }) .await; assert!(result.is_ok(), "Should be able to read normal files"); - assert_eq!(result.unwrap(), "Normal file content"); + assert_eq!(result.unwrap(), "Normal file content".into()); // Path traversal attempts with .. should fail let result = cx @@ -835,7 +815,10 @@ mod test { .await .unwrap(); - assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }"); + assert_eq!( + result, + "fn main() { println!(\"Hello from worktree1\"); }".into() + ); // Test reading private file in worktree1 should fail let result = cx @@ -894,7 +877,7 @@ mod test { assert_eq!( result, - "export function greet() { return 'Hello from worktree2'; }" + "export function greet() { return 'Hello from worktree2'; }".into() ); // Test reading private file in worktree2 should fail diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index dc485e9937..edce3d03b7 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -297,6 +297,12 @@ impl From for LanguageModelToolResultContent { } } +impl From for LanguageModelToolResultContent { + fn from(image: LanguageModelImage) -> Self { + Self::Image(image) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String),