mas_storage_pg/user/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the user-related
8//! repositories
9
10use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13    Clock,
14    user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24    DatabaseError,
25    filter::{Filter, StatementExt},
26    iden::Users,
27    pagination::QueryBuilderExt,
28    tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod registration_token;
36mod session;
37mod terms;
38
39#[cfg(test)]
40mod tests;
41
42pub use self::{
43    email::PgUserEmailRepository, password::PgUserPasswordRepository,
44    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
45    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
46    terms::PgUserTermsRepository,
47};
48
49/// An implementation of [`UserRepository`] for a PostgreSQL connection
50pub struct PgUserRepository<'c> {
51    conn: &'c mut PgConnection,
52}
53
54impl<'c> PgUserRepository<'c> {
55    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
56    pub fn new(conn: &'c mut PgConnection) -> Self {
57        Self { conn }
58    }
59}
60
mod priv_ {
    // The enum_def macro generates a public enum, which we don't want, because it
    // triggers the missing docs warning
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// One row of the `users` table, as selected by the queries of the parent
    /// module.
    ///
    /// Field names are load-bearing. `sqlx` maps result columns onto these
    /// fields by name. `#[enum_def]` generates the companion `UserLookupIden`
    /// enum, whose variants are used as column aliases in dynamically built
    /// queries. Renaming a field therefore requires updating the selected
    /// column names or aliases.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        /// ID of the user, stored as a UUID (converted to a ULID by the parent
        /// module)
        pub(super) user_id: Uuid,
        /// Username, with its original casing
        pub(super) username: String,
        /// When the user row was created
        pub(super) created_at: DateTime<Utc>,
        /// When the user was locked, or `None` if it is not locked
        pub(super) locked_at: Option<DateTime<Utc>>,
        /// When the user was deactivated, or `None` if it is still active
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        /// Value of the `can_request_admin` column
        pub(super) can_request_admin: bool,
    }
}
81
82use priv_::{UserLookup, UserLookupIden};
83
84impl From<UserLookup> for User {
85    fn from(value: UserLookup) -> Self {
86        let id = value.user_id.into();
87        Self {
88            id,
89            username: value.username,
90            sub: id.to_string(),
91            created_at: value.created_at,
92            locked_at: value.locked_at,
93            deactivated_at: value.deactivated_at,
94            can_request_admin: value.can_request_admin,
95        }
96    }
97}
98
99impl Filter for UserFilter<'_> {
100    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
101        sea_query::Condition::all()
102            .add_option(self.state().map(|state| {
103                match state {
104                    mas_storage::user::UserState::Deactivated => {
105                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
106                    }
107                    mas_storage::user::UserState::Locked => {
108                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
109                    }
110                    mas_storage::user::UserState::Active => {
111                        Expr::col((Users::Table, Users::LockedAt))
112                            .is_null()
113                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
114                    }
115                }
116            }))
117            .add_option(self.can_request_admin().map(|can_request_admin| {
118                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
119            }))
120    }
121}
122
123#[async_trait]
124impl UserRepository for PgUserRepository<'_> {
125    type Error = DatabaseError;
126
127    #[tracing::instrument(
128        name = "db.user.lookup",
129        skip_all,
130        fields(
131            db.query.text,
132            user.id = %id,
133        ),
134        err,
135    )]
136    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
137        let res = sqlx::query_as!(
138            UserLookup,
139            r#"
140                SELECT user_id
141                     , username
142                     , created_at
143                     , locked_at
144                     , deactivated_at
145                     , can_request_admin
146                FROM users
147                WHERE user_id = $1
148            "#,
149            Uuid::from(id),
150        )
151        .traced()
152        .fetch_optional(&mut *self.conn)
153        .await?;
154
155        let Some(res) = res else { return Ok(None) };
156
157        Ok(Some(res.into()))
158    }
159
160    #[tracing::instrument(
161        name = "db.user.find_by_username",
162        skip_all,
163        fields(
164            db.query.text,
165            user.username = username,
166        ),
167        err,
168    )]
169    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
170        // We may have multiple users with the same username, but with a different
171        // casing. In this case, we want to return the one which matches the exact
172        // casing
173        let res = sqlx::query_as!(
174            UserLookup,
175            r#"
176                SELECT user_id
177                     , username
178                     , created_at
179                     , locked_at
180                     , deactivated_at
181                     , can_request_admin
182                FROM users
183                WHERE LOWER(username) = LOWER($1)
184            "#,
185            username,
186        )
187        .traced()
188        .fetch_all(&mut *self.conn)
189        .await?;
190
191        match &res[..] {
192            // Happy path: there is only one user matching the username…
193            [user] => Ok(Some(user.clone().into())),
194            // …or none.
195            [] => Ok(None),
196            list => {
197                // If there are multiple users with the same username, we want to
198                // return the one which matches the exact casing
199                if let Some(user) = list.iter().find(|user| user.username == username) {
200                    Ok(Some(user.clone().into()))
201                } else {
202                    // If none match exactly, we prefer to return nothing
203                    Ok(None)
204                }
205            }
206        }
207    }
208
209    #[tracing::instrument(
210        name = "db.user.add",
211        skip_all,
212        fields(
213            db.query.text,
214            user.username = username,
215            user.id,
216        ),
217        err,
218    )]
219    async fn add(
220        &mut self,
221        rng: &mut (dyn RngCore + Send),
222        clock: &dyn Clock,
223        username: String,
224    ) -> Result<User, Self::Error> {
225        let created_at = clock.now();
226        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
227        tracing::Span::current().record("user.id", tracing::field::display(id));
228
229        let res = sqlx::query!(
230            r#"
231                INSERT INTO users (user_id, username, created_at)
232                VALUES ($1, $2, $3)
233                ON CONFLICT (username) DO NOTHING
234            "#,
235            Uuid::from(id),
236            username,
237            created_at,
238        )
239        .traced()
240        .execute(&mut *self.conn)
241        .await?;
242
243        // If the user already exists, want to return an error but not poison the
244        // transaction
245        DatabaseError::ensure_affected_rows(&res, 1)?;
246
247        Ok(User {
248            id,
249            username,
250            sub: id.to_string(),
251            created_at,
252            locked_at: None,
253            deactivated_at: None,
254            can_request_admin: false,
255        })
256    }
257
258    #[tracing::instrument(
259        name = "db.user.exists",
260        skip_all,
261        fields(
262            db.query.text,
263            user.username = username,
264        ),
265        err,
266    )]
267    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
268        let exists = sqlx::query_scalar!(
269            r#"
270                SELECT EXISTS(
271                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
272                ) AS "exists!"
273            "#,
274            username
275        )
276        .traced()
277        .fetch_one(&mut *self.conn)
278        .await?;
279
280        Ok(exists)
281    }
282
283    #[tracing::instrument(
284        name = "db.user.lock",
285        skip_all,
286        fields(
287            db.query.text,
288            %user.id,
289        ),
290        err,
291    )]
292    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
293        if user.locked_at.is_some() {
294            return Ok(user);
295        }
296
297        let locked_at = clock.now();
298        let res = sqlx::query!(
299            r#"
300                UPDATE users
301                SET locked_at = $1
302                WHERE user_id = $2
303            "#,
304            locked_at,
305            Uuid::from(user.id),
306        )
307        .traced()
308        .execute(&mut *self.conn)
309        .await?;
310
311        DatabaseError::ensure_affected_rows(&res, 1)?;
312
313        user.locked_at = Some(locked_at);
314
315        Ok(user)
316    }
317
318    #[tracing::instrument(
319        name = "db.user.unlock",
320        skip_all,
321        fields(
322            db.query.text,
323            %user.id,
324        ),
325        err,
326    )]
327    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
328        if user.locked_at.is_none() {
329            return Ok(user);
330        }
331
332        let res = sqlx::query!(
333            r#"
334                UPDATE users
335                SET locked_at = NULL
336                WHERE user_id = $1
337            "#,
338            Uuid::from(user.id),
339        )
340        .traced()
341        .execute(&mut *self.conn)
342        .await?;
343
344        DatabaseError::ensure_affected_rows(&res, 1)?;
345
346        user.locked_at = None;
347
348        Ok(user)
349    }
350
351    #[tracing::instrument(
352        name = "db.user.deactivate",
353        skip_all,
354        fields(
355            db.query.text,
356            %user.id,
357        ),
358        err,
359    )]
360    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
361        if user.deactivated_at.is_some() {
362            return Ok(user);
363        }
364
365        let deactivated_at = clock.now();
366        let res = sqlx::query!(
367            r#"
368                UPDATE users
369                SET deactivated_at = $2
370                WHERE user_id = $1
371                  AND deactivated_at IS NULL
372            "#,
373            Uuid::from(user.id),
374            deactivated_at,
375        )
376        .traced()
377        .execute(&mut *self.conn)
378        .await?;
379
380        DatabaseError::ensure_affected_rows(&res, 1)?;
381
382        user.deactivated_at = Some(user.created_at);
383
384        Ok(user)
385    }
386
387    #[tracing::instrument(
388        name = "db.user.set_can_request_admin",
389        skip_all,
390        fields(
391            db.query.text,
392            %user.id,
393            user.can_request_admin = can_request_admin,
394        ),
395        err,
396    )]
397    async fn set_can_request_admin(
398        &mut self,
399        mut user: User,
400        can_request_admin: bool,
401    ) -> Result<User, Self::Error> {
402        let res = sqlx::query!(
403            r#"
404                UPDATE users
405                SET can_request_admin = $2
406                WHERE user_id = $1
407            "#,
408            Uuid::from(user.id),
409            can_request_admin,
410        )
411        .traced()
412        .execute(&mut *self.conn)
413        .await?;
414
415        DatabaseError::ensure_affected_rows(&res, 1)?;
416
417        user.can_request_admin = can_request_admin;
418
419        Ok(user)
420    }
421
422    #[tracing::instrument(
423        name = "db.user.list",
424        skip_all,
425        fields(
426            db.query.text,
427        ),
428        err,
429    )]
430    async fn list(
431        &mut self,
432        filter: UserFilter<'_>,
433        pagination: mas_storage::Pagination,
434    ) -> Result<mas_storage::Page<User>, Self::Error> {
435        let (sql, arguments) = Query::select()
436            .expr_as(
437                Expr::col((Users::Table, Users::UserId)),
438                UserLookupIden::UserId,
439            )
440            .expr_as(
441                Expr::col((Users::Table, Users::Username)),
442                UserLookupIden::Username,
443            )
444            .expr_as(
445                Expr::col((Users::Table, Users::CreatedAt)),
446                UserLookupIden::CreatedAt,
447            )
448            .expr_as(
449                Expr::col((Users::Table, Users::LockedAt)),
450                UserLookupIden::LockedAt,
451            )
452            .expr_as(
453                Expr::col((Users::Table, Users::DeactivatedAt)),
454                UserLookupIden::DeactivatedAt,
455            )
456            .expr_as(
457                Expr::col((Users::Table, Users::CanRequestAdmin)),
458                UserLookupIden::CanRequestAdmin,
459            )
460            .from(Users::Table)
461            .apply_filter(filter)
462            .generate_pagination((Users::Table, Users::UserId), pagination)
463            .build_sqlx(PostgresQueryBuilder);
464
465        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
466            .traced()
467            .fetch_all(&mut *self.conn)
468            .await?;
469
470        let page = pagination.process(edges).map(User::from);
471
472        Ok(page)
473    }
474
475    #[tracing::instrument(
476        name = "db.user.count",
477        skip_all,
478        fields(
479            db.query.text,
480        ),
481        err,
482    )]
483    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
484        let (sql, arguments) = Query::select()
485            .expr(Expr::col((Users::Table, Users::UserId)).count())
486            .from(Users::Table)
487            .apply_filter(filter)
488            .build_sqlx(PostgresQueryBuilder);
489
490        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
491            .traced()
492            .fetch_one(&mut *self.conn)
493            .await?;
494
495        count
496            .try_into()
497            .map_err(DatabaseError::to_invalid_operation)
498    }
499
500    #[tracing::instrument(
501        name = "db.user.acquire_lock_for_sync",
502        skip_all,
503        fields(
504            db.query.text,
505            user.id = %user.id,
506        ),
507        err,
508    )]
509    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
510        // XXX: this lock isn't stictly scoped to users, but as we don't use many
511        // postgres advisory locks, it's fine for now. Later on, we could use row-level
512        // locks to make sure we don't get into trouble
513
514        // Convert the user ID to a u128 and grab the lower 64 bits
515        // As this includes 64bit of the random part of the ULID, it should be random
516        // enough to not collide
517        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
518
519        // Use a PG advisory lock, which will be released when the transaction is
520        // committed or rolled back
521        sqlx::query!(
522            r#"
523                SELECT pg_advisory_xact_lock($1)
524            "#,
525            lock_id,
526        )
527        .traced()
528        .execute(&mut *self.conn)
529        .await?;
530
531        Ok(())
532    }
533}