manually fix tests

This commit is contained in:
Georges-Antoine Assi
2025-11-18 00:00:49 -05:00
parent 6a1a344ba2
commit d1824bf894
2 changed files with 26 additions and 11 deletions

View File

@@ -1,3 +1,7 @@
# Original source: https://github.com/frankie567/starlette-csrf
# Copyright (c) 2021 Sebastien Delisle
# MIT License
import functools
import http.cookies
import secrets
@@ -91,7 +95,9 @@ class CSRFMiddleware:
cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie()
cookie_name = self.cookie_name
cookie[cookie_name] = self._generate_csrf_token(request.user.id)
cookie[cookie_name] = self._generate_csrf_token(
request.user.id if request.user.is_authenticated else None
)
cookie[cookie_name]["path"] = self.cookie_path
cookie[cookie_name]["secure"] = self.cookie_secure
cookie[cookie_name]["httponly"] = self.cookie_httponly

View File

@@ -1,7 +1,9 @@
import asyncio
import re
from itsdangerous import URLSafeSerializer
from starlette.applications import Starlette
from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
@@ -11,12 +13,17 @@ from starlette.testclient import TestClient
from config import ROMM_AUTH_SECRET_KEY
from handler.auth.constants import ALGORITHM
from handler.auth.hybrid_auth import HybridAuthBackend
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
from handler.auth.middleware.session_middleware import SessionMiddleware
from models.user import User
# Test app factory #
class BasicAuthBackend(AuthenticationBackend):
async def authenticate(self, conn):
return AuthCredentials(["authenticated"]), User(id=1, username="user_1")
# Test app factory
def create_test_app(**csrf_kwargs) -> Starlette:
"""Return a Starlette app wired with CSRFMiddleware."""
@@ -32,13 +39,13 @@ def create_test_app(**csrf_kwargs) -> Starlette:
return JSONResponse({"token": token})
routes = [
Route("/get", get_handler, methods=["GET"]),
Route("/get", get_handler, methods=["GET", "HEAD", "OPTIONS", "TRACE"]),
Route("/post", post_handler, methods=["POST"]),
Route("/echo", post_echo, methods=["POST"]),
]
middleware = [
Middleware(AuthenticationMiddleware, backend=BasicAuthBackend()),
Middleware(CSRFMiddleware, secret="test-secret", **csrf_kwargs),
Middleware(AuthenticationMiddleware, backend=HybridAuthBackend()),
Middleware(
SessionMiddleware,
secret_key=ROMM_AUTH_SECRET_KEY,
@@ -110,7 +117,7 @@ class TestCSRFMiddleware:
client = TestClient(app)
for method in ("GET", "HEAD", "OPTIONS", "TRACE"):
resp = client.request(method, "/post")
resp = client.request(method, "/get")
assert resp.status_code == 200
def test_custom_header_name(self) -> None:
@@ -148,7 +155,7 @@ class TestCSRFMiddleware:
client = TestClient(app)
resp = client.get("/get")
set_cookie = resp.headers["set-cookie"]
set_cookie = resp.headers["set-cookie"].lower()
assert "secure" in set_cookie
assert "httponly" in set_cookie
assert "samesite=strict" in set_cookie
@@ -200,15 +207,17 @@ class TestCSRFMiddleware:
"""WebSocket (or other non-HTTP) scopes should pass through."""
# Manual ASGI call; TestClient doesn't expose WebSocket easily
scope = {"type": "websocket", "path": "/ws", "headers": []}
receive = lambda: {} # noqa: E731
send = lambda msg: None # noqa: E731
async def receive():
return {}
async def send(msg):
pass
async def dummy_app(scope, receive, send):
await send({"type": "websocket.accept"})
middleware = CSRFMiddleware(dummy_app, secret="test")
import asyncio
asyncio.run(middleware(scope, receive, send)) # should not raise
def test_token_generation_and_validation(self) -> None: