From 315df0c6fce8c70b5ba02d3cf0e3d40973db88c4 Mon Sep 17 00:00:00 2001 From: Simon Oliver Tveit Date: Thu, 6 Feb 2025 08:02:09 +0100 Subject: [PATCH] Add endpoint for JWT refresh tokens --- changelog.d/3270.added.md | 1 + python/nav/web/api/v1/urls.py | 1 + python/nav/web/api/v1/views.py | 42 ++++ python/nav/web/jwtgen.py | 13 ++ .../integration/jwt_refresh_endpoint_test.py | 185 ++++++++++++++++++ 5 files changed, 242 insertions(+) create mode 100644 changelog.d/3270.added.md create mode 100644 tests/integration/jwt_refresh_endpoint_test.py diff --git a/changelog.d/3270.added.md b/changelog.d/3270.added.md new file mode 100644 index 0000000000..1923819760 --- /dev/null +++ b/changelog.d/3270.added.md @@ -0,0 +1 @@ +Add endpoint for JWT refresh tokens diff --git a/python/nav/web/api/v1/urls.py b/python/nav/web/api/v1/urls.py index 04a1fe6e31..c48499c61b 100644 --- a/python/nav/web/api/v1/urls.py +++ b/python/nav/web/api/v1/urls.py @@ -73,4 +73,5 @@ name="prefix-usage-detail", ), re_path(r'^', include(router.urls)), + re_path(r'^refresh/$', views.JWTRefreshViewSet.as_view(), name='jwt-refresh'), ] diff --git a/python/nav/web/api/v1/views.py b/python/nav/web/api/v1/views.py index fbeef1ef99..d8ec714122 100644 --- a/python/nav/web/api/v1/views.py +++ b/python/nav/web/api/v1/views.py @@ -45,6 +45,7 @@ from oidc_auth.authentication import JSONWebTokenAuthentication from nav.models import manage, event, cabling, rack, profiles +from nav.models.api import JWTRefreshToken from nav.models.fields import INFINITY, UNRESOLVED from nav.web.servicecheckers import load_checker_classes from nav.util import auth_token, is_valid_cidr @@ -52,6 +53,12 @@ from nav.buildconf import VERSION from nav.web.api.v1 import serializers, alert_serializers from nav.web.status2 import STATELESS_THRESHOLD +from nav.web.jwtgen import ( + generate_access_token, + generate_refresh_token, + hash_token, + decode_token, +) from nav.macaddress import MacPrefix from .auth import ( APIPermission, @@ -1153,3 +1160,38 @@ class ModuleViewSet(NAVAPIMixin, viewsets.ReadOnlyModelViewSet): 'device__serial', ) serializer_class = serializers.ModuleSerializer + + +class JWTRefreshViewSet(APIView): + """ + Accepts a valid refresh token. + Returns a new refresh token and an access token. + """ + + def post(self, request): + incoming_token = request.data.get('refresh_token') + token_hash = hash_token(incoming_token) + try: + # If hash exists in the database, then we know it is a real token + db_token = JWTRefreshToken.objects.get(hash=token_hash) + except JWTRefreshToken.DoesNotExist: + return Response("Invalid token", status=status.HTTP_403_FORBIDDEN) + if not db_token.is_active(): + return Response("Inactive token", status=status.HTTP_403_FORBIDDEN) + + claims = decode_token(incoming_token) + access_token = generate_access_token(claims) + refresh_token = generate_refresh_token(claims) + + new_claims = decode_token(refresh_token) + new_hash = hash_token(refresh_token) + db_token.hash = new_hash + db_token.expires = datetime.fromtimestamp(new_claims['exp']) + db_token.activates = datetime.fromtimestamp(new_claims['nbf']) + db_token.save() + + response_data = { + 'access_token': access_token, + 'refresh_token': refresh_token, + } + return Response(response_data) diff --git a/python/nav/web/jwtgen.py b/python/nav/web/jwtgen.py index 93a97b364a..be905510fb 100644 --- a/python/nav/web/jwtgen.py +++ b/python/nav/web/jwtgen.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta, timezone from typing import Any, Optional +import hashlib import jwt @@ -49,3 +50,15 @@ def _generate_token( new_token, JWTConf().get_nav_private_key(), algorithm="RS256" ) return encoded_token + + +def hash_token(token: str) -> str: + """Hashes a token with SHA256""" + hash_object = hashlib.sha256(token.encode('utf-8')) + hex_dig = hash_object.hexdigest() + return hex_dig + + +def decode_token(token: str) -> dict[str, Any]: + """Decodes a token in JWT format and returns the data of the decoded token""" + return jwt.decode(token, options={'verify_signature': False}) diff --git a/tests/integration/jwt_refresh_endpoint_test.py b/tests/integration/jwt_refresh_endpoint_test.py new file mode 100644 index 0000000000..b4ebf033df --- /dev/null +++ b/tests/integration/jwt_refresh_endpoint_test.py @@ -0,0 +1,185 @@ +from typing import Generator +import pytest +from datetime import datetime, timedelta + +from unittest.mock import Mock, patch + +from django.urls import reverse +from nav.models.api import JWTRefreshToken +from nav.web.jwtgen import generate_refresh_token, hash_token, decode_token + + +def test_token_not_in_database_should_be_rejected(db, api_client, url): + token = generate_refresh_token() + token_hash = hash_token(token) + + assert not JWTRefreshToken.objects.filter(hash=token_hash).exists() + response = api_client.post( + url, + follow=True, + data={ + 'refresh_token': token, + }, + ) + assert response.status_code == 403 + + +def test_inactive_token_should_be_rejected(db, api_client, url): + token = generate_refresh_token() + # Set expiry date in the past + now = datetime.now() + db_token = JWTRefreshToken( + name="testtoken", + hash=hash_token(token), + expires=now - timedelta(hours=1), + activates=now - timedelta(hours=2), + ) + db_token.save() + + response = api_client.post( + url, + follow=True, + data={ + 'refresh_token': token, + }, + ) + + assert response.status_code == 403 + + +def test_valid_token_should_be_accepted(db, api_client, url): + token = generate_refresh_token() + data = decode_token(token) + db_token = JWTRefreshToken( + name="testtoken", + hash=hash_token(token), + expires=datetime.fromtimestamp(data['exp']), + activates=datetime.fromtimestamp(data['nbf']), + ) + db_token.save() + response = api_client.post( + url, + follow=True, + data={ + 'refresh_token': token, + }, + ) + assert response.status_code == 200 + + +def test_valid_token_should_be_replaced_by_new_token_in_db(db, api_client, url): + token = generate_refresh_token() + token_hash = hash_token(token) + data = decode_token(token) + db_token = JWTRefreshToken( + name="testtoken", + hash=token_hash, + expires=datetime.fromtimestamp(data['exp']), + activates=datetime.fromtimestamp(data['nbf']), + ) + db_token.save() + response = api_client.post( + url, + follow=True, + data={ + 'refresh_token': token, + }, + ) + assert response.status_code == 200 + assert not JWTRefreshToken.objects.filter(hash=token_hash).exists() + new_token = response.data.get("refresh_token") + new_hash = hash_token(new_token) + assert JWTRefreshToken.objects.filter(hash=new_hash).exists() + + +def test_should_include_access_and_refresh_token_in_response(db, api_client, url): + token = generate_refresh_token() + data = decode_token(token) + db_token = JWTRefreshToken( + name="testtoken", + hash=hash_token(token), + expires=datetime.fromtimestamp(data['exp']), + activates=datetime.fromtimestamp(data['nbf']), + ) + db_token.save() + response = api_client.post( + url, + follow=True, + data={ + 'refresh_token': token, + }, + ) + assert response.status_code == 200 + assert "access_token" in response.data + assert "refresh_token" in response.data + + +@pytest.fixture() +def url(): + return reverse('api:1:jwt-refresh') + + +@pytest.fixture(scope="module", autouse=True) +def jwtconf_mock(private_key, nav_name) -> Generator[str, None, None]: + """Mocks the get_nave_name and get_nav_private_key functions for + the JWTConf class + """ + with patch("nav.web.jwtgen.JWTConf") as _jwtconf_mock: + instance = _jwtconf_mock.return_value + instance.get_nav_name = Mock(return_value=nav_name) + instance.get_nav_private_key = Mock(return_value=private_key) + yield _jwtconf_mock + + +@pytest.fixture(scope="module") +def private_key() -> str: + """Yields a private key in PEM format""" + key = """-----BEGIN PRIVATE KEY----- +MIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCp+4AEZM4uYZKu +/hrKzySMTFFx3/ncWo6XAFpADQHXLOwRB9Xh1/OwigHiqs/wHRAAmnrlkwCCQA8r +xiHBAMjp5ApbkyggQz/DVijrpSba6Tiy1cyBTZC3cvOK2FpJzsakJLhIXD1HaULO +ClyIJB/YrmHmQc8SL3Uzou5mMpdcBC2pzwmEW1cvQURpnvgrDF8V86GrQkjK6nIP +IEeuW6kbD5lWFAPfLf1ohDWex3yxeSFyXNRApJhbF4HrKFemPkOi7acsky38UomQ +jZgAMHPotJNkQvAHcnXHhg0FcWGdohv5bc/Ctt9GwZOzJxwyJLBBsSewbE310TZi +3oLU1TmvAgMBAAECgf8zrhi95+gdMeKRpwV+TnxOK5CXjqvo0vTcnr7Runf/c9On +WeUtRPr83E4LxuMcSGRqdTfoP0loUGb3EsYwZ+IDOnyWWvytfRoQdExSA2RM1PDo +GRiUN4Dy8CrGNqvnb3agG99Ay3Ura6q5T20n9ykM4qKL3yDrO9fmWyMgRJbAOAYm +xzf7H910mDZghXPpq8nzDky0JLNZcaqbxuPQ3+EI4p2dLNXbNqMPs8Y20JKLeOPs +HikRM0zfhHEJSt5IPFQ54/CzscGHGeCleQINWTgvDLMcE5fJMvbLLZixV+YsBfAq +e2JsSubS+9RI2ktMlSKaemr8yeoIpsXfAiJSHkECgYEA0NKU18xK+9w5IXfgNwI4 +peu2tWgwyZSp5R2pdLT7O1dJoLYRoAmcXNePB0VXNARqGxTNypJ9zmMawNmf3YRS +BqG8aKz7qpATlx9OwYlk09fsS6MeVmaur8bHGHP6O+gt7Xg+zhiFPvU9P5LB+C0Z +0d4grEmIxNhJCtJRQOThD8ECgYEA0GKRO9SJdnhw1b6LPLd+o/AX7IEzQDHwdtfi +0h7hKHHGBlUMbIBwwjKmyKm6cSe0PYe96LqrVg+cVf84wbLZPAixhOjyplLznBzF +LqOrfFPfI5lQVhslE1H1CdLlk9eyT96jDgmLAg8EGSMV8aLGj++Gi2l/isujHlWF +BI4YpW8CgYEAsyKyhJzABmbYq5lGQmopZkxapCwJDiP1ypIzd+Z5TmKGytLlM8CK +3iocjEQzlm/jBfBGyWv5eD8UCDOoLEMCiqXcFn+uNJb79zvoN6ZBVGl6TzhTIhNb +73Y5/QQguZtnKrtoRSxLwcJnFE41D0zBRYOjy6gZJ6PSpPHeuiid2QECgYACuZc+ +mgvmIbMQCHrXo2qjiCs364SZDU4gr7gGmWLGXZ6CTLBp5tASqgjmTNnkSumfeFvy +ZCaDbJbVxQ2f8s/GajKwEz/BDwqievnVH0zJxmr/kyyqw5Ybh5HVvA1GfqaVRssJ +DvTjZQDft0a9Lyy7ix1OS2XgkcMjTWj840LNPwKBgDPXMBgL5h41jd7jCsXzPhyr +V96RzQkPcKsoVvrCoNi8eoEYgRd9jwfiU12rlXv+fgVXrrfMoJBoYT6YtrxEJVdM +RAjRpnE8PMqCUA8Rd7RFK9Vp5Uo8RxTNvk9yPvDv1+lHHV7lEltIk5PXuKPHIrc1 +nNUyhzvJs2Qba2L/huNC +-----END PRIVATE KEY-----""" + return key + + +@pytest.fixture() +def public_key() -> str: + """Yields a public key in PEM format""" + key = """-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqfuABGTOLmGSrv4ays8k +jExRcd/53FqOlwBaQA0B1yzsEQfV4dfzsIoB4qrP8B0QAJp65ZMAgkAPK8YhwQDI +6eQKW5MoIEM/w1Yo66Um2uk4stXMgU2Qt3LzithaSc7GpCS4SFw9R2lCzgpciCQf +2K5h5kHPEi91M6LuZjKXXAQtqc8JhFtXL0FEaZ74KwxfFfOhq0JIyupyDyBHrlup +Gw+ZVhQD3y39aIQ1nsd8sXkhclzUQKSYWxeB6yhXpj5Dou2nLJMt/FKJkI2YADBz +6LSTZELwB3J1x4YNBXFhnaIb+W3PwrbfRsGTsyccMiSwQbEnsGxN9dE2Yt6C1NU5 +rwIDAQAB +-----END PUBLIC KEY-----""" + return key + + +@pytest.fixture(scope="module") +def nav_name() -> str: + return "nav"