From 3ad8c7a653a8560aeebb0c05ab0c695616ddc62d Mon Sep 17 00:00:00 2001 From: Georges-Antoine Assi Date: Sun, 17 Mar 2024 12:22:36 -0400 Subject: [PATCH] fix jwt tokens --- backend/handler/auth_handler/middleware.py | 37 +++++++++++++++++----- backend/main.py | 26 +++++++-------- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/backend/handler/auth_handler/middleware.py b/backend/handler/auth_handler/middleware.py index 428245d09..a6430f4ce 100644 --- a/backend/handler/auth_handler/middleware.py +++ b/backend/handler/auth_handler/middleware.py @@ -32,22 +32,23 @@ class SessionMiddleware: jwt_alg: str = "HS256", ) -> None: self.app = app + self.jwt_alg = jwt_alg - self.jwt_header = {"alg": jwt_alg} if not isinstance(secret_key, SecretKey): self.jwt_secret = SecretKey(Secret(str(secret_key)), None) else: self.jwt_secret = secret_key # check crypto setup so we bail out if needed - _jwt = jwt.encode(self.jwt_header, {"1": 2}, str(self.jwt_secret.encode)) + _jwt = jwt.encode({"1": 2}, key=str(self.jwt_secret.encode), algorithm=jwt_alg) assert {"1": 2} == jwt.decode( _jwt, - str( + key=str( self.jwt_secret.decode if self.jwt_secret.decode else self.jwt_secret.encode ), + algorithms=[jwt_alg], ), "wrong crypto setup" self.session_cookie = session_cookie @@ -56,6 +57,22 @@ class SessionMiddleware: if https_only: # Secure flag can be used with HTTPS only self.security_flags += "; secure" + def _validate_jwt_payload(self, jwt_payload): + if not isinstance(jwt_payload, dict): + return {} + + # The "exp" (expiration time) claim identifies the expiration time on + # or after which the JWT MUST NOT be accepted for processing. + if "exp" in jwt_payload and jwt_payload["exp"] < int(time.time()): + return {} + + # The "nbf" (not before) claim identifies the time before which the JWT + # MUST NOT be accepted for processing. + if "nbf" in jwt_payload and jwt_payload["nbf"] > int(time.time()): + return {} + + return jwt_payload + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): # pragma: no cover await self.app(scope, receive, send) @@ -69,14 +86,15 @@ class SessionMiddleware: try: jwt_payload = jwt.decode( data, - str( + key=str( self.jwt_secret.decode if self.jwt_secret.decode else self.jwt_secret.encode ), + algorithms=[self.jwt_alg], ) - jwt_payload.validate_exp(time.time(), 0) - jwt_payload.validate_nbf(time.time(), 0) + + jwt_payload = self._validate_jwt_payload(jwt_payload) scope["session"] = jwt_payload initial_session_was_empty = False except JWTError: @@ -89,14 +107,17 @@ class SessionMiddleware: if scope["session"]: if "exp" not in scope["session"]: scope["session"]["exp"] = int(time.time()) + self.max_age + data = jwt.encode( - self.jwt_header, scope["session"], str(self.jwt_secret.encode) + scope["session"], + key=str(self.jwt_secret.encode), + algorithm=self.jwt_alg, ) headers = MutableHeaders(scope=message) header_value = "%s=%s; path=/; Max-Age=%d; %s" % ( self.session_cookie, - data.decode("utf-8"), + data, self.max_age, self.security_flags, ) diff --git a/backend/main.py b/backend/main.py index bfe6e970f..e0b1e0edb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -4,6 +4,7 @@ import sys import alembic.config import uvicorn from config import DEV_HOST, DEV_PORT, ROMM_AUTH_SECRET_KEY, DISABLE_CSRF_PROTECTION +from contextlib import asynccontextmanager from endpoints import ( auth, config, @@ -30,7 +31,18 @@ from handler.auth_handler.hybrid_auth import HybridAuthBackend from handler.auth_handler.middleware import CustomCSRFMiddleware, SessionMiddleware from starlette.middleware.authentication import AuthenticationMiddleware -app = FastAPI(title="RomM API", version=github_handler.get_version()) + +@asynccontextmanager +async def lifespan(app: FastAPI): + if "pytest" not in sys.modules: + # Create default admin user if no admin user exists + if len(db_user_handler.get_admin_users()) == 0: + auth_handler.create_default_admin_user() + + yield + + +app = FastAPI(title="RomM API", version=github_handler.get_version(), lifespan=lifespan) app.add_middleware( CORSMiddleware, @@ -82,18 +94,6 @@ add_pagination(app) app.mount("/ws", socket_handler.socket_app) -@app.on_event("startup") -def startup() -> None: - """Event to handle RomM startup logic.""" - - if "pytest" in sys.modules: - return - - # Create default admin user if no admin user exists - if len(db_user_handler.get_admin_users()) == 0: - auth_handler.create_default_admin_user() - - if __name__ == "__main__": # Run migrations alembic.config.main(argv=["upgrade", "head"])