assistant: Require user to accept TOS for cloud provider (#16111)
This adds the requirement for users to accept the terms of service the first time they send a message with the Cloud provider. Once this is out and in a nightly, we need to add the check to the server side too, to authenticate access to the models. Demo: https://github.com/user-attachments/assets/0edebf74-8120-4fa2-b801-bb76f04e8a17 Release Notes: - N/A
This commit is contained in:
parent
98f314ba21
commit
fbb533b3e0
14 changed files with 297 additions and 9 deletions
|
@ -9,7 +9,8 @@ CREATE TABLE "users" (
|
|||
"connected_once" BOOLEAN NOT NULL DEFAULT false,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"metrics_id" TEXT,
|
||||
"github_user_id" INTEGER
|
||||
"github_user_id" INTEGER,
|
||||
"accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE
|
||||
);
|
||||
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
|
||||
CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE;
|
|
@ -225,6 +225,26 @@ impl Database {
|
|||
.await
|
||||
}
|
||||
|
||||
/// Sets "accepted_tos_at" on the user to the given timestamp.
|
||||
pub async fn set_user_accepted_tos_at(
|
||||
&self,
|
||||
id: UserId,
|
||||
accepted_tos_at: Option<DateTime>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
user::Entity::update_many()
|
||||
.filter(user::Column::Id.eq(id))
|
||||
.set(user::ActiveModel {
|
||||
accepted_tos_at: ActiveValue::set(accepted_tos_at),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// hard delete the user.
|
||||
pub async fn destroy_user(&self, id: UserId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
|
|
|
@ -18,6 +18,7 @@ pub struct Model {
|
|||
pub connected_once: bool,
|
||||
pub metrics_id: Uuid,
|
||||
pub created_at: DateTime,
|
||||
pub accepted_tos_at: Option<DateTime>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
|
|
@ -10,6 +10,7 @@ mod extension_tests;
|
|||
mod feature_flag_tests;
|
||||
mod message_tests;
|
||||
mod processed_stripe_event_tests;
|
||||
mod user_tests;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
||||
|
|
45
crates/collab/src/db/tests/user_tests.rs
Normal file
45
crates/collab/src/db/tests/user_tests.rs
Normal file
|
@ -0,0 +1,45 @@
|
|||
use chrono::Utc;
|
||||
|
||||
use crate::{
|
||||
db::{Database, NewUserParams},
|
||||
test_both_dbs,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
test_both_dbs!(
|
||||
test_accepted_tos,
|
||||
test_accepted_tos_postgres,
|
||||
test_accepted_tos_sqlite
|
||||
);
|
||||
|
||||
async fn test_accepted_tos(db: &Arc<Database>) {
|
||||
let user_id = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".to_string(),
|
||||
github_user_id: 1,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
|
||||
assert!(user.accepted_tos_at.is_none());
|
||||
|
||||
let accepted_tos_at = Utc::now().naive_utc();
|
||||
db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
|
||||
assert!(user.accepted_tos_at.is_some());
|
||||
assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));
|
||||
|
||||
db.set_user_accepted_tos_at(user_id, None).await.unwrap();
|
||||
|
||||
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
|
||||
assert!(user.accepted_tos_at.is_none());
|
||||
}
|
|
@ -31,6 +31,7 @@ use axum::{
|
|||
routing::get,
|
||||
Extension, Router, TypedHeader,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
|
@ -604,6 +605,7 @@ impl Server {
|
|||
.add_message_handler(user_message_handler(update_followers))
|
||||
.add_request_handler(user_handler(get_private_user_info))
|
||||
.add_request_handler(user_handler(get_llm_api_token))
|
||||
.add_request_handler(user_handler(accept_terms_of_service))
|
||||
.add_message_handler(user_message_handler(acknowledge_channel_message))
|
||||
.add_message_handler(user_message_handler(acknowledge_buffer_version))
|
||||
.add_request_handler(user_handler(get_supermaven_api_key))
|
||||
|
@ -4882,6 +4884,25 @@ async fn get_private_user_info(
|
|||
metrics_id,
|
||||
staff: user.admin,
|
||||
flags,
|
||||
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Accept the terms of service (tos) on behalf of the current user
|
||||
async fn accept_terms_of_service(
|
||||
_request: proto::AcceptTermsOfService,
|
||||
response: Response<proto::AcceptTermsOfService>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let accepted_tos_at = Utc::now();
|
||||
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
|
||||
.await?;
|
||||
|
||||
response.send(proto::AcceptTermsOfServiceResponse {
|
||||
accepted_tos_at: accepted_tos_at.timestamp() as u64,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue