From 153724aad3709abc8bbbc59d584fe139d4ec801f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 22 Aug 2025 15:44:58 -0700 Subject: [PATCH] Clean up handling of serialized ssh connection ids (#36781) Small follow-up to #36714 Release Notes: - N/A --- crates/remote/src/ssh_session.rs | 5 - crates/workspace/src/persistence.rs | 166 +++++++++++----------- crates/workspace/src/persistence/model.rs | 7 +- crates/workspace/src/workspace.rs | 12 +- 4 files changed, 93 insertions(+), 97 deletions(-) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index c02d0ad7e7..b9af528643 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -52,11 +52,6 @@ use util::{ paths::{PathStyle, RemotePathBuf}, }; -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, -)] -pub struct SshProjectId(pub u64); - #[derive(Clone)] pub struct SshSocket { connection_options: SshConnectionOptions, diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index de8f63957c..39a1e08c93 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -9,13 +9,13 @@ use std::{ }; use anyhow::{Context as _, Result, bail}; +use collections::HashMap; use db::{define_connection, query, sqlez::connection::Connection, sqlez_macros::sql}; use gpui::{Axis, Bounds, Task, WindowBounds, WindowId, point, size}; use project::debugger::breakpoint_store::{BreakpointState, SourceBreakpoint}; use language::{LanguageName, Toolchain}; use project::WorktreeId; -use remote::ssh_session::SshProjectId; use sqlez::{ bindable::{Bind, Column, StaticColumnCount}, statement::{SqlType, Statement}, @@ -33,7 +33,7 @@ use crate::{ use model::{ GroupId, ItemId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup, - SerializedSshConnection, SerializedWorkspace, + SerializedSshConnection, SerializedWorkspace, SshConnectionId, }; use self::model::{DockStructure, SerializedWorkspaceLocation}; @@ -615,7 +615,7 @@ impl WorkspaceDb { pub(crate) fn ssh_workspace_for_roots>( &self, worktree_roots: &[P], - ssh_project_id: SshProjectId, + ssh_project_id: SshConnectionId, ) -> Option { self.workspace_for_roots_internal(worktree_roots, Some(ssh_project_id)) } @@ -623,7 +623,7 @@ impl WorkspaceDb { pub(crate) fn workspace_for_roots_internal>( &self, worktree_roots: &[P], - ssh_connection_id: Option, + ssh_connection_id: Option, ) -> Option { // paths are sorted before db interactions to ensure that the order of the paths // doesn't affect the workspace selection for existing workspaces @@ -762,15 +762,21 @@ impl WorkspaceDb { /// that used this workspace previously pub(crate) async fn save_workspace(&self, workspace: SerializedWorkspace) { let paths = workspace.paths.serialize(); - let ssh_connection_id = match &workspace.location { - SerializedWorkspaceLocation::Local => None, - SerializedWorkspaceLocation::Ssh(serialized_ssh_connection) => { - Some(serialized_ssh_connection.id.0) - } - }; log::debug!("Saving workspace at location: {:?}", workspace.location); self.write(move |conn| { conn.with_savepoint("update_worktrees", || { + let ssh_connection_id = match &workspace.location { + SerializedWorkspaceLocation::Local => None, + SerializedWorkspaceLocation::Ssh(connection) => { + Some(Self::get_or_create_ssh_connection_query( + conn, + connection.host.clone(), + connection.port, + connection.user.clone(), + )?.0) + } + }; + // Clear out panes and pane_groups conn.exec_bound(sql!( DELETE FROM pane_groups WHERE workspace_id = ?1; @@ -893,39 +899,34 @@ impl WorkspaceDb { host: String, port: Option, user: Option, - ) -> Result { - if let Some(id) = self - .get_ssh_connection(host.clone(), port, user.clone()) - .await? + ) -> Result { + self.write(move |conn| Self::get_or_create_ssh_connection_query(conn, host, port, user)) + .await + } + + fn get_or_create_ssh_connection_query( + this: &Connection, + host: String, + port: Option, + user: Option, + ) -> Result { + if let Some(id) = this.select_row_bound(sql!( + SELECT id FROM ssh_connections WHERE host IS ? AND port IS ? AND user IS ? LIMIT 1 + ))?((host.clone(), port, user.clone()))? { - Ok(SshProjectId(id)) + Ok(SshConnectionId(id)) } else { log::debug!("Inserting SSH project at host {host}"); - let id = self - .insert_ssh_connection(host, port, user) - .await? - .context("failed to insert ssh project")?; - Ok(SshProjectId(id)) - } - } - - query! { - async fn get_ssh_connection(host: String, port: Option, user: Option) -> Result> { - SELECT id - FROM ssh_connections - WHERE host IS ? AND port IS ? AND user IS ? - LIMIT 1 - } - } - - query! { - async fn insert_ssh_connection(host: String, port: Option, user: Option) -> Result> { - INSERT INTO ssh_connections ( - host, - port, - user - ) VALUES (?1, ?2, ?3) - RETURNING id + let id = this.select_row_bound(sql!( + INSERT INTO ssh_connections ( + host, + port, + user + ) VALUES (?1, ?2, ?3) + RETURNING id + ))?((host, port, user))? + .context("failed to insert ssh project")?; + Ok(SshConnectionId(id)) } } @@ -963,7 +964,7 @@ impl WorkspaceDb { fn session_workspaces( &self, session_id: String, - ) -> Result, Option)>> { + ) -> Result, Option)>> { Ok(self .session_workspaces_query(session_id)? .into_iter() @@ -971,7 +972,7 @@ impl WorkspaceDb { ( PathList::deserialize(&SerializedPathList { paths, order }), window_id, - ssh_connection_id.map(SshProjectId), + ssh_connection_id.map(SshConnectionId), ) }) .collect()) @@ -1001,15 +1002,15 @@ impl WorkspaceDb { } } - fn ssh_connections(&self) -> Result> { + fn ssh_connections(&self) -> Result> { Ok(self .ssh_connections_query()? .into_iter() - .map(|(id, host, port, user)| SerializedSshConnection { - id: SshProjectId(id), - host, - port, - user, + .map(|(id, host, port, user)| { + ( + SshConnectionId(id), + SerializedSshConnection { host, port, user }, + ) }) .collect()) } @@ -1021,19 +1022,18 @@ impl WorkspaceDb { } } - pub fn ssh_connection(&self, id: SshProjectId) -> Result { + pub(crate) fn ssh_connection(&self, id: SshConnectionId) -> Result { let row = self.ssh_connection_query(id.0)?; Ok(SerializedSshConnection { - id: SshProjectId(row.0), - host: row.1, - port: row.2, - user: row.3, + host: row.0, + port: row.1, + user: row.2, }) } query! { - fn ssh_connection_query(id: u64) -> Result<(u64, String, Option, Option)> { - SELECT id, host, port, user + fn ssh_connection_query(id: u64) -> Result<(String, Option, Option)> { + SELECT host, port, user FROM ssh_connections WHERE id = ? } @@ -1075,10 +1075,8 @@ impl WorkspaceDb { let ssh_connections = self.ssh_connections()?; for (id, paths, ssh_connection_id) in self.recent_workspaces()? { - if let Some(ssh_connection_id) = ssh_connection_id.map(SshProjectId) { - if let Some(ssh_connection) = - ssh_connections.iter().find(|rp| rp.id == ssh_connection_id) - { + if let Some(ssh_connection_id) = ssh_connection_id.map(SshConnectionId) { + if let Some(ssh_connection) = ssh_connections.get(&ssh_connection_id) { result.push(( id, SerializedWorkspaceLocation::Ssh(ssh_connection.clone()), @@ -2340,12 +2338,10 @@ mod tests { ] .into_iter() .map(|(host, user)| async { - let id = db - .get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string())) + db.get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string())) .await .unwrap(); SerializedSshConnection { - id, host: host.into(), port: None, user: Some(user.into()), @@ -2501,26 +2497,34 @@ mod tests { let stored_projects = db.ssh_connections().unwrap(); assert_eq!( stored_projects, - &[ - SerializedSshConnection { - id: ids[0], - host: "example.com".into(), - port: None, - user: None, - }, - SerializedSshConnection { - id: ids[1], - host: "anotherexample.com".into(), - port: Some(123), - user: Some("user2".into()), - }, - SerializedSshConnection { - id: ids[2], - host: "yetanother.com".into(), - port: Some(345), - user: None, - }, + [ + ( + ids[0], + SerializedSshConnection { + host: "example.com".into(), + port: None, + user: None, + } + ), + ( + ids[1], + SerializedSshConnection { + host: "anotherexample.com".into(), + port: Some(123), + user: Some("user2".into()), + } + ), + ( + ids[2], + SerializedSshConnection { + host: "yetanother.com".into(), + port: Some(345), + user: None, + } + ), ] + .into_iter() + .collect::>(), ); } diff --git a/crates/workspace/src/persistence/model.rs b/crates/workspace/src/persistence/model.rs index afe4ae6235..04757d0495 100644 --- a/crates/workspace/src/persistence/model.rs +++ b/crates/workspace/src/persistence/model.rs @@ -12,7 +12,6 @@ use db::sqlez::{ use gpui::{AsyncWindowContext, Entity, WeakEntity}; use project::{Project, debugger::breakpoint_store::SourceBreakpoint}; -use remote::ssh_session::SshProjectId; use serde::{Deserialize, Serialize}; use std::{ collections::BTreeMap, @@ -22,9 +21,13 @@ use std::{ use util::ResultExt; use uuid::Uuid; +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, +)] +pub(crate) struct SshConnectionId(pub u64); + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct SerializedSshConnection { - pub id: SshProjectId, pub host: String, pub port: Option, pub user: Option, diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index d07ea30cf9..bf58786d67 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -74,10 +74,7 @@ use project::{ DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId, debugger::{breakpoint_store::BreakpointStoreEvent, session::ThreadStatus}, }; -use remote::{ - SshClientDelegate, SshConnectionOptions, - ssh_session::{ConnectionIdentifier, SshProjectId}, -}; +use remote::{SshClientDelegate, SshConnectionOptions, ssh_session::ConnectionIdentifier}; use schemars::JsonSchema; use serde::Deserialize; use session::AppSession; @@ -1128,7 +1125,6 @@ pub struct Workspace { terminal_provider: Option>, debugger_provider: Option>, serializable_items_tx: UnboundedSender>, - serialized_ssh_connection_id: Option, _items_serializer: Task>, session_id: Option, scheduled_tasks: Vec>, @@ -1461,7 +1457,7 @@ impl Workspace { serializable_items_tx, _items_serializer, session_id: Some(session_id), - serialized_ssh_connection_id: None, + scheduled_tasks: Vec::new(), } } @@ -5288,11 +5284,9 @@ impl Workspace { fn serialize_workspace_location(&self, cx: &App) -> WorkspaceLocation { let paths = PathList::new(&self.root_paths(cx)); - let connection = self.project.read(cx).ssh_connection_options(cx); - if let Some((id, connection)) = self.serialized_ssh_connection_id.zip(connection) { + if let Some(connection) = self.project.read(cx).ssh_connection_options(cx) { WorkspaceLocation::Location( SerializedWorkspaceLocation::Ssh(SerializedSshConnection { - id, host: connection.host, port: connection.port, user: connection.username,