assistant: Require user to accept TOS for cloud provider (#16111)

This adds the requirement for users to accept the terms of service the
first time they send a message with the Cloud provider.

Once this is out and in a nightly, we need to add the check to the
server side too, to authenticate access to the models.

Demo:


https://github.com/user-attachments/assets/0edebf74-8120-4fa2-b801-bb76f04e8a17



Release Notes:

- N/A
This commit is contained in:
Thorsten Ball 2024-08-12 17:43:35 +02:00 committed by GitHub
parent 98f314ba21
commit fbb533b3e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 297 additions and 9 deletions

View file

@ -9,7 +9,8 @@ CREATE TABLE "users" (
"connected_once" BOOLEAN NOT NULL DEFAULT false,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"metrics_id" TEXT,
"github_user_id" INTEGER
"github_user_id" INTEGER,
"accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE
);
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");

View file

@ -0,0 +1 @@
ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE;

View file

@ -225,6 +225,26 @@ impl Database {
.await
}
/// Sets "accepted_tos_at" on the user to the given timestamp.
pub async fn set_user_accepted_tos_at(
&self,
id: UserId,
accepted_tos_at: Option<DateTime>,
) -> Result<()> {
self.transaction(|tx| async move {
user::Entity::update_many()
.filter(user::Column::Id.eq(id))
.set(user::ActiveModel {
accepted_tos_at: ActiveValue::set(accepted_tos_at),
..Default::default()
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// hard delete the user.
pub async fn destroy_user(&self, id: UserId) -> Result<()> {
self.transaction(|tx| async move {

View file

@ -18,6 +18,7 @@ pub struct Model {
pub connected_once: bool,
pub metrics_id: Uuid,
pub created_at: DateTime,
pub accepted_tos_at: Option<DateTime>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View file

@ -10,6 +10,7 @@ mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
mod user_tests;
use crate::migrations::run_database_migrations;

View file

@ -0,0 +1,45 @@
use chrono::Utc;
use crate::{
db::{Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;
test_both_dbs!(
test_accepted_tos,
test_accepted_tos_postgres,
test_accepted_tos_sqlite
);
async fn test_accepted_tos(db: &Arc<Database>) {
let user_id = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".to_string(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_none());
let accepted_tos_at = Utc::now().naive_utc();
db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
.await
.unwrap();
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_some());
assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));
db.set_user_accepted_tos_at(user_id, None).await.unwrap();
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_none());
}

View file

@ -31,6 +31,7 @@ use axum::{
routing::get,
Extension, Router, TypedHeader,
};
use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
@ -604,6 +605,7 @@ impl Server {
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
.add_request_handler(user_handler(accept_terms_of_service))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@ -4882,6 +4884,25 @@ async fn get_private_user_info(
metrics_id,
staff: user.admin,
flags,
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
})?;
Ok(())
}
/// Accept the terms of service (tos) on behalf of the current user
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: UserSession,
) -> Result<()> {
let db = session.db().await;
let accepted_tos_at = Utc::now();
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
.await?;
response.send(proto::AcceptTermsOfServiceResponse {
accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
Ok(())
}