Streaming tool calls (#29179)

https://github.com/user-attachments/assets/7854a737-ef83-414c-b397-45122e4f32e8



Release Notes:

- Create file and edit file tools now stream their tool descriptions, so
you can see what they're doing sooner.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Richard Feldman 2025-04-21 18:28:32 -04:00 committed by GitHub
parent 7aa0fa1543
commit 4f2f9ff762
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 358 additions and 47 deletions

7
Cargo.lock generated
View file

@ -7713,6 +7713,7 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
"partial-json-fixer",
"project",
"proto",
"schemars",
@ -9828,6 +9829,12 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "partial-json-fixer"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
[[package]]
name = "password-hash"
version = "0.4.2"

View file

@ -480,6 +480,7 @@ num-format = "0.4.4"
ordered-float = "2.1.1"
palette = { version = "0.7.5", default-features = false, features = ["std"] }
parking_lot = "0.12.1"
partial-json-fixer = "0.5.3"
pathdiff = "0.2"
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }

View file

@ -266,14 +266,6 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
}
}
fn render_tool_use_markdown(
text: SharedString,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Entity<Markdown> {
cx.new(|cx| Markdown::new(text, Some(language_registry), None, cx))
}
fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
@ -867,21 +859,34 @@ impl ActiveThread {
tool_output: SharedString,
cx: &mut Context<Self>,
) {
let rendered = RenderedToolUse {
label: render_tool_use_markdown(tool_label.into(), self.language_registry.clone(), cx),
input: render_tool_use_markdown(
format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
)
.into(),
self.language_registry.clone(),
cx,
),
output: render_tool_use_markdown(tool_output, self.language_registry.clone(), cx),
};
self.rendered_tool_uses
.insert(tool_use_id.clone(), rendered);
let rendered = self
.rendered_tool_uses
.entry(tool_use_id.clone())
.or_insert_with(|| RenderedToolUse {
label: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
input: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
output: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
});
rendered.label.update(cx, |this, cx| {
this.replace(tool_label, cx);
});
rendered.input.update(cx, |this, cx| {
let input = format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
);
this.replace(input, cx);
});
rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx);
});
}
fn handle_thread_event(
@ -974,6 +979,19 @@ impl ActiveThread {
);
}
}
ThreadEvent::StreamedToolUse {
tool_use_id,
ui_text,
input,
} => {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text.clone(),
input,
"".into(),
cx,
);
}
ThreadEvent::ToolFinished {
pending_tool_use, ..
} => {
@ -2478,13 +2496,15 @@ impl ActiveThread {
let edit_tools = tool_use.needs_confirmation;
let status_icons = div().child(match &tool_use.status {
ToolUseStatus::Pending | ToolUseStatus::NeedsConfirmation => {
ToolUseStatus::NeedsConfirmation => {
let icon = Icon::new(IconName::Warning)
.color(Color::Warning)
.size(IconSize::Small);
icon.into_any_element()
}
ToolUseStatus::Running => {
ToolUseStatus::Pending
| ToolUseStatus::InputStillStreaming
| ToolUseStatus::Running => {
let icon = Icon::new(IconName::ArrowCircle)
.color(Color::Accent)
.size(IconSize::Small);
@ -2570,7 +2590,7 @@ impl ActiveThread {
}),
)),
),
ToolUseStatus::Running => container.child(
ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child(
results_content_container().child(
h_flex()
.gap_1()

View file

@ -1293,12 +1293,27 @@ impl Thread {
thread.insert_message(Role::Assistant, vec![], cx)
});
thread.tool_use.request_tool_use(
let tool_use_id = tool_use.id.clone();
let streamed_input = if tool_use.is_input_complete {
None
} else {
Some((&tool_use.input).clone())
};
let ui_text = thread.tool_use.request_tool_use(
last_assistant_message_id,
tool_use,
tool_use_metadata.clone(),
cx,
);
if let Some(input) = streamed_input {
cx.emit(ThreadEvent::StreamedToolUse {
tool_use_id,
ui_text,
input,
});
}
}
}
@ -2189,6 +2204,11 @@ pub enum ThreadEvent {
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
StreamedToolUse {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
input: serde_json::Value,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),

View file

@ -75,6 +75,7 @@ impl ToolUseState {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
input: tool_use.input.clone(),
is_input_complete: true,
})
.collect::<Vec<_>>();
@ -176,6 +177,9 @@ impl ToolUseState {
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
PendingToolUseStatus::InputStillStreaming => {
ToolUseStatus::InputStillStreaming
}
}
} else {
ToolUseStatus::Pending
@ -192,7 +196,12 @@ impl ToolUseState {
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
ui_text: self.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
),
input: tool_use.input.clone(),
status,
icon,
@ -207,10 +216,15 @@ impl ToolUseState {
&self,
tool_name: &str,
input: &serde_json::Value,
is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
tool.ui_text(input).into()
if is_input_complete {
tool.ui_text(input).into()
} else {
tool.still_streaming_ui_text(input).into()
}
} else {
format!("Unknown tool {tool_name:?}").into()
}
@ -258,22 +272,50 @@ impl ToolUseState {
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
) {
self.tool_uses_by_assistant_message
) -> Arc<str> {
let tool_uses = self
.tool_uses_by_assistant_message
.entry(assistant_message_id)
.or_default()
.push(tool_use.clone());
.or_default();
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
let mut existing_tool_use_found = false;
// The tool use is being requested by the Assistant, so we want to
// attach the tool results to the next user message.
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
self.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
for existing_tool_use in tool_uses.iter_mut() {
if existing_tool_use.id == tool_use.id {
*existing_tool_use = tool_use.clone();
existing_tool_use_found = true;
}
}
if !existing_tool_use_found {
tool_uses.push(tool_use.clone());
}
let status = if tool_use.is_input_complete {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
// The tool use is being requested by the Assistant, so we want to
// attach the tool results to the next user message.
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
self.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
};
let ui_text: Arc<str> = self
.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
)
.into();
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
@ -281,13 +323,13 @@ impl ToolUseState {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: self
.tool_ui_label(&tool_use.name, &tool_use.input, cx)
.into(),
ui_text: ui_text.clone(),
input: tool_use.input,
status: PendingToolUseStatus::Idle,
status,
},
);
ui_text
}
pub fn run_pending_tool(
@ -497,6 +539,7 @@ pub struct Confirmation {
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },

View file

@ -30,6 +30,7 @@ pub fn init(cx: &mut App) {
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
InputStillStreaming,
NeedsConfirmation,
Pending,
Running,
@ -41,6 +42,7 @@ impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::InputStillStreaming => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
@ -148,6 +150,12 @@ pub trait Tool: 'static + Send + Sync {
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String;
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
/// (so information may be missing).
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.ui_text(input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,

View file

@ -33,8 +33,18 @@ pub struct CreateFileToolInput {
pub contents: String,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
contents: String,
}
pub struct CreateFileTool;
const DEFAULT_UI_TEXT: &str = "Create file";
impl Tool for CreateFileTool {
fn name(&self) -> String {
"create_file".into()
@ -62,7 +72,14 @@ impl Tool for CreateFileTool {
let path = MarkdownString::inline_code(&input.path);
format!("Create file {path}")
}
Err(_) => "Create file".to_string(),
Err(_) => DEFAULT_UI_TEXT.to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<PartialInput>(input.clone()).ok() {
Some(input) if !input.path.is_empty() => input.path,
_ => DEFAULT_UI_TEXT.to_string(),
}
}
@ -111,3 +128,60 @@ impl Tool for CreateFileTool {
.into()
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let tool = CreateFileTool;
let input = json!({
"path": "src/main.rs",
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
});
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_without_path() {
let tool = CreateFileTool;
let input = json!({
"path": "",
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
});
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn still_streaming_ui_text_with_null() {
let tool = CreateFileTool;
let input = serde_json::Value::Null;
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn ui_text_with_valid_input() {
let tool = CreateFileTool;
let input = json!({
"path": "src/main.rs",
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
});
assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
}
#[test]
fn ui_text_with_invalid_input() {
let tool = CreateFileTool;
let input = json!({
"invalid": "field"
});
assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
}
}

View file

@ -47,8 +47,22 @@ pub struct EditFileToolInput {
pub new_string: String,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
#[serde(default)]
old_string: String,
#[serde(default)]
new_string: String,
}
pub struct EditFileTool;
const DEFAULT_UI_TEXT: &str = "Edit file";
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
@ -77,6 +91,22 @@ impl Tool for EditFileTool {
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
@ -181,3 +211,69 @@ impl Tool for EditFileTool {
}).into()
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let tool = EditFileTool;
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let tool = EditFileTool;
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let tool = EditFileTool;
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let tool = EditFileTool;
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn still_streaming_ui_text_with_null() {
let tool = EditFileTool;
let input = serde_json::Value::Null;
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
}

View file

@ -426,6 +426,7 @@ impl Example {
ThreadEvent::ToolConfirmationNeeded => {
panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
},
ThreadEvent::StreamedToolUse { .. } |
ThreadEvent::StreamedCompletion |
ThreadEvent::MessageAdded(_) |
ThreadEvent::MessageEdited(_) |

View file

@ -187,6 +187,7 @@ pub struct LanguageModelToolUse {
pub id: LanguageModelToolUseId,
pub name: Arc<str>,
pub input: serde_json::Value,
pub is_input_complete: bool,
}
pub struct LanguageModelTextStream {

View file

@ -38,6 +38,7 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
project.workspace = true
proto.workspace = true
schemars.workspace = true

View file

@ -713,6 +713,35 @@ pub fn map_to_language_model_completion_events(
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
return Some((
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.clone().into(),
name: tool_use.name.clone().into(),
is_input_complete: false,
input: if tool_use.input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
)
} else {
serde_json::Value::from_str(
// Convert invalid (incomplete) JSON into
// JSON that serde will accept, e.g. by closing
// unclosed delimiters. This way, we can update
// the UI with whatever has been streamed back so far.
&partial_json_fixer::fix_json(
&tool_use.input_json,
),
)
.map_err(|err| anyhow!(err))?
},
},
))
})],
state,
));
}
}
},
@ -724,6 +753,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input: if tool_use.input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),

View file

@ -893,6 +893,7 @@ pub fn map_to_language_model_completion_events(
let tool_use_event = LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input: if tool_use.input_json.is_empty() {
Value::Null
} else {

View file

@ -367,6 +367,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,

View file

@ -529,6 +529,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id,
name,
is_input_complete: true,
input: function_call_part.function_call.args,
},
)));

View file

@ -490,6 +490,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,

View file

@ -192,6 +192,11 @@ impl Markdown {
self.parse(cx);
}
pub fn replace(&mut self, source: impl Into<SharedString>, cx: &mut Context<Self>) {
self.source = source.into();
self.parse(cx);
}
pub fn reset(&mut self, source: SharedString, cx: &mut Context<Self>) {
if source == self.source() {
return;