mirror of
https://github.com/rommapp/romm.git
synced 2026-06-28 06:46:00 +00:00
manually fix tests
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
# Original source: https://github.com/frankie567/starlette-csrf
|
||||
# Copyright (c) 2021 Sebastien Delisle
|
||||
# MIT License
|
||||
|
||||
import functools
|
||||
import http.cookies
|
||||
import secrets
|
||||
@@ -91,7 +95,9 @@ class CSRFMiddleware:
|
||||
|
||||
cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie()
|
||||
cookie_name = self.cookie_name
|
||||
cookie[cookie_name] = self._generate_csrf_token(request.user.id)
|
||||
cookie[cookie_name] = self._generate_csrf_token(
|
||||
request.user.id if request.user.is_authenticated else None
|
||||
)
|
||||
cookie[cookie_name]["path"] = self.cookie_path
|
||||
cookie[cookie_name]["secure"] = self.cookie_secure
|
||||
cookie[cookie_name]["httponly"] = self.cookie_httponly
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from itsdangerous import URLSafeSerializer
|
||||
from starlette.applications import Starlette
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.requests import Request
|
||||
@@ -11,12 +13,17 @@ from starlette.testclient import TestClient
|
||||
|
||||
from config import ROMM_AUTH_SECRET_KEY
|
||||
from handler.auth.constants import ALGORITHM
|
||||
from handler.auth.hybrid_auth import HybridAuthBackend
|
||||
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
|
||||
from handler.auth.middleware.session_middleware import SessionMiddleware
|
||||
from models.user import User
|
||||
|
||||
|
||||
# Test app factory #
|
||||
class BasicAuthBackend(AuthenticationBackend):
|
||||
async def authenticate(self, conn):
|
||||
return AuthCredentials(["authenticated"]), User(id=1, username="user_1")
|
||||
|
||||
|
||||
# Test app factory
|
||||
def create_test_app(**csrf_kwargs) -> Starlette:
|
||||
"""Return a Starlette app wired with CSRFMiddleware."""
|
||||
|
||||
@@ -32,13 +39,13 @@ def create_test_app(**csrf_kwargs) -> Starlette:
|
||||
return JSONResponse({"token": token})
|
||||
|
||||
routes = [
|
||||
Route("/get", get_handler, methods=["GET"]),
|
||||
Route("/get", get_handler, methods=["GET", "HEAD", "OPTIONS", "TRACE"]),
|
||||
Route("/post", post_handler, methods=["POST"]),
|
||||
Route("/echo", post_echo, methods=["POST"]),
|
||||
]
|
||||
middleware = [
|
||||
Middleware(AuthenticationMiddleware, backend=BasicAuthBackend()),
|
||||
Middleware(CSRFMiddleware, secret="test-secret", **csrf_kwargs),
|
||||
Middleware(AuthenticationMiddleware, backend=HybridAuthBackend()),
|
||||
Middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=ROMM_AUTH_SECRET_KEY,
|
||||
@@ -110,7 +117,7 @@ class TestCSRFMiddleware:
|
||||
client = TestClient(app)
|
||||
|
||||
for method in ("GET", "HEAD", "OPTIONS", "TRACE"):
|
||||
resp = client.request(method, "/post")
|
||||
resp = client.request(method, "/get")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_custom_header_name(self) -> None:
|
||||
@@ -148,7 +155,7 @@ class TestCSRFMiddleware:
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/get")
|
||||
set_cookie = resp.headers["set-cookie"]
|
||||
set_cookie = resp.headers["set-cookie"].lower()
|
||||
assert "secure" in set_cookie
|
||||
assert "httponly" in set_cookie
|
||||
assert "samesite=strict" in set_cookie
|
||||
@@ -200,15 +207,17 @@ class TestCSRFMiddleware:
|
||||
"""WebSocket (or other non-HTTP) scopes should pass through."""
|
||||
# Manual ASGI call; TestClient doesn't expose WebSocket easily
|
||||
scope = {"type": "websocket", "path": "/ws", "headers": []}
|
||||
receive = lambda: {} # noqa: E731
|
||||
send = lambda msg: None # noqa: E731
|
||||
|
||||
async def receive():
|
||||
return {}
|
||||
|
||||
async def send(msg):
|
||||
pass
|
||||
|
||||
async def dummy_app(scope, receive, send):
|
||||
await send({"type": "websocket.accept"})
|
||||
|
||||
middleware = CSRFMiddleware(dummy_app, secret="test")
|
||||
import asyncio
|
||||
|
||||
asyncio.run(middleware(scope, receive, send)) # should not raise
|
||||
|
||||
def test_token_generation_and_validation(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user