add tests for oidc handler

This commit is contained in:
Georges-Antoine Assi
2024-12-12 17:37:30 -05:00
parent 9e844801ed
commit 2d5bc34e9c
5 changed files with 104 additions and 10 deletions

View File

@@ -15,7 +15,7 @@ from exceptions.auth_exceptions import (
from fastapi import Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from fastapi.security.http import HTTPBasic
from handler.auth import auth_handler, oauth_handler, open_id_handler
from handler.auth import auth_handler, oauth_handler, oidc_handler
from handler.database import db_user_handler
from utils.router import APIRouter
@@ -248,9 +248,7 @@ async def auth_openid(request: Request):
raise OIDCNotConfiguredException
token = await oauth.openid.authorize_access_token(request)
potential_user = await open_id_handler.get_current_active_user_from_openid_token(
token
)
potential_user = await oidc_handler.get_current_active_user_from_openid_token(token)
if not potential_user:
raise AuthCredentialsException

View File

@@ -2,4 +2,4 @@ from .base_handler import AuthHandler, OAuthHandler, OpenIDHandler
auth_handler = AuthHandler()
oauth_handler = OAuthHandler()
open_id_handler = OpenIDHandler()
oidc_handler = OpenIDHandler()

View File

@@ -179,29 +179,25 @@ class OpenIDHandler:
from handler.database import db_user_handler
if not OIDC_ENABLED:
return None
return None, None
id_token = token.get("id_token")
try:
payload = jwt.decode(id_token, self.rsa_key, algorithms=["RS256"])
except (BadSignatureError, ValueError) as exc:
print("Error decoding token")
raise OAuthCredentialsException from exc
iss = payload.claims.get("iss")
if OIDC_SERVER_APPLICATION_URL not in str(iss):
print("Invalid issuer")
raise OAuthCredentialsException
email = payload.claims.get("email")
if email is None:
print("No email")
raise OAuthCredentialsException
user = db_user_handler.get_user_by_email(email)
if user is None:
print("User not found")
raise OAuthCredentialsException
if not user.enabled:

View File

View File

@@ -0,0 +1,100 @@
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from handler.auth.base_handler import OpenIDHandler
from httpx import HTTPStatusError, Request, Response
from joserfc.errors import BadSignatureError
from joserfc.jwt import Token
# Mock constants
OIDC_SERVER_APPLICATION_URL = "http://mock-oidc-server"
OIDC_ENABLED = True
@pytest.fixture
def mock_oidc_disabled(mocker):
mocker.patch("handler.auth.base_handler.OIDC_ENABLED", False)
@pytest.fixture
def mock_oidc_enabled(mocker):
mocker.patch(
"handler.auth.base_handler.OIDC_SERVER_APPLICATION_URL",
OIDC_SERVER_APPLICATION_URL,
)
mocker.patch("handler.auth.base_handler.OIDC_ENABLED", True)
def test_oidc_disabled_initialization(mock_oidc_disabled):
"""Test that the handler initializes correctly when OIDC is disabled."""
oidc_handler = OpenIDHandler()
assert not hasattr(oidc_handler, "rsa_key")
def test_oidc_enabled_server_unreachable(mocker, mock_oidc_enabled):
"""Test that initialization raises an HTTPException when the OIDC server is unreachable."""
# Mock request and response
mock_request = Request("GET", f"{OIDC_SERVER_APPLICATION_URL}/jwks/")
mock_response = Response(500, request=mock_request)
# Mock the HTTPStatusError
mocker.patch(
"httpx.Client.get",
side_effect=HTTPStatusError(
"Mocked error", request=mock_request, response=mock_response
),
)
with pytest.raises(HTTPException):
OpenIDHandler()
async def test_oidc_valid_token_decoding(mocker, mock_oidc_enabled):
"""Test token decoding with valid RSA key and token."""
mocker.patch(
"httpx.Client.get",
return_value=MagicMock(
json=lambda: {"keys": [{"kty": "RSA", "n": "fake", "e": "AQAB"}]}
),
)
mock_rsa_key = MagicMock()
mocker.patch(
"handler.auth.base_handler.RSAKey.import_key", return_value=mock_rsa_key
)
mock_jwt_payload = Token(
header={"alg": "RS256"},
claims={"iss": OIDC_SERVER_APPLICATION_URL, "email": "test@example.com"},
)
mocker.patch("joserfc.jwt.decode", return_value=mock_jwt_payload)
mock_user = MagicMock(enabled=True)
mocker.patch(
"handler.database.db_user_handler.get_user_by_email", return_value=mock_user
)
oidc_handler = OpenIDHandler()
token = {"id_token": "valid_token"}
user, claims = await oidc_handler.get_current_active_user_from_openid_token(token)
assert user == mock_user
assert claims == mock_jwt_payload.claims
async def test_oidc_invalid_token_signature(mocker, mock_oidc_enabled):
"""Test token decoding raises exception for invalid signature."""
mocker.patch(
"httpx.Client.get",
return_value=MagicMock(
json=lambda: {"keys": [{"kty": "RSA", "n": "fake", "e": "AQAB"}]}
),
)
mock_rsa_key = MagicMock()
mocker.patch(
"handler.auth.base_handler.RSAKey.import_key", return_value=mock_rsa_key
)
mocker.patch("joserfc.jwt.decode", side_effect=BadSignatureError)
oidc_handler = OpenIDHandler()
token = {"id_token": "invalid_signature_token"}
with pytest.raises(HTTPException):
await oidc_handler.get_current_active_user_from_openid_token(token)