Skip to content

Commit 874af87

Browse files
authored
Merge pull request #38 from cicekhayri/load-user-middleware
Load user middleware
2 parents 14120b4 + c08ee95 commit 874af87

File tree

7 files changed

+178
-37
lines changed

7 files changed

+178
-37
lines changed

inspira/auth/mixins/user_mixin.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,17 @@ def set_password(self, password):
2525

2626
def check_password_hash(self, password):
2727
return bcrypt.checkpw(password.encode(UTF8), self.password.encode(UTF8))
28+
29+
30+
class AnonymousUserMixin:
31+
@property
32+
def is_authenticated(self):
33+
return False
34+
35+
@property
36+
def is_active(self):
37+
return False
38+
39+
@property
40+
def is_anonymous(self):
41+
return True

inspira/middlewares/sessions.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,29 @@
99

1010

1111
class SessionMiddleware:
12-
def __init__(self, secret_key):
13-
self.secret_key = secret_key
14-
self.config = get_global_app().config or Config()
12+
def __init__(self):
13+
self.app = get_global_app()
1514

1615
def build_set_cookie_header(self, session_data):
17-
encoded_payload = encode_session_data(session_data, self.secret_key)
16+
encoded_payload = encode_session_data(session_data, self.app.secret_key)
1817
expires_date = datetime.datetime.utcnow() + datetime.timedelta(
19-
seconds=self.config["SESSION_MAX_AGE"]
18+
seconds=self.app.config["SESSION_MAX_AGE"]
2019
)
2120
formatted_expires = expires_date.strftime("%a, %d %b %Y %H:%M:%S GMT")
2221

2322
cookie_value = (
24-
f"{self.config['SESSION_COOKIE_NAME']}={encoded_payload}; "
25-
f"Expires={formatted_expires}; Path={self.config['SESSION_COOKIE_PATH'] or '/'}; HttpOnly"
23+
f"{self.app.config['SESSION_COOKIE_NAME']}={encoded_payload}; "
24+
f"Expires={formatted_expires}; Path={self.app.config['SESSION_COOKIE_PATH'] or '/'}; HttpOnly"
2625
)
2726

28-
if self.config["SESSION_COOKIE_DOMAIN"]:
29-
cookie_value += f"; Domain={self.config['SESSION_COOKIE_DOMAIN']}"
27+
if self.app.config["SESSION_COOKIE_DOMAIN"]:
28+
cookie_value += f"; Domain={self.app.config['SESSION_COOKIE_DOMAIN']}"
3029

31-
if self.config["SESSION_COOKIE_SECURE"]:
30+
if self.app.config["SESSION_COOKIE_SECURE"]:
3231
cookie_value += "; Secure"
3332

34-
if self.config["SESSION_COOKIE_SAMESITE"]:
35-
cookie_value += f"; SameSite={self.config['SESSION_COOKIE_SAMESITE']}"
33+
if self.app.config["SESSION_COOKIE_SAMESITE"]:
34+
cookie_value += f"; SameSite={self.app.config['SESSION_COOKIE_SAMESITE']}"
3635

3736
return cookie_value
3837

@@ -44,12 +43,12 @@ async def send_wrapper(message):
4443
request = RequestContext().get_request()
4544

4645
cookies = SimpleCookie(request.get_headers().get("cookie", ""))
47-
session_cookie = cookies.get(self.config["SESSION_COOKIE_NAME"])
46+
session_cookie = cookies.get(self.app.config["SESSION_COOKIE_NAME"])
4847
decoded_session = {}
4948

5049
if session_cookie:
5150
decoded_session = decode_session_data(
52-
session_cookie.value, self.secret_key
51+
session_cookie.value, self.app.secret_key
5352
)
5453

5554
if not request.session or decoded_session != request.session:
@@ -58,7 +57,7 @@ async def send_wrapper(message):
5857

5958
headers.append((b"Set-Cookie", cookie_value.encode()))
6059
else:
61-
cookie_value = f"{self.config['SESSION_COOKIE_NAME']}=; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path={self.config['SESSION_COOKIE_PATH'] or '/'}; HttpOnly"
60+
cookie_value = f"{self.app.config['SESSION_COOKIE_NAME']}=; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path={self.app.config['SESSION_COOKIE_PATH'] or '/'}; HttpOnly"
6261

6362
headers.append((b"Set-Cookie", cookie_value.encode()))
6463

