diff --git a/backend/decorators/auth.py b/backend/decorators/auth.py index 37c174e2b..9c03e8fcc 100644 --- a/backend/decorators/auth.py +++ b/backend/decorators/auth.py @@ -8,6 +8,7 @@ from handler.auth.base_handler import ( DEFAULT_SCOPES_MAP, FULL_SCOPES_MAP, WRITE_SCOPES_MAP, + Scope, ) from starlette.authentication import requires @@ -25,7 +26,7 @@ oauth2_password_bearer = OAuth2PasswordBearer( def protected_route( method: Any, path: str, - scopes: list[str] | None = None, + scopes: list[Scope] | None = None, **kwargs, ): def decorator(func: DecoratedCallable): diff --git a/backend/endpoints/collections.py b/backend/endpoints/collections.py index 19270244d..8dff4fe4a 100644 --- a/backend/endpoints/collections.py +++ b/backend/endpoints/collections.py @@ -13,6 +13,7 @@ from exceptions.endpoint_exceptions import ( CollectionPermissionError, ) from fastapi import Request, UploadFile +from handler.auth.base_handler import Scope from handler.database import db_collection_handler from handler.filesystem import fs_resource_handler from handler.filesystem.base_handler import CoverSize @@ -25,7 +26,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/collections", ["collections.write"]) +@protected_route(router.post, "/collections", [Scope.COLLECTIONS_WRITE]) async def add_collection( request: Request, artwork: UploadFile | None = None, @@ -91,7 +92,7 @@ async def add_collection( ) -@protected_route(router.get, "/collections", ["collections.read"]) +@protected_route(router.get, "/collections", [Scope.COLLECTIONS_READ]) def get_collections(request: Request) -> list[CollectionSchema]: """Get collections endpoint @@ -107,7 +108,7 @@ def get_collections(request: Request) -> list[CollectionSchema]: return CollectionSchema.for_user(request.user.id, collections) -@protected_route(router.get, "/collections/{id}", ["collections.read"]) +@protected_route(router.get, "/collections/{id}", [Scope.COLLECTIONS_READ]) def get_collection(request: Request, id: int) -> CollectionSchema: """Get collections endpoint @@ -127,7 +128,7 @@ def get_collection(request: Request, id: int) -> CollectionSchema: return collection -@protected_route(router.put, "/collections/{id}", ["collections.write"]) +@protected_route(router.put, "/collections/{id}", [Scope.COLLECTIONS_WRITE]) async def update_collection( request: Request, id: int, @@ -213,7 +214,7 @@ async def update_collection( return db_collection_handler.update_collection(id, cleaned_data) -@protected_route(router.delete, "/collections/{id}", ["collections.write"]) +@protected_route(router.delete, "/collections/{id}", [Scope.COLLECTIONS_WRITE]) async def delete_collections(request: Request, id: int) -> MessageResponse: """Delete collections endpoint diff --git a/backend/endpoints/config.py b/backend/endpoints/config.py index eba2bbb37..a7f9b2a45 100644 --- a/backend/endpoints/config.py +++ b/backend/endpoints/config.py @@ -7,6 +7,7 @@ from exceptions.config_exceptions import ( ConfigNotWritableException, ) from fastapi import HTTPException, Request, status +from handler.auth.base_handler import Scope from logger.logger import log from utils.router import APIRouter @@ -43,7 +44,7 @@ def get_config() -> ConfigResponse: ) from exc -@protected_route(router.post, "/config/system/platforms", ["platforms.write"]) +@protected_route(router.post, "/config/system/platforms", [Scope.PLATFORMS_WRITE]) async def add_platform_binding(request: Request) -> MessageResponse: """Add platform binding to the configuration""" @@ -63,7 +64,7 @@ async def add_platform_binding(request: Request) -> MessageResponse: @protected_route( - router.delete, "/config/system/platforms/{fs_slug}", ["platforms.write"] + router.delete, "/config/system/platforms/{fs_slug}", [Scope.PLATFORMS_WRITE] ) async def delete_platform_binding(request: Request, fs_slug: str) -> MessageResponse: """Delete platform binding from the configuration""" @@ -79,7 +80,7 @@ async def delete_platform_binding(request: Request, fs_slug: str) -> MessageResp return {"msg": f"{fs_slug} bind removed successfully!"} -@protected_route(router.post, "/config/system/versions", ["platforms.write"]) +@protected_route(router.post, "/config/system/versions", [Scope.PLATFORMS_WRITE]) async def add_platform_version(request: Request) -> MessageResponse: """Add platform version to the configuration""" @@ -99,7 +100,7 @@ async def add_platform_version(request: Request) -> MessageResponse: @protected_route( - router.delete, "/config/system/versions/{fs_slug}", ["platforms.write"] + router.delete, "/config/system/versions/{fs_slug}", [Scope.PLATFORMS_WRITE] ) async def delete_platform_version(request: Request, fs_slug: str) -> MessageResponse: """Delete platform version from the configuration""" @@ -115,7 +116,7 @@ async def delete_platform_version(request: Request, fs_slug: str) -> MessageResp return {"msg": f"{fs_slug} version removed successfully!"} -# @protected_route(router.post, "/config/exclude", ["platforms.write"]) +# @protected_route(router.post, "/config/exclude", [Scope.PLATFORMS_WRITE]) # async def add_exclusion(request: Request) -> MessageResponse: # """Add platform binding to the configuration""" @@ -127,7 +128,7 @@ async def delete_platform_version(request: Request, fs_slug: str) -> MessageResp # return {"msg": f"Exclusion {exclusion} added to {exclude} successfully!"} -# @protected_route(router.delete, "/config/exclude", ["platforms.write"]) +# @protected_route(router.delete, "/config/exclude", [Scope.PLATFORMS_WRITE]) # async def delete_exclusion(request: Request) -> MessageResponse: # """Delete platform binding from the configuration""" diff --git a/backend/endpoints/feeds.py b/backend/endpoints/feeds.py index ac761b98a..40f305999 100644 --- a/backend/endpoints/feeds.py +++ b/backend/endpoints/feeds.py @@ -12,6 +12,7 @@ from endpoints.responses.feeds import ( WebrcadeFeedSchema, ) from fastapi import Request +from handler.auth.base_handler import Scope from handler.database import db_platform_handler, db_rom_handler from handler.metadata import meta_igdb_handler from handler.metadata.base_hander import SWITCH_TITLEDB_REGEX @@ -25,7 +26,7 @@ router = APIRouter() @protected_route( router.get, "/webrcade/feed", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) def platforms_webrcade_feed(request: Request) -> WebrcadeFeedSchema: """Get webrcade feed endpoint diff --git a/backend/endpoints/firmware.py b/backend/endpoints/firmware.py index e09fa7aa3..39a1dc5b6 100644 --- a/backend/endpoints/firmware.py +++ b/backend/endpoints/firmware.py @@ -4,6 +4,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.firmware import AddFirmwareResponse, FirmwareSchema from fastapi import File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse +from handler.auth.base_handler import Scope from handler.database import db_firmware_handler, db_platform_handler from handler.filesystem import fs_firmware_handler from handler.scan_handler import scan_firmware @@ -13,7 +14,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/firmware", ["firmware.write"]) +@protected_route(router.post, "/firmware", [Scope.FIRMWARE_WRITE]) def add_firmware( request: Request, platform_id: int, @@ -76,7 +77,7 @@ def add_firmware( } -@protected_route(router.get, "/firmware", ["firmware.read"]) +@protected_route(router.get, "/firmware", [Scope.FIRMWARE_READ]) def get_platform_firmware( request: Request, platform_id: int | None = None, @@ -95,7 +96,7 @@ def get_platform_firmware( @protected_route( router.get, "/firmware/{id}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["firmware.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.FIRMWARE_READ], ) def get_firmware(request: Request, id: int) -> FirmwareSchema: """Get firmware endpoint @@ -113,7 +114,7 @@ def get_firmware(request: Request, id: int) -> FirmwareSchema: @protected_route( router.head, "/firmware/{id}/content/{file_name}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["firmware.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.FIRMWARE_READ], ) def head_firmware_content(request: Request, id: int, file_name: str): """Head firmware content endpoint @@ -139,7 +140,11 @@ def head_firmware_content(request: Request, id: int, file_name: str): ) -@protected_route(router.get, "/firmware/{id}/content/{file_name}", ["firmware.read"]) +@protected_route( + router.get, + "/firmware/{id}/content/{file_name}", + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.FIRMWARE_READ], +) def get_firmware_content( request: Request, id: int, @@ -162,7 +167,7 @@ def get_firmware_content( return FileResponse(path=firmware_path, filename=firmware.file_name) -@protected_route(router.post, "/firmware/delete", ["firmware.write"]) +@protected_route(router.post, "/firmware/delete", [Scope.FIRMWARE_WRITE]) async def delete_firmware( request: Request, ) -> MessageResponse: diff --git a/backend/endpoints/platform.py b/backend/endpoints/platform.py index 24133f6da..54d4388f3 100644 --- a/backend/endpoints/platform.py +++ b/backend/endpoints/platform.py @@ -6,6 +6,7 @@ from endpoints.responses.platform import PlatformSchema from exceptions.endpoint_exceptions import PlatformNotFoundInDatabaseException from exceptions.fs_exceptions import PlatformAlreadyExistsException from fastapi import Request +from handler.auth.base_handler import Scope from handler.database import db_platform_handler from handler.filesystem import fs_platform_handler from handler.metadata.igdb_handler import IGDB_PLATFORM_LIST @@ -17,7 +18,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/platforms", ["platforms.write"]) +@protected_route(router.post, "/platforms", [Scope.PLATFORMS_WRITE]) async def add_platforms(request: Request) -> PlatformSchema: """Create platform endpoint @@ -38,7 +39,7 @@ async def add_platforms(request: Request) -> PlatformSchema: return db_platform_handler.add_platform(scanned_platform) -@protected_route(router.get, "/platforms", ["platforms.read"]) +@protected_route(router.get, "/platforms", [Scope.PLATFORMS_READ]) def get_platforms(request: Request) -> list[PlatformSchema]: """Get platforms endpoint @@ -53,7 +54,7 @@ def get_platforms(request: Request) -> list[PlatformSchema]: return db_platform_handler.get_platforms() -@protected_route(router.get, "/platforms/supported", ["platforms.read"]) +@protected_route(router.get, "/platforms/supported", [Scope.PLATFORMS_READ]) def get_supported_platforms(request: Request) -> list[PlatformSchema]: """Get list of supported platforms endpoint @@ -90,7 +91,7 @@ def get_supported_platforms(request: Request) -> list[PlatformSchema]: return supported_platforms -@protected_route(router.get, "/platforms/{id}", ["platforms.read"]) +@protected_route(router.get, "/platforms/{id}", [Scope.PLATFORMS_READ]) def get_platform(request: Request, id: int) -> PlatformSchema: """Get platforms endpoint @@ -110,7 +111,7 @@ def get_platform(request: Request, id: int) -> PlatformSchema: return platform -@protected_route(router.put, "/platforms/{id}", ["platforms.write"]) +@protected_route(router.put, "/platforms/{id}", [Scope.PLATFORMS_WRITE]) async def update_platform(request: Request) -> MessageResponse: """Update platform endpoint @@ -124,7 +125,7 @@ async def update_platform(request: Request) -> MessageResponse: return {"msg": "Enpoint not available yet"} -@protected_route(router.delete, "/platforms/{id}", ["platforms.write"]) +@protected_route(router.delete, "/platforms/{id}", [Scope.PLATFORMS_WRITE]) async def delete_platforms(request: Request, id: int) -> MessageResponse: """Delete platforms endpoint diff --git a/backend/endpoints/raw.py b/backend/endpoints/raw.py index 564735981..c3af8b2c4 100644 --- a/backend/endpoints/raw.py +++ b/backend/endpoints/raw.py @@ -2,18 +2,19 @@ from config import ASSETS_BASE_PATH from decorators.auth import protected_route from fastapi import Request from fastapi.responses import FileResponse +from handler.auth.base_handler import Scope from utils.router import APIRouter router = APIRouter() -@protected_route(router.head, "/raw/assets/{path:path}", ["assets.read"]) +@protected_route(router.head, "/raw/assets/{path:path}", [Scope.ASSETS_READ]) def head_raw_asset(request: Request, path: str): asset_path = f"{ASSETS_BASE_PATH}/{path}" return FileResponse(path=asset_path, filename=path.split("/")[-1]) -@protected_route(router.get, "/raw/assets/{path:path}", ["assets.read"]) +@protected_route(router.get, "/raw/assets/{path:path}", [Scope.ASSETS_READ]) def get_raw_asset(request: Request, path: str): """Download a single asset file diff --git a/backend/endpoints/rom.py b/backend/endpoints/rom.py index dc9562b8b..0dd449e18 100644 --- a/backend/endpoints/rom.py +++ b/backend/endpoints/rom.py @@ -19,6 +19,7 @@ from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from exceptions.fs_exceptions import RomAlreadyExistsException from fastapi import HTTPException, Query, Request, UploadFile, status from fastapi.responses import Response +from handler.auth.base_handler import Scope from handler.database import db_platform_handler, db_rom_handler from handler.filesystem import fs_resource_handler, fs_rom_handler from handler.filesystem.base_handler import CoverSize @@ -37,7 +38,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/roms", ["roms.write"]) +@protected_route(router.post, "/roms", [Scope.ROMS_WRITE]) async def add_rom(request: Request): """Upload single rom endpoint @@ -99,7 +100,7 @@ async def add_rom(request: Request): return Response(status_code=status.HTTP_201_CREATED) -@protected_route(router.get, "/roms", ["roms.read"]) +@protected_route(router.get, "/roms", [Scope.ROMS_READ]) def get_roms( request: Request, platform_id: int | None = None, @@ -137,7 +138,7 @@ def get_roms( @protected_route( router.get, "/roms/{id}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) def get_rom(request: Request, id: int) -> DetailedRomSchema: """Get rom endpoint @@ -161,7 +162,7 @@ def get_rom(request: Request, id: int) -> DetailedRomSchema: @protected_route( router.head, "/roms/{id}/content/{file_name}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) async def head_rom_content( request: Request, @@ -222,7 +223,7 @@ async def head_rom_content( @protected_route( router.get, "/roms/{id}/content/{file_name}", - [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else ["roms.read"], + [] if DISABLE_DOWNLOAD_ENDPOINT_AUTH else [Scope.ROMS_READ], ) async def get_rom_content( request: Request, @@ -289,7 +290,7 @@ async def get_rom_content( ) -@protected_route(router.put, "/roms/{id}", ["roms.write"]) +@protected_route(router.put, "/roms/{id}", [Scope.ROMS_WRITE]) async def update_rom( request: Request, id: int, @@ -466,7 +467,7 @@ async def update_rom( return DetailedRomSchema.from_orm_with_request(db_rom_handler.get_rom(id), request) -@protected_route(router.post, "/roms/delete", ["roms.write"]) +@protected_route(router.post, "/roms/delete", [Scope.ROMS_WRITE]) async def delete_roms( request: Request, ) -> MessageResponse: @@ -517,7 +518,7 @@ async def delete_roms( return {"msg": f"{len(roms_ids)} roms deleted successfully!"} -@protected_route(router.put, "/roms/{id}/props", ["roms.user.write"]) +@protected_route(router.put, "/roms/{id}/props", [Scope.ROMS_USER_WRITE]) async def update_rom_user(request: Request, id: int) -> RomUserSchema: data = await request.json() diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index 0a7893d3f..5cee897ec 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -5,6 +5,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.assets import SaveSchema, UploadedSavesResponse from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from fastapi import File, HTTPException, Request, UploadFile, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_save_handler, db_screenshot_handler from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_save @@ -14,7 +15,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/saves", ["assets.write"]) +@protected_route(router.post, "/saves", [Scope.ASSETS_WRITE]) def add_saves( request: Request, rom_id: int, @@ -82,17 +83,17 @@ def add_saves( } -# @protected_route(router.get, "/saves", ["assets.read"]) +# @protected_route(router.get, "/saves", [Scope.ASSETS_READ]) # def get_saves(request: Request) -> MessageResponse: # pass -# @protected_route(router.get, "/saves/{id}", ["assets.read"]) +# @protected_route(router.get, "/saves/{id}", [Scope.ASSETS_READ]) # def get_save(request: Request, id: int) -> MessageResponse: # pass -@protected_route(router.put, "/saves/{id}", ["assets.write"]) +@protected_route(router.put, "/saves/{id}", [Scope.ASSETS_WRITE]) async def update_save(request: Request, id: int) -> SaveSchema: data = await request.form() @@ -126,7 +127,7 @@ async def update_save(request: Request, id: int) -> SaveSchema: return db_save -@protected_route(router.post, "/saves/delete", ["assets.write"]) +@protected_route(router.post, "/saves/delete", [Scope.ASSETS_WRITE]) async def delete_saves(request: Request) -> MessageResponse: data: dict = await request.json() save_ids: list = data["saves"] diff --git a/backend/endpoints/screenshots.py b/backend/endpoints/screenshots.py index 1b97cab0c..480a52407 100644 --- a/backend/endpoints/screenshots.py +++ b/backend/endpoints/screenshots.py @@ -1,6 +1,7 @@ from decorators.auth import protected_route from endpoints.responses.assets import UploadedScreenshotsResponse from fastapi import File, HTTPException, Request, UploadFile, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_screenshot_handler from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_screenshot @@ -10,7 +11,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/screenshots", ["assets.write"]) +@protected_route(router.post, "/screenshots", [Scope.ASSETS_WRITE]) def add_screenshots( request: Request, rom_id: int, diff --git a/backend/endpoints/search.py b/backend/endpoints/search.py index 57ee96457..e6b66b6e7 100644 --- a/backend/endpoints/search.py +++ b/backend/endpoints/search.py @@ -2,6 +2,7 @@ import emoji from decorators.auth import protected_route from endpoints.responses.search import SearchCoverSchema, SearchRomSchema from fastapi import HTTPException, Request, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler from handler.metadata import meta_igdb_handler, meta_moby_handler, meta_sgdb_handler from handler.metadata.igdb_handler import IGDB_API_ENABLED @@ -14,7 +15,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.get, "/search/roms", ["roms.read"]) +@protected_route(router.get, "/search/roms", [Scope.ROMS_READ]) async def search_rom( request: Request, rom_id: str, @@ -111,7 +112,7 @@ async def search_rom( return matched_roms -@protected_route(router.get, "/search/cover", ["roms.read"]) +@protected_route(router.get, "/search/cover", [Scope.ROMS_READ]) async def search_cover( request: Request, search_term: str = "", diff --git a/backend/endpoints/states.py b/backend/endpoints/states.py index 462c236fc..dc95311c6 100644 --- a/backend/endpoints/states.py +++ b/backend/endpoints/states.py @@ -5,6 +5,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.assets import StateSchema, UploadedStatesResponse from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from fastapi import File, HTTPException, Request, UploadFile, status +from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_screenshot_handler, db_state_handler from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_state @@ -14,7 +15,7 @@ from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/states", ["assets.write"]) +@protected_route(router.post, "/states", [Scope.ASSETS_WRITE]) def add_states( request: Request, rom_id: int, @@ -78,17 +79,17 @@ def add_states( } -# @protected_route(router.get, "/states", ["assets.read"]) +# @protected_route(router.get, "/states", [Scope.ASSETS_READ]) # def get_states(request: Request) -> MessageResponse: # pass -# @protected_route(router.get, "/states/{id}", ["assets.read"]) +# @protected_route(router.get, "/states/{id}", [Scope.ASSETS_READ]) # def get_state(request: Request, id: int) -> MessageResponse: # pass -@protected_route(router.put, "/states/{id}", ["assets.write"]) +@protected_route(router.put, "/states/{id}", [Scope.ASSETS_WRITE]) async def update_state(request: Request, id: int) -> StateSchema: data = await request.form() @@ -121,7 +122,7 @@ async def update_state(request: Request, id: int) -> StateSchema: return db_state -@protected_route(router.post, "/states/delete", ["assets.write"]) +@protected_route(router.post, "/states/delete", [Scope.ASSETS_WRITE]) async def delete_states(request: Request) -> MessageResponse: data: dict = await request.json() state_ids: list = data["states"] diff --git a/backend/endpoints/tasks.py b/backend/endpoints/tasks.py index 1873c7ed7..91015ee86 100644 --- a/backend/endpoints/tasks.py +++ b/backend/endpoints/tasks.py @@ -1,13 +1,14 @@ from decorators.auth import protected_route from endpoints.responses import MessageResponse from fastapi import Request +from handler.auth.base_handler import Scope from tasks.update_switch_titledb import update_switch_titledb_task from utils.router import APIRouter router = APIRouter() -@protected_route(router.post, "/tasks/run", ["tasks.run"]) +@protected_route(router.post, "/tasks/run", [Scope.TASKS_RUN]) async def run_tasks(request: Request) -> MessageResponse: """Run all tasks endpoint @@ -21,7 +22,7 @@ async def run_tasks(request: Request) -> MessageResponse: return {"msg": "All tasks ran successfully!"} -@protected_route(router.post, "/tasks/{task}/run", ["tasks.run"]) +@protected_route(router.post, "/tasks/{task}/run", [Scope.TASKS_RUN]) async def run_task(request: Request, task: str) -> MessageResponse: """Run all tasks endpoint diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 3bed5391b..b5862cc13 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -9,6 +9,7 @@ from endpoints.responses import MessageResponse from endpoints.responses.identity import UserSchema from fastapi import Depends, HTTPException, Request, status from handler.auth import auth_handler +from handler.auth.base_handler import Scope from handler.database import db_user_handler from handler.filesystem import fs_asset_handler from logger.logger import log @@ -37,9 +38,9 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS UserSchema: Created user info """ - # If there are admin users already, enforce the `users.write` scope. + # If there are admin users already, enforce the USERS_WRITE scope. if ( - "users.write" not in request.auth.scopes + Scope.USERS_WRITE not in request.auth.scopes and len(db_user_handler.get_admin_users()) > 0 ): raise HTTPException( @@ -65,7 +66,7 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS return db_user_handler.add_user(user) -@protected_route(router.get, "/users", ["users.read"]) +@protected_route(router.get, "/users", [Scope.USERS_READ]) def get_users(request: Request) -> list[UserSchema]: """Get all users endpoint @@ -79,7 +80,7 @@ def get_users(request: Request) -> list[UserSchema]: return db_user_handler.get_users() -@protected_route(router.get, "/users/me", ["me.read"]) +@protected_route(router.get, "/users/me", [Scope.ME_READ]) def get_current_user(request: Request) -> UserSchema | None: """Get current user endpoint @@ -93,7 +94,7 @@ def get_current_user(request: Request) -> UserSchema | None: return request.user -@protected_route(router.get, "/users/{id}", ["users.read"]) +@protected_route(router.get, "/users/{id}", [Scope.USERS_READ]) def get_user(request: Request, id: int) -> UserSchema: """Get user endpoint @@ -111,7 +112,7 @@ def get_user(request: Request, id: int) -> UserSchema: return user -@protected_route(router.put, "/users/{id}", ["me.write"]) +@protected_route(router.put, "/users/{id}", [Scope.USERS_WRITE]) async def update_user( request: Request, id: int, form_data: Annotated[UserForm, Depends()] ) -> UserSchema: @@ -192,7 +193,7 @@ async def update_user( return db_user_handler.get_user(id) -@protected_route(router.delete, "/users/{id}", ["users.write"]) +@protected_route(router.delete, "/users/{id}", [Scope.USERS_WRITE]) def delete_user(request: Request, id: int) -> MessageResponse: """Delete user endpoint diff --git a/backend/handler/auth/base_handler.py b/backend/handler/auth/base_handler.py index 836c8f88f..0ce5e0413 100644 --- a/backend/handler/auth/base_handler.py +++ b/backend/handler/auth/base_handler.py @@ -1,3 +1,4 @@ +import enum from datetime import datetime, timedelta, timezone from typing import Final @@ -11,32 +12,53 @@ from passlib.context import CryptContext from starlette.requests import HTTPConnection ALGORITHM: Final = "HS256" -DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15 +DEFAULT_OAUTH_TOKEN_EXPIRY: Final = timedelta(minutes=15) + + +class Scope(enum.StrEnum): + ME_READ = "me.read" + ME_WRITE = "me.write" + ROMS_READ = "roms.read" + ROMS_WRITE = "roms.write" + ROMS_USER_READ = "roms.user.read" + ROMS_USER_WRITE = "roms.user.write" + PLATFORMS_READ = "platforms.read" + PLATFORMS_WRITE = "platforms.write" + ASSETS_READ = "assets.read" + ASSETS_WRITE = "assets.write" + FIRMWARE_READ = "firmware.read" + FIRMWARE_WRITE = "firmware.write" + COLLECTIONS_READ = "collections.read" + COLLECTIONS_WRITE = "collections.write" + USERS_READ = "users.read" + USERS_WRITE = "users.write" + TASKS_RUN = "tasks.run" + DEFAULT_SCOPES_MAP: Final = { - "me.read": "View your profile", - "me.write": "Modify your profile", - "roms.read": "View ROMs", - "platforms.read": "View platforms", - "assets.read": "View assets", - "assets.write": "Modify assets", - "firmware.read": "View firmware", - "roms.user.read": "View user-rom properties", - "roms.user.write": "Modify user-rom properties", - "collections.read": "View collections", - "collections.write": "Modify collections", + Scope.ME_READ: "View your profile", + Scope.ME_WRITE: "Modify your profile", + Scope.ROMS_READ: "View ROMs", + Scope.PLATFORMS_READ: "View platforms", + Scope.ASSETS_READ: "View assets", + Scope.ASSETS_WRITE: "Modify assets", + Scope.FIRMWARE_READ: "View firmware", + Scope.ROMS_USER_READ: "View user-rom properties", + Scope.ROMS_USER_WRITE: "Modify user-rom properties", + Scope.COLLECTIONS_READ: "View collections", + Scope.COLLECTIONS_WRITE: "Modify collections", } WRITE_SCOPES_MAP: Final = { - "roms.write": "Modify ROMs", - "platforms.write": "Modify platforms", - "firmware.write": "Modify firmware", + Scope.ROMS_WRITE: "Modify ROMs", + Scope.PLATFORMS_WRITE: "Modify platforms", + Scope.FIRMWARE_WRITE: "Modify firmware", } FULL_SCOPES_MAP: Final = { - "users.read": "View users", - "users.write": "Modify users", - "tasks.run": "Run tasks", + Scope.USERS_READ: "View users", + Scope.USERS_WRITE: "Modify users", + Scope.TASKS_RUN: "Run tasks", } DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys()) @@ -102,16 +124,11 @@ class OAuthHandler: def __init__(self) -> None: pass - def create_oauth_token(self, data: dict, expires_delta: timedelta | None = None): + def create_oauth_token( + self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY + ) -> str: to_encode = data.copy() - - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta( - minutes=DEFAULT_OAUTH_TOKEN_EXPIRY - ) - + expire = datetime.now(timezone.utc) + expires_delta to_encode.update({"exp": expire}) return jwt.encode( diff --git a/backend/models/user.py b/backend/models/user.py index 635e19eaa..92c7d260a 100644 --- a/backend/models/user.py +++ b/backend/models/user.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from starlette.authentication import SimpleUser if TYPE_CHECKING: + from handler.auth.base_handler import Scope from models.assets import Save, Screenshot, State from models.collection import Collection from models.rom import RomUser @@ -42,7 +43,7 @@ class User(BaseModel, SimpleUser): collections: Mapped[list[Collection]] = relationship(back_populates="user") @property - def oauth_scopes(self): + def oauth_scopes(self) -> list[Scope]: from handler.auth.base_handler import DEFAULT_SCOPES, FULL_SCOPES, WRITE_SCOPES if self.role == Role.ADMIN: