import time
import typing
from collections import namedtuple

from jose import JWTError, jwt
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette_csrf.middleware import CSRFMiddleware


class CustomCSRFMiddleware(CSRFMiddleware):
    """``CSRFMiddleware`` that applies CSRF processing to ``http`` scopes only.

    Any other scope type (e.g. ``websocket``) is forwarded to the wrapped
    application untouched; ``http`` scopes are delegated to the parent class.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        await super().__call__(scope, receive, send)


# Signing / verification key pair used by SessionMiddleware.
# ``encode`` signs outgoing session tokens. ``decode`` verifies incoming ones
# and falls back to ``encode`` when it is falsy, so one key can serve both
# roles. Separate keys allow, e.g., asymmetric JWT algorithms.
SecretKey = namedtuple("SecretKey", ("encode", "decode"))


class SessionMiddleware:
    """ASGI middleware that keeps ``scope["session"]`` in a signed JWT cookie.

    On each ``http``/``websocket`` request the session cookie (if present) is
    decoded and verified with python-jose and exposed as ``scope["session"]``.
    On ``http.response.start``:

    * a non-empty session is re-encoded and sent back in ``Set-Cookie``;
    * an empty session, when a cookie was received and decoded, causes the
      cookie to be expired.

    Cookies are only written for HTTP responses. Websocket scopes get a
    session loaded, but no cookie is ever emitted for them.

    The ``exp`` claim is added only when missing from the session. It is
    therefore fixed at the first write: a session expires ``max_age`` seconds
    after it was created, even though the cookie's ``Max-Age`` is refreshed on
    every response.

    NOTE(review): python-jose validates registered claims on decode (e.g.
    ``sub`` must be a string, ``aud`` requires a matching audience). Session
    keys with those names may therefore make decoding fail and silently reset
    the session. Confirm callers avoid them.

    Args:
        app: The wrapped ASGI application.
        secret_key: A plain string / ``Secret`` (used both to sign and to
            verify), or a ``SecretKey`` with separate signing and verification
            keys.
        session_cookie: Name of the session cookie.
        max_age: Cookie ``Max-Age`` in seconds. Also the lifetime used for the
            ``exp`` claim when the session does not already carry one.
        same_site: Value of the cookie's ``samesite`` attribute.
        https_only: When true, the cookie is marked ``secure``.
        jwt_alg: Algorithm used to sign tokens. It is also the only algorithm
            accepted when verifying them.

    Raises:
        ValueError: If a start-up sign/verify round-trip with the configured
            keys and algorithm does not reproduce the original payload.
            Errors raised by python-jose during that round-trip propagate
            unchanged.
    """

    def __init__(
        self,
        app: ASGIApp,
        secret_key: typing.Union[str, Secret, SecretKey],
        session_cookie: str = "session",
        max_age: int = 14 * 24 * 60 * 60,  # 14 days, in seconds
        same_site: str = "lax",
        https_only: bool = False,
        jwt_alg: str = "HS256",
    ) -> None:
        self.app = app
        self.jwt_alg = jwt_alg
        if not isinstance(secret_key, SecretKey):
            # A single key signs and verifies; decode=None triggers the fallback.
            self.jwt_secret = SecretKey(Secret(str(secret_key)), None)
        else:
            self.jwt_secret = secret_key

        # Check the crypto setup up front so a misconfiguration fails at start-up.
        # This is an explicit raise rather than ``assert`` so the check still
        # runs under ``python -O``.
        probe = {"1": 2}
        token = jwt.encode(probe, key=self._encode_key(), algorithm=jwt_alg)
        decoded = jwt.decode(token, key=self._decode_key(), algorithms=[jwt_alg])
        if decoded != probe:
            raise ValueError("wrong crypto setup")

        self.session_cookie = session_cookie
        self.max_age = max_age
        self.security_flags = "httponly; samesite=" + same_site
        if https_only:  # Secure flag can be used with HTTPS only
            self.security_flags += "; secure"

    def _encode_key(self) -> str:
        """Return the key used to sign outgoing session tokens."""
        return str(self.jwt_secret.encode)

    def _decode_key(self) -> str:
        """Return the verification key, falling back to the signing key."""
        secret = self.jwt_secret
        return str(secret.decode if secret.decode else secret.encode)

    def _validate_jwt_payload(self, jwt_payload: typing.Any) -> dict:
        """Return ``jwt_payload`` if it is a currently valid claims dict.

        Return ``{}`` in these cases:

        * the payload is not a dict;
        * its ``exp`` claim lies in the past;
        * its ``nbf`` claim lies in the future.

        This duplicates checks python-jose performs during ``jwt.decode`` and
        acts as a second line of defence.
        """
        if not isinstance(jwt_payload, dict):
            return {}
        now = int(time.time())
        # The "exp" (expiration time) claim identifies the expiration time on
        # or after which the JWT MUST NOT be accepted for processing.
        if "exp" in jwt_payload and jwt_payload["exp"] < now:
            return {}
        # The "nbf" (not before) claim identifies the time before which the JWT
        # MUST NOT be accepted for processing.
        if "nbf" in jwt_payload and jwt_payload["nbf"] > now:
            return {}
        return jwt_payload

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] not in ("http", "websocket"):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        connection = HTTPConnection(scope)
        # Becomes False once a cookie was received and decoded. It is used
        # below to decide whether an emptied session must expire the cookie.
        initial_session_was_empty = True

        if self.session_cookie in connection.cookies:
            data = connection.cookies[self.session_cookie].encode("utf-8")
            try:
                jwt_payload = jwt.decode(
                    data,
                    key=self._decode_key(),
                    algorithms=[self.jwt_alg],
                )
                jwt_payload = self._validate_jwt_payload(jwt_payload)
                scope["session"] = jwt_payload
                initial_session_was_empty = False
            except JWTError:
                # The token failed python-jose decoding/verification, so start
                # with an empty session. The bad cookie is not expired here.
                scope["session"] = {}
        else:
            scope["session"] = {}

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                if scope["session"]:
                    # Only set "exp" when absent, so the expiry stays anchored
                    # to the session's first write.
                    if "exp" not in scope["session"]:
                        scope["session"]["exp"] = int(time.time()) + self.max_age
                    token = jwt.encode(
                        scope["session"],
                        key=self._encode_key(),
                        algorithm=self.jwt_alg,
                    )
                    headers = MutableHeaders(scope=message)
                    header_value = "%s=%s; path=/; Max-Age=%d; %s" % (
                        self.session_cookie,
                        token,
                        self.max_age,
                        self.security_flags,
                    )
                    headers.append("Set-Cookie", header_value)
                elif not initial_session_was_empty:
                    # The session is empty now but a cookie was received and
                    # decoded: expire it. The single "; " separator avoids the
                    # malformed ";;" the previous format string produced.
                    headers = MutableHeaders(scope=message)
                    header_value = (
                        "%s=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; %s"
                        % (self.session_cookie, self.security_flags)
                    )
                    headers.append("Set-Cookie", header_value)
            await send(message)

        await self.app(scope, receive, send_wrapper)