Get integration tests passing with sqlite

Co-authored-by: Antonio Scandurra <antonio@zed.dev>
Max Brunsfeld 2022-11-10 11:03:52 -08:00
parent 1bb41b6f54
commit 05a6bd914d
11 changed files with 473 additions and 3301 deletions


@@ -2,7 +2,7 @@ mod store;
use crate::{
auth,
db::{self, ChannelId, MessageId, ProjectId, User, UserId},
db::{self, ProjectId, User, UserId},
AppState, Result,
};
use anyhow::anyhow;
@@ -24,7 +24,7 @@ use axum::{
};
use collections::{HashMap, HashSet};
use futures::{
channel::{mpsc, oneshot},
channel::oneshot,
future::{self, BoxFuture},
stream::FuturesUnordered,
FutureExt, SinkExt, StreamExt, TryStreamExt,
@@ -51,7 +51,6 @@ use std::{
time::Duration,
};
pub use store::{Store, Worktree};
use time::OffsetDateTime;
use tokio::{
sync::{Mutex, MutexGuard},
time::Sleep,
@@ -62,10 +61,6 @@ use tracing::{info_span, instrument, Instrument};
lazy_static! {
static ref METRIC_CONNECTIONS: IntGauge =
register_int_gauge!("connections", "number of connections").unwrap();
static ref METRIC_REGISTERED_PROJECTS: IntGauge =
register_int_gauge!("registered_projects", "number of registered projects").unwrap();
static ref METRIC_ACTIVE_PROJECTS: IntGauge =
register_int_gauge!("active_projects", "number of active projects").unwrap();
static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
"shared_projects",
"number of open projects with one or more guests"
@@ -95,7 +90,6 @@ pub struct Server {
pub(crate) store: Mutex<Store>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
notifications: Option<mpsc::UnboundedSender<()>>,
}
pub trait Executor: Send + Clone {
@@ -107,9 +101,6 @@ pub trait Executor: Send + Clone {
#[derive(Clone)]
pub struct RealExecutor;
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
pub(crate) struct StoreGuard<'a> {
guard: MutexGuard<'a, Store>,
_not_send: PhantomData<Rc<()>>,
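A brief aside on the StoreGuard definition above: the PhantomData<Rc<()>> field is a standard way to make a guard type !Send, since Rc is !Send and the marker propagates that to the containing struct. A minimal sketch of the idea, using hypothetical names that are not taken from this commit:

// Hypothetical sketch: a guard made !Send via a PhantomData<Rc<()>> marker,
// so it cannot be moved to another thread while it is held.
use std::{marker::PhantomData, rc::Rc};

pub struct NotSendGuard<'a, T> {
    pub value: &'a T,
    // Rc<()> is !Send, so the containing guard type is !Send as well.
    _not_send: PhantomData<Rc<()>>,
}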
@@ -132,16 +123,12 @@ where
}
impl Server {
pub fn new(
app_state: Arc<AppState>,
notifications: Option<mpsc::UnboundedSender<()>>,
) -> Arc<Self> {
pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
let mut server = Self {
peer: Peer::new(),
app_state,
store: Default::default(),
handlers: Default::default(),
notifications,
};
server
@@ -158,9 +145,7 @@ impl Server {
.add_request_handler(Server::join_project)
.add_message_handler(Server::leave_project)
.add_message_handler(Server::update_project)
.add_message_handler(Server::register_project_activity)
.add_request_handler(Server::update_worktree)
.add_message_handler(Server::update_worktree_extensions)
.add_message_handler(Server::start_language_server)
.add_message_handler(Server::update_language_server)
.add_message_handler(Server::update_diagnostic_summary)
@@ -194,19 +179,14 @@ impl Server {
.add_message_handler(Server::buffer_reloaded)
.add_message_handler(Server::buffer_saved)
.add_request_handler(Server::save_buffer)
.add_request_handler(Server::get_channels)
.add_request_handler(Server::get_users)
.add_request_handler(Server::fuzzy_search_users)
.add_request_handler(Server::request_contact)
.add_request_handler(Server::remove_contact)
.add_request_handler(Server::respond_to_contact_request)
.add_request_handler(Server::join_channel)
.add_message_handler(Server::leave_channel)
.add_request_handler(Server::send_channel_message)
.add_request_handler(Server::follow)
.add_message_handler(Server::unfollow)
.add_message_handler(Server::update_followers)
.add_request_handler(Server::get_channel_messages)
.add_message_handler(Server::update_diff_base)
.add_request_handler(Server::get_private_user_info);
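The add_request_handler / add_message_handler chain above fills the handlers: HashMap<TypeId, MessageHandler> table declared on Server, so each protobuf payload type dispatches to exactly one handler. A minimal sketch of that TypeId-keyed dispatch, using hypothetical stand-in types rather than the crate's actual TypedEnvelope and MessageHandler definitions:

// Hypothetical sketch: register handlers keyed by the message's TypeId and
// dispatch by downcasting back to the concrete payload type.
use std::{
    any::{Any, TypeId},
    collections::HashMap,
};

type Handler = Box<dyn Fn(&dyn Any)>;

#[derive(Default)]
struct Registry {
    handlers: HashMap<TypeId, Handler>,
}

impl Registry {
    // Register a handler for one concrete message type M.
    fn add_handler<M: Any>(&mut self, handler: impl Fn(&M) + 'static) -> &mut Self {
        self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |message| {
                // Recover the concrete type before invoking the handler.
                if let Some(message) = message.downcast_ref::<M>() {
                    handler(message);
                }
            }),
        );
        self
    }

    // Route an incoming message to the handler registered for its type, if any.
    fn handle(&self, message: &dyn Any) {
        if let Some(handler) = self.handlers.get(&message.type_id()) {
            handler(message);
        }
    }
}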
@@ -290,58 +270,6 @@ impl Server {
})
}
/// Start a long lived task that records which users are active in which projects.
pub fn start_recording_project_activity<E: 'static + Executor>(
self: &Arc<Self>,
interval: Duration,
executor: E,
) {
executor.spawn_detached({
let this = Arc::downgrade(self);
let executor = executor.clone();
async move {
let mut period_start = OffsetDateTime::now_utc();
let mut active_projects = Vec::<(UserId, ProjectId)>::new();
loop {
let sleep = executor.sleep(interval);
sleep.await;
let this = if let Some(this) = this.upgrade() {
this
} else {
break;
};
active_projects.clear();
active_projects.extend(this.store().await.projects().flat_map(
|(project_id, project)| {
project.guests.values().chain([&project.host]).filter_map(
|collaborator| {
if !collaborator.admin
&& collaborator
.last_activity
.map_or(false, |activity| activity > period_start)
{
Some((collaborator.user_id, *project_id))
} else {
None
}
},
)
},
));
let period_end = OffsetDateTime::now_utc();
this.app_state
.db
.record_user_activity(period_start..period_end, &active_projects)
.await
.trace_err();
period_start = period_end;
}
}
});
}
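The removed start_recording_project_activity function above illustrates a common shape for periodic background work: the spawned loop holds only a Weak reference to the server, so it exits cleanly once the server is dropped. A minimal sketch of that shape, assuming a tokio runtime and hypothetical names not taken from this commit:

// Hypothetical sketch: a periodic task that stops itself once its owner is dropped.
use std::{sync::Arc, time::Duration};

struct Recorder;

impl Recorder {
    fn start_recording(self: &Arc<Self>, interval: Duration) {
        let this = Arc::downgrade(self);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                // Upgrading fails once the last strong reference is gone,
                // which ends the background loop.
                let _this = match this.upgrade() {
                    Some(recorder) => recorder,
                    None => break,
                };
                // ...record activity for the interval that just ended...
            }
        });
    }
}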
pub fn handle_connection<E: Executor>(
self: &Arc<Self>,
connection: Connection,
@@ -432,18 +360,11 @@ impl Server {
let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
let span_enter = span.enter();
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
let notifications = this.notifications.clone();
let is_background = message.is_background();
let handle_message = (handler)(this.clone(), message);
drop(span_enter);
let handle_message = async move {
handle_message.await;
if let Some(mut notifications) = notifications {
let _ = notifications.send(()).await;
}
}.instrument(span);
let handle_message = handle_message.instrument(span);
if is_background {
executor.spawn_detached(handle_message);
} else {
@@ -1172,17 +1093,6 @@ impl Server {
Ok(())
}
async fn register_project_activity(
self: Arc<Server>,
request: TypedEnvelope<proto::RegisterProjectActivity>,
) -> Result<()> {
self.store().await.register_project_activity(
ProjectId::from_proto(request.payload.project_id),
request.sender_id,
)?;
Ok(())
}
async fn update_worktree(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
@@ -1209,25 +1119,6 @@ impl Server {
Ok(())
}
async fn update_worktree_extensions(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktreeExtensions>,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
let worktree_id = request.payload.worktree_id;
let extensions = request
.payload
.extensions
.into_iter()
.zip(request.payload.counts)
.collect();
self.app_state
.db
.update_worktree_extensions(project_id, worktree_id, extensions)
.await?;
Ok(())
}
async fn update_diagnostic_summary(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
@@ -1363,8 +1254,7 @@ impl Server {
) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
let receiver_ids = {
let mut store = self.store().await;
store.register_project_activity(project_id, request.sender_id)?;
let store = self.store().await;
store.project_connection_ids(project_id, request.sender_id)?
};
@@ -1430,15 +1320,13 @@ impl Server {
let leader_id = ConnectionId(request.payload.leader_id);
let follower_id = request.sender_id;
{
let mut store = self.store().await;
let store = self.store().await;
if !store
.project_connection_ids(project_id, follower_id)?
.contains(&leader_id)
{
Err(anyhow!("no such peer"))?;
}
store.register_project_activity(project_id, follower_id)?;
}
let mut response_payload = self
@@ -1455,14 +1343,13 @@ impl Server {
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
let leader_id = ConnectionId(request.payload.leader_id);
let mut store = self.store().await;
let store = self.store().await;
if !store
.project_connection_ids(project_id, request.sender_id)?
.contains(&leader_id)
{
Err(anyhow!("no such peer"))?;
}
store.register_project_activity(project_id, request.sender_id)?;
self.peer
.forward_send(request.sender_id, leader_id, request.payload)?;
Ok(())
@@ -1473,8 +1360,7 @@ impl Server {
request: TypedEnvelope<proto::UpdateFollowers>,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id);
let mut store = self.store().await;
store.register_project_activity(project_id, request.sender_id)?;
let store = self.store().await;
let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
let leader_id = request
.payload
@@ -1495,28 +1381,6 @@ impl Server {
Ok(())
}
async fn get_channels(
self: Arc<Server>,
request: TypedEnvelope<proto::GetChannels>,
response: Response<proto::GetChannels>,
) -> Result<()> {
let user_id = self
.store()
.await
.user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
response.send(proto::GetChannelsResponse {
channels: channels
.into_iter()
.map(|chan| proto::Channel {
id: chan.id.to_proto(),
name: chan.name,
})
.collect(),
})?;
Ok(())
}
async fn get_users(
self: Arc<Server>,
request: TypedEnvelope<proto::GetUsers>,
@@ -1712,175 +1576,6 @@ impl Server {
Ok(())
}
async fn join_channel(
self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>,
response: Response<proto::JoinChannel>,
) -> Result<()> {
let user_id = self
.store()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
.db
.can_user_access_channel(user_id, channel_id)
.await?
{
Err(anyhow!("access denied"))?;
}
self.store()
.await
.join_channel(request.sender_id, channel_id);
let messages = self
.app_state
.db
.get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
.await?
.into_iter()
.map(|msg| proto::ChannelMessage {
id: msg.id.to_proto(),
body: msg.body,
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
nonce: Some(msg.nonce.as_u128().into()),
})
.collect::<Vec<_>>();
response.send(proto::JoinChannelResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
}
async fn leave_channel(
self: Arc<Self>,
request: TypedEnvelope<proto::LeaveChannel>,
) -> Result<()> {
let user_id = self
.store()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
.db
.can_user_access_channel(user_id, channel_id)
.await?
{
Err(anyhow!("access denied"))?;
}
self.store()
.await
.leave_channel(request.sender_id, channel_id);
Ok(())
}
async fn send_channel_message(
self: Arc<Self>,
request: TypedEnvelope<proto::SendChannelMessage>,
response: Response<proto::SendChannelMessage>,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.payload.channel_id);
let user_id;
let connection_ids;
{
let state = self.store().await;
user_id = state.user_id_for_connection(request.sender_id)?;
connection_ids = state.channel_connection_ids(channel_id)?;
}
// Validate the message body.
let body = request.payload.body.trim().to_string();
if body.len() > MAX_MESSAGE_LEN {
return Err(anyhow!("message is too long"))?;
}
if body.is_empty() {
return Err(anyhow!("message can't be blank"))?;
}
let timestamp = OffsetDateTime::now_utc();
let nonce = request
.payload
.nonce
.ok_or_else(|| anyhow!("nonce can't be blank"))?;
let message_id = self
.app_state
.db
.create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
.await?
.to_proto();
let message = proto::ChannelMessage {
sender_id: user_id.to_proto(),
id: message_id,
body,
timestamp: timestamp.unix_timestamp() as u64,
nonce: Some(nonce),
};
broadcast(request.sender_id, connection_ids, |conn_id| {
self.peer.send(
conn_id,
proto::ChannelMessageSent {
channel_id: channel_id.to_proto(),
message: Some(message.clone()),
},
)
});
response.send(proto::SendChannelMessageResponse {
message: Some(message),
})?;
Ok(())
}
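In the removed send_channel_message handler above, broadcast(request.sender_id, connection_ids, |conn_id| ...) fans the new message out to every connection in the channel except the sender's own. A minimal sketch of that fan-out shape as a hypothetical free function (not the crate's actual broadcast helper):

// Hypothetical sketch: send to every connection except the originating one,
// ignoring per-connection failures so one bad peer does not stop the fan-out.
fn broadcast<E>(
    sender: u32,
    connections: impl IntoIterator<Item = u32>,
    mut send: impl FnMut(u32) -> Result<(), E>,
) {
    for connection in connections {
        if connection != sender {
            let _ = send(connection);
        }
    }
}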
async fn get_channel_messages(
self: Arc<Self>,
request: TypedEnvelope<proto::GetChannelMessages>,
response: Response<proto::GetChannelMessages>,
) -> Result<()> {
let user_id = self
.store()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
.db
.can_user_access_channel(user_id, channel_id)
.await?
{
Err(anyhow!("access denied"))?;
}
let messages = self
.app_state
.db
.get_channel_messages(
channel_id,
MESSAGE_COUNT_PER_PAGE,
Some(MessageId::from_proto(request.payload.before_message_id)),
)
.await?
.into_iter()
.map(|msg| proto::ChannelMessage {
id: msg.id.to_proto(),
body: msg.body,
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
nonce: Some(msg.nonce.as_u128().into()),
})
.collect::<Vec<_>>();
response.send(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
}
async fn update_diff_base(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateDiffBase>,
@@ -2061,11 +1756,8 @@ pub async fn handle_websocket_request(
}
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
// We call `store_mut` here for its side effects of updating metrics.
let metrics = server.store().await.metrics();
METRIC_CONNECTIONS.set(metrics.connections as _);
METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
let encoder = prometheus::TextEncoder::new();
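Presumably handle_metrics finishes by gathering the registered metric families and encoding them with the TextEncoder created above. A minimal sketch of that final step with the prometheus crate, leaving out the axum response plumbing (the error handling here is an assumption, not taken from this commit):

// Hypothetical sketch: encode every registered metric into the Prometheus text format.
use prometheus::{Encoder, TextEncoder};

fn encode_metrics() -> anyhow::Result<String> {
    let encoder = TextEncoder::new();
    let metric_families = prometheus::gather();
    let mut buffer = Vec::new();
    encoder.encode(&metric_families, &mut buffer)?;
    Ok(String::from_utf8(buffer)?)
}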