Streaming tool calls (#29179)
https://github.com/user-attachments/assets/7854a737-ef83-414c-b397-45122e4f32e8 Release Notes: - Create file and edit file tools now stream their tool descriptions, so you can see what they're doing sooner. --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
7aa0fa1543
commit
4f2f9ff762
17 changed files with 358 additions and 47 deletions
7
Cargo.lock
generated
7
Cargo.lock
generated
|
@ -7713,6 +7713,7 @@ dependencies = [
|
||||||
"mistral",
|
"mistral",
|
||||||
"ollama",
|
"ollama",
|
||||||
"open_ai",
|
"open_ai",
|
||||||
|
"partial-json-fixer",
|
||||||
"project",
|
"project",
|
||||||
"proto",
|
"proto",
|
||||||
"schemars",
|
"schemars",
|
||||||
|
@ -9828,6 +9829,12 @@ dependencies = [
|
||||||
"windows-targets 0.52.6",
|
"windows-targets 0.52.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "partial-json-fixer"
|
||||||
|
version = "0.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "password-hash"
|
name = "password-hash"
|
||||||
version = "0.4.2"
|
version = "0.4.2"
|
||||||
|
|
|
@ -480,6 +480,7 @@ num-format = "0.4.4"
|
||||||
ordered-float = "2.1.1"
|
ordered-float = "2.1.1"
|
||||||
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
||||||
parking_lot = "0.12.1"
|
parking_lot = "0.12.1"
|
||||||
|
partial-json-fixer = "0.5.3"
|
||||||
pathdiff = "0.2"
|
pathdiff = "0.2"
|
||||||
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||||
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||||
|
|
|
@ -266,14 +266,6 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_tool_use_markdown(
|
|
||||||
text: SharedString,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Entity<Markdown> {
|
|
||||||
cx.new(|cx| Markdown::new(text, Some(language_registry), None, cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
|
fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
|
||||||
let theme_settings = ThemeSettings::get_global(cx);
|
let theme_settings = ThemeSettings::get_global(cx);
|
||||||
let colors = cx.theme().colors();
|
let colors = cx.theme().colors();
|
||||||
|
@ -867,21 +859,34 @@ impl ActiveThread {
|
||||||
tool_output: SharedString,
|
tool_output: SharedString,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
let rendered = RenderedToolUse {
|
let rendered = self
|
||||||
label: render_tool_use_markdown(tool_label.into(), self.language_registry.clone(), cx),
|
.rendered_tool_uses
|
||||||
input: render_tool_use_markdown(
|
.entry(tool_use_id.clone())
|
||||||
format!(
|
.or_insert_with(|| RenderedToolUse {
|
||||||
|
label: cx.new(|cx| {
|
||||||
|
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
|
||||||
|
}),
|
||||||
|
input: cx.new(|cx| {
|
||||||
|
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
|
||||||
|
}),
|
||||||
|
output: cx.new(|cx| {
|
||||||
|
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
rendered.label.update(cx, |this, cx| {
|
||||||
|
this.replace(tool_label, cx);
|
||||||
|
});
|
||||||
|
rendered.input.update(cx, |this, cx| {
|
||||||
|
let input = format!(
|
||||||
"```json\n{}\n```",
|
"```json\n{}\n```",
|
||||||
serde_json::to_string_pretty(tool_input).unwrap_or_default()
|
serde_json::to_string_pretty(tool_input).unwrap_or_default()
|
||||||
)
|
);
|
||||||
.into(),
|
this.replace(input, cx);
|
||||||
self.language_registry.clone(),
|
});
|
||||||
cx,
|
rendered.output.update(cx, |this, cx| {
|
||||||
),
|
this.replace(tool_output, cx);
|
||||||
output: render_tool_use_markdown(tool_output, self.language_registry.clone(), cx),
|
});
|
||||||
};
|
|
||||||
self.rendered_tool_uses
|
|
||||||
.insert(tool_use_id.clone(), rendered);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_thread_event(
|
fn handle_thread_event(
|
||||||
|
@ -974,6 +979,19 @@ impl ActiveThread {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ThreadEvent::StreamedToolUse {
|
||||||
|
tool_use_id,
|
||||||
|
ui_text,
|
||||||
|
input,
|
||||||
|
} => {
|
||||||
|
self.render_tool_use_markdown(
|
||||||
|
tool_use_id.clone(),
|
||||||
|
ui_text.clone(),
|
||||||
|
input,
|
||||||
|
"".into(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
ThreadEvent::ToolFinished {
|
ThreadEvent::ToolFinished {
|
||||||
pending_tool_use, ..
|
pending_tool_use, ..
|
||||||
} => {
|
} => {
|
||||||
|
@ -2478,13 +2496,15 @@ impl ActiveThread {
|
||||||
let edit_tools = tool_use.needs_confirmation;
|
let edit_tools = tool_use.needs_confirmation;
|
||||||
|
|
||||||
let status_icons = div().child(match &tool_use.status {
|
let status_icons = div().child(match &tool_use.status {
|
||||||
ToolUseStatus::Pending | ToolUseStatus::NeedsConfirmation => {
|
ToolUseStatus::NeedsConfirmation => {
|
||||||
let icon = Icon::new(IconName::Warning)
|
let icon = Icon::new(IconName::Warning)
|
||||||
.color(Color::Warning)
|
.color(Color::Warning)
|
||||||
.size(IconSize::Small);
|
.size(IconSize::Small);
|
||||||
icon.into_any_element()
|
icon.into_any_element()
|
||||||
}
|
}
|
||||||
ToolUseStatus::Running => {
|
ToolUseStatus::Pending
|
||||||
|
| ToolUseStatus::InputStillStreaming
|
||||||
|
| ToolUseStatus::Running => {
|
||||||
let icon = Icon::new(IconName::ArrowCircle)
|
let icon = Icon::new(IconName::ArrowCircle)
|
||||||
.color(Color::Accent)
|
.color(Color::Accent)
|
||||||
.size(IconSize::Small);
|
.size(IconSize::Small);
|
||||||
|
@ -2570,7 +2590,7 @@ impl ActiveThread {
|
||||||
}),
|
}),
|
||||||
)),
|
)),
|
||||||
),
|
),
|
||||||
ToolUseStatus::Running => container.child(
|
ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child(
|
||||||
results_content_container().child(
|
results_content_container().child(
|
||||||
h_flex()
|
h_flex()
|
||||||
.gap_1()
|
.gap_1()
|
||||||
|
|
|
@ -1293,12 +1293,27 @@ impl Thread {
|
||||||
thread.insert_message(Role::Assistant, vec![], cx)
|
thread.insert_message(Role::Assistant, vec![], cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
thread.tool_use.request_tool_use(
|
let tool_use_id = tool_use.id.clone();
|
||||||
|
let streamed_input = if tool_use.is_input_complete {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some((&tool_use.input).clone())
|
||||||
|
};
|
||||||
|
|
||||||
|
let ui_text = thread.tool_use.request_tool_use(
|
||||||
last_assistant_message_id,
|
last_assistant_message_id,
|
||||||
tool_use,
|
tool_use,
|
||||||
tool_use_metadata.clone(),
|
tool_use_metadata.clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if let Some(input) = streamed_input {
|
||||||
|
cx.emit(ThreadEvent::StreamedToolUse {
|
||||||
|
tool_use_id,
|
||||||
|
ui_text,
|
||||||
|
input,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2189,6 +2204,11 @@ pub enum ThreadEvent {
|
||||||
StreamedCompletion,
|
StreamedCompletion,
|
||||||
StreamedAssistantText(MessageId, String),
|
StreamedAssistantText(MessageId, String),
|
||||||
StreamedAssistantThinking(MessageId, String),
|
StreamedAssistantThinking(MessageId, String),
|
||||||
|
StreamedToolUse {
|
||||||
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
ui_text: Arc<str>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
},
|
||||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||||
MessageAdded(MessageId),
|
MessageAdded(MessageId),
|
||||||
MessageEdited(MessageId),
|
MessageEdited(MessageId),
|
||||||
|
|
|
@ -75,6 +75,7 @@ impl ToolUseState {
|
||||||
id: tool_use.id.clone(),
|
id: tool_use.id.clone(),
|
||||||
name: tool_use.name.clone().into(),
|
name: tool_use.name.clone().into(),
|
||||||
input: tool_use.input.clone(),
|
input: tool_use.input.clone(),
|
||||||
|
is_input_complete: true,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
@ -176,6 +177,9 @@ impl ToolUseState {
|
||||||
PendingToolUseStatus::Error(ref err) => {
|
PendingToolUseStatus::Error(ref err) => {
|
||||||
ToolUseStatus::Error(err.clone().into())
|
ToolUseStatus::Error(err.clone().into())
|
||||||
}
|
}
|
||||||
|
PendingToolUseStatus::InputStillStreaming => {
|
||||||
|
ToolUseStatus::InputStillStreaming
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ToolUseStatus::Pending
|
ToolUseStatus::Pending
|
||||||
|
@ -192,7 +196,12 @@ impl ToolUseState {
|
||||||
tool_uses.push(ToolUse {
|
tool_uses.push(ToolUse {
|
||||||
id: tool_use.id.clone(),
|
id: tool_use.id.clone(),
|
||||||
name: tool_use.name.clone().into(),
|
name: tool_use.name.clone().into(),
|
||||||
ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
|
ui_text: self.tool_ui_label(
|
||||||
|
&tool_use.name,
|
||||||
|
&tool_use.input,
|
||||||
|
tool_use.is_input_complete,
|
||||||
|
cx,
|
||||||
|
),
|
||||||
input: tool_use.input.clone(),
|
input: tool_use.input.clone(),
|
||||||
status,
|
status,
|
||||||
icon,
|
icon,
|
||||||
|
@ -207,10 +216,15 @@ impl ToolUseState {
|
||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
|
is_input_complete: bool,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) -> SharedString {
|
) -> SharedString {
|
||||||
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
|
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
|
||||||
|
if is_input_complete {
|
||||||
tool.ui_text(input).into()
|
tool.ui_text(input).into()
|
||||||
|
} else {
|
||||||
|
tool.still_streaming_ui_text(input).into()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
format!("Unknown tool {tool_name:?}").into()
|
format!("Unknown tool {tool_name:?}").into()
|
||||||
}
|
}
|
||||||
|
@ -258,12 +272,26 @@ impl ToolUseState {
|
||||||
tool_use: LanguageModelToolUse,
|
tool_use: LanguageModelToolUse,
|
||||||
metadata: ToolUseMetadata,
|
metadata: ToolUseMetadata,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) {
|
) -> Arc<str> {
|
||||||
self.tool_uses_by_assistant_message
|
let tool_uses = self
|
||||||
|
.tool_uses_by_assistant_message
|
||||||
.entry(assistant_message_id)
|
.entry(assistant_message_id)
|
||||||
.or_default()
|
.or_default();
|
||||||
.push(tool_use.clone());
|
|
||||||
|
|
||||||
|
let mut existing_tool_use_found = false;
|
||||||
|
|
||||||
|
for existing_tool_use in tool_uses.iter_mut() {
|
||||||
|
if existing_tool_use.id == tool_use.id {
|
||||||
|
*existing_tool_use = tool_use.clone();
|
||||||
|
existing_tool_use_found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !existing_tool_use_found {
|
||||||
|
tool_uses.push(tool_use.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let status = if tool_use.is_input_complete {
|
||||||
self.tool_use_metadata_by_id
|
self.tool_use_metadata_by_id
|
||||||
.insert(tool_use.id.clone(), metadata);
|
.insert(tool_use.id.clone(), metadata);
|
||||||
|
|
||||||
|
@ -275,19 +303,33 @@ impl ToolUseState {
|
||||||
.or_default()
|
.or_default()
|
||||||
.push(tool_use.id.clone());
|
.push(tool_use.id.clone());
|
||||||
|
|
||||||
|
PendingToolUseStatus::Idle
|
||||||
|
} else {
|
||||||
|
PendingToolUseStatus::InputStillStreaming
|
||||||
|
};
|
||||||
|
|
||||||
|
let ui_text: Arc<str> = self
|
||||||
|
.tool_ui_label(
|
||||||
|
&tool_use.name,
|
||||||
|
&tool_use.input,
|
||||||
|
tool_use.is_input_complete,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
self.pending_tool_uses_by_id.insert(
|
self.pending_tool_uses_by_id.insert(
|
||||||
tool_use.id.clone(),
|
tool_use.id.clone(),
|
||||||
PendingToolUse {
|
PendingToolUse {
|
||||||
assistant_message_id,
|
assistant_message_id,
|
||||||
id: tool_use.id,
|
id: tool_use.id,
|
||||||
name: tool_use.name.clone(),
|
name: tool_use.name.clone(),
|
||||||
ui_text: self
|
ui_text: ui_text.clone(),
|
||||||
.tool_ui_label(&tool_use.name, &tool_use.input, cx)
|
|
||||||
.into(),
|
|
||||||
input: tool_use.input,
|
input: tool_use.input,
|
||||||
status: PendingToolUseStatus::Idle,
|
status,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
ui_text
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run_pending_tool(
|
pub fn run_pending_tool(
|
||||||
|
@ -497,6 +539,7 @@ pub struct Confirmation {
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum PendingToolUseStatus {
|
pub enum PendingToolUseStatus {
|
||||||
|
InputStillStreaming,
|
||||||
Idle,
|
Idle,
|
||||||
NeedsConfirmation(Arc<Confirmation>),
|
NeedsConfirmation(Arc<Confirmation>),
|
||||||
Running { _task: Shared<Task<()>> },
|
Running { _task: Shared<Task<()>> },
|
||||||
|
|
|
@ -30,6 +30,7 @@ pub fn init(cx: &mut App) {
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ToolUseStatus {
|
pub enum ToolUseStatus {
|
||||||
|
InputStillStreaming,
|
||||||
NeedsConfirmation,
|
NeedsConfirmation,
|
||||||
Pending,
|
Pending,
|
||||||
Running,
|
Running,
|
||||||
|
@ -41,6 +42,7 @@ impl ToolUseStatus {
|
||||||
pub fn text(&self) -> SharedString {
|
pub fn text(&self) -> SharedString {
|
||||||
match self {
|
match self {
|
||||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||||
|
ToolUseStatus::InputStillStreaming => "".into(),
|
||||||
ToolUseStatus::Pending => "".into(),
|
ToolUseStatus::Pending => "".into(),
|
||||||
ToolUseStatus::Running => "".into(),
|
ToolUseStatus::Running => "".into(),
|
||||||
ToolUseStatus::Finished(out) => out.clone(),
|
ToolUseStatus::Finished(out) => out.clone(),
|
||||||
|
@ -148,6 +150,12 @@ pub trait Tool: 'static + Send + Sync {
|
||||||
/// Returns markdown to be displayed in the UI for this tool.
|
/// Returns markdown to be displayed in the UI for this tool.
|
||||||
fn ui_text(&self, input: &serde_json::Value) -> String;
|
fn ui_text(&self, input: &serde_json::Value) -> String;
|
||||||
|
|
||||||
|
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
|
||||||
|
/// (so information may be missing).
|
||||||
|
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||||
|
self.ui_text(input)
|
||||||
|
}
|
||||||
|
|
||||||
/// Runs the tool with the provided input.
|
/// Runs the tool with the provided input.
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
|
|
|
@ -33,8 +33,18 @@ pub struct CreateFileToolInput {
|
||||||
pub contents: String,
|
pub contents: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
struct PartialInput {
|
||||||
|
#[serde(default)]
|
||||||
|
path: String,
|
||||||
|
#[serde(default)]
|
||||||
|
contents: String,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct CreateFileTool;
|
pub struct CreateFileTool;
|
||||||
|
|
||||||
|
const DEFAULT_UI_TEXT: &str = "Create file";
|
||||||
|
|
||||||
impl Tool for CreateFileTool {
|
impl Tool for CreateFileTool {
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
"create_file".into()
|
"create_file".into()
|
||||||
|
@ -62,7 +72,14 @@ impl Tool for CreateFileTool {
|
||||||
let path = MarkdownString::inline_code(&input.path);
|
let path = MarkdownString::inline_code(&input.path);
|
||||||
format!("Create file {path}")
|
format!("Create file {path}")
|
||||||
}
|
}
|
||||||
Err(_) => "Create file".to_string(),
|
Err(_) => DEFAULT_UI_TEXT.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||||
|
match serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||||
|
Some(input) if !input.path.is_empty() => input.path,
|
||||||
|
_ => DEFAULT_UI_TEXT.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,3 +128,60 @@ impl Tool for CreateFileTool {
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_with_path() {
|
||||||
|
let tool = CreateFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "src/main.rs",
|
||||||
|
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_without_path() {
|
||||||
|
let tool = CreateFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "",
|
||||||
|
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_with_null() {
|
||||||
|
let tool = CreateFileTool;
|
||||||
|
let input = serde_json::Value::Null;
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ui_text_with_valid_input() {
|
||||||
|
let tool = CreateFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "src/main.rs",
|
||||||
|
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ui_text_with_invalid_input() {
|
||||||
|
let tool = CreateFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"invalid": "field"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -47,8 +47,22 @@ pub struct EditFileToolInput {
|
||||||
pub new_string: String,
|
pub new_string: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
struct PartialInput {
|
||||||
|
#[serde(default)]
|
||||||
|
path: String,
|
||||||
|
#[serde(default)]
|
||||||
|
display_description: String,
|
||||||
|
#[serde(default)]
|
||||||
|
old_string: String,
|
||||||
|
#[serde(default)]
|
||||||
|
new_string: String,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct EditFileTool;
|
pub struct EditFileTool;
|
||||||
|
|
||||||
|
const DEFAULT_UI_TEXT: &str = "Edit file";
|
||||||
|
|
||||||
impl Tool for EditFileTool {
|
impl Tool for EditFileTool {
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
"edit_file".into()
|
"edit_file".into()
|
||||||
|
@ -77,6 +91,22 @@ impl Tool for EditFileTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||||
|
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||||
|
let description = input.display_description.trim();
|
||||||
|
if !description.is_empty() {
|
||||||
|
return description.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = input.path.trim();
|
||||||
|
if !path.is_empty() {
|
||||||
|
return path.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_UI_TEXT.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
@ -181,3 +211,69 @@ impl Tool for EditFileTool {
|
||||||
}).into()
|
}).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_with_path() {
|
||||||
|
let tool = EditFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "src/main.rs",
|
||||||
|
"display_description": "",
|
||||||
|
"old_string": "old code",
|
||||||
|
"new_string": "new code"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_with_description() {
|
||||||
|
let tool = EditFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "",
|
||||||
|
"display_description": "Fix error handling",
|
||||||
|
"old_string": "old code",
|
||||||
|
"new_string": "new code"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_with_path_and_description() {
|
||||||
|
let tool = EditFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "src/main.rs",
|
||||||
|
"display_description": "Fix error handling",
|
||||||
|
"old_string": "old code",
|
||||||
|
"new_string": "new code"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_no_path_or_description() {
|
||||||
|
let tool = EditFileTool;
|
||||||
|
let input = json!({
|
||||||
|
"path": "",
|
||||||
|
"display_description": "",
|
||||||
|
"old_string": "old code",
|
||||||
|
"new_string": "new code"
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn still_streaming_ui_text_with_null() {
|
||||||
|
let tool = EditFileTool;
|
||||||
|
let input = serde_json::Value::Null;
|
||||||
|
|
||||||
|
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -426,6 +426,7 @@ impl Example {
|
||||||
ThreadEvent::ToolConfirmationNeeded => {
|
ThreadEvent::ToolConfirmationNeeded => {
|
||||||
panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
|
panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
|
||||||
},
|
},
|
||||||
|
ThreadEvent::StreamedToolUse { .. } |
|
||||||
ThreadEvent::StreamedCompletion |
|
ThreadEvent::StreamedCompletion |
|
||||||
ThreadEvent::MessageAdded(_) |
|
ThreadEvent::MessageAdded(_) |
|
||||||
ThreadEvent::MessageEdited(_) |
|
ThreadEvent::MessageEdited(_) |
|
||||||
|
|
|
@ -187,6 +187,7 @@ pub struct LanguageModelToolUse {
|
||||||
pub id: LanguageModelToolUseId,
|
pub id: LanguageModelToolUseId,
|
||||||
pub name: Arc<str>,
|
pub name: Arc<str>,
|
||||||
pub input: serde_json::Value,
|
pub input: serde_json::Value,
|
||||||
|
pub is_input_complete: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LanguageModelTextStream {
|
pub struct LanguageModelTextStream {
|
||||||
|
|
|
@ -38,6 +38,7 @@ menu.workspace = true
|
||||||
mistral = { workspace = true, features = ["schemars"] }
|
mistral = { workspace = true, features = ["schemars"] }
|
||||||
ollama = { workspace = true, features = ["schemars"] }
|
ollama = { workspace = true, features = ["schemars"] }
|
||||||
open_ai = { workspace = true, features = ["schemars"] }
|
open_ai = { workspace = true, features = ["schemars"] }
|
||||||
|
partial-json-fixer.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
proto.workspace = true
|
proto.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
|
|
|
@ -713,6 +713,35 @@ pub fn map_to_language_model_completion_events(
|
||||||
ContentDelta::InputJsonDelta { partial_json } => {
|
ContentDelta::InputJsonDelta { partial_json } => {
|
||||||
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
||||||
tool_use.input_json.push_str(&partial_json);
|
tool_use.input_json.push_str(&partial_json);
|
||||||
|
|
||||||
|
return Some((
|
||||||
|
vec![maybe!({
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: tool_use.id.clone().into(),
|
||||||
|
name: tool_use.name.clone().into(),
|
||||||
|
is_input_complete: false,
|
||||||
|
input: if tool_use.input_json.is_empty() {
|
||||||
|
serde_json::Value::Object(
|
||||||
|
serde_json::Map::default(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
serde_json::Value::from_str(
|
||||||
|
// Convert invalid (incomplete) JSON into
|
||||||
|
// JSON that serde will accept, e.g. by closing
|
||||||
|
// unclosed delimiters. This way, we can update
|
||||||
|
// the UI with whatever has been streamed back so far.
|
||||||
|
&partial_json_fixer::fix_json(
|
||||||
|
&tool_use.input_json,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.map_err(|err| anyhow!(err))?
|
||||||
|
},
|
||||||
|
},
|
||||||
|
))
|
||||||
|
})],
|
||||||
|
state,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -724,6 +753,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
id: tool_use.id.into(),
|
id: tool_use.id.into(),
|
||||||
name: tool_use.name.into(),
|
name: tool_use.name.into(),
|
||||||
|
is_input_complete: true,
|
||||||
input: if tool_use.input_json.is_empty() {
|
input: if tool_use.input_json.is_empty() {
|
||||||
serde_json::Value::Object(
|
serde_json::Value::Object(
|
||||||
serde_json::Map::default(),
|
serde_json::Map::default(),
|
||||||
|
|
|
@ -893,6 +893,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
let tool_use_event = LanguageModelToolUse {
|
let tool_use_event = LanguageModelToolUse {
|
||||||
id: tool_use.id.into(),
|
id: tool_use.id.into(),
|
||||||
name: tool_use.name.into(),
|
name: tool_use.name.into(),
|
||||||
|
is_input_complete: true,
|
||||||
input: if tool_use.input_json.is_empty() {
|
input: if tool_use.input_json.is_empty() {
|
||||||
Value::Null
|
Value::Null
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -367,6 +367,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.into(),
|
||||||
name: tool_call.name.as_str().into(),
|
name: tool_call.name.as_str().into(),
|
||||||
|
is_input_complete: true,
|
||||||
input: serde_json::Value::from_str(
|
input: serde_json::Value::from_str(
|
||||||
&tool_call.arguments,
|
&tool_call.arguments,
|
||||||
)?,
|
)?,
|
||||||
|
|
|
@ -529,6 +529,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
id,
|
id,
|
||||||
name,
|
name,
|
||||||
|
is_input_complete: true,
|
||||||
input: function_call_part.function_call.args,
|
input: function_call_part.function_call.args,
|
||||||
},
|
},
|
||||||
)));
|
)));
|
||||||
|
|
|
@ -490,6 +490,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
id: tool_call.id.into(),
|
id: tool_call.id.into(),
|
||||||
name: tool_call.name.as_str().into(),
|
name: tool_call.name.as_str().into(),
|
||||||
|
is_input_complete: true,
|
||||||
input: serde_json::Value::from_str(
|
input: serde_json::Value::from_str(
|
||||||
&tool_call.arguments,
|
&tool_call.arguments,
|
||||||
)?,
|
)?,
|
||||||
|
|
|
@ -192,6 +192,11 @@ impl Markdown {
|
||||||
self.parse(cx);
|
self.parse(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn replace(&mut self, source: impl Into<SharedString>, cx: &mut Context<Self>) {
|
||||||
|
self.source = source.into();
|
||||||
|
self.parse(cx);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn reset(&mut self, source: SharedString, cx: &mut Context<Self>) {
|
pub fn reset(&mut self, source: SharedString, cx: &mut Context<Self>) {
|
||||||
if source == self.source() {
|
if source == self.source() {
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue