diff --git a/Cargo.lock b/Cargo.lock index c7c878f654..e2318eac72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7713,6 +7713,7 @@ dependencies = [ "mistral", "ollama", "open_ai", + "partial-json-fixer", "project", "proto", "schemars", @@ -9828,6 +9829,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "partial-json-fixer" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13" + [[package]] name = "password-hash" version = "0.4.2" diff --git a/Cargo.toml b/Cargo.toml index bb2ffeba0e..891d6a16c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -480,6 +480,7 @@ num-format = "0.4.4" ordered-float = "2.1.1" palette = { version = "0.7.5", default-features = false, features = ["std"] } parking_lot = "0.12.1" +partial-json-fixer = "0.5.3" pathdiff = "0.2" pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index a0dd54218d..55968d6849 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -266,14 +266,6 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { } } -fn render_tool_use_markdown( - text: SharedString, - language_registry: Arc, - cx: &mut App, -) -> Entity { - cx.new(|cx| Markdown::new(text, Some(language_registry), None, cx)) -} - fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); @@ -867,21 +859,34 @@ impl ActiveThread { tool_output: SharedString, cx: &mut Context, ) { - let rendered = RenderedToolUse { - label: render_tool_use_markdown(tool_label.into(), self.language_registry.clone(), cx), - input: render_tool_use_markdown( - format!( - "```json\n{}\n```", - serde_json::to_string_pretty(tool_input).unwrap_or_default() - ) - .into(), - self.language_registry.clone(), - cx, - ), - output: render_tool_use_markdown(tool_output, self.language_registry.clone(), cx), - }; - self.rendered_tool_uses - .insert(tool_use_id.clone(), rendered); + let rendered = self + .rendered_tool_uses + .entry(tool_use_id.clone()) + .or_insert_with(|| RenderedToolUse { + label: cx.new(|cx| { + Markdown::new("".into(), Some(self.language_registry.clone()), None, cx) + }), + input: cx.new(|cx| { + Markdown::new("".into(), Some(self.language_registry.clone()), None, cx) + }), + output: cx.new(|cx| { + Markdown::new("".into(), Some(self.language_registry.clone()), None, cx) + }), + }); + + rendered.label.update(cx, |this, cx| { + this.replace(tool_label, cx); + }); + rendered.input.update(cx, |this, cx| { + let input = format!( + "```json\n{}\n```", + serde_json::to_string_pretty(tool_input).unwrap_or_default() + ); + this.replace(input, cx); + }); + rendered.output.update(cx, |this, cx| { + this.replace(tool_output, cx); + }); } fn handle_thread_event( @@ -974,6 +979,19 @@ impl ActiveThread { ); } } + ThreadEvent::StreamedToolUse { + tool_use_id, + ui_text, + input, + } => { + self.render_tool_use_markdown( + tool_use_id.clone(), + ui_text.clone(), + input, + "".into(), + cx, + ); + } ThreadEvent::ToolFinished { pending_tool_use, .. } => { @@ -2478,13 +2496,15 @@ impl ActiveThread { let edit_tools = tool_use.needs_confirmation; let status_icons = div().child(match &tool_use.status { - ToolUseStatus::Pending | ToolUseStatus::NeedsConfirmation => { + ToolUseStatus::NeedsConfirmation => { let icon = Icon::new(IconName::Warning) .color(Color::Warning) .size(IconSize::Small); icon.into_any_element() } - ToolUseStatus::Running => { + ToolUseStatus::Pending + | ToolUseStatus::InputStillStreaming + | ToolUseStatus::Running => { let icon = Icon::new(IconName::ArrowCircle) .color(Color::Accent) .size(IconSize::Small); @@ -2570,7 +2590,7 @@ impl ActiveThread { }), )), ), - ToolUseStatus::Running => container.child( + ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child( results_content_container().child( h_flex() .gap_1() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 9769fd92ba..81b6b3c980 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1293,12 +1293,27 @@ impl Thread { thread.insert_message(Role::Assistant, vec![], cx) }); - thread.tool_use.request_tool_use( + let tool_use_id = tool_use.id.clone(); + let streamed_input = if tool_use.is_input_complete { + None + } else { + Some((&tool_use.input).clone()) + }; + + let ui_text = thread.tool_use.request_tool_use( last_assistant_message_id, tool_use, tool_use_metadata.clone(), cx, ); + + if let Some(input) = streamed_input { + cx.emit(ThreadEvent::StreamedToolUse { + tool_use_id, + ui_text, + input, + }); + } } } @@ -2189,6 +2204,11 @@ pub enum ThreadEvent { StreamedCompletion, StreamedAssistantText(MessageId, String), StreamedAssistantThinking(MessageId, String), + StreamedToolUse { + tool_use_id: LanguageModelToolUseId, + ui_text: Arc, + input: serde_json::Value, + }, Stopped(Result>), MessageAdded(MessageId), MessageEdited(MessageId), diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 2e10248e79..4fb49a2d16 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -75,6 +75,7 @@ impl ToolUseState { id: tool_use.id.clone(), name: tool_use.name.clone().into(), input: tool_use.input.clone(), + is_input_complete: true, }) .collect::>(); @@ -176,6 +177,9 @@ impl ToolUseState { PendingToolUseStatus::Error(ref err) => { ToolUseStatus::Error(err.clone().into()) } + PendingToolUseStatus::InputStillStreaming => { + ToolUseStatus::InputStillStreaming + } } } else { ToolUseStatus::Pending @@ -192,7 +196,12 @@ impl ToolUseState { tool_uses.push(ToolUse { id: tool_use.id.clone(), name: tool_use.name.clone().into(), - ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx), + ui_text: self.tool_ui_label( + &tool_use.name, + &tool_use.input, + tool_use.is_input_complete, + cx, + ), input: tool_use.input.clone(), status, icon, @@ -207,10 +216,15 @@ impl ToolUseState { &self, tool_name: &str, input: &serde_json::Value, + is_input_complete: bool, cx: &App, ) -> SharedString { if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) { - tool.ui_text(input).into() + if is_input_complete { + tool.ui_text(input).into() + } else { + tool.still_streaming_ui_text(input).into() + } } else { format!("Unknown tool {tool_name:?}").into() } @@ -258,22 +272,50 @@ impl ToolUseState { tool_use: LanguageModelToolUse, metadata: ToolUseMetadata, cx: &App, - ) { - self.tool_uses_by_assistant_message + ) -> Arc { + let tool_uses = self + .tool_uses_by_assistant_message .entry(assistant_message_id) - .or_default() - .push(tool_use.clone()); + .or_default(); - self.tool_use_metadata_by_id - .insert(tool_use.id.clone(), metadata); + let mut existing_tool_use_found = false; - // The tool use is being requested by the Assistant, so we want to - // attach the tool results to the next user message. - let next_user_message_id = MessageId(assistant_message_id.0 + 1); - self.tool_uses_by_user_message - .entry(next_user_message_id) - .or_default() - .push(tool_use.id.clone()); + for existing_tool_use in tool_uses.iter_mut() { + if existing_tool_use.id == tool_use.id { + *existing_tool_use = tool_use.clone(); + existing_tool_use_found = true; + } + } + + if !existing_tool_use_found { + tool_uses.push(tool_use.clone()); + } + + let status = if tool_use.is_input_complete { + self.tool_use_metadata_by_id + .insert(tool_use.id.clone(), metadata); + + // The tool use is being requested by the Assistant, so we want to + // attach the tool results to the next user message. + let next_user_message_id = MessageId(assistant_message_id.0 + 1); + self.tool_uses_by_user_message + .entry(next_user_message_id) + .or_default() + .push(tool_use.id.clone()); + + PendingToolUseStatus::Idle + } else { + PendingToolUseStatus::InputStillStreaming + }; + + let ui_text: Arc = self + .tool_ui_label( + &tool_use.name, + &tool_use.input, + tool_use.is_input_complete, + cx, + ) + .into(); self.pending_tool_uses_by_id.insert( tool_use.id.clone(), @@ -281,13 +323,13 @@ impl ToolUseState { assistant_message_id, id: tool_use.id, name: tool_use.name.clone(), - ui_text: self - .tool_ui_label(&tool_use.name, &tool_use.input, cx) - .into(), + ui_text: ui_text.clone(), input: tool_use.input, - status: PendingToolUseStatus::Idle, + status, }, ); + + ui_text } pub fn run_pending_tool( @@ -497,6 +539,7 @@ pub struct Confirmation { #[derive(Debug, Clone)] pub enum PendingToolUseStatus { + InputStillStreaming, Idle, NeedsConfirmation(Arc), Running { _task: Shared> }, diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index cb7f0ff518..c3e277c783 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -30,6 +30,7 @@ pub fn init(cx: &mut App) { #[derive(Debug, Clone)] pub enum ToolUseStatus { + InputStillStreaming, NeedsConfirmation, Pending, Running, @@ -41,6 +42,7 @@ impl ToolUseStatus { pub fn text(&self) -> SharedString { match self { ToolUseStatus::NeedsConfirmation => "".into(), + ToolUseStatus::InputStillStreaming => "".into(), ToolUseStatus::Pending => "".into(), ToolUseStatus::Running => "".into(), ToolUseStatus::Finished(out) => out.clone(), @@ -148,6 +150,12 @@ pub trait Tool: 'static + Send + Sync { /// Returns markdown to be displayed in the UI for this tool. fn ui_text(&self, input: &serde_json::Value) -> String; + /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming + /// (so information may be missing). + fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { + self.ui_text(input) + } + /// Runs the tool with the provided input. fn run( self: Arc, diff --git a/crates/assistant_tools/src/create_file_tool.rs b/crates/assistant_tools/src/create_file_tool.rs index dc777bfb8d..934811cbfa 100644 --- a/crates/assistant_tools/src/create_file_tool.rs +++ b/crates/assistant_tools/src/create_file_tool.rs @@ -33,8 +33,18 @@ pub struct CreateFileToolInput { pub contents: String, } +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct PartialInput { + #[serde(default)] + path: String, + #[serde(default)] + contents: String, +} + pub struct CreateFileTool; +const DEFAULT_UI_TEXT: &str = "Create file"; + impl Tool for CreateFileTool { fn name(&self) -> String { "create_file".into() @@ -62,7 +72,14 @@ impl Tool for CreateFileTool { let path = MarkdownString::inline_code(&input.path); format!("Create file {path}") } - Err(_) => "Create file".to_string(), + Err(_) => DEFAULT_UI_TEXT.to_string(), + } + } + + fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()).ok() { + Some(input) if !input.path.is_empty() => input.path, + _ => DEFAULT_UI_TEXT.to_string(), } } @@ -111,3 +128,60 @@ impl Tool for CreateFileTool { .into() } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn still_streaming_ui_text_with_path() { + let tool = CreateFileTool; + let input = json!({ + "path": "src/main.rs", + "contents": "fn main() {\n println!(\"Hello, world!\");\n}" + }); + + assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs"); + } + + #[test] + fn still_streaming_ui_text_without_path() { + let tool = CreateFileTool; + let input = json!({ + "path": "", + "contents": "fn main() {\n println!(\"Hello, world!\");\n}" + }); + + assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT); + } + + #[test] + fn still_streaming_ui_text_with_null() { + let tool = CreateFileTool; + let input = serde_json::Value::Null; + + assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT); + } + + #[test] + fn ui_text_with_valid_input() { + let tool = CreateFileTool; + let input = json!({ + "path": "src/main.rs", + "contents": "fn main() {\n println!(\"Hello, world!\");\n}" + }); + + assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`"); + } + + #[test] + fn ui_text_with_invalid_input() { + let tool = CreateFileTool; + let input = json!({ + "invalid": "field" + }); + + assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT); + } +} diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 136dd60bfe..c66e2ffb69 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -47,8 +47,22 @@ pub struct EditFileToolInput { pub new_string: String, } +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct PartialInput { + #[serde(default)] + path: String, + #[serde(default)] + display_description: String, + #[serde(default)] + old_string: String, + #[serde(default)] + new_string: String, +} + pub struct EditFileTool; +const DEFAULT_UI_TEXT: &str = "Edit file"; + impl Tool for EditFileTool { fn name(&self) -> String { "edit_file".into() @@ -77,6 +91,22 @@ impl Tool for EditFileTool { } } + fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { + if let Some(input) = serde_json::from_value::(input.clone()).ok() { + let description = input.display_description.trim(); + if !description.is_empty() { + return description.to_string(); + } + + let path = input.path.trim(); + if !path.is_empty() { + return path.to_string(); + } + } + + DEFAULT_UI_TEXT.to_string() + } + fn run( self: Arc, input: serde_json::Value, @@ -181,3 +211,69 @@ impl Tool for EditFileTool { }).into() } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn still_streaming_ui_text_with_path() { + let tool = EditFileTool; + let input = json!({ + "path": "src/main.rs", + "display_description": "", + "old_string": "old code", + "new_string": "new code" + }); + + assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs"); + } + + #[test] + fn still_streaming_ui_text_with_description() { + let tool = EditFileTool; + let input = json!({ + "path": "", + "display_description": "Fix error handling", + "old_string": "old code", + "new_string": "new code" + }); + + assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling"); + } + + #[test] + fn still_streaming_ui_text_with_path_and_description() { + let tool = EditFileTool; + let input = json!({ + "path": "src/main.rs", + "display_description": "Fix error handling", + "old_string": "old code", + "new_string": "new code" + }); + + assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling"); + } + + #[test] + fn still_streaming_ui_text_no_path_or_description() { + let tool = EditFileTool; + let input = json!({ + "path": "", + "display_description": "", + "old_string": "old code", + "new_string": "new code" + }); + + assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT); + } + + #[test] + fn still_streaming_ui_text_with_null() { + let tool = EditFileTool; + let input = serde_json::Value::Null; + + assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT); + } +} diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index a82fbecc93..65f5a659dc 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -426,6 +426,7 @@ impl Example { ThreadEvent::ToolConfirmationNeeded => { panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix); }, + ThreadEvent::StreamedToolUse { .. } | ThreadEvent::StreamedCompletion | ThreadEvent::MessageAdded(_) | ThreadEvent::MessageEdited(_) | diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 71d8551bd5..8dc641f88c 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -187,6 +187,7 @@ pub struct LanguageModelToolUse { pub id: LanguageModelToolUseId, pub name: Arc, pub input: serde_json::Value, + pub is_input_complete: bool, } pub struct LanguageModelTextStream { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 304a284f4f..5ded3386b5 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -38,6 +38,7 @@ menu.workspace = true mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +partial-json-fixer.workspace = true project.workspace = true proto.workspace = true schemars.workspace = true diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index f998969bfe..6c8c664a26 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -713,6 +713,35 @@ pub fn map_to_language_model_completion_events( ContentDelta::InputJsonDelta { partial_json } => { if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) { tool_use.input_json.push_str(&partial_json); + + return Some(( + vec![maybe!({ + Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + input: if tool_use.input_json.is_empty() { + serde_json::Value::Object( + serde_json::Map::default(), + ) + } else { + serde_json::Value::from_str( + // Convert invalid (incomplete) JSON into + // JSON that serde will accept, e.g. by closing + // unclosed delimiters. This way, we can update + // the UI with whatever has been streamed back so far. + &partial_json_fixer::fix_json( + &tool_use.input_json, + ), + ) + .map_err(|err| anyhow!(err))? + }, + }, + )) + })], + state, + )); } } }, @@ -724,6 +753,7 @@ pub fn map_to_language_model_completion_events( LanguageModelToolUse { id: tool_use.id.into(), name: tool_use.name.into(), + is_input_complete: true, input: if tool_use.input_json.is_empty() { serde_json::Value::Object( serde_json::Map::default(), diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index a2748b45be..bb5c10cf93 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -893,6 +893,7 @@ pub fn map_to_language_model_completion_events( let tool_use_event = LanguageModelToolUse { id: tool_use.id.into(), name: tool_use.name.into(), + is_input_complete: true, input: if tool_use.input_json.is_empty() { Value::Null } else { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 255de2d536..eac138cd61 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -367,6 +367,7 @@ pub fn map_to_language_model_completion_events( LanguageModelToolUse { id: tool_call.id.into(), name: tool_call.name.as_str().into(), + is_input_complete: true, input: serde_json::Value::from_str( &tool_call.arguments, )?, diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index bbe6c58353..3db0157396 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -529,6 +529,7 @@ pub fn map_to_language_model_completion_events( LanguageModelToolUse { id, name, + is_input_complete: true, input: function_call_part.function_call.args, }, ))); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 188a219e2d..11ac5394e6 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -490,6 +490,7 @@ pub fn map_to_language_model_completion_events( LanguageModelToolUse { id: tool_call.id.into(), name: tool_call.name.as_str().into(), + is_input_complete: true, input: serde_json::Value::from_str( &tool_call.arguments, )?, diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index 0c1539824e..556a48135b 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -192,6 +192,11 @@ impl Markdown { self.parse(cx); } + pub fn replace(&mut self, source: impl Into, cx: &mut Context) { + self.source = source.into(); + self.parse(cx); + } + pub fn reset(&mut self, source: SharedString, cx: &mut Context) { if source == self.source() { return;