inspira/middlewares/user_loader.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from http.cookies import SimpleCookie
2+
from typing import Dict, Any, Callable
3+
4+
from inspira.auth.auth_utils import decode_auth_token
5+
from inspira.auth.mixins.user_mixin import AnonymousUserMixin
6+
from inspira.config import Config
7+
from inspira.globals import get_global_app
8+
from inspira.requests import RequestContext
9+
from inspira.utils.session_utils import decode_session_data
10+
11+
12+
class UserLoaderMiddleware:
13+
def __init__(self, user_model):
14+
self.user_model = user_model
15+
self.app = get_global_app()
16+
17+
async def __call__(self, handler):
18+
async def middleware(scope: Dict[str, Any], receive: Callable, send: Callable):
19+
request = RequestContext().get_request()
20+
cookies = SimpleCookie(request.get_headers().get("cookie", ""))
21+
token = cookies.get(self.app.config["SESSION_COOKIE_NAME"])
22+
23+
user = None
24+
25+
if token:
26+
decoded_session = decode_session_data(token.value, self.app.secret_key)
27+
user_id = decode_auth_token(decoded_session.get("token", None))
28+
29+
if user_id:
30+
user = self.user_model.query.get(user_id)
31+
RequestContext.set_current_user(user or AnonymousUserMixin())
32+
request.user = RequestContext.get_current_user()
33+
34+
await handler(scope, receive, send)
35+
36+
return middleware

inspira/requests.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
class RequestContext:
99
_current_request = None
10+
_current_user = None
1011

1112
@classmethod
1213
def set_request(cls, request):
@@ -16,6 +17,14 @@ def set_request(cls, request):
1617
def get_request(cls):
1718
return cls._current_request
1819

20+
@classmethod
21+
def get_current_user(cls):
22+
return cls._current_user
23+
24+
@classmethod
25+
def set_current_user(cls, user):
26+
cls._current_user = user
27+
1928

2029
class Request:
2130
def __init__(self, scope: Dict[str, Any], receive: Callable, send: Callable):
@@ -25,6 +34,7 @@ def __init__(self, scope: Dict[str, Any], receive: Callable, send: Callable):
2534
self._session = {}
2635
self._headers = {}
2736
self._forbidden = False
37+
self.user = None
2838

2939
def is_forbidden(self):
3040
return self._forbidden

tests/conftest.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,14 @@
1010
from inspira.testclient import TestClient
1111

1212

13-
@pytest.fixture
14-
def app():
15-
return Inspira()
16-
17-
1813
@pytest.fixture
1914
def secret_key():
2015
return "your_secret_key"
2116

2217

2318
@pytest.fixture
24-
def app_with_secret_token(secret_key):
25-
app = Inspira()
26-
app.secret_key = secret_key
27-
28-
return app
19+
def app(secret_key):
20+
return Inspira(secret_key)
2921

3022

3123
@pytest.fixture
@@ -34,8 +26,8 @@ def client(app):
3426

3527

3628
@pytest.fixture
37-
def client_session(app_with_secret_token):
38-
return TestClient(app_with_secret_token)
29+
def client_session(app):
30+
return TestClient(app)
3931

4032

4133
@pytest.fixture
@@ -84,3 +76,12 @@ def mock_scope():
8476
@pytest.fixture
8577
def sample_config():
8678
return Config()
79+
80+
81+
@pytest.fixture
82+
def user_mock():
83+
class User:
84+
def __init__(self, id):
85+
self.id = id
86+
87+
return User

tests/test_middleware.py

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33

44
import pytest
55

6+
from inspira.auth.auth_utils import login_user, decode_auth_token
7+
from inspira.auth.decorators import login_required
8+
from inspira.auth.mixins.user_mixin import AnonymousUserMixin
69
from inspira.decorators.http_methods import get
710
from inspira.enums import HttpMethod
811
from inspira.middlewares.cors import CORSMiddleware
912
from inspira.middlewares.sessions import SessionMiddleware
10-
from inspira.requests import Request
13+
from inspira.middlewares.user_loader import UserLoaderMiddleware
14+
from inspira.requests import Request, RequestContext
1115
from inspira.responses import JsonResponse, HttpResponse
1216
from inspira.utils.session_utils import decode_session_data
1317

@@ -93,8 +97,8 @@ async def test_route(request: Request):
9397

9498

9599
@pytest.mark.asyncio
96-
async def test_set_session_success(app, client):
97-
session_middleware = SessionMiddleware(secret_key="dummy")
100+
async def test_set_session_success(app, secret_key, client):
101+
session_middleware = SessionMiddleware()
98102

99103
app.add_middleware(session_middleware)
100104

@@ -113,7 +117,7 @@ async def test_route(request: Request):
113117

114118
assert session_cookie is not None
115119

116-
decoded_session = decode_session_data(session_cookie.value, "dummy")
120+
decoded_session = decode_session_data(session_cookie.value, secret_key)
117121
expected_session = {"message": "Hej"}
118122

119123
assert decoded_session == expected_session
@@ -122,7 +126,7 @@ async def test_route(request: Request):
122126

123127
@pytest.mark.asyncio
124128
async def test_invalid_signature_exception(app, client):
125-
session_middleware = SessionMiddleware(secret_key="dummy")
129+
session_middleware = SessionMiddleware()
126130

