From beeb9f0c31d055f843e41c1443f49cfdb66e5204 Mon Sep 17 00:00:00 2001 From: Michael Manganiello Date: Fri, 18 Oct 2024 23:42:18 -0300 Subject: [PATCH] misc: Create enum for authorization scopes Instead of using just strings, this change converts the scopes to a `StrEnum`, to be compatible with places where a string is expected. This avoids typos when using these scopes, simplifies searching for usages, and improves type hints. An extra change was the fix to the Firmware download endpoint, which wasn't respecting the `DISABLE_DOWNLOAD_ENDPOINT_AUTH` flag. --- backend/decorators/auth.py | 3 +- backend/endpoints/collections.py | 11 +++-- backend/endpoints/config.py | 13 ++--- backend/endpoints/feeds.py | 3 +- backend/endpoints/firmware.py | 17 ++++--- backend/endpoints/platform.py | 13 ++--- backend/endpoints/raw.py | 5 +- backend/endpoints/rom.py | 17 +++---- backend/endpoints/saves.py | 11 +++-- backend/endpoints/screenshots.py | 3 +- backend/endpoints/search.py | 5 +- backend/endpoints/states.py | 11 +++-- backend/endpoints/tasks.py | 5 +- backend/endpoints/user.py | 15 +++--- backend/handler/auth/base_handler.py | 71 +++++++++++++++++----------- backend/models/user.py | 3 +- 16 files changed, 121 insertions(+), 85 deletions(-) diff --git a/backend/decorators/auth.py b/backend/decorators/auth.py index 37c174e2b..9c03e8fcc 100644 --- a/backend/decorators/auth.py +++ b/backend/decorators/auth.py @@ -8,6 +8,7 @@ from handler.auth.base_handler import ( DEFAULT_SCOPES_MAP, FULL_SCOPES_MAP, WRITE_SCOPES_MAP, + Scope, ) from starlette.authentication import requires @@ -25,7 +26,7 @@ oauth2_password_bearer = OAuth2PasswordBearer( def protected_route( method: Any, path: str, - scopes: list[str] | None = None, + scopes: list[Scope] | None = None, **kwargs, ): def decorator(func: DecoratedCallable): diff --git a/backend/endpoints/collections.py b/backend/endpoints/collections.py index 19270244d..8dff4fe4a 100644 --- a/backend/endpoints/collections.py +++ b/backend/endpoints/collections.py @@ -13,6 +13,7 @@ from exceptions.endpoint_exceptions import ( CollectionPermissionError, ) from fastapi import Request, UploadFile +from handler.auth.base_handler import Scope from handler.database import db_collection_handler from handler.filesystem import fs_resource_handler from handler.filesystem.base_handler import CoverSize @@ -25,7 +26,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/collections", ["collections.write"]) +@protected_route(router.post, "/collections", [Scope.COLLECTIONS_WRITE]) async def add_collection( request: Request, artwork: UploadFile | None = None, @@ -91,7 +92,7 @@ async def add_collection( ) -@protected_route(router.get, "/collections", ["collections.read"]) +@protected_route(router.get, "/collections", [Scope.COLLECTIONS_READ]) def get_collections(request: Request) -> list[CollectionSchema]: """Get collections endpoint @@ -107,7 +108,7 @@ def get_collections(request: Request) -> list[CollectionSchema]: return CollectionSchema.for_user(request.user.id, collections) -@protected_route(router.get, "/collections/{id}", ["collections.read"]) +@protected_route(router.get, "/collections/{id}", [Scope.COLLECTIONS_READ]) def get_collection(request: Request, id: int) -> CollectionSchema: """Get collections endpoint @@ -127,7 +128,7 @@ def get_collection(request: Request, id: int) -> CollectionSchema: return collection -@protected_route(router.put, "/collections/{id}", ["collections.write"]) +@protected_route(router.put, "/collections/{id}", [Scope.COLLECTIONS_WRITE]) async def update_collection( request: Request, id: int, @@ -213,7 +214,7 @@ async def update_collection( return db_collection_handler.update_collection(id, cleaned_data) -@protected_route(router.delete, "/collections/{id}", ["collections.write"]) +@protected_route(router.delete, "/collections/{id}", [Scope.COLLECTIONS_WRITE]) async def delete_collections(request: Request, id: int) -> MessageResponse: """Delete collections endpoint diff --git a/backend/endpoints/config.py b/backend/endpoints/config.py index eba2bbb37..a7f9b2a45 100644 --- a/backend/endpoints/config.py +++ b/backend/endpoints/config.py @@ -7,6 +7,7 @@ from exceptions.config_exceptions import ( ConfigNotWritableException, ) from fastapi import HTTPException, Request, status +from handler.auth.base_handler import Scope from logger.logger import log from utils.router import APIRouter @@ -43,7 +44,7 @@ def get_config() -> ConfigResponse: ) from exc -@protected_route(router.post, "/config/system/platforms", ["platforms.write"]) +@protected_route(router.post, "/config/system/platforms", [Scope.PLATFORMS_WRITE]) async def add_platform_binding(request: Request) -> MessageResponse: """Add platform binding to the configuration""" @@ -63,7 +64,7 @@ async def add_platform_binding(request: Request) -> MessageResponse: @protected_route( - router.delete, "/config/system/platforms/{fs_slug}", ["platforms.write"] + router.delete, "/config/system/platforms/{fs_slug}", [Scope.PLATFORMS_WRITE] ) async def delete_platform_binding(request: Request, fs_slug: str) -> MessageResponse: """Delete platform binding from the configuration""" @@ -79,7 +80,7 @@ async def delete_platform_binding(request: Request, fs_slug: str) -> MessageResp return {"msg": f"{fs_slug} bind removed successfully!"} -@protected_route(router.post, "/config/system/versions", ["platforms.write"]) +@protected_route(router.post, "/config/system/versions", [Scope.PLATFORMS_WRITE]) async def add_platform_version(request: Request) -> MessageResponse: """Add platform version to the configuration""" @@ -99,7 +100,7 @@ async def add_platform_version(request: Request) -> MessageResponse: @protected_route( - router.delete, "/config/system/versions/{fs_slug}", ["platforms.write"] + router.delete, "/config/system/versions/{fs_slug}", [Scope.PLATFORMS_WRITE] ) async def delete_platform_version(request: Request, fs_slug: str) -> MessageResponse: """Delete platform version from the configuration""" @@ -115,7 +116,7 @@ async def delete_platform_version(request: Request, fs_slug: str) -> MessageResp return {"msg": f"{fs_slug} version removed successfully!"} -# @protected_route(router.post, "/config/exclude", ["platforms.write"]) +# @protected_route(router.post, "/config/exclude", [Scope.PLATFORMS_WRITE]) # async def add_exclusion(request: Request) -> MessageResponse: # """Add platform binding to the configuration""" @@ -127,7 +128,7 @@ async def delete_platform_version(request: Request, fs_slug: str) -> MessageResp # return {"msg": f"Exclusion {exclusion} added to {exclude} successfully!"} -# @protected_route(router.delete, "/config/exclude", ["platforms.write"]) +# @protected_route(router.delete, "/config/exclude", [Scope.PLATFORMS_WRITE]) # async def delete_exclusion(request: Request) -> MessageResponse: # """Delete platform binding from the configuration""" diff --git a/backend/endpoints/feeds.py b/backend/endpoints/feeds.py index ac761b98a..40f305999 100644 --- a/backend/endpoints/feeds.py +++ b/backend/endpoints/feeds.py @@ -12,6 +12,7 @@ from endpoints.responses.feeds import ( WebrcadeFeedSchema, ) from fastapi import Request +from handler.auth.base_handler import Scope from handler.database import db_platform_handler, db_rom_handler from handler.metadata import meta_igdb_handler from handler.metadata.base_hander import SWITCH_TITLEDB_REGEX @@ -25,7 +26,7 @@ router = APIRouter() @protected_route( router.get, "/webrcade/feed", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) def platforms_webrcade_feed(request: Request) -> WebrcadeFeedSchema: """Get webrcade feed endpoint diff --git a/backend/endpoints/firmware.py b/backend/endpoints/firmware.py index e09fa7aa3..39a1dc5b6 100644 --- a/backend/endpoints/firmware.py +++ b/backend/endpoints/firmware.py @@ -4,6 +4,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.firmware import AddFirmwareResponse, FirmwareSchema from fastapi import File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse +from handler.auth.base_handler import Scope from handler.database import db_firmware_handler, db_platform_handler from handler.filesystem import fs_firmware_handler from handler.scan_handler import scan_firmware @@ -13,7 +14,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/firmware", ["firmware.write"]) +@protected_route(router.post, "/firmware", [Scope.FIRMWARE_WRITE]) def add_firmware( request: Request, platform_id: int, @@ -76,7 +77,7 @@ def add_firmware( } -@protected_route(router.get, "/firmware", ["firmware.read"]) +@protected_route(router.get, "/firmware", [Scope.FIRMWARE_READ]) def get_platform_firmware( request: Request, platform_id: int | None = None, @@ -95,7 +96,7 @@ def get_platform_firmware( @protected_route( router.get, "/firmware/{id}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["firmware.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.FIRMWARE_READ], ) def get_firmware(request: Request, id: int) -> FirmwareSchema: """Get firmware endpoint @@ -113,7 +114,7 @@ def get_firmware(request: Request, id: int) -> FirmwareSchema: @protected_route( router.head, "/firmware/{id}/content/{file_name}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["firmware.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.FIRMWARE_READ], ) def head_firmware_content(request: Request, id: int, file_name: str): """Head firmware content endpoint @@ -139,7 +140,11 @@ def head_firmware_content(request: Request, id: int, file_name: str): ) -@protected_route(router.get, "/firmware/{id}/content/{file_name}", ["firmware.read"]) +@protected_route( + router.get, + "/firmware/{id}/content/{file_name}", + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.FIRMWARE_READ], +) def get_firmware_content( request: Request, id: int, @@ -162,7 +167,7 @@ def get_firmware_content( return FileResponse(path=firmware_path, filename=firmware.file_name) -@protected_route(router.post, "/firmware/delete", ["firmware.write"]) +@protected_route(router.post, "/firmware/delete", [Scope.FIRMWARE_WRITE]) async def delete_firmware( request: Request, ) -> MessageResponse: diff --git a/backend/endpoints/platform.py b/backend/endpoints/platform.py index 24133f6da..54d4388f3 100644 --- a/backend/endpoints/platform.py +++ b/backend/endpoints/platform.py @@ -6,6 +6,7 @@ from endpoints.responses.platform import PlatformSchema from exceptions.endpoint_exceptions import PlatformNotFoundInDatabaseException from exceptions.fs_exceptions import PlatformAlreadyExistsException from fastapi import Request +from handler.auth.base_handler import Scope from handler.database import db_platform_handler from handler.filesystem import fs_platform_handler from handler.metadata.igdb_handler import IGDB_PLATFORM_LIST @@ -17,7 +18,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/platforms", ["platforms.write"]) +@protected_route(router.post, "/platforms", [Scope.PLATFORMS_WRITE]) async def add_platforms(request: Request) -> PlatformSchema: """Create platform endpoint @@ -38,7 +39,7 @@ async def add_platforms(request: Request) -> PlatformSchema: return db_platform_handler.add_platform(scanned_platform) -@protected_route(router.get, "/platforms", ["platforms.read"]) +@protected_route(router.get, "/platforms", [Scope.PLATFORMS_READ]) def get_platforms(request: Request) -> list[PlatformSchema]: """Get platforms endpoint @@ -53,7 +54,7 @@ def get_platforms(request: Request) -> list[PlatformSchema]: return db_platform_handler.get_platforms() -@protected_route(router.get, "/platforms/supported", ["platforms.read"]) +@protected_route(router.get, "/platforms/supported", [Scope.PLATFORMS_READ]) def get_supported_platforms(request: Request) -> list[PlatformSchema]: """Get list of supported platforms endpoint @@ -90,7 +91,7 @@ def get_supported_platforms(request: Request) -> list[PlatformSchema]: return supported_platforms -@protected_route(router.get, "/platforms/{id}", ["platforms.read"]) +@protected_route(router.get, "/platforms/{id}", [Scope.PLATFORMS_READ]) def get_platform(request: Request, id: int) -> PlatformSchema: """Get platforms endpoint @@ -110,7 +111,7 @@ def get_platform(request: Request, id: int) -> PlatformSchema: return platform -@protected_route(router.put, "/platforms/{id}", ["platforms.write"]) +@protected_route(router.put, "/platforms/{id}", [Scope.PLATFORMS_WRITE]) async def update_platform(request: Request) -> MessageResponse: """Update platform endpoint @@ -124,7 +125,7 @@ async def update_platform(request: Request) -> MessageResponse: return {"msg": "Enpoint not available yet"} -@protected_route(router.delete, "/platforms/{id}", ["platforms.write"]) +@protected_route(router.delete, "/platforms/{id}", [Scope.PLATFORMS_WRITE]) async def delete_platforms(request: Request, id: int) -> MessageResponse: """Delete platforms endpoint diff --git a/backend/endpoints/raw.py b/backend/endpoints/raw.py index 564735981..c3af8b2c4 100644 --- a/backend/endpoints/raw.py +++ b/backend/endpoints/raw.py @@ -2,18 +2,19 @@ from config import ASSETS_BASE_PATH from decorators.auth import protected_route from fastapi import Request from fastapi.responses import FileResponse +from handler.auth.base_handler import Scope from utils.router import APIRouter router = APIRouter() -@protected_route(router.head, "/raw/assets/{path:path}", ["assets.read"]) +@protected_route(router.head, "/raw/assets/{path:path}", [Scope.ASSETS_READ]) def head_raw_asset(request: Request, path: str): asset_path = f"{ASSETS_BASE_PATH}/{path}" return FileResponse(path=asset_path, filename=path.split("/")[-1]) -@protected_route(router.get, "/raw/assets/{path:path}", ["assets.read"]) +@protected_route(router.get, "/raw/assets/{path:path}", [Scope.ASSETS_READ]) def get_raw_asset(request: Request, path: str): """Download a single asset file diff --git a/backend/endpoints/rom.py b/backend/endpoints/rom.py index dc9562b8b..0dd449e18 100644 --- a/backend/endpoints/rom.py +++ b/backend/endpoints/rom.py @@ -19,6 +19,7 @@ from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from exceptions.fs_exceptions import RomAlreadyExistsException from fastapi import HTTPException, Query, Request, UploadFile, status from fastapi.responses import Response +from handler.auth.base_handler import Scope from handler.database import db_platform_handler, db_rom_handler from handler.filesystem import fs_resource_handler, fs_rom_handler from handler.filesystem.base_handler import CoverSize @@ -37,7 +38,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/roms", ["roms.write"]) +@protected_route(router.post, "/roms", [Scope.ROMS_WRITE]) async def add_rom(request: Request): """Upload single rom endpoint @@ -99,7 +100,7 @@ async def add_rom(request: Request): return Response(status_code=status.HTTP_201_CREATED) -@protected_route(router.get, "/roms", ["roms.read"]) +@protected_route(router.get, "/roms", [Scope.ROMS_READ]) def get_roms( request: Request, platform_id: int | None = None, @@ -137,7 +138,7 @@ def get_roms( @protected_route( router.get, "/roms/{id}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) def get_rom(request: Request, id: int) -> DetailedRomSchema: """Get rom endpoint @@ -161,7 +162,7 @@ def get_rom(request: Request, id: int) -> DetailedRomSchema: @protected_route( router.head, "/roms/{id}/content/{file_name}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) async def head_rom_content( request: Request, @@ -222,7 +223,7 @@ async def head_rom_content( @protected_route( router.get, "/roms/{id}/content/{file_name}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) async def get_rom_content( request: Request, @@ -289,7 +290,7 @@ async def get_rom_content( ) -@protected_route(router.put, "/roms/{id}", ["roms.write"]) +@protected_route(router.put, "/roms/{id}", [Scope.ROMS_WRITE]) async def update_rom( request: Request, id: int, @@ -466,7 +467,7 @@ async def update_rom( return DetailedRomSchema.from_orm_with_request(db_rom_handler.get_rom(id), request) -@protected_route(router.post, "/roms/delete", ["roms.write"]) +@protected_route(router.post, "/roms/delete", [Scope.ROMS_WRITE]) async def delete_roms( request: Request, ) -> MessageResponse: @@ -517,7 +518,7 @@ async def delete_roms( return {"msg": f"{len(roms_ids)} roms deleted successfully!"} -@protected_route(router.put, "/roms/{id}/props", ["roms.user.write"]) +@protected_route(router.put, "/roms/{id}/props", [Scope.ROMS_USER_WRITE]) async def update_rom_user(request: Request, id: int) -> RomUserSchema: data = await request.json() diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index 0a7893d3f..5cee897ec 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -5,6 +5,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.assets import SaveSchema, UploadedSavesResponse from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from fastapi import File, HTTPException, Request, UploadFile, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_save_handler, db_screenshot_handler from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_save @@ -14,7 +15,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/saves", ["assets.write"]) +@protected_route(router.post, "/saves", [Scope.ASSETS_WRITE]) def add_saves( request: Request, rom_id: int, @@ -82,17 +83,17 @@ def add_saves( } -# @protected_route(router.get, "/saves", ["assets.read"]) +# @protected_route(router.get, "/saves", [Scope.ASSETS_READ]) # def get_saves(request: Request) -> MessageResponse: # pass -# @protected_route(router.get, "/saves/{id}", ["assets.read"]) +# @protected_route(router.get, "/saves/{id}", [Scope.ASSETS_READ]) # def get_save(request: Request, id: int) -> MessageResponse: # pass -@protected_route(router.put, "/saves/{id}", ["assets.write"]) +@protected_route(router.put, "/saves/{id}", [Scope.ASSETS_WRITE]) async def update_save(request: Request, id: int) -> SaveSchema: data = await request.form() @@ -126,7 +127,7 @@ async def update_save(request: Request, id: int) -> SaveSchema: return db_save -@protected_route(router.post, "/saves/delete", ["assets.write"]) +@protected_route(router.post, "/saves/delete", [Scope.ASSETS_WRITE]) async def delete_saves(request: Request) -> MessageResponse: data: dict = await request.json() save_ids: list = data["saves"] diff --git a/backend/endpoints/screenshots.py b/backend/endpoints/screenshots.py index 1b97cab0c..480a52407 100644 --- a/backend/endpoints/screenshots.py +++ b/backend/endpoints/screenshots.py @@ -1,6 +1,7 @@ from decorators.auth import protected_route from endpoints.responses.assets import UploadedScreenshotsResponse from fastapi import File, HTTPException, Request, UploadFile, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_screenshot_handler from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_screenshot @@ -10,7 +11,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/screenshots", ["assets.write"]) +@protected_route(router.post, "/screenshots", [Scope.ASSETS_WRITE]) def add_screenshots( request: Request, rom_id: int, diff --git a/backend/endpoints/search.py b/backend/endpoints/search.py index 57ee96457..e6b66b6e7 100644 --- a/backend/endpoints/search.py +++ b/backend/endpoints/search.py @@ -2,6 +2,7 @@ import emoji from decorators.auth import protected_route from endpoints.responses.search import SearchCoverSchema, SearchRomSchema from fastapi import HTTPException, Request, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler from handler.metadata import meta_igdb_handler, meta_moby_handler, meta_sgdb_handler from handler.metadata.igdb_handler import IGDB_API_ENABLED @@ -14,7 +15,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.get, "/search/roms", ["roms.read"]) +@protected_route(router.get, "/search/roms", [Scope.ROMS_READ]) async def search_rom( request: Request, rom_id: str, @@ -111,7 +112,7 @@ async def search_rom( return matched_roms -@protected_route(router.get, "/search/cover", ["roms.read"]) +@protected_route(router.get, "/search/cover", [Scope.ROMS_READ]) async def search_cover( request: Request, search_term: str = "", diff --git a/backend/endpoints/states.py b/backend/endpoints/states.py index 462c236fc..dc95311c6 100644 --- a/backend/endpoints/states.py +++ b/backend/endpoints/states.py @@ -5,6 +5,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.assets import StateSchema, UploadedStatesResponse from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from fastapi import File, HTTPException, Request, UploadFile, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_screenshot_handler, db_state_handler from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_state @@ -14,7 +15,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/states", ["assets.write"]) +@protected_route(router.post, "/states", [Scope.ASSETS_WRITE]) def add_states( request: Request, rom_id: int, @@ -78,17 +79,17 @@ def add_states( } -# @protected_route(router.get, "/states", ["assets.read"]) +# @protected_route(router.get, "/states", [Scope.ASSETS_READ]) # def get_states(request: Request) -> MessageResponse: # pass -# @protected_route(router.get, "/states/{id}", ["assets.read"]) +# @protected_route(router.get, "/states/{id}", [Scope.ASSETS_READ]) # def get_state(request: Request, id: int) -> MessageResponse: # pass -@protected_route(router.put, "/states/{id}", ["assets.write"]) +@protected_route(router.put, "/states/{id}", [Scope.ASSETS_WRITE]) async def update_state(request: Request, id: int) -> StateSchema: data = await request.form() @@ -121,7 +122,7 @@ async def update_state(request: Request, id: int) -> StateSchema: return db_state -@protected_route(router.post, "/states/delete", ["assets.write"]) +@protected_route(router.post, "/states/delete", [Scope.ASSETS_WRITE]) async def delete_states(request: Request) -> MessageResponse: data: dict = await request.json() state_ids: list = data["states"] diff --git a/backend/endpoints/tasks.py b/backend/endpoints/tasks.py index 1873c7ed7..91015ee86 100644 --- a/backend/endpoints/tasks.py +++ b/backend/endpoints/tasks.py @@ -1,13 +1,14 @@ from decorators.auth import protected_route from endpoints.responses import MessageResponse from fastapi import Request +from handler.auth.base_handler import Scope from tasks.update_switch_titledb import update_switch_titledb_task from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/tasks/run", ["tasks.run"]) +@protected_route(router.post, "/tasks/run", [Scope.TASKS_RUN]) async def run_tasks(request: Request) -> MessageResponse: """Run all tasks endpoint @@ -21,7 +22,7 @@ async def run_tasks(request: Request) -> MessageResponse: return {"msg": "All tasks ran successfully!"} -@protected_route(router.post, "/tasks/{task}/run", ["tasks.run"]) +@protected_route(router.post, "/tasks/{task}/run", [Scope.TASKS_RUN]) async def run_task(request: Request, task: str) -> MessageResponse: """Run all tasks endpoint diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 3bed5391b..b5862cc13 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -9,6 +9,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.identity import UserSchema from fastapi import Depends, HTTPException, Request, status from handler.auth import auth_handler +from handler.auth.base_handler import Scope from handler.database import db_user_handler from handler.filesystem import fs_asset_handler from logger.logger import log @@ -37,9 +38,9 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS UserSchema: Created user info """ - # If there are admin users already, enforce the `users.write` scope. + # If there are admin users already, enforce the USERS_WRITE scope. if ( - "users.write" not in request.auth.scopes + Scope.USERS_WRITE not in request.auth.scopes and len(db_user_handler.get_admin_users()) > 0 ): raise HTTPException( @@ -65,7 +66,7 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS return db_user_handler.add_user(user) -@protected_route(router.get, "/users", ["users.read"]) +@protected_route(router.get, "/users", [Scope.USERS_READ]) def get_users(request: Request) -> list[UserSchema]: """Get all users endpoint @@ -79,7 +80,7 @@ def get_users(request: Request) -> list[UserSchema]: return db_user_handler.get_users() -@protected_route(router.get, "/users/me", ["me.read"]) +@protected_route(router.get, "/users/me", [Scope.ME_READ]) def get_current_user(request: Request) -> UserSchema | None: """Get current user endpoint @@ -93,7 +94,7 @@ def get_current_user(request: Request) -> UserSchema | None: return request.user -@protected_route(router.get, "/users/{id}", ["users.read"]) +@protected_route(router.get, "/users/{id}", [Scope.USERS_READ]) def get_user(request: Request, id: int) -> UserSchema: """Get user endpoint @@ -111,7 +112,7 @@ def get_user(request: Request, id: int) -> UserSchema: return user -@protected_route(router.put, "/users/{id}", ["me.write"]) +@protected_route(router.put, "/users/{id}", [Scope.USERS_WRITE]) async def update_user( request: Request, id: int, form_data: Annotated[UserForm, Depends()] ) -> UserSchema: @@ -192,7 +193,7 @@ async def update_user( return db_user_handler.get_user(id) -@protected_route(router.delete, "/users/{id}", ["users.write"]) +@protected_route(router.delete, "/users/{id}", [Scope.USERS_WRITE]) def delete_user(request: Request, id: int) -> MessageResponse: """Delete user endpoint diff --git a/backend/handler/auth/base_handler.py b/backend/handler/auth/base_handler.py index 836c8f88f..0ce5e0413 100644 --- a/backend/handler/auth/base_handler.py +++ b/backend/handler/auth/base_handler.py @@ -1,3 +1,4 @@ +import enum from datetime import datetime, timedelta, timezone from typing import Final @@ -11,32 +12,53 @@ from passlib.context import CryptContext from starlette.requests import HTTPConnection ALGORITHM: Final = "HS256" -DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15 +DEFAULT_OAUTH_TOKEN_EXPIRY: Final = timedelta(minutes=15) + + +class Scope(enum.StrEnum): + ME_READ = "me.read" + ME_WRITE = "me.write" + ROMS_READ = "roms.read" + ROMS_WRITE = "roms.write" + ROMS_USER_READ = "roms.user.read" + ROMS_USER_WRITE = "roms.user.write" + PLATFORMS_READ = "platforms.read" + PLATFORMS_WRITE = "platforms.write" + ASSETS_READ = "assets.read" + ASSETS_WRITE = "assets.write" + FIRMWARE_READ = "firmware.read" + FIRMWARE_WRITE = "firmware.write" + COLLECTIONS_READ = "collections.read" + COLLECTIONS_WRITE = "collections.write" + USERS_READ = "users.read" + USERS_WRITE = "users.write" + TASKS_RUN = "tasks.run" + DEFAULT_SCOPES_MAP: Final = { - "me.read": "View your profile", - "me.write": "Modify your profile", - "roms.read": "View ROMs", - "platforms.read": "View platforms", - "assets.read": "View assets", - "assets.write": "Modify assets", - "firmware.read": "View firmware", - "roms.user.read": "View user-rom properties", - "roms.user.write": "Modify user-rom properties", - "collections.read": "View collections", - "collections.write": "Modify collections", + Scope.ME_READ: "View your profile", + Scope.ME_WRITE: "Modify your profile", + Scope.ROMS_READ: "View ROMs", + Scope.PLATFORMS_READ: "View platforms", + Scope.ASSETS_READ: "View assets", + Scope.ASSETS_WRITE: "Modify assets", + Scope.FIRMWARE_READ: "View firmware", + Scope.ROMS_USER_READ: "View user-rom properties", + Scope.ROMS_USER_WRITE: "Modify user-rom properties", + Scope.COLLECTIONS_READ: "View collections", + Scope.COLLECTIONS_WRITE: "Modify collections", } WRITE_SCOPES_MAP: Final = { - "roms.write": "Modify ROMs", - "platforms.write": "Modify platforms", - "firmware.write": "Modify firmware", + Scope.ROMS_WRITE: "Modify ROMs", + Scope.PLATFORMS_WRITE: "Modify platforms", + Scope.FIRMWARE_WRITE: "Modify firmware", } FULL_SCOPES_MAP: Final = { - "users.read": "View users", - "users.write": "Modify users", - "tasks.run": "Run tasks", + Scope.USERS_READ: "View users", + Scope.USERS_WRITE: "Modify users", + Scope.TASKS_RUN: "Run tasks", } DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys()) @@ -102,16 +124,11 @@ class OAuthHandler: def __init__(self) -> None: pass - def create_oauth_token(self, data: dict, expires_delta: timedelta | None = None): + def create_oauth_token( + self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY + ) -> str: to_encode = data.copy() - - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta( - minutes=DEFAULT_OAUTH_TOKEN_EXPIRY - ) - + expire = datetime.now(timezone.utc) + expires_delta to_encode.update({"exp": expire}) return jwt.encode( diff --git a/backend/models/user.py b/backend/models/user.py index 635e19eaa..92c7d260a 100644 --- a/backend/models/user.py +++ b/backend/models/user.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from starlette.authentication import SimpleUser if TYPE_CHECKING: + from handler.auth.base_handler import Scope from models.assets import Save, Screenshot, State from models.collection import Collection from models.rom import RomUser @@ -42,7 +43,7 @@ class User(BaseModel, SimpleUser): collections: Mapped[list[Collection]] = relationship(back_populates="user") @property - def oauth_scopes(self): + def oauth_scopes(self) -> list[Scope]: from handler.auth.base_handler import DEFAULT_SCOPES, FULL_SCOPES, WRITE_SCOPES if self.role == Role.ADMIN: