Get integration tests passing with sqlite
Co-authored-by: Antonio Scandurra <antonio@zed.dev>
This commit is contained in:
parent
1bb41b6f54
commit
05a6bd914d
11 changed files with 473 additions and 3301 deletions
|
@ -2,7 +2,7 @@ mod store;
|
|||
|
||||
use crate::{
|
||||
auth,
|
||||
db::{self, ChannelId, MessageId, ProjectId, User, UserId},
|
||||
db::{self, ProjectId, User, UserId},
|
||||
AppState, Result,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
|
@ -24,7 +24,7 @@ use axum::{
|
|||
};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
channel::oneshot,
|
||||
future::{self, BoxFuture},
|
||||
stream::FuturesUnordered,
|
||||
FutureExt, SinkExt, StreamExt, TryStreamExt,
|
||||
|
@ -51,7 +51,6 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
pub use store::{Store, Worktree};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::{
|
||||
sync::{Mutex, MutexGuard},
|
||||
time::Sleep,
|
||||
|
@ -62,10 +61,6 @@ use tracing::{info_span, instrument, Instrument};
|
|||
lazy_static! {
|
||||
static ref METRIC_CONNECTIONS: IntGauge =
|
||||
register_int_gauge!("connections", "number of connections").unwrap();
|
||||
static ref METRIC_REGISTERED_PROJECTS: IntGauge =
|
||||
register_int_gauge!("registered_projects", "number of registered projects").unwrap();
|
||||
static ref METRIC_ACTIVE_PROJECTS: IntGauge =
|
||||
register_int_gauge!("active_projects", "number of active projects").unwrap();
|
||||
static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
|
||||
"shared_projects",
|
||||
"number of open projects with one or more guests"
|
||||
|
@ -95,7 +90,6 @@ pub struct Server {
|
|||
pub(crate) store: Mutex<Store>,
|
||||
app_state: Arc<AppState>,
|
||||
handlers: HashMap<TypeId, MessageHandler>,
|
||||
notifications: Option<mpsc::UnboundedSender<()>>,
|
||||
}
|
||||
|
||||
pub trait Executor: Send + Clone {
|
||||
|
@ -107,9 +101,6 @@ pub trait Executor: Send + Clone {
|
|||
#[derive(Clone)]
|
||||
pub struct RealExecutor;
|
||||
|
||||
const MESSAGE_COUNT_PER_PAGE: usize = 100;
|
||||
const MAX_MESSAGE_LEN: usize = 1024;
|
||||
|
||||
pub(crate) struct StoreGuard<'a> {
|
||||
guard: MutexGuard<'a, Store>,
|
||||
_not_send: PhantomData<Rc<()>>,
|
||||
|
@ -132,16 +123,12 @@ where
|
|||
}
|
||||
|
||||
impl Server {
|
||||
pub fn new(
|
||||
app_state: Arc<AppState>,
|
||||
notifications: Option<mpsc::UnboundedSender<()>>,
|
||||
) -> Arc<Self> {
|
||||
pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
|
||||
let mut server = Self {
|
||||
peer: Peer::new(),
|
||||
app_state,
|
||||
store: Default::default(),
|
||||
handlers: Default::default(),
|
||||
notifications,
|
||||
};
|
||||
|
||||
server
|
||||
|
@ -158,9 +145,7 @@ impl Server {
|
|||
.add_request_handler(Server::join_project)
|
||||
.add_message_handler(Server::leave_project)
|
||||
.add_message_handler(Server::update_project)
|
||||
.add_message_handler(Server::register_project_activity)
|
||||
.add_request_handler(Server::update_worktree)
|
||||
.add_message_handler(Server::update_worktree_extensions)
|
||||
.add_message_handler(Server::start_language_server)
|
||||
.add_message_handler(Server::update_language_server)
|
||||
.add_message_handler(Server::update_diagnostic_summary)
|
||||
|
@ -194,19 +179,14 @@ impl Server {
|
|||
.add_message_handler(Server::buffer_reloaded)
|
||||
.add_message_handler(Server::buffer_saved)
|
||||
.add_request_handler(Server::save_buffer)
|
||||
.add_request_handler(Server::get_channels)
|
||||
.add_request_handler(Server::get_users)
|
||||
.add_request_handler(Server::fuzzy_search_users)
|
||||
.add_request_handler(Server::request_contact)
|
||||
.add_request_handler(Server::remove_contact)
|
||||
.add_request_handler(Server::respond_to_contact_request)
|
||||
.add_request_handler(Server::join_channel)
|
||||
.add_message_handler(Server::leave_channel)
|
||||
.add_request_handler(Server::send_channel_message)
|
||||
.add_request_handler(Server::follow)
|
||||
.add_message_handler(Server::unfollow)
|
||||
.add_message_handler(Server::update_followers)
|
||||
.add_request_handler(Server::get_channel_messages)
|
||||
.add_message_handler(Server::update_diff_base)
|
||||
.add_request_handler(Server::get_private_user_info);
|
||||
|
||||
|
@ -290,58 +270,6 @@ impl Server {
|
|||
})
|
||||
}
|
||||
|
||||
/// Start a long lived task that records which users are active in which projects.
|
||||
pub fn start_recording_project_activity<E: 'static + Executor>(
|
||||
self: &Arc<Self>,
|
||||
interval: Duration,
|
||||
executor: E,
|
||||
) {
|
||||
executor.spawn_detached({
|
||||
let this = Arc::downgrade(self);
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
let mut period_start = OffsetDateTime::now_utc();
|
||||
let mut active_projects = Vec::<(UserId, ProjectId)>::new();
|
||||
loop {
|
||||
let sleep = executor.sleep(interval);
|
||||
sleep.await;
|
||||
let this = if let Some(this) = this.upgrade() {
|
||||
this
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
|
||||
active_projects.clear();
|
||||
active_projects.extend(this.store().await.projects().flat_map(
|
||||
|(project_id, project)| {
|
||||
project.guests.values().chain([&project.host]).filter_map(
|
||||
|collaborator| {
|
||||
if !collaborator.admin
|
||||
&& collaborator
|
||||
.last_activity
|
||||
.map_or(false, |activity| activity > period_start)
|
||||
{
|
||||
Some((collaborator.user_id, *project_id))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
},
|
||||
)
|
||||
},
|
||||
));
|
||||
|
||||
let period_end = OffsetDateTime::now_utc();
|
||||
this.app_state
|
||||
.db
|
||||
.record_user_activity(period_start..period_end, &active_projects)
|
||||
.await
|
||||
.trace_err();
|
||||
period_start = period_end;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn handle_connection<E: Executor>(
|
||||
self: &Arc<Self>,
|
||||
connection: Connection,
|
||||
|
@ -432,18 +360,11 @@ impl Server {
|
|||
let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
|
||||
let span_enter = span.enter();
|
||||
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
|
||||
let notifications = this.notifications.clone();
|
||||
let is_background = message.is_background();
|
||||
let handle_message = (handler)(this.clone(), message);
|
||||
|
||||
drop(span_enter);
|
||||
let handle_message = async move {
|
||||
handle_message.await;
|
||||
if let Some(mut notifications) = notifications {
|
||||
let _ = notifications.send(()).await;
|
||||
}
|
||||
}.instrument(span);
|
||||
|
||||
let handle_message = handle_message.instrument(span);
|
||||
if is_background {
|
||||
executor.spawn_detached(handle_message);
|
||||
} else {
|
||||
|
@ -1172,17 +1093,6 @@ impl Server {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn register_project_activity(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::RegisterProjectActivity>,
|
||||
) -> Result<()> {
|
||||
self.store().await.register_project_activity(
|
||||
ProjectId::from_proto(request.payload.project_id),
|
||||
request.sender_id,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_worktree(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateWorktree>,
|
||||
|
@ -1209,25 +1119,6 @@ impl Server {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_worktree_extensions(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateWorktreeExtensions>,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
let extensions = request
|
||||
.payload
|
||||
.extensions
|
||||
.into_iter()
|
||||
.zip(request.payload.counts)
|
||||
.collect();
|
||||
self.app_state
|
||||
.db
|
||||
.update_worktree_extensions(project_id, worktree_id, extensions)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_diagnostic_summary(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
|
||||
|
@ -1363,8 +1254,7 @@ impl Server {
|
|||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||
let receiver_ids = {
|
||||
let mut store = self.store().await;
|
||||
store.register_project_activity(project_id, request.sender_id)?;
|
||||
let store = self.store().await;
|
||||
store.project_connection_ids(project_id, request.sender_id)?
|
||||
};
|
||||
|
||||
|
@ -1430,15 +1320,13 @@ impl Server {
|
|||
let leader_id = ConnectionId(request.payload.leader_id);
|
||||
let follower_id = request.sender_id;
|
||||
{
|
||||
let mut store = self.store().await;
|
||||
let store = self.store().await;
|
||||
if !store
|
||||
.project_connection_ids(project_id, follower_id)?
|
||||
.contains(&leader_id)
|
||||
{
|
||||
Err(anyhow!("no such peer"))?;
|
||||
}
|
||||
|
||||
store.register_project_activity(project_id, follower_id)?;
|
||||
}
|
||||
|
||||
let mut response_payload = self
|
||||
|
@ -1455,14 +1343,13 @@ impl Server {
|
|||
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||
let leader_id = ConnectionId(request.payload.leader_id);
|
||||
let mut store = self.store().await;
|
||||
let store = self.store().await;
|
||||
if !store
|
||||
.project_connection_ids(project_id, request.sender_id)?
|
||||
.contains(&leader_id)
|
||||
{
|
||||
Err(anyhow!("no such peer"))?;
|
||||
}
|
||||
store.register_project_activity(project_id, request.sender_id)?;
|
||||
self.peer
|
||||
.forward_send(request.sender_id, leader_id, request.payload)?;
|
||||
Ok(())
|
||||
|
@ -1473,8 +1360,7 @@ impl Server {
|
|||
request: TypedEnvelope<proto::UpdateFollowers>,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||
let mut store = self.store().await;
|
||||
store.register_project_activity(project_id, request.sender_id)?;
|
||||
let store = self.store().await;
|
||||
let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
|
||||
let leader_id = request
|
||||
.payload
|
||||
|
@ -1495,28 +1381,6 @@ impl Server {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_channels(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::GetChannels>,
|
||||
response: Response<proto::GetChannels>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.store()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
|
||||
response.send(proto::GetChannelsResponse {
|
||||
channels: channels
|
||||
.into_iter()
|
||||
.map(|chan| proto::Channel {
|
||||
id: chan.id.to_proto(),
|
||||
name: chan.name,
|
||||
})
|
||||
.collect(),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_users(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::GetUsers>,
|
||||
|
@ -1712,175 +1576,6 @@ impl Server {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn join_channel(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::JoinChannel>,
|
||||
response: Response<proto::JoinChannel>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.store()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
if !self
|
||||
.app_state
|
||||
.db
|
||||
.can_user_access_channel(user_id, channel_id)
|
||||
.await?
|
||||
{
|
||||
Err(anyhow!("access denied"))?;
|
||||
}
|
||||
|
||||
self.store()
|
||||
.await
|
||||
.join_channel(request.sender_id, channel_id);
|
||||
let messages = self
|
||||
.app_state
|
||||
.db
|
||||
.get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|msg| proto::ChannelMessage {
|
||||
id: msg.id.to_proto(),
|
||||
body: msg.body,
|
||||
timestamp: msg.sent_at.unix_timestamp() as u64,
|
||||
sender_id: msg.sender_id.to_proto(),
|
||||
nonce: Some(msg.nonce.as_u128().into()),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
response.send(proto::JoinChannelResponse {
|
||||
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
|
||||
messages,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn leave_channel(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::LeaveChannel>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.store()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
if !self
|
||||
.app_state
|
||||
.db
|
||||
.can_user_access_channel(user_id, channel_id)
|
||||
.await?
|
||||
{
|
||||
Err(anyhow!("access denied"))?;
|
||||
}
|
||||
|
||||
self.store()
|
||||
.await
|
||||
.leave_channel(request.sender_id, channel_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_channel_message(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::SendChannelMessage>,
|
||||
response: Response<proto::SendChannelMessage>,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
let user_id;
|
||||
let connection_ids;
|
||||
{
|
||||
let state = self.store().await;
|
||||
user_id = state.user_id_for_connection(request.sender_id)?;
|
||||
connection_ids = state.channel_connection_ids(channel_id)?;
|
||||
}
|
||||
|
||||
// Validate the message body.
|
||||
let body = request.payload.body.trim().to_string();
|
||||
if body.len() > MAX_MESSAGE_LEN {
|
||||
return Err(anyhow!("message is too long"))?;
|
||||
}
|
||||
if body.is_empty() {
|
||||
return Err(anyhow!("message can't be blank"))?;
|
||||
}
|
||||
|
||||
let timestamp = OffsetDateTime::now_utc();
|
||||
let nonce = request
|
||||
.payload
|
||||
.nonce
|
||||
.ok_or_else(|| anyhow!("nonce can't be blank"))?;
|
||||
|
||||
let message_id = self
|
||||
.app_state
|
||||
.db
|
||||
.create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
|
||||
.await?
|
||||
.to_proto();
|
||||
let message = proto::ChannelMessage {
|
||||
sender_id: user_id.to_proto(),
|
||||
id: message_id,
|
||||
body,
|
||||
timestamp: timestamp.unix_timestamp() as u64,
|
||||
nonce: Some(nonce),
|
||||
};
|
||||
broadcast(request.sender_id, connection_ids, |conn_id| {
|
||||
self.peer.send(
|
||||
conn_id,
|
||||
proto::ChannelMessageSent {
|
||||
channel_id: channel_id.to_proto(),
|
||||
message: Some(message.clone()),
|
||||
},
|
||||
)
|
||||
});
|
||||
response.send(proto::SendChannelMessageResponse {
|
||||
message: Some(message),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_channel_messages(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::GetChannelMessages>,
|
||||
response: Response<proto::GetChannelMessages>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.store()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
if !self
|
||||
.app_state
|
||||
.db
|
||||
.can_user_access_channel(user_id, channel_id)
|
||||
.await?
|
||||
{
|
||||
Err(anyhow!("access denied"))?;
|
||||
}
|
||||
|
||||
let messages = self
|
||||
.app_state
|
||||
.db
|
||||
.get_channel_messages(
|
||||
channel_id,
|
||||
MESSAGE_COUNT_PER_PAGE,
|
||||
Some(MessageId::from_proto(request.payload.before_message_id)),
|
||||
)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|msg| proto::ChannelMessage {
|
||||
id: msg.id.to_proto(),
|
||||
body: msg.body,
|
||||
timestamp: msg.sent_at.unix_timestamp() as u64,
|
||||
sender_id: msg.sender_id.to_proto(),
|
||||
nonce: Some(msg.nonce.as_u128().into()),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
response.send(proto::GetChannelMessagesResponse {
|
||||
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
|
||||
messages,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_diff_base(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateDiffBase>,
|
||||
|
@ -2061,11 +1756,8 @@ pub async fn handle_websocket_request(
|
|||
}
|
||||
|
||||
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
|
||||
// We call `store_mut` here for its side effects of updating metrics.
|
||||
let metrics = server.store().await.metrics();
|
||||
METRIC_CONNECTIONS.set(metrics.connections as _);
|
||||
METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
|
||||
METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
|
||||
METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
|
||||
|
||||
let encoder = prometheus::TextEncoder::new();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue