diff --git a/flask_session/py.typed b/flask_session/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/flask_session/sessions.py b/flask_session/sessions.py index 40da962b..9fb0399a 100644 --- a/flask_session/sessions.py +++ b/flask_session/sessions.py @@ -12,7 +12,9 @@ import os import sys import time +import typing as t from datetime import datetime +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Type from uuid import uuid4 import pytz @@ -22,6 +24,11 @@ except ImportError: import pickle # type: ignore[no-redef] +if TYPE_CHECKING: + from redis import Redis + from flask import Flask + from flask.wrappers import Request + from flask.sessions import SessionInterface as FlaskSessionInterface from flask.sessions import SessionMixin from itsdangerous import BadSignature, Signer, want_bytes @@ -38,11 +45,23 @@ def total_seconds(td): return td.days * 60 * 60 * 24 + td.seconds +class SessionSerializerType(Protocol): + def dumps(self, session_content: Any) -> bytes: + """serialize a session object""" + ... + + def loads(self, serialized_session: bytes) -> Any: + """deserialize a session object""" + ... + + class ServerSideSession(CallbackDict, SessionMixin): """Baseclass for server-side based sessions.""" - def __init__(self, initial=None, sid=None, permanent=None): - def on_update(self): + def __init__( + self, initial=None, sid: Optional[str] = None, permanent: Optional[bool] = None + ): + def on_update(self) -> None: self.modified = True CallbackDict.__init__(self, initial, on_update) @@ -89,16 +108,25 @@ class DynamoDBSession(ServerSideSession): class PeeweeSession(ServerSideSession): - def __init__(self, initial=None, sid=None, permanent=None, ip=None): + def __init__( + self, + initial=None, + sid: Optional[str] = None, + permanent: Optional[bool] = None, + ip=None, + ): super().__init__(initial, sid, permanent) self.ip = ip class SessionInterface(FlaskSessionInterface): - def _generate_sid(self): + serializer: ClassVar[SessionSerializerType] + session_class: ClassVar[Type[ServerSideSession]] + + def _generate_sid(self) -> str: return str(uuid4()) - def _get_signer(self, app): + def _get_signer(self, app: Flask) -> Optional[Signer]: if not app.secret_key: return None return Signer(app.secret_key, salt="flask-session", key_derivation="hmac") @@ -107,7 +135,7 @@ def _get_signer(self, app): class NullSessionInterface(SessionInterface): """Used to open a :class:`flask.sessions.NullSession` instance.""" - def open_session(self, app, request): + def open_session(self, app: Flask, request: Request) -> None: return None @@ -123,10 +151,16 @@ class RedisSessionInterface(SessionInterface): :param permanent: Whether to use permanent session or not. """ - serializer = pickle - session_class = RedisSession + serializer: ClassVar[SessionSerializerType] = t.cast(SessionSerializerType, pickle) + session_class: ClassVar[Type[RedisSession]] = RedisSession - def __init__(self, redis, key_prefix, use_signer=False, permanent=True): + def __init__( + self, + redis: Optional[Redis], + key_prefix: str, + use_signer: bool = False, + permanent: bool = True, + ): if redis is None: from redis import Redis @@ -137,7 +171,7 @@ def __init__(self, redis, key_prefix, use_signer=False, permanent=True): self.permanent = permanent self.has_same_site_capability = hasattr(self, "get_cookie_samesite") - def open_session(self, app, request): + def open_session(self, app: Flask, request: Request) -> Optional[SessionMixin]: sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"]) if not sid: sid = self._generate_sid() @@ -164,7 +198,9 @@ def open_session(self, app, request): return self.session_class(sid=sid, permanent=self.permanent) return self.session_class(sid=sid, permanent=self.permanent) - def save_session(self, app, session, response): + def save_session( + self, app: Flask, session: SessionMixin, response: Response + ) -> None: if not self.should_set_cookie(app, session): return domain = self.get_cookie_domain(app) @@ -202,7 +238,9 @@ def save_session(self, app, session, response): ) if self.use_signer: - session_id = self._get_signer(app).sign(want_bytes(session.sid)) + session_id = t.cast(Signer, self._get_signer(app)).sign( + want_bytes(session.sid) + ) else: session_id = session.sid response.set_cookie(