127131
app.add_middleware(session_middleware)
128132

@@ -149,7 +153,7 @@ async def test_route(request: Request):
149153

150154
@pytest.mark.asyncio
151155
async def test_remove_session_success(app, client):
152-
session_middleware = SessionMiddleware(secret_key="dummy")
156+
session_middleware = SessionMiddleware()
153157

154158
app.add_middleware(session_middleware)
155159

@@ -184,7 +188,7 @@ async def remove_route(request: Request):
184188

185189
@pytest.mark.asyncio
186190
async def test_get_session_success(app, client):
187-
session_middleware = SessionMiddleware(secret_key="dummy")
191+
session_middleware = SessionMiddleware()
188192

189193
app.add_middleware(session_middleware)
190194

@@ -203,7 +207,7 @@ async def get_route(request: Request):
203207

204208
@pytest.mark.asyncio
205209
async def test_remove_nonexistent_key(app, client):
206-
session_middleware = SessionMiddleware(secret_key="dummy")
210+
session_middleware = SessionMiddleware()
207211

208212
app.add_middleware(session_middleware)
209213

@@ -224,3 +228,80 @@ async def remove_route(request: Request):
224228

225229
assert session_cookie.value == ""
226230
assert response.status_code == HTTPStatus.OK
231+
232+
233+
@pytest.mark.asyncio
234+
async def test_user_loader_middleware(app, client, user_mock, secret_key):
235+
session_middleware = SessionMiddleware()
236+
237+
user_loader_middleware = UserLoaderMiddleware(user_mock)
238+
239+
app.add_middleware(session_middleware)
240+
app.add_middleware(user_loader_middleware)
241+
242+
@get("/get")
243+
async def get_route(request: Request):
244+
user = user_mock(id=1)
245+
login_user(user.id)
246+
return HttpResponse(f"User ID: 1323")
247+
248+
app.add_route("/get", HttpMethod.GET, get_route)
249+
250+
response = await client.get("/get")
251+
set_cookie_header = response.headers.get("set-cookie", "")
252+
253+
assert set_cookie_header is not None
254+
assert "session=" in set_cookie_header
255+
256+
cookies = SimpleCookie(set_cookie_header)
257+
session_cookie = cookies.get("session")
258+
259+
assert session_cookie is not None
260+
261+
decoded_session = decode_session_data(session_cookie.value, secret_key)
262+
get_user_id = decode_auth_token(decoded_session["token"])
263+
264+
assert get_user_id == 1
265+
266+
267+
@pytest.mark.asyncio
268+
async def test_user_not_logged_in(app, client, secret_key, user_mock):
269+
session_middleware = SessionMiddleware()
270+
user_loader_middleware = UserLoaderMiddleware(user_mock)
271+
app.add_middleware(session_middleware)
272+
app.add_middleware(user_loader_middleware)
273+
274+
@get("/protected")
275+
@login_required
276+
async def protected(request: Request):
277+
return HttpResponse("Protected Route")
278+
279+
app.add_route("/protected", HttpMethod.GET, protected)
280+
281+
response = await client.get("/protected")
282+
283+
assert response.status_code == HTTPStatus.UNAUTHORIZED.value
284+
assert "Unauthorized" in response.text
285+
286+
287+
@pytest.mark.asyncio
288+
async def test_user_loader_middleware_anonymous_user(
289+
app, client, secret_key, user_mock
290+
):
291+
user_loader_middleware = UserLoaderMiddleware(user_mock)
292+
app.add_middleware(user_loader_middleware)
293+
294+
@get("/protected")
295+
async def protected(request: Request):
296+
user_authenticated = request.user.is_authenticated
297+
return JsonResponse({"user_authenticated": user_authenticated})
298+
299+
app.add_route("/protected", HttpMethod.GET, protected)
300+
301+
response = await client.get("/protected")
302+
303+
assert response.status_code == 200
304+
assert response.json() == {"user_authenticated": False}
305+
306+
user_in_method = RequestContext.get_current_user()
307+
assert isinstance(user_in_method, AnonymousUserMixin)

tests/test_responses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def home(request):
6969

7070

7171
@pytest.mark.asyncio
72-
async def test_set_multiple_cookie(app_with_secret_token, client_session):
72+
async def test_set_multiple_cookie(app, client_session):
7373
@get("/home")
7474
async def home(request):
7575
http_response = HttpResponse("This is a test endpoint")
@@ -78,7 +78,7 @@ async def home(request):
7878

7979
return http_response
8080

81-
app_with_secret_token.add_route("/home", HttpMethod.GET, home)
81+
app.add_route("/home", HttpMethod.GET, home)
8282

8383
response = await client_session.get("/home")
8484
headers_dict = dict(response.headers)

0 commit comments

Comments
 (0)