fix jwt tokens

This commit is contained in:
Georges-Antoine Assi
2024-03-17 12:22:36 -04:00
parent 6bee52cc98
commit 3ad8c7a653
2 changed files with 42 additions and 21 deletions

View File

@@ -32,22 +32,23 @@ class SessionMiddleware:
jwt_alg: str = "HS256",
) -> None:
self.app = app
self.jwt_alg = jwt_alg
self.jwt_header = {"alg": jwt_alg}
if not isinstance(secret_key, SecretKey):
self.jwt_secret = SecretKey(Secret(str(secret_key)), None)
else:
self.jwt_secret = secret_key
# check crypto setup so we bail out if needed
_jwt = jwt.encode(self.jwt_header, {"1": 2}, str(self.jwt_secret.encode))
_jwt = jwt.encode({"1": 2}, key=str(self.jwt_secret.encode), algorithm=jwt_alg)
assert {"1": 2} == jwt.decode(
_jwt,
str(
key=str(
self.jwt_secret.decode
if self.jwt_secret.decode
else self.jwt_secret.encode
),
algorithms=[jwt_alg],
), "wrong crypto setup"
self.session_cookie = session_cookie
@@ -56,6 +57,22 @@ class SessionMiddleware:
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
def _validate_jwt_payload(self, jwt_payload):
if not isinstance(jwt_payload, dict):
return {}
# The "exp" (expiration time) claim identifies the expiration time on
# or after which the JWT MUST NOT be accepted for processing.
if "exp" in jwt_payload and jwt_payload["exp"] < int(time.time()):
return {}
# The "nbf" (not before) claim identifies the time before which the JWT
# MUST NOT be accepted for processing.
if "nbf" in jwt_payload and jwt_payload["nbf"] > int(time.time()):
return {}
return jwt_payload
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
@@ -69,14 +86,15 @@ class SessionMiddleware:
try:
jwt_payload = jwt.decode(
data,
str(
key=str(
self.jwt_secret.decode
if self.jwt_secret.decode
else self.jwt_secret.encode
),
algorithms=[self.jwt_alg],
)
jwt_payload.validate_exp(time.time(), 0)
jwt_payload.validate_nbf(time.time(), 0)
jwt_payload = self._validate_jwt_payload(jwt_payload)
scope["session"] = jwt_payload
initial_session_was_empty = False
except JWTError:
@@ -89,14 +107,17 @@ class SessionMiddleware:
if scope["session"]:
if "exp" not in scope["session"]:
scope["session"]["exp"] = int(time.time()) + self.max_age
data = jwt.encode(
self.jwt_header, scope["session"], str(self.jwt_secret.encode)
scope["session"],
key=str(self.jwt_secret.encode),
algorithm=self.jwt_alg,
)
headers = MutableHeaders(scope=message)
header_value = "%s=%s; path=/; Max-Age=%d; %s" % (
self.session_cookie,
data.decode("utf-8"),
data,
self.max_age,
self.security_flags,
)

View File

@@ -4,6 +4,7 @@ import sys
import alembic.config
import uvicorn
from config import DEV_HOST, DEV_PORT, ROMM_AUTH_SECRET_KEY, DISABLE_CSRF_PROTECTION
from contextlib import asynccontextmanager
from endpoints import (
auth,
config,
@@ -30,7 +31,18 @@ from handler.auth_handler.hybrid_auth import HybridAuthBackend
from handler.auth_handler.middleware import CustomCSRFMiddleware, SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
app = FastAPI(title="RomM API", version=github_handler.get_version())
@asynccontextmanager
async def lifespan(app: FastAPI):
if "pytest" not in sys.modules:
# Create default admin user if no admin user exists
if len(db_user_handler.get_admin_users()) == 0:
auth_handler.create_default_admin_user()
yield
app = FastAPI(title="RomM API", version=github_handler.get_version(), lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
@@ -82,18 +94,6 @@ add_pagination(app)
app.mount("/ws", socket_handler.socket_app)
@app.on_event("startup")
def startup() -> None:
"""Event to handle RomM startup logic."""
if "pytest" in sys.modules:
return
# Create default admin user if no admin user exists
if len(db_user_handler.get_admin_users()) == 0:
auth_handler.create_default_admin_user()
if __name__ == "__main__":
# Run migrations
alembic.config.main(argv=["upgrade", "head"])