Take a weak thread in EditFileTool to avoid cycle

This commit is contained in:
Antonio Scandurra 2025-08-18 10:13:02 +02:00
parent fae5900749
commit a231fd3ee5
3 changed files with 89 additions and 75 deletions

View file

@ -574,7 +574,7 @@ impl NativeAgentConnection {
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread.add_tool(EditFileTool::new(cx.weak_entity()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn load_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
_cwd: &Path,
session_id: acp::SessionId,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let agent = self.0.clone();
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
let configured_model = LanguageModelRegistry::global(cx)
.update(cx, |registry, cx| {
db_thread
.model
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.clone().into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
})
.context("no default model configured")?;
let thread = agent.update(cx, |agent, cx| {
let configured_model = LanguageModelRegistry::global(cx)
.update(cx, |registry, cx| {
db_thread
.model
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.clone().into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
})
.context("no default model configured")?;
let model = agent
.models
.model_from_id(&LanguageModels::model_id(&configured_model.model))
.context("no model by id")?;
let model = agent
.models
.model_from_id(&LanguageModels::model_id(&configured_model.model))
.context("no model by id")?;
let thread = cx.new(|cx| {
let mut thread = Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
model,
cx,
);
Self::register_tools(&mut thread, project, action_log, cx);
thread
});
let thread = cx.new(|cx| {
let mut thread = Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
model,
cx,
);
Self::register_tools(&mut thread, project, action_log, cx);
thread
});
Ok(thread)
},
)??;
anyhow::Ok(thread)
})??;
// Store the session
agent.update(cx, |agent, cx| {
@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})?;
// we need to actually deserialize the DbThread.
todo!()
// todo!()
Ok(acp_thread)
})

View file

@ -441,7 +441,7 @@ impl Thread {
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
let this = Self {
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
running_turn: None,
@ -455,7 +455,7 @@ impl Thread {
model,
project,
action_log,
};
}
}
pub fn project(&self) -> &Entity<Project> {

View file

@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
@ -122,11 +122,11 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct EditFileTool {
thread: Entity<Thread>,
thread: WeakEntity<Thread>,
}
impl EditFileTool {
pub fn new(thread: Entity<Thread>) -> Self {
pub fn new(thread: WeakEntity<Thread>) -> Self {
Self { thread }
}
@ -167,8 +167,11 @@ impl EditFileTool {
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
let thread = self.thread.read(cx);
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -221,7 +224,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone();
let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
@ -237,17 +245,15 @@ impl AgentTool for EditFileTool {
});
}
let request = self.thread.update(cx, |thread, cx| {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
});
let thread = self.thread.read(cx);
let model = thread.model().clone();
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
(request, thread.model().clone(), thread.action_log().clone())
})?;
let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new(
model,
@ -531,7 +537,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
})
.await;
assert_eq!(
@ -744,10 +754,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
.run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the unformatted content
@ -800,7 +811,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the unformatted content
@ -881,10 +896,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
.run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the content with trailing whitespace
@ -932,10 +948,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
.run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade())).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the content with trailing whitespace
@ -983,7 +1000,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@ -1114,7 +1131,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@ -1224,7 +1241,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test files in different worktrees
let test_cases = vec![
@ -1305,7 +1322,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test edge cases
let test_cases = vec![
@ -1389,7 +1406,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
// Test different EditFileMode values
let modes = vec![
@ -1470,7 +1487,7 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
assert_eq!(
tool.initial_title(Err(json!({