Move expensive participant update out of transaction

Co-Authored-By: Marshall <marshall@zed.dev>
This commit is contained in:
Conrad Irwin 2024-01-25 11:03:13 -07:00
parent ca27ac21c2
commit adb6f3e9f7
4 changed files with 131 additions and 97 deletions

View file

@ -169,6 +169,30 @@ impl Database {
self.run(body).await self.run(body).await
} }
pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_weak_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(result),
Err(error) => {
return Err(error);
}
},
Err(error) => {
tx.rollback().await?;
return Err(error);
}
}
};
self.run(body).await
}
/// The same as room_transaction, but if you need to only optionally return a Room. /// The same as room_transaction, but if you need to only optionally return a Room.
async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>> async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
where where
@ -284,6 +308,30 @@ impl Database {
Ok((tx, result)) Ok((tx, result))
} }
async fn with_weak_transaction<F, Fut, T>(
&self,
f: &F,
) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let tx = self
.pool
.begin_with_config(Some(IsolationLevel::ReadCommitted), None)
.await?;
let mut tx = Arc::new(Some(tx));
let result = f(TransactionHandle(tx.clone())).await;
let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
return Err(anyhow!(
"couldn't complete transaction because it's still in use"
))?;
};
Ok((tx, result))
}
async fn run<F, T>(&self, future: F) -> Result<T> async fn run<F, T>(&self, future: F) -> Result<T>
where where
F: Future<Output = Result<T>>, F: Future<Output = Result<T>>,
@ -457,9 +505,8 @@ pub struct NewUserResult {
/// The result of moving a channel. /// The result of moving a channel.
#[derive(Debug)] #[derive(Debug)]
pub struct MoveChannelResult { pub struct MoveChannelResult {
pub participants_to_update: HashMap<UserId, ChannelsForUser>, pub previous_participants: Vec<ChannelMember>,
pub participants_to_remove: HashSet<UserId>, pub descendent_ids: Vec<ChannelId>,
pub moved_channels: HashSet<ChannelId>,
} }
/// The result of renaming a channel. /// The result of renaming a channel.

View file

@ -22,7 +22,6 @@ impl Database {
Ok(self Ok(self
.create_channel(name, None, creator_id) .create_channel(name, None, creator_id)
.await? .await?
.channel
.id) .id)
} }
@ -36,7 +35,6 @@ impl Database {
Ok(self Ok(self
.create_channel(name, Some(parent), creator_id) .create_channel(name, Some(parent), creator_id)
.await? .await?
.channel
.id) .id)
} }
@ -46,7 +44,7 @@ impl Database {
name: &str, name: &str,
parent_channel_id: Option<ChannelId>, parent_channel_id: Option<ChannelId>,
admin_id: UserId, admin_id: UserId,
) -> Result<CreateChannelResult> { ) -> Result<Channel> {
let name = Self::sanitize_channel_name(name)?; let name = Self::sanitize_channel_name(name)?;
self.transaction(move |tx| async move { self.transaction(move |tx| async move {
let mut parent = None; let mut parent = None;
@ -72,14 +70,7 @@ impl Database {
.insert(&*tx) .insert(&*tx)
.await?; .await?;
let participants_to_update; if parent.is_none() {
if let Some(parent) = &parent {
participants_to_update = self
.participants_to_notify_for_channel_change(parent, &*tx)
.await?;
} else {
participants_to_update = vec![];
channel_member::ActiveModel { channel_member::ActiveModel {
id: ActiveValue::NotSet, id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel.id), channel_id: ActiveValue::Set(channel.id),
@ -89,12 +80,9 @@ impl Database {
} }
.insert(&*tx) .insert(&*tx)
.await?; .await?;
}; }
Ok(CreateChannelResult { Ok(Channel::from_model(channel, ChannelRole::Admin))
channel: Channel::from_model(channel, ChannelRole::Admin),
participants_to_update,
})
}) })
.await .await
} }
@ -718,6 +706,19 @@ impl Database {
}) })
} }
pub async fn new_participants_to_notify(
&self,
parent_channel_id: ChannelId,
) -> Result<Vec<(UserId, ChannelsForUser)>> {
self.weak_transaction(|tx| async move {
let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?;
self.participants_to_notify_for_channel_change(&parent_channel, &*tx)
.await
})
.await
}
// TODO: this is very expensive, and we should rethink
async fn participants_to_notify_for_channel_change( async fn participants_to_notify_for_channel_change(
&self, &self,
new_parent: &channel::Model, new_parent: &channel::Model,
@ -1287,7 +1288,7 @@ impl Database {
let mut model = channel.into_active_model(); let mut model = channel.into_active_model();
model.parent_path = ActiveValue::Set(new_parent_path); model.parent_path = ActiveValue::Set(new_parent_path);
let channel = model.update(&*tx).await?; model.update(&*tx).await?;
if new_parent_channel.is_none() { if new_parent_channel.is_none() {
channel_member::ActiveModel { channel_member::ActiveModel {
@ -1314,34 +1315,9 @@ impl Database {
.all(&*tx) .all(&*tx)
.await?; .await?;
let participants_to_update: HashMap<_, _> = self
.participants_to_notify_for_channel_change(
new_parent_channel.as_ref().unwrap_or(&channel),
&*tx,
)
.await?
.into_iter()
.collect();
let mut moved_channels: HashSet<ChannelId> = HashSet::default();
for id in descendent_ids {
moved_channels.insert(id);
}
moved_channels.insert(channel_id);
let mut participants_to_remove: HashSet<UserId> = HashSet::default();
for participant in previous_participants {
if participant.kind == proto::channel_member::Kind::AncestorMember {
if !participants_to_update.contains_key(&participant.user_id) {
participants_to_remove.insert(participant.user_id);
}
}
}
Ok(Some(MoveChannelResult { Ok(Some(MoveChannelResult {
participants_to_remove, previous_participants,
participants_to_update, descendent_ids,
moved_channels,
})) }))
}) })
.await .await

View file

@ -15,11 +15,11 @@ test_both_dbs!(
async fn test_channel_message_retrieval(db: &Arc<Database>) { async fn test_channel_message_retrieval(db: &Arc<Database>) {
let user = new_test_user(db, "user@example.com").await; let user = new_test_user(db, "user@example.com").await;
let result = db.create_channel("channel", None, user).await.unwrap(); let channel = db.create_channel("channel", None, user).await.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32; let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat( db.join_channel_chat(
result.channel.id, channel.id,
rpc::ConnectionId { owner_id, id: 0 }, rpc::ConnectionId { owner_id, id: 0 },
user, user,
) )
@ -30,7 +30,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
for i in 0..10 { for i in 0..10 {
all_messages.push( all_messages.push(
db.create_channel_message( db.create_channel_message(
result.channel.id, channel.id,
user, user,
&i.to_string(), &i.to_string(),
&[], &[],
@ -45,7 +45,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
} }
let messages = db let messages = db
.get_channel_messages(result.channel.id, user, 3, None) .get_channel_messages(channel.id, user, 3, None)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
@ -55,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
let messages = db let messages = db
.get_channel_messages( .get_channel_messages(
result.channel.id, channel.id,
user, user,
4, 4,
Some(MessageId::from_proto(all_messages[6])), Some(MessageId::from_proto(all_messages[6])),
@ -370,7 +370,6 @@ async fn test_channel_message_mentions(db: &Arc<Database>) {
.create_channel("channel", None, user_a) .create_channel("channel", None, user_a)
.await .await
.unwrap() .unwrap()
.channel
.id; .id;
db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
.await .await

View file

@ -3,9 +3,9 @@ mod connection_pool;
use crate::{ use crate::{
auth::{self, Impersonator}, auth::{self, Impersonator},
db::{ db::{
self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, self, BufferId, ChannelId, ChannelRole, ChannelsForUser,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult, NotificationId, ProjectId, RemoveChannelMemberResult,
RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult, RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
User, UserId, User, UserId,
}, },
@ -2301,10 +2301,7 @@ async fn create_channel(
let db = session.db().await; let db = session.db().await;
let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
let CreateChannelResult { let channel = db
channel,
participants_to_update,
} = db
.create_channel(&request.name, parent_id, session.user_id) .create_channel(&request.name, parent_id, session.user_id)
.await?; .await?;
@ -2313,6 +2310,13 @@ async fn create_channel(
parent_id: request.parent_id, parent_id: request.parent_id,
})?; })?;
let participants_to_update;
if let Some(parent) = parent_id {
participants_to_update = db.new_participants_to_notify(parent).await?;
} else {
participants_to_update = vec![];
}
let connection_pool = session.connection_pool().await; let connection_pool = session.connection_pool().await;
for (user_id, channels) in participants_to_update { for (user_id, channels) in participants_to_update {
let update = build_channels_update(channels, vec![]); let update = build_channels_update(channels, vec![]);
@ -2566,27 +2570,32 @@ async fn move_channel(
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let to = request.to.map(ChannelId::from_proto); let to = request.to.map(ChannelId::from_proto);
let result = session let result = session.db().await.move_channel(channel_id, to, session.user_id).await?;
.db()
.await
.move_channel(channel_id, to, session.user_id)
.await?;
notify_channel_moved(result, session).await?; if let Some(result) = result {
let participants_to_update: HashMap<_, _> = session.db().await
.new_participants_to_notify(
to.unwrap_or(channel_id)
)
.await?
.into_iter()
.collect();
response.send(Ack {})?; let mut moved_channels: HashSet<ChannelId> = HashSet::default();
Ok(()) for id in result.descendent_ids {
} moved_channels.insert(id);
}
moved_channels.insert(channel_id);
let mut participants_to_remove: HashSet<UserId> = HashSet::default();
for participant in result.previous_participants {
if participant.kind == proto::channel_member::Kind::AncestorMember {
if !participants_to_update.contains_key(&participant.user_id) {
participants_to_remove.insert(participant.user_id);
}
}
}
async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
let Some(MoveChannelResult {
participants_to_remove,
participants_to_update,
moved_channels,
}) = result
else {
return Ok(());
};
let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect(); let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
let connection_pool = session.connection_pool().await; let connection_pool = session.connection_pool().await;
@ -2607,6 +2616,9 @@ async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Sessio
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
} }
} }
}
response.send(Ack {})?;
Ok(()) Ok(())
} }