Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added flask_session/py.typed
Empty file.
62 changes: 50 additions & 12 deletions flask_session/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import os
import sys
import time
import typing as t
from datetime import datetime
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Type
from uuid import uuid4

import pytz
Expand All @@ -22,6 +24,11 @@
except ImportError:
import pickle # type: ignore[no-redef]

if TYPE_CHECKING:
from redis import Redis
from flask import Flask
from flask.wrappers import Request

from flask.sessions import SessionInterface as FlaskSessionInterface
from flask.sessions import SessionMixin
from itsdangerous import BadSignature, Signer, want_bytes
Expand All @@ -38,11 +45,23 @@ def total_seconds(td):
return td.days * 60 * 60 * 24 + td.seconds


class SessionSerializerType(Protocol):
def dumps(self, session_content: Any) -> bytes:
"""serialize a session object"""
...

def loads(self, serialized_session: bytes) -> Any:
"""deserialize a session object"""
...


class ServerSideSession(CallbackDict, SessionMixin):
"""Baseclass for server-side based sessions."""

def __init__(self, initial=None, sid=None, permanent=None):
def on_update(self):
def __init__(
self, initial=None, sid: Optional[str] = None, permanent: Optional[bool] = None
):
def on_update(self) -> None:
self.modified = True

CallbackDict.__init__(self, initial, on_update)
Expand Down Expand Up @@ -89,16 +108,25 @@ class DynamoDBSession(ServerSideSession):


class PeeweeSession(ServerSideSession):
def __init__(self, initial=None, sid=None, permanent=None, ip=None):
def __init__(
self,
initial=None,
sid: Optional[str] = None,
permanent: Optional[bool] = None,
ip=None,
):
super().__init__(initial, sid, permanent)
self.ip = ip


class SessionInterface(FlaskSessionInterface):
def _generate_sid(self):
serializer: ClassVar[SessionSerializerType]
session_class: ClassVar[Type[ServerSideSession]]

def _generate_sid(self) -> str:
return str(uuid4())

def _get_signer(self, app):
def _get_signer(self, app: Flask) -> Optional[Signer]:
if not app.secret_key:
return None
return Signer(app.secret_key, salt="flask-session", key_derivation="hmac")
Expand All @@ -107,7 +135,7 @@ def _get_signer(self, app):
class NullSessionInterface(SessionInterface):
"""Used to open a :class:`flask.sessions.NullSession` instance."""

def open_session(self, app, request):
def open_session(self, app: Flask, request: Request) -> None:
return None


Expand All @@ -123,10 +151,16 @@ class RedisSessionInterface(SessionInterface):
:param permanent: Whether to use permanent session or not.
"""

serializer = pickle
session_class = RedisSession
serializer: ClassVar[SessionSerializerType] = t.cast(SessionSerializerType, pickle)
session_class: ClassVar[Type[RedisSession]] = RedisSession

def __init__(self, redis, key_prefix, use_signer=False, permanent=True):
def __init__(
self,
redis: Optional[Redis],
key_prefix: str,
use_signer: bool = False,
permanent: bool = True,
):
if redis is None:
from redis import Redis

Expand All @@ -137,7 +171,7 @@ def __init__(self, redis, key_prefix, use_signer=False, permanent=True):
self.permanent = permanent
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")

def open_session(self, app, request):
def open_session(self, app: Flask, request: Request) -> Optional[SessionMixin]:
sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"])
if not sid:
sid = self._generate_sid()
Expand All @@ -164,7 +198,9 @@ def open_session(self, app, request):
return self.session_class(sid=sid, permanent=self.permanent)
return self.session_class(sid=sid, permanent=self.permanent)

def save_session(self, app, session, response):
def save_session(
self, app: Flask, session: SessionMixin, response: Response
) -> None:
if not self.should_set_cookie(app, session):
return
domain = self.get_cookie_domain(app)
Expand Down Expand Up @@ -202,7 +238,9 @@ def save_session(self, app, session, response):
)

if self.use_signer:
session_id = self._get_signer(app).sign(want_bytes(session.sid))
session_id = t.cast(Signer, self._get_signer(app)).sign(
want_bytes(session.sid)
)
else:
session_id = session.sid
response.set_cookie(
Expand Down