mirror of
https://github.com/rommapp/romm.git
synced 2026-06-29 15:25:46 +00:00
fix jwt tokens
This commit is contained in:
@@ -32,22 +32,23 @@ class SessionMiddleware:
|
||||
jwt_alg: str = "HS256",
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.jwt_alg = jwt_alg
|
||||
|
||||
self.jwt_header = {"alg": jwt_alg}
|
||||
if not isinstance(secret_key, SecretKey):
|
||||
self.jwt_secret = SecretKey(Secret(str(secret_key)), None)
|
||||
else:
|
||||
self.jwt_secret = secret_key
|
||||
|
||||
# check crypto setup so we bail out if needed
|
||||
_jwt = jwt.encode(self.jwt_header, {"1": 2}, str(self.jwt_secret.encode))
|
||||
_jwt = jwt.encode({"1": 2}, key=str(self.jwt_secret.encode), algorithm=jwt_alg)
|
||||
assert {"1": 2} == jwt.decode(
|
||||
_jwt,
|
||||
str(
|
||||
key=str(
|
||||
self.jwt_secret.decode
|
||||
if self.jwt_secret.decode
|
||||
else self.jwt_secret.encode
|
||||
),
|
||||
algorithms=[jwt_alg],
|
||||
), "wrong crypto setup"
|
||||
|
||||
self.session_cookie = session_cookie
|
||||
@@ -56,6 +57,22 @@ class SessionMiddleware:
|
||||
if https_only: # Secure flag can be used with HTTPS only
|
||||
self.security_flags += "; secure"
|
||||
|
||||
def _validate_jwt_payload(self, jwt_payload):
|
||||
if not isinstance(jwt_payload, dict):
|
||||
return {}
|
||||
|
||||
# The "exp" (expiration time) claim identifies the expiration time on
|
||||
# or after which the JWT MUST NOT be accepted for processing.
|
||||
if "exp" in jwt_payload and jwt_payload["exp"] < int(time.time()):
|
||||
return {}
|
||||
|
||||
# The "nbf" (not before) claim identifies the time before which the JWT
|
||||
# MUST NOT be accepted for processing.
|
||||
if "nbf" in jwt_payload and jwt_payload["nbf"] > int(time.time()):
|
||||
return {}
|
||||
|
||||
return jwt_payload
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
@@ -69,14 +86,15 @@ class SessionMiddleware:
|
||||
try:
|
||||
jwt_payload = jwt.decode(
|
||||
data,
|
||||
str(
|
||||
key=str(
|
||||
self.jwt_secret.decode
|
||||
if self.jwt_secret.decode
|
||||
else self.jwt_secret.encode
|
||||
),
|
||||
algorithms=[self.jwt_alg],
|
||||
)
|
||||
jwt_payload.validate_exp(time.time(), 0)
|
||||
jwt_payload.validate_nbf(time.time(), 0)
|
||||
|
||||
jwt_payload = self._validate_jwt_payload(jwt_payload)
|
||||
scope["session"] = jwt_payload
|
||||
initial_session_was_empty = False
|
||||
except JWTError:
|
||||
@@ -89,14 +107,17 @@ class SessionMiddleware:
|
||||
if scope["session"]:
|
||||
if "exp" not in scope["session"]:
|
||||
scope["session"]["exp"] = int(time.time()) + self.max_age
|
||||
|
||||
data = jwt.encode(
|
||||
self.jwt_header, scope["session"], str(self.jwt_secret.encode)
|
||||
scope["session"],
|
||||
key=str(self.jwt_secret.encode),
|
||||
algorithm=self.jwt_alg,
|
||||
)
|
||||
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "%s=%s; path=/; Max-Age=%d; %s" % (
|
||||
self.session_cookie,
|
||||
data.decode("utf-8"),
|
||||
data,
|
||||
self.max_age,
|
||||
self.security_flags,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import sys
|
||||
import alembic.config
|
||||
import uvicorn
|
||||
from config import DEV_HOST, DEV_PORT, ROMM_AUTH_SECRET_KEY, DISABLE_CSRF_PROTECTION
|
||||
from contextlib import asynccontextmanager
|
||||
from endpoints import (
|
||||
auth,
|
||||
config,
|
||||
@@ -30,7 +31,18 @@ from handler.auth_handler.hybrid_auth import HybridAuthBackend
|
||||
from handler.auth_handler.middleware import CustomCSRFMiddleware, SessionMiddleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
|
||||
app = FastAPI(title="RomM API", version=github_handler.get_version())
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
if "pytest" not in sys.modules:
|
||||
# Create default admin user if no admin user exists
|
||||
if len(db_user_handler.get_admin_users()) == 0:
|
||||
auth_handler.create_default_admin_user()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="RomM API", version=github_handler.get_version(), lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -82,18 +94,6 @@ add_pagination(app)
|
||||
app.mount("/ws", socket_handler.socket_app)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
"""Event to handle RomM startup logic."""
|
||||
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
|
||||
# Create default admin user if no admin user exists
|
||||
if len(db_user_handler.get_admin_users()) == 0:
|
||||
auth_handler.create_default_admin_user()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run migrations
|
||||
alembic.config.main(argv=["upgrade", "head"])
|
||||
|
||||
Reference in New Issue
Block a user