Skip to content

Commit

Permalink
Add user_id field to auth_external.refresh_tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
nlachat-compassion committed Oct 29, 2024
1 parent ee8bfcc commit 95977f9
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 6 deletions.
25 changes: 25 additions & 0 deletions auth_external/models/refresh_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ class RefreshTokens(models.Model):
)
]

user_id = fields.Many2one("res.users", ondelete="cascade", required=True)
"""
User for whom the token was issued. This is not used directly for any
functionality, but could be useful in an emergency situation where
suspicious activity is detected for a user. In that case, an administrator
could revoke all tokens for that particular user to prevent further harm.
"""

is_revoked = fields.Boolean(default=False, required=True)
"""
Whether the refresh token is revoked. False by default (for newly generated
Expand Down Expand Up @@ -137,6 +145,23 @@ def family_str(self) -> str:
f_str = f"[{f_str}]"
out += f"{f_str} <-> "
return out

def revoke_tokens_for_user(self, user_id: int) -> None:
"""
Revokes all the refresh_tokens for the user with the given user_id.
Requires admin privileges to run. This function can be used in an
emergency to revoke all the tokens for a user who acts suspiciously.
"""
user_tokens = self.search([("user_id", "=", user_id)])
nb_user_tokens = len(user_tokens)
nb_tokens_revoked = 0
for t in user_tokens:
if not t.is_revoked:
t.revoke()
nb_tokens_revoked += 1
_logger.info(f"""Revoked {nb_tokens_revoked} refresh_tokens for
{user_id=} ({nb_user_tokens} total tokens in the database
for this user, now all revoked).""")

@api.model
def remove_expired_tokens(self):
Expand Down
3 changes: 2 additions & 1 deletion auth_external/models/res_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def generate_external_auth_token(self, rt_old=None):
# timezone information) because all odoo datetimes are forced to
# UTC.
# See https://www.odoo.com/documentation/14.0/developer/reference/addons/orm.html?highlight=fields%20many2many#date-time-fields
"exp": rt_new_exp.replace(tzinfo=None)
"exp": rt_new_exp.replace(tzinfo=None),
"user_id": self.env.user.id
}
)
if rt_old is not None:
Expand Down
2 changes: 1 addition & 1 deletion auth_external/security/ir.model.access.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
id,name,model_id:id,group_id:id,perm_read,perm_write,perm_create,perm_unlink
access_auth_external_tokens_config_user,access_auth_external_tokens_config_user,model_auth_external_tokens_config,base.group_user,1,0,0,0
access_auth_external_tokens_config_admin,access_auth_external_tokens_config_admin,model_auth_external_tokens_config,auth_external.group_admin,1,1,1,1
access_auth_external_refresh_tokens_admin,access_auth_external_refresh_tokens_admin,model_auth_external_refresh_tokens,auth_external.group_admin,1,0,0,0
access_auth_external_refresh_tokens_admin,access_auth_external_refresh_tokens_admin,model_auth_external_refresh_tokens,auth_external.group_admin,1,1,1,1
54 changes: 50 additions & 4 deletions auth_external/tests/test_refresh_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,42 @@
from odoo.tests.common import TransactionCase
from datetime import datetime, timedelta
from ..models.refresh_tokens import RefreshTokens

FUTURE_TIMEDELTA = timedelta(minutes=3)


class TestRefreshTokens(TransactionCase):

@staticmethod
def get_uuid() -> str:
return str(uuid.uuid4())

def get_refresh_tokens(self):
return self.env["auth_external.refresh_tokens"]

def create_refresh_token(self, timediff: timedelta, parent: RefreshTokens = None) -> RefreshTokens:

def create_refresh_token(
self, timediff: timedelta, parent: RefreshTokens = None, user_id = None
) -> RefreshTokens:
exp = datetime.now() + timediff
rt = self.get_refresh_tokens().create({"jti": TestRefreshTokens.get_uuid(), "exp": exp})
user_id = self.test_user.id if user_id is None else user_id
rt = self.get_refresh_tokens().create(
{"jti": TestRefreshTokens.get_uuid(), "exp": exp, "user_id": user_id}
)
if parent is not None:
parent.link_child(rt)
return rt

def create_user(self) -> "res.users":
login = f"testuser_{random.randint(0, 10000)}"
return self.env["res.users"].create(
{"name": f"Name {login}", "login": login}
)

def setUp(self, *args, **kwargs):
super(TestRefreshTokens, self).setUp(*args, **kwargs)

self.test_user = self.create_user()

def test_get_by_jti(self):
rt1 = self.create_refresh_token(timedelta(hours=1))
got_rt1 = self.get_refresh_tokens().get_by_jti(rt1.jti)
Expand Down Expand Up @@ -69,5 +86,34 @@ def test_remove_expired_tokens(self):
self.assertIn(rt3, rts)
self.assertIn(rt4, rts)

def test_revoke_tokens_for_user(self):
test_user_tokens = [
self.create_refresh_token(FUTURE_TIMEDELTA) for _ in range(42)
]

for i, t in enumerate(test_user_tokens):
# Revoke some tokens to simulate a realistic scenario
if i % 3 == 0:
t.sudo().revoke()

# function under test
self.get_refresh_tokens().revoke_tokens_for_user(self.test_user.id)

for t in test_user_tokens:
self.assertTrue(t.is_revoked)


def test_ondelete_cascade(self):
user = self.create_user()
tokens = [self.create_refresh_token(FUTURE_TIMEDELTA, user_id=user.id) for _ in range(13)]

all_tokens = self.get_refresh_tokens().search([])
for t in tokens:
self.assertIn(t, all_tokens)

user.sudo().unlink()

all_tokens = self.get_refresh_tokens().search([])
for t in tokens:
self.assertNotIn(t, all_tokens)

0 comments on commit 95977f9

Please sign in to comment.