mirror of
https://github.com/rommapp/romm.git
synced 2026-06-28 06:46:00 +00:00
implement csrf middleware directly in repo
This commit is contained in:
159
backend/handler/auth/middleware/csrf_middleware.py
Normal file
159
backend/handler/auth/middleware/csrf_middleware.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
import functools
|
||||||
|
import http.cookies
|
||||||
|
import secrets
|
||||||
|
from re import Pattern
|
||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from itsdangerous import BadSignature
|
||||||
|
from itsdangerous.url_safe import URLSafeSerializer
|
||||||
|
from starlette.datastructures import URL, MutableHeaders
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import PlainTextResponse, Response
|
||||||
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFMiddleware:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
secret: str,
|
||||||
|
*,
|
||||||
|
required_urls: Optional[list[Pattern]] = None,
|
||||||
|
exempt_urls: Optional[list[Pattern]] = None,
|
||||||
|
sensitive_cookies: Optional[set[str]] = None,
|
||||||
|
safe_methods: set[str] = {"GET", "HEAD", "OPTIONS", "TRACE"},
|
||||||
|
cookie_name: str = "csrftoken",
|
||||||
|
cookie_path: str = "/",
|
||||||
|
cookie_domain: Optional[str] = None,
|
||||||
|
cookie_secure: bool = False,
|
||||||
|
cookie_httponly: bool = False,
|
||||||
|
cookie_samesite: str = "lax",
|
||||||
|
header_name: str = "x-csrftoken",
|
||||||
|
) -> None:
|
||||||
|
self.app = app
|
||||||
|
self.serializer = URLSafeSerializer(secret, "csrftoken")
|
||||||
|
self.secret = secret
|
||||||
|
self.required_urls = required_urls
|
||||||
|
self.exempt_urls = exempt_urls
|
||||||
|
self.sensitive_cookies = sensitive_cookies
|
||||||
|
self.safe_methods = safe_methods
|
||||||
|
self.cookie_name = cookie_name
|
||||||
|
self.cookie_path = cookie_path
|
||||||
|
self.cookie_domain = cookie_domain
|
||||||
|
self.cookie_secure = cookie_secure
|
||||||
|
self.cookie_httponly = cookie_httponly
|
||||||
|
self.cookie_samesite = cookie_samesite
|
||||||
|
self.header_name = header_name
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
# Skip CSRF check if not an HTTP request, like websockets
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return None
|
||||||
|
|
||||||
|
request = Request(scope, receive)
|
||||||
|
|
||||||
|
# Skip CSRF check if Authorization header is present
|
||||||
|
auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower()
|
||||||
|
if auth_scheme == "bearer" or auth_scheme == "basic":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return None
|
||||||
|
|
||||||
|
csrf_cookie = request.cookies.get(self.cookie_name)
|
||||||
|
|
||||||
|
if self._url_is_required(request.url) or (
|
||||||
|
request.method not in self.safe_methods
|
||||||
|
and not self._url_is_exempt(request.url)
|
||||||
|
and self._has_sensitive_cookies(request.cookies)
|
||||||
|
):
|
||||||
|
submitted_csrf_token = await self._get_submitted_csrf_token(request)
|
||||||
|
if (
|
||||||
|
not csrf_cookie
|
||||||
|
or not submitted_csrf_token
|
||||||
|
or not self._csrf_tokens_match(
|
||||||
|
csrf_cookie, submitted_csrf_token, request.user.id
|
||||||
|
)
|
||||||
|
):
|
||||||
|
response = self._get_error_response(request)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
send = functools.partial(self.send, send=send, scope=scope)
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
async def send(self, message: Message, send: Send, scope: Scope) -> None:
|
||||||
|
request = Request(scope)
|
||||||
|
csrf_cookie = request.cookies.get(self.cookie_name)
|
||||||
|
|
||||||
|
if csrf_cookie is None:
|
||||||
|
message.setdefault("headers", [])
|
||||||
|
headers = MutableHeaders(scope=message)
|
||||||
|
|
||||||
|
cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie()
|
||||||
|
cookie_name = self.cookie_name
|
||||||
|
cookie[cookie_name] = self._generate_csrf_token(request.user.id)
|
||||||
|
cookie[cookie_name]["path"] = self.cookie_path
|
||||||
|
cookie[cookie_name]["secure"] = self.cookie_secure
|
||||||
|
cookie[cookie_name]["httponly"] = self.cookie_httponly
|
||||||
|
cookie[cookie_name]["samesite"] = self.cookie_samesite
|
||||||
|
if self.cookie_domain is not None:
|
||||||
|
cookie[cookie_name]["domain"] = self.cookie_domain # pragma: no cover
|
||||||
|
headers.append("set-cookie", cookie.output(header="").strip())
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
def _has_sensitive_cookies(self, cookies: dict[str, str]) -> bool:
|
||||||
|
if not self.sensitive_cookies:
|
||||||
|
return True
|
||||||
|
for sensitive_cookie in self.sensitive_cookies:
|
||||||
|
if sensitive_cookie in cookies:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _url_is_required(self, url: URL) -> bool:
|
||||||
|
if not self.required_urls:
|
||||||
|
return False
|
||||||
|
for required_url in self.required_urls:
|
||||||
|
if required_url.match(url.path):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _url_is_exempt(self, url: URL) -> bool:
|
||||||
|
if not self.exempt_urls:
|
||||||
|
return False
|
||||||
|
for exempt_url in self.exempt_urls:
|
||||||
|
if exempt_url.match(url.path):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _get_submitted_csrf_token(self, request: Request) -> Optional[str]:
|
||||||
|
return request.headers.get(self.header_name)
|
||||||
|
|
||||||
|
def _generate_csrf_token(self, user_id: int | None = None) -> str:
|
||||||
|
obj = {"token": secrets.token_urlsafe(128), "user_id": user_id}
|
||||||
|
return cast(str, self.serializer.dumps(obj))
|
||||||
|
|
||||||
|
def _csrf_tokens_match(
|
||||||
|
self, document_cookie: str, header_cookie: str, user_id: str | None
|
||||||
|
) -> bool:
|
||||||
|
try:
|
||||||
|
decoded_doc_cookie: str = self.serializer.loads(document_cookie)
|
||||||
|
decoded_header_cookie: str = self.serializer.loads(header_cookie)
|
||||||
|
|
||||||
|
# Verify that the tokens match, the user IDs match
|
||||||
|
# and the user_id matches the authenticated user
|
||||||
|
return (
|
||||||
|
secrets.compare_digest(
|
||||||
|
decoded_doc_cookie["token"], decoded_doc_cookie["token"]
|
||||||
|
)
|
||||||
|
and decoded_header_cookie["user_id"] == decoded_header_cookie["user_id"]
|
||||||
|
and decoded_doc_cookie["user_id"] == user_id
|
||||||
|
and decoded_header_cookie["user_id"] == user_id
|
||||||
|
)
|
||||||
|
except BadSignature:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_error_response(self, request: Request) -> Response:
|
||||||
|
return PlainTextResponse(
|
||||||
|
content="CSRF token verification failed", status_code=403
|
||||||
|
)
|
||||||
@@ -5,31 +5,11 @@ from joserfc import jwt
|
|||||||
from joserfc.errors import BadSignatureError
|
from joserfc.errors import BadSignatureError
|
||||||
from joserfc.jwk import OctKey
|
from joserfc.jwk import OctKey
|
||||||
from starlette.datastructures import MutableHeaders, Secret
|
from starlette.datastructures import MutableHeaders, Secret
|
||||||
from starlette.requests import HTTPConnection, Request
|
from starlette.requests import HTTPConnection
|
||||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
from starlette_csrf.middleware import CSRFMiddleware
|
|
||||||
|
|
||||||
from config import SESSION_MAX_AGE_SECONDS
|
from config import SESSION_MAX_AGE_SECONDS
|
||||||
|
|
||||||
|
|
||||||
class CustomCSRFMiddleware(CSRFMiddleware):
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
||||||
# Skip CSRF check if not an HTTP request, like websockets
|
|
||||||
if scope["type"] != "http":
|
|
||||||
await self.app(scope, receive, send)
|
|
||||||
return None
|
|
||||||
|
|
||||||
request = Request(scope, receive)
|
|
||||||
|
|
||||||
# Skip CSRF check if Authorization header is present
|
|
||||||
auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower()
|
|
||||||
if auth_scheme == "bearer" or auth_scheme == "basic":
|
|
||||||
await self.app(scope, receive, send)
|
|
||||||
return None
|
|
||||||
|
|
||||||
await super().__call__(scope, receive, send)
|
|
||||||
|
|
||||||
|
|
||||||
SecretKey = namedtuple("SecretKey", ("encode", "decode"))
|
SecretKey = namedtuple("SecretKey", ("encode", "decode"))
|
||||||
|
|
||||||
|
|
||||||
@@ -44,7 +44,8 @@ from endpoints import (
|
|||||||
)
|
)
|
||||||
from handler.auth.constants import ALGORITHM
|
from handler.auth.constants import ALGORITHM
|
||||||
from handler.auth.hybrid_auth import HybridAuthBackend
|
from handler.auth.hybrid_auth import HybridAuthBackend
|
||||||
from handler.auth.middleware import CustomCSRFMiddleware, SessionMiddleware
|
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
|
||||||
|
from handler.auth.middleware.session_middleware import SessionMiddleware
|
||||||
from handler.socket_handler import socket_handler
|
from handler.socket_handler import socket_handler
|
||||||
from logger.formatter import LOGGING_CONFIG
|
from logger.formatter import LOGGING_CONFIG
|
||||||
from utils import get_version
|
from utils import get_version
|
||||||
@@ -90,7 +91,7 @@ app.add_middleware(
|
|||||||
if not IS_PYTEST_RUN and not DISABLE_CSRF_PROTECTION:
|
if not IS_PYTEST_RUN and not DISABLE_CSRF_PROTECTION:
|
||||||
# CSRF protection (except endpoints listed in exempt_urls)
|
# CSRF protection (except endpoints listed in exempt_urls)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CustomCSRFMiddleware,
|
CSRFMiddleware,
|
||||||
cookie_name="romm_csrftoken",
|
cookie_name="romm_csrftoken",
|
||||||
secret=ROMM_AUTH_SECRET_KEY,
|
secret=ROMM_AUTH_SECRET_KEY,
|
||||||
exempt_urls=[re.compile(r"^/api/token.*"), re.compile(r"^/ws")],
|
exempt_urls=[re.compile(r"^/api/token.*"), re.compile(r"^/ws")],
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ dependencies = [
|
|||||||
"fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1",
|
"fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1",
|
||||||
"gunicorn ~= 23.0",
|
"gunicorn ~= 23.0",
|
||||||
"httpx ~= 0.27",
|
"httpx ~= 0.27",
|
||||||
|
"itsdangerous>=2.2.0",
|
||||||
"joserfc ~= 1.3.4",
|
"joserfc ~= 1.3.4",
|
||||||
"opentelemetry-distro ~= 0.56",
|
"opentelemetry-distro ~= 0.56",
|
||||||
"opentelemetry-exporter-otlp ~= 1.36",
|
"opentelemetry-exporter-otlp ~= 1.36",
|
||||||
@@ -46,7 +47,6 @@ dependencies = [
|
|||||||
"rq-scheduler @ git+https://github.com/adamantike/rq-scheduler.git@feat/script-options-username-ssl",
|
"rq-scheduler @ git+https://github.com/adamantike/rq-scheduler.git@feat/script-options-username-ssl",
|
||||||
"sentry-sdk ~= 2.32",
|
"sentry-sdk ~= 2.32",
|
||||||
"starlette ~= 0.49",
|
"starlette ~= 0.49",
|
||||||
"starlette-csrf ~= 3.0",
|
|
||||||
"streaming-form-data ~= 1.19",
|
"streaming-form-data ~= 1.19",
|
||||||
"strsimpy ~= 0.2",
|
"strsimpy ~= 0.2",
|
||||||
"types-colorama ~= 0.4",
|
"types-colorama ~= 0.4",
|
||||||
|
|||||||
17
uv.lock
generated
17
uv.lock
generated
@@ -1911,6 +1911,7 @@ dependencies = [
|
|||||||
{ name = "fastapi-pagination", extra = ["sqlalchemy"] },
|
{ name = "fastapi-pagination", extra = ["sqlalchemy"] },
|
||||||
{ name = "gunicorn" },
|
{ name = "gunicorn" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
|
{ name = "itsdangerous" },
|
||||||
{ name = "joserfc" },
|
{ name = "joserfc" },
|
||||||
{ name = "opentelemetry-distro" },
|
{ name = "opentelemetry-distro" },
|
||||||
{ name = "opentelemetry-exporter-otlp" },
|
{ name = "opentelemetry-exporter-otlp" },
|
||||||
@@ -1933,7 +1934,6 @@ dependencies = [
|
|||||||
{ name = "sentry-sdk" },
|
{ name = "sentry-sdk" },
|
||||||
{ name = "sqlalchemy", extra = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"] },
|
{ name = "sqlalchemy", extra = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"] },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
{ name = "starlette-csrf" },
|
|
||||||
{ name = "streaming-form-data" },
|
{ name = "streaming-form-data" },
|
||||||
{ name = "strsimpy" },
|
{ name = "strsimpy" },
|
||||||
{ name = "types-colorama" },
|
{ name = "types-colorama" },
|
||||||
@@ -1982,6 +1982,7 @@ requires-dist = [
|
|||||||
{ name = "httpx", specifier = "~=0.27" },
|
{ name = "httpx", specifier = "~=0.27" },
|
||||||
{ name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" },
|
{ name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" },
|
||||||
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" },
|
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" },
|
||||||
|
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||||
{ name = "joserfc", specifier = "~=1.3.4" },
|
{ name = "joserfc", specifier = "~=1.3.4" },
|
||||||
{ name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" },
|
{ name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" },
|
||||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" },
|
{ name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" },
|
||||||
@@ -2013,7 +2014,6 @@ requires-dist = [
|
|||||||
{ name = "sentry-sdk", specifier = "~=2.32" },
|
{ name = "sentry-sdk", specifier = "~=2.32" },
|
||||||
{ name = "sqlalchemy", extras = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"], specifier = "~=2.0" },
|
{ name = "sqlalchemy", extras = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"], specifier = "~=2.0" },
|
||||||
{ name = "starlette", specifier = "~=0.49" },
|
{ name = "starlette", specifier = "~=0.49" },
|
||||||
{ name = "starlette-csrf", specifier = "~=3.0" },
|
|
||||||
{ name = "streaming-form-data", specifier = "~=1.19" },
|
{ name = "streaming-form-data", specifier = "~=1.19" },
|
||||||
{ name = "strsimpy", specifier = "~=0.2" },
|
{ name = "strsimpy", specifier = "~=0.2" },
|
||||||
{ name = "types-colorama", specifier = "~=0.4" },
|
{ name = "types-colorama", specifier = "~=0.4" },
|
||||||
@@ -2198,19 +2198,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" },
|
{ url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "starlette-csrf"
|
|
||||||
version = "3.0.0"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "itsdangerous" },
|
|
||||||
{ name = "starlette" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/0f/7c/53c57b4cd76c9a4493a8525d34a76d7e4bbe0ff957de1c53f30241aa757a/starlette_csrf-3.0.0.tar.gz", hash = "sha256:7afaca8c72cc3c726e5942778af53454607ca3e653fd86cd75ee35d8cd1cfa77", size = 8371, upload-time = "2023-06-27T13:23:24.387Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/b9/83/6641e4fdcf33b1cc614a74ecabe5835236a1b2564bf6735db7e35d788795/starlette_csrf-3.0.0-py3-none-any.whl", hash = "sha256:aac29b366e83621d3fc56be690866e16f3c56df91ab5e184b77950540a4e2761", size = 6170, upload-time = "2023-06-27T13:23:25.563Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "streaming-form-data"
|
name = "streaming-form-data"
|
||||||
version = "1.19.1"
|
version = "1.19.1"
|
||||||
|
|||||||
Reference in New Issue
Block a user