Display what the tool is doing (#27120)

<img width="639" alt="Screenshot 2025-03-19 at 4 56 47 PM"
src="https://github.com/user-attachments/assets/b997f04d-4aff-4070-87b1-ffdb61019bd1"
/>

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
Richard Feldman 2025-03-20 09:16:39 -04:00 committed by GitHub
parent aae81fd54c
commit e3578fc44a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 349 additions and 132 deletions

View file

@ -146,10 +146,10 @@ impl Thread {
pending_completions: Vec::new(),
project: project.clone(),
prompt_builder,
tools,
tool_use: ToolUseState::new(),
tools: tools.clone(),
tool_use: ToolUseState::new(tools.clone()),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(),
scripting_tool_use: ToolUseState::new(tools),
action_log: cx.new(|_| ActionLog::new()),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx);
@ -176,11 +176,12 @@ impl Thread {
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
name != ScriptingTool::NAME
});
let tool_use =
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
name != ScriptingTool::NAME
});
let scripting_tool_use =
ToolUseState::from_serialized_messages(&serialized.messages, |name| {
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
name == ScriptingTool::NAME
});
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
@ -328,12 +329,12 @@ impl Thread {
all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
}
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id)
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id, cx)
}
pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.scripting_tool_use.tool_uses_for_message(id)
pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.scripting_tool_use.tool_uses_for_message(id, cx)
}
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
@ -448,7 +449,7 @@ impl Thread {
let initial_project_snapshot = self.initial_project_snapshot.clone();
cx.spawn(async move |this, cx| {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, _| SerializedThread {
this.read_with(cx, |this, cx| SerializedThread {
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
@ -458,9 +459,9 @@ impl Thread {
role: message.role,
text: message.text.clone(),
tool_uses: this
.tool_uses_for_message(message.id)
.tool_uses_for_message(message.id, cx)
.into_iter()
.chain(this.scripting_tool_uses_for_message(message.id))
.chain(this.scripting_tool_uses_for_message(message.id, cx))
.map(|tool_use| SerializedToolUse {
id: tool_use.id,
name: tool_use.name,
@ -809,13 +810,17 @@ impl Thread {
.rfind(|message| message.role == Role::Assistant)
{
if tool_use.name.as_ref() == ScriptingTool::NAME {
thread
.scripting_tool_use
.request_tool_use(last_assistant_message.id, tool_use);
thread.scripting_tool_use.request_tool_use(
last_assistant_message.id,
tool_use,
cx,
);
} else {
thread
.tool_use
.request_tool_use(last_assistant_message.id, tool_use);
thread.tool_use.request_tool_use(
last_assistant_message.id,
tool_use,
cx,
);
}
}
}
@ -956,7 +961,10 @@ impl Thread {
});
}
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
pub fn use_pending_tools(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> {
let request = self.to_completion_request(RequestKind::Chat, cx);
let pending_tool_uses = self
.tool_use
@ -966,17 +974,22 @@ impl Thread {
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(
tool_use.input,
tool_use.input.clone(),
&request.messages,
self.project.clone(),
self.action_log.clone(),
cx,
);
self.insert_tool_output(tool_use.id.clone(), task, cx);
self.insert_tool_output(
tool_use.id.clone(),
tool_use.ui_text.clone().into(),
task,
cx,
);
}
}
@ -988,8 +1001,8 @@ impl Thread {
.cloned()
.collect::<Vec<_>>();
for scripting_tool_use in pending_scripting_tool_uses {
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
for scripting_tool_use in pending_scripting_tool_uses.iter() {
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
Err(err) => Task::ready(Err(err.into())),
Ok(input) => {
let (script_id, script_task) =
@ -1016,13 +1029,20 @@ impl Thread {
}
};
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
let ui_text: SharedString = scripting_tool_use.name.clone().into();
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
}
pending_tool_uses
.into_iter()
.chain(pending_scripting_tool_uses)
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>,
cx: &mut Context<Self>,
) {
@ -1047,12 +1067,13 @@ impl Thread {
});
self.tool_use
.run_pending_tool(tool_use_id, insert_output_task);
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
}
pub fn insert_scripting_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>,
cx: &mut Context<Self>,
) {
@ -1077,7 +1098,7 @@ impl Thread {
});
self.scripting_tool_use
.run_pending_tool(tool_use_id, insert_output_task);
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
}
pub fn attach_tool_results(
@ -1250,7 +1271,7 @@ impl Thread {
})
}
pub fn to_markdown(&self) -> Result<String> {
pub fn to_markdown(&self, cx: &App) -> Result<String> {
let mut markdown = Vec::new();
if let Some(summary) = self.summary() {
@ -1269,7 +1290,7 @@ impl Thread {
)?;
writeln!(markdown, "{}\n", message.text)?;
for tool_use in self.tool_uses_for_message(message.id) {
for tool_use in self.tool_uses_for_message(message.id, cx) {
writeln!(
markdown,
"**Use Tool: {} ({})**",