Add new redis-backed session middleware

This commit is contained in:
Georges-Antoine Assi
2025-11-22 10:47:59 -05:00
parent 28e1db8d2b
commit ec6bb24662
6 changed files with 203 additions and 35 deletions

View File

@@ -21,11 +21,14 @@ from config import (
from decorators.auth import oauth
from exceptions.auth_exceptions import OAuthCredentialsException, UserDisabledException
from handler.auth.constants import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY, TokenPurpose
from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware
from handler.redis_handler import redis_client
from logger.formatter import CYAN
from logger.formatter import highlight as hl
from logger.logger import log
oct_key = OctKey.import_key(ROMM_AUTH_SECRET_KEY)
class AuthHandler:
def __init__(self) -> None:
@@ -95,7 +98,7 @@ class AuthHandler:
token = jwt.encode(
{"alg": ALGORITHM},
to_encode,
OctKey.import_key(ROMM_AUTH_SECRET_KEY),
oct_key,
)
log.info(
f"Reset password link requested for {hl(user.username, color=CYAN)}. Reset link: {hl(f'{ROMM_BASE_URL}/reset-password?token={token}')}"
@@ -119,7 +122,7 @@ class AuthHandler:
from handler.database import db_user_handler
try:
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
except (BadSignatureError, DecodeError, ValueError) as exc:
raise HTTPException(status_code=400, detail="Invalid token") from exc
@@ -146,12 +149,12 @@ class AuthHandler:
raise HTTPException(status_code=404, detail="User not found")
now = datetime.now(timezone.utc).timestamp()
if now > payload.claims.get("exp"):
if now > payload.claims.get("exp", 0.0):
raise HTTPException(status_code=400, detail="Token has expired")
return user
def set_user_new_password(self, user: Any, new_password: str) -> None:
async def set_user_new_password(self, user: Any, new_password: str) -> None:
"""
Set the new password for the user.
Args:
@@ -163,6 +166,7 @@ class AuthHandler:
db_user_handler.update_user(
user.id, {"hashed_password": self.get_password_hash(new_password)}
)
await RedisSessionMiddleware.clear_user_sessions(user.username)
def generate_invite_link_token(self, user: Any, role: str) -> str:
"""
@@ -192,7 +196,7 @@ class AuthHandler:
token = jwt.encode(
{"alg": ALGORITHM},
to_encode,
OctKey.import_key(ROMM_AUTH_SECRET_KEY),
oct_key,
)
invite_link = f"{ROMM_BASE_URL}/register?token={token}"
log.info(
@@ -212,9 +216,7 @@ class AuthHandler:
str: The JTI (JWT ID) of the token.
"""
try:
payload = jwt.decode(
token, OctKey.import_key(ROMM_AUTH_SECRET_KEY), algorithms=[ALGORITHM]
)
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
except (BadSignatureError, DecodeError, ValueError) as exc:
raise HTTPException(status_code=400, detail="Invalid token") from exc
@@ -256,16 +258,14 @@ class OAuthHandler:
return jwt.encode(
{"alg": ALGORITHM},
to_encode,
OctKey.import_key(ROMM_AUTH_SECRET_KEY),
oct_key,
)
async def get_current_active_user_from_bearer_token(self, token: str):
from handler.database import db_user_handler
try:
payload = jwt.decode(
token, OctKey.import_key(ROMM_AUTH_SECRET_KEY), algorithms=[ALGORITHM]
)
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
except (BadSignatureError, DecodeError, ValueError) as exc:
raise OAuthCredentialsException from exc

View File

@@ -0,0 +1,92 @@
import json
import uuid
from starlette.datastructures import MutableHeaders
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from config import SESSION_MAX_AGE_SECONDS
from handler.redis_handler import async_cache
class RedisSessionMiddleware:
def __init__(
self,
app: ASGIApp,
session_cookie: str = "session",
max_age: int = SESSION_MAX_AGE_SECONDS,
same_site: str = "lax",
https_only: bool = False,
) -> None:
self.app = app
self.session_cookie = session_cookie
self.max_age = max_age
self.security_flags = "httponly; samesite=" + same_site
if https_only:
self.security_flags += "; secure"
@staticmethod
async def clear_user_sessions(user_id: str) -> None:
"""
Clears all active sessions for a given user.
"""
session_ids = await async_cache.smembers(f"user_sessions:{user_id}")
if session_ids:
for session_id in session_ids:
await async_cache.delete(f"session:{session_id}")
await async_cache.delete(f"user_sessions:{user_id}")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
connection = HTTPConnection(scope)
session_id = None # Initialize session_id to None
session_cookie_from_request = connection.cookies.get(self.session_cookie)
if session_cookie_from_request:
session_id = session_cookie_from_request
session_data = await async_cache.get(f"session:{session_id}")
if session_data:
scope["session"] = json.loads(session_data)
scope["session"]["session_id"] = session_id
else:
scope["session"] = {}
else:
scope["session"] = {}
async def send_wrapper(message: Message) -> None:
nonlocal session_id
if message["type"] == "http.response.start":
headers = MutableHeaders(scope=message)
# Check for user_id to track user-specific sessions
user_id = scope["session"].get("sub")
if scope["session"]:
session_id = scope["session"].pop("session_id", None) or str(
uuid.uuid4()
) # Retrieve or create session_id
session_data_json = json.dumps(scope["session"])
await async_cache.set(
f"session:{session_id}", session_data_json, ex=self.max_age
)
# Add session_id to user set of sessions
if user_id:
await async_cache.sadd(f"user_sessions:{user_id}", session_id)
header_value = f"{self.session_cookie}={session_id}; path=/; Max-Age={self.max_age}; {self.security_flags}"
headers.append("Set-Cookie", header_value)
elif session_id:
await async_cache.delete(f"session:{session_id}")
# Remove session_id from user set of sessions
if user_id:
await async_cache.srem(f"user_sessions:{user_id}", session_id)
header_value = f"{self.session_cookie}=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; {self.security_flags}"
headers.append("Set-Cookie", header_value)
await send(message)
await self.app(scope, receive, send_wrapper)