Skip to content

Commit

Permalink
Add endpoint for JWT refresh tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
stveit committed Feb 6, 2025
1 parent 50e46b4 commit 315df0c
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 0 deletions.
1 change: 1 addition & 0 deletions changelog.d/3270.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add endpoint for JWT refresh tokens
1 change: 1 addition & 0 deletions python/nav/web/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@
name="prefix-usage-detail",
),
re_path(r'^', include(router.urls)),
re_path(r'^refresh/$', views.JWTRefreshViewSet.as_view(), name='jwt-refresh'),
]
42 changes: 42 additions & 0 deletions python/nav/web/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@
from oidc_auth.authentication import JSONWebTokenAuthentication

from nav.models import manage, event, cabling, rack, profiles
from nav.models.api import JWTRefreshToken
from nav.models.fields import INFINITY, UNRESOLVED
from nav.web.servicecheckers import load_checker_classes
from nav.util import auth_token, is_valid_cidr

from nav.buildconf import VERSION
from nav.web.api.v1 import serializers, alert_serializers
from nav.web.status2 import STATELESS_THRESHOLD
from nav.web.jwtgen import (
generate_access_token,
generate_refresh_token,
hash_token,
decode_token,
)
from nav.macaddress import MacPrefix
from .auth import (
APIPermission,
Expand Down Expand Up @@ -1153,3 +1160,38 @@ class ModuleViewSet(NAVAPIMixin, viewsets.ReadOnlyModelViewSet):
'device__serial',
)
serializer_class = serializers.ModuleSerializer


class JWTRefreshViewSet(APIView):
"""
Accepts a valid refresh token.
Returns a new refresh token and an access token.
"""

def post(self, request):
incoming_token = request.data.get('refresh_token')
token_hash = hash_token(incoming_token)
try:
# If hash exists in the database, then we know it is a real token
db_token = JWTRefreshToken.objects.get(hash=token_hash)
except JWTRefreshToken.DoesNotExist:
return Response("Invalid token", status=status.HTTP_403_FORBIDDEN)
if not db_token.is_active():
return Response("Inactive token", status=status.HTTP_403_FORBIDDEN)

claims = decode_token(incoming_token)
access_token = generate_access_token(claims)
refresh_token = generate_refresh_token(claims)

new_claims = decode_token(refresh_token)
new_hash = hash_token(refresh_token)
db_token.hash = new_hash
db_token.expires = datetime.fromtimestamp(new_claims['exp'])
db_token.activates = datetime.fromtimestamp(new_claims['nbf'])
db_token.save()

response_data = {
'access_token': access_token,
'refresh_token': refresh_token,
}
return Response(response_data)
13 changes: 13 additions & 0 deletions python/nav/web/jwtgen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
import hashlib

import jwt

Expand Down Expand Up @@ -49,3 +50,15 @@ def _generate_token(
new_token, JWTConf().get_nav_private_key(), algorithm="RS256"
)
return encoded_token


def hash_token(token: str) -> str:
"""Hashes a token with SHA256"""
hash_object = hashlib.sha256(token.encode('utf-8'))
hex_dig = hash_object.hexdigest()
return hex_dig


def decode_token(token: str) -> dict[str, Any]:
"""Decodes a token in JWT format and returns the data of the decoded token"""
return jwt.decode(token, options={'verify_signature': False})
185 changes: 185 additions & 0 deletions tests/integration/jwt_refresh_endpoint_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from typing import Generator
import pytest
from datetime import datetime, timedelta

from unittest.mock import Mock, patch

from django.urls import reverse
from nav.models.api import JWTRefreshToken
from nav.web.jwtgen import generate_refresh_token, hash_token, decode_token


def test_token_not_in_database_should_be_rejected(db, api_client, url):
token = generate_refresh_token()
token_hash = hash_token(token)

assert not JWTRefreshToken.objects.filter(hash=token_hash).exists()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 403


def test_inactive_token_should_be_rejected(db, api_client, url):
token = generate_refresh_token()
# Set expiry date in the past
now = datetime.now()
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(token),
expires=now - timedelta(hours=1),
activates=now - timedelta(hours=2),
)
db_token.save()

response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)

assert response.status_code == 403


def test_valid_token_should_be_accepted(db, api_client, url):
token = generate_refresh_token()
data = decode_token(token)
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(token),
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 200


def test_valid_token_should_be_replaced_by_new_token_in_db(db, api_client, url):
token = generate_refresh_token()
token_hash = hash_token(token)
data = decode_token(token)
db_token = JWTRefreshToken(
name="testtoken",
hash=token_hash,
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 200
assert not JWTRefreshToken.objects.filter(hash=token_hash).exists()
new_token = response.data.get("refresh_token")
new_hash = hash_token(new_token)
assert JWTRefreshToken.objects.filter(hash=new_hash).exists()


def test_should_include_access_and_refresh_token_in_response(db, api_client, url):
token = generate_refresh_token()
data = decode_token(token)
db_token = JWTRefreshToken(
name="testtoken",
hash=hash_token(token),
expires=datetime.fromtimestamp(data['exp']),
activates=datetime.fromtimestamp(data['nbf']),
)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 200
assert "access_token" in response.data
assert "refresh_token" in response.data


@pytest.fixture()
def url():
return reverse('api:1:jwt-refresh')


@pytest.fixture(scope="module", autouse=True)
def jwtconf_mock(private_key, nav_name) -> Generator[str, None, None]:
"""Mocks the get_nave_name and get_nav_private_key functions for
the JWTConf class
"""
with patch("nav.web.jwtgen.JWTConf") as _jwtconf_mock:
instance = _jwtconf_mock.return_value
instance.get_nav_name = Mock(return_value=nav_name)
instance.get_nav_private_key = Mock(return_value=private_key)
yield _jwtconf_mock


@pytest.fixture(scope="module")
def private_key() -> str:
"""Yields a private key in PEM format"""
key = """-----BEGIN PRIVATE KEY-----
MIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCp+4AEZM4uYZKu
/hrKzySMTFFx3/ncWo6XAFpADQHXLOwRB9Xh1/OwigHiqs/wHRAAmnrlkwCCQA8r
xiHBAMjp5ApbkyggQz/DVijrpSba6Tiy1cyBTZC3cvOK2FpJzsakJLhIXD1HaULO
ClyIJB/YrmHmQc8SL3Uzou5mMpdcBC2pzwmEW1cvQURpnvgrDF8V86GrQkjK6nIP
IEeuW6kbD5lWFAPfLf1ohDWex3yxeSFyXNRApJhbF4HrKFemPkOi7acsky38UomQ
jZgAMHPotJNkQvAHcnXHhg0FcWGdohv5bc/Ctt9GwZOzJxwyJLBBsSewbE310TZi
3oLU1TmvAgMBAAECgf8zrhi95+gdMeKRpwV+TnxOK5CXjqvo0vTcnr7Runf/c9On
WeUtRPr83E4LxuMcSGRqdTfoP0loUGb3EsYwZ+IDOnyWWvytfRoQdExSA2RM1PDo
GRiUN4Dy8CrGNqvnb3agG99Ay3Ura6q5T20n9ykM4qKL3yDrO9fmWyMgRJbAOAYm
xzf7H910mDZghXPpq8nzDky0JLNZcaqbxuPQ3+EI4p2dLNXbNqMPs8Y20JKLeOPs
HikRM0zfhHEJSt5IPFQ54/CzscGHGeCleQINWTgvDLMcE5fJMvbLLZixV+YsBfAq
e2JsSubS+9RI2ktMlSKaemr8yeoIpsXfAiJSHkECgYEA0NKU18xK+9w5IXfgNwI4
peu2tWgwyZSp5R2pdLT7O1dJoLYRoAmcXNePB0VXNARqGxTNypJ9zmMawNmf3YRS
BqG8aKz7qpATlx9OwYlk09fsS6MeVmaur8bHGHP6O+gt7Xg+zhiFPvU9P5LB+C0Z
0d4grEmIxNhJCtJRQOThD8ECgYEA0GKRO9SJdnhw1b6LPLd+o/AX7IEzQDHwdtfi
0h7hKHHGBlUMbIBwwjKmyKm6cSe0PYe96LqrVg+cVf84wbLZPAixhOjyplLznBzF
LqOrfFPfI5lQVhslE1H1CdLlk9eyT96jDgmLAg8EGSMV8aLGj++Gi2l/isujHlWF
BI4YpW8CgYEAsyKyhJzABmbYq5lGQmopZkxapCwJDiP1ypIzd+Z5TmKGytLlM8CK
3iocjEQzlm/jBfBGyWv5eD8UCDOoLEMCiqXcFn+uNJb79zvoN6ZBVGl6TzhTIhNb
73Y5/QQguZtnKrtoRSxLwcJnFE41D0zBRYOjy6gZJ6PSpPHeuiid2QECgYACuZc+
mgvmIbMQCHrXo2qjiCs364SZDU4gr7gGmWLGXZ6CTLBp5tASqgjmTNnkSumfeFvy
ZCaDbJbVxQ2f8s/GajKwEz/BDwqievnVH0zJxmr/kyyqw5Ybh5HVvA1GfqaVRssJ
DvTjZQDft0a9Lyy7ix1OS2XgkcMjTWj840LNPwKBgDPXMBgL5h41jd7jCsXzPhyr
V96RzQkPcKsoVvrCoNi8eoEYgRd9jwfiU12rlXv+fgVXrrfMoJBoYT6YtrxEJVdM
RAjRpnE8PMqCUA8Rd7RFK9Vp5Uo8RxTNvk9yPvDv1+lHHV7lEltIk5PXuKPHIrc1
nNUyhzvJs2Qba2L/huNC
-----END PRIVATE KEY-----"""
return key


@pytest.fixture()
def public_key() -> str:
"""Yields a public key in PEM format"""
key = """-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqfuABGTOLmGSrv4ays8k
jExRcd/53FqOlwBaQA0B1yzsEQfV4dfzsIoB4qrP8B0QAJp65ZMAgkAPK8YhwQDI
6eQKW5MoIEM/w1Yo66Um2uk4stXMgU2Qt3LzithaSc7GpCS4SFw9R2lCzgpciCQf
2K5h5kHPEi91M6LuZjKXXAQtqc8JhFtXL0FEaZ74KwxfFfOhq0JIyupyDyBHrlup
Gw+ZVhQD3y39aIQ1nsd8sXkhclzUQKSYWxeB6yhXpj5Dou2nLJMt/FKJkI2YADBz
6LSTZELwB3J1x4YNBXFhnaIb+W3PwrbfRsGTsyccMiSwQbEnsGxN9dE2Yt6C1NU5
rwIDAQAB
-----END PUBLIC KEY-----"""
return key


@pytest.fixture(scope="module")
def nav_name() -> str:
return "nav"

0 comments on commit 315df0c

Please sign in to comment.