diff --git a/auth_external/models/refresh_tokens.py b/auth_external/models/refresh_tokens.py index d8a0cff5..48e105d3 100644 --- a/auth_external/models/refresh_tokens.py +++ b/auth_external/models/refresh_tokens.py @@ -28,6 +28,14 @@ class RefreshTokens(models.Model): ) ] + user_id = fields.Many2one("res.users", ondelete="cascade", required=True) + """ + User for whom the token was issued. This is not used directly for any + functionality, but could be useful in an emergency situation where + suspicious activity is detected for a user. In that case, an administrator + could revoke all tokens for that particular user to prevent further harm. + """ + is_revoked = fields.Boolean(default=False, required=True) """ Whether the refresh token is revoked. False by default (for newly generated @@ -137,6 +145,23 @@ def family_str(self) -> str: f_str = f"[{f_str}]" out += f"{f_str} <-> " return out + + def revoke_tokens_for_user(self, user_id: int) -> None: + """ + Revokes all the refresh_tokens for the user with the given user_id. + Requires admin privileges to run. This function can be used in an + emergency to revoke all the tokens for a user who acts suspiciously. + """ + user_tokens = self.search([("user_id", "=", user_id)]) + nb_user_tokens = len(user_tokens) + nb_tokens_revoked = 0 + for t in user_tokens: + if not t.is_revoked: + t.revoke() + nb_tokens_revoked += 1 + _logger.info(f"""Revoked {nb_tokens_revoked} refresh_tokens for + {user_id=} ({nb_user_tokens} total tokens in the database + for this user, now all revoked).""") @api.model def remove_expired_tokens(self): diff --git a/auth_external/models/res_users.py b/auth_external/models/res_users.py index 18212081..79001dcf 100644 --- a/auth_external/models/res_users.py +++ b/auth_external/models/res_users.py @@ -210,7 +210,8 @@ def generate_external_auth_token(self, rt_old=None): # timezone information) because all odoo datetimes are forced to # UTC. # See https://www.odoo.com/documentation/14.0/developer/reference/addons/orm.html?highlight=fields%20many2many#date-time-fields - "exp": rt_new_exp.replace(tzinfo=None) + "exp": rt_new_exp.replace(tzinfo=None), + "user_id": self.env.user.id } ) if rt_old is not None: diff --git a/auth_external/security/ir.model.access.csv b/auth_external/security/ir.model.access.csv index 88da4424..7c0ca80e 100644 --- a/auth_external/security/ir.model.access.csv +++ b/auth_external/security/ir.model.access.csv @@ -1,4 +1,4 @@ id,name,model_id:id,group_id:id,perm_read,perm_write,perm_create,perm_unlink access_auth_external_tokens_config_user,access_auth_external_tokens_config_user,model_auth_external_tokens_config,base.group_user,1,0,0,0 access_auth_external_tokens_config_admin,access_auth_external_tokens_config_admin,model_auth_external_tokens_config,auth_external.group_admin,1,1,1,1 -access_auth_external_refresh_tokens_admin,access_auth_external_refresh_tokens_admin,model_auth_external_refresh_tokens,auth_external.group_admin,1,0,0,0 +access_auth_external_refresh_tokens_admin,access_auth_external_refresh_tokens_admin,model_auth_external_refresh_tokens,auth_external.group_admin,1,1,1,1 diff --git a/auth_external/tests/test_refresh_tokens.py b/auth_external/tests/test_refresh_tokens.py index edaf43f3..7b52176a 100644 --- a/auth_external/tests/test_refresh_tokens.py +++ b/auth_external/tests/test_refresh_tokens.py @@ -3,25 +3,42 @@ from odoo.tests.common import TransactionCase from datetime import datetime, timedelta from ..models.refresh_tokens import RefreshTokens + +FUTURE_TIMEDELTA = timedelta(minutes=3) + + class TestRefreshTokens(TransactionCase): @staticmethod def get_uuid() -> str: return str(uuid.uuid4()) - + def get_refresh_tokens(self): return self.env["auth_external.refresh_tokens"] - - def create_refresh_token(self, timediff: timedelta, parent: RefreshTokens = None) -> RefreshTokens: + + def create_refresh_token( + self, timediff: timedelta, parent: RefreshTokens = None, user_id = None + ) -> RefreshTokens: exp = datetime.now() + timediff - rt = self.get_refresh_tokens().create({"jti": TestRefreshTokens.get_uuid(), "exp": exp}) + user_id = self.test_user.id if user_id is None else user_id + rt = self.get_refresh_tokens().create( + {"jti": TestRefreshTokens.get_uuid(), "exp": exp, "user_id": user_id} + ) if parent is not None: parent.link_child(rt) return rt + + def create_user(self) -> "res.users": + login = f"testuser_{random.randint(0, 10000)}" + return self.env["res.users"].create( + {"name": f"Name {login}", "login": login} + ) def setUp(self, *args, **kwargs): super(TestRefreshTokens, self).setUp(*args, **kwargs) + self.test_user = self.create_user() + def test_get_by_jti(self): rt1 = self.create_refresh_token(timedelta(hours=1)) got_rt1 = self.get_refresh_tokens().get_by_jti(rt1.jti) @@ -69,5 +86,34 @@ def test_remove_expired_tokens(self): self.assertIn(rt3, rts) self.assertIn(rt4, rts) + def test_revoke_tokens_for_user(self): + test_user_tokens = [ + self.create_refresh_token(FUTURE_TIMEDELTA) for _ in range(42) + ] + + for i, t in enumerate(test_user_tokens): + # Revoke some tokens to simulate a realistic scenario + if i % 3 == 0: + t.sudo().revoke() + + # function under test + self.get_refresh_tokens().revoke_tokens_for_user(self.test_user.id) + + for t in test_user_tokens: + self.assertTrue(t.is_revoked) + + + def test_ondelete_cascade(self): + user = self.create_user() + tokens = [self.create_refresh_token(FUTURE_TIMEDELTA, user_id=user.id) for _ in range(13)] + + all_tokens = self.get_refresh_tokens().search([]) + for t in tokens: + self.assertIn(t, all_tokens) + + user.sudo().unlink() + all_tokens = self.get_refresh_tokens().search([]) + for t in tokens: + self.assertNotIn(t, all_tokens)