assistant_tool: Pass an Entity<Project> to Tool::run (#26312)

This PR updates the `Tool::run` method to take an `Entity<Project>`
instead of a `WeakEntity<Project>`.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-07 18:30:56 -05:00 committed by GitHub
parent 921c24e274
commit e70d0edfac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 20 additions and 30 deletions

View file

@ -5,7 +5,7 @@ use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@ -72,7 +72,7 @@ pub struct Thread {
context_by_message: HashMap<MessageId, Vec<ContextId>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: WeakEntity<Project>,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
}
@ -94,7 +94,7 @@ impl Thread {
context_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.downgrade(),
project,
tools,
tool_use: ToolUseState::new(),
}
@ -135,7 +135,7 @@ impl Thread {
context_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.downgrade(),
project,
tools,
tool_use,
}

View file

@ -4,7 +4,7 @@ mod tool_working_set;
use std::sync::Arc;
use anyhow::Result;
use gpui::{App, Task, WeakEntity};
use gpui::{App, Entity, Task};
use project::Project;
pub use crate::tool_registry::*;
@ -31,7 +31,7 @@ pub trait Tool: 'static + Send + Sync {
fn run(
self: Arc<Self>,
input: serde_json::Value,
project: WeakEntity<Project>,
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>>;
}

View file

@ -1,8 +1,8 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use anyhow::Result;
use assistant_tool::Tool;
use gpui::{App, Task, WeakEntity};
use gpui::{App, Entity, Task};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -34,13 +34,9 @@ impl Tool for ListWorktreesTool {
fn run(
self: Arc<Self>,
_input: serde_json::Value,
project: WeakEntity<Project>,
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
let Some(project) = project.upgrade() else {
return Task::ready(Err(anyhow!("project dropped")));
};
cx.spawn(|cx| async move {
cx.update(|cx| {
#[derive(Debug, Serialize)]

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use chrono::{Local, Utc};
use gpui::{App, Task, WeakEntity};
use gpui::{App, Entity, Task};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -42,7 +42,7 @@ impl Tool for NowTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
_project: WeakEntity<Project>,
_project: Entity<Project>,
_cx: &mut App,
) -> Task<Result<String>> {
let input: NowToolInput = match serde_json::from_value(input) {

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use gpui::{App, Task, WeakEntity};
use gpui::{App, Entity, Task};
use project::{Project, ProjectPath, WorktreeId};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -37,13 +37,9 @@ impl Tool for ReadFileTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
project: WeakEntity<Project>,
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
let Some(project) = project.upgrade() else {
return Task::ready(Err(anyhow!("project dropped")));
};
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),

View file

@ -1,8 +1,9 @@
use std::sync::Arc;
use anyhow::{anyhow, bail};
use anyhow::{anyhow, bail, Result};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use project::Project;
use crate::manager::ContextServerManager;
use crate::types;
@ -49,11 +50,11 @@ impl Tool for ContextServerTool {
}
fn run(
self: std::sync::Arc<Self>,
self: Arc<Self>,
input: serde_json::Value,
_project: gpui::WeakEntity<project::Project>,
_project: Entity<Project>,
cx: &mut App,
) -> gpui::Task<gpui::Result<String>> {
) -> Task<Result<String>> {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
cx.foreground_executor().spawn({
let tool_name = self.tool.name.clone();

View file

@ -4,7 +4,7 @@ use project::Project;
pub(crate) use session::*;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Task, WeakEntity};
use gpui::{App, AppContext as _, Entity, Task};
use schemars::JsonSchema;
use serde::Deserialize;
use std::sync::Arc;
@ -38,16 +38,13 @@ impl Tool for ScriptingTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
project: WeakEntity<Project>,
project: Entity<Project>,
cx: &mut App,
) -> Task<anyhow::Result<String>> {
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
Err(err) => return Task::ready(Err(err.into())),
Ok(input) => input,
};
let Some(project) = project.upgrade() else {
return Task::ready(Err(anyhow::anyhow!("project dropped")));
};
let session = cx.new(|cx| Session::new(project, cx));
let lua_script = input.lua_script;