diff --git a/backend/endpoints/collections.py b/backend/endpoints/collections.py index d1a632754..fe5157525 100644 --- a/backend/endpoints/collections.py +++ b/backend/endpoints/collections.py @@ -24,7 +24,11 @@ from handler.filesystem.base_handler import CoverSize from logger.formatter import BLUE from logger.formatter import highlight as hl from logger.logger import log -from models.collection import Collection, SmartCollection +from models.collection import ( + Collection, + SmartCollection, + VirtualCollection, +) from utils.router import APIRouter router = APIRouter( @@ -153,9 +157,6 @@ def get_collections( description="Filter collections updated after this datetime (ISO 8601 format with timezone information)." ), ] = None, - only_ids: Annotated[ - bool | None, Query(description="Only return list of IDs") - ] = None, ) -> list[CollectionSchema]: """Get collections endpoint @@ -168,12 +169,32 @@ def get_collections( """ collections = db_collection_handler.get_collections( - updated_after=updated_after, only_fields=[Collection.id] if only_ids else None + updated_after=updated_after, ) return CollectionSchema.for_user(request.user.id, [c for c in collections]) +@protected_route(router.get, "/identifiers", [Scope.COLLECTIONS_READ]) +def get_collection_identifiers( + request: Request, +) -> list[int]: + """Get collections identifiers endpoint + + Args: + request (Request): Fastapi Request object + + Returns: + list[int]: List of collection ids + """ + + collections = db_collection_handler.get_collections( + only_fields=[Collection.id, Collection.user_id, Collection.is_public], + ) + + return [c.id for c in collections if c.user_id == request.user.id or c.is_public] + + @protected_route(router.get, "/virtual", [Scope.COLLECTIONS_READ]) def get_virtual_collections( request: Request, @@ -215,7 +236,8 @@ def get_smart_collections( """ smart_collections = db_collection_handler.get_smart_collections( - request.user.id, updated_after=updated_after + request.user.id, + updated_after=updated_after, ) return SmartCollectionSchema.for_user( @@ -223,6 +245,27 @@ def get_smart_collections( ) +@protected_route(router.get, "/smart/identifiers", [Scope.COLLECTIONS_READ]) +def get_smart_collection_identifiers( + request: Request, +) -> list[int]: + """Get smart collections identifiers endpoint + + Args: + request (Request): Fastapi Request object + + Returns: + list[int]: List of smart collection ids + """ + + smart_collections = db_collection_handler.get_smart_collections( + request.user.id, + only_fields=[SmartCollection.id], + ) + + return [s.id for s in smart_collections] + + @protected_route(router.get, "/{id}", [Scope.COLLECTIONS_READ]) def get_collection(request: Request, id: int) -> CollectionSchema: """Get collections endpoint diff --git a/backend/endpoints/firmware.py b/backend/endpoints/firmware.py index efc5ee0e7..b68c22ad4 100644 --- a/backend/endpoints/firmware.py +++ b/backend/endpoints/firmware.py @@ -1,6 +1,6 @@ from typing import Annotated -from fastapi import Body, File, HTTPException, Request, UploadFile, status +from fastapi import Body, File, HTTPException, Query, Request, UploadFile, status from fastapi.responses import FileResponse from config import DISABLE_DOWNLOAD_ENDPOINT_AUTH @@ -119,10 +119,28 @@ def get_platform_firmware( Returns: list[FirmwareSchema]: Firmware stored in the database """ - return [ - FirmwareSchema.model_validate(f) - for f in db_firmware_handler.list_firmware(platform_id=platform_id) - ] + firmware = db_firmware_handler.list_firmware( + platform_id=platform_id, + ) + return [FirmwareSchema.model_validate(f) for f in firmware] + + +@protected_route(router.get, "/identifiers", [Scope.FIRMWARE_READ]) +def get_firmware_identifiers( + request: Request, +) -> list[int]: + """Get firmware identifiers endpoint + + Args: + request (Request): Fastapi Request object + + Returns: + list[int]: List of firmware ids + """ + firmware = db_firmware_handler.list_firmware( + only_fields=[Firmware.id], + ) + return [f.id for f in firmware] @protected_route( diff --git a/backend/endpoints/platform.py b/backend/endpoints/platform.py index 9f77f766d..cfa497698 100644 --- a/backend/endpoints/platform.py +++ b/backend/endpoints/platform.py @@ -16,6 +16,7 @@ from handler.scan_handler import scan_platform from logger.formatter import BLUE from logger.formatter import highlight as hl from logger.logger import log +from models.platform import Platform from utils.platforms import get_supported_platforms from utils.router import APIRouter @@ -60,10 +61,20 @@ def get_platforms( ) -> list[PlatformSchema]: """Retrieve platforms.""" - return [ - PlatformSchema.model_validate(p) - for p in db_platform_handler.get_platforms(updated_after=updated_after) - ] + platforms = db_platform_handler.get_platforms(updated_after=updated_after) + return [PlatformSchema.model_validate(p) for p in platforms] + + +@protected_route(router.get, "/identifiers", [Scope.PLATFORMS_READ]) +def get_platform_identifiers( + request: Request, +) -> list[int]: + """Retrieve platform identifiers.""" + + platforms = db_platform_handler.get_platforms( + only_fields=[Platform.id], + ) + return [p.id for p in platforms] @protected_route(router.get, "/supported", [Scope.PLATFORMS_READ]) diff --git a/backend/endpoints/rom.py b/backend/endpoints/rom.py index 54b13bbf3..6fa00628c 100644 --- a/backend/endpoints/rom.py +++ b/backend/endpoints/rom.py @@ -66,7 +66,7 @@ from handler.metadata.ss_handler import get_preferred_media_types from logger.formatter import BLUE from logger.formatter import highlight as hl from logger.logger import log -from models.rom import Rom +from models.rom import Rom, RomNote from utils.database import safe_int, safe_str_to_bool from utils.filesystem import sanitize_filename from utils.hashing import crc32_to_hex @@ -517,6 +517,19 @@ def get_roms( ) +@protected_route(router.get, "/identifiers", [Scope.ROMS_READ]) +def get_rom_identifiers( + request: Request, +) -> list[int]: + """Retrieve rom identifiers.""" + db_roms = db_rom_handler.get_roms_scalar( + user_id=request.user.id, + only_fields=[Rom.id], + ) + + return [r.id for r in db_roms] + + @protected_route( router.get, "/download", @@ -1633,8 +1646,8 @@ async def get_rom_notes( request: Request, id: Annotated[int, PathVar(description="Rom internal id.", ge=1)], public_only: bool = DEFAULT_PUBLIC_ONLY, - search: str = DEFAULT_SEARCH, - tags: list[str] = DEFAULT_TAGS, + search: str | None = DEFAULT_SEARCH, + tags: list[str] | None = DEFAULT_TAGS, ) -> list[UserNoteSchema]: """Get all notes for a ROM.""" from handler.database import db_rom_handler @@ -1657,6 +1670,32 @@ async def get_rom_notes( return [UserNoteSchema.model_validate(note) for note in notes] +@protected_route( + router.get, + "/{id}/notes/identifiers", + [Scope.ROMS_READ], + responses={status.HTTP_404_NOT_FOUND: {}}, +) +async def get_rom_note_identifiers( + request: Request, + id: Annotated[int, PathVar(description="Rom internal id.", ge=1)], +) -> list[int]: + """Get all note identifiers for a ROM.""" + from handler.database import db_rom_handler + + rom = db_rom_handler.get_rom(id) + if not rom: + raise RomNotFoundInDatabaseException(id) + + notes = db_rom_handler.get_rom_notes( + rom_id=id, + user_id=request.user.id, + only_fields=[RomNote.id], + ) + + return [note.id for note in notes] + + @protected_route( router.post, "/{id}/notes", diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index 46d465616..a6bc2f460 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import Annotated -from fastapi import Body, HTTPException, Request, UploadFile, status +from fastapi import Body, HTTPException, Query, Request, UploadFile, status from decorators.auth import protected_route from endpoints.responses.assets import SaveSchema @@ -13,6 +13,7 @@ from handler.scan_handler import scan_save, scan_screenshot from logger.formatter import BLUE from logger.formatter import highlight as hl from logger.logger import log +from models.assets import Save from utils.router import APIRouter router = APIRouter( @@ -142,15 +143,39 @@ async def add_save( @protected_route(router.get, "", [Scope.ASSETS_READ]) def get_saves( - request: Request, rom_id: int | None = None, platform_id: int | None = None + request: Request, + rom_id: int | None = None, + platform_id: int | None = None, ) -> list[SaveSchema]: saves = db_save_handler.get_saves( - user_id=request.user.id, rom_id=rom_id, platform_id=platform_id + user_id=request.user.id, + rom_id=rom_id, + platform_id=platform_id, ) return [SaveSchema.model_validate(save) for save in saves] +@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ]) +def get_save_identifiers( + request: Request, +) -> list[int]: + """Get save identifiers endpoint + + Args: + request (Request): Fastapi Request object + + Returns: + list[int]: List of save ids + """ + saves = db_save_handler.get_saves( + user_id=request.user.id, + only_fields=[Save.id], + ) + + return [save.id for save in saves] + + @protected_route(router.get, "/{id}", [Scope.ASSETS_READ]) def get_save(request: Request, id: int) -> SaveSchema: save = db_save_handler.get_save(user_id=request.user.id, id=id) diff --git a/backend/endpoints/states.py b/backend/endpoints/states.py index f26d6d857..bb46e6536 100644 --- a/backend/endpoints/states.py +++ b/backend/endpoints/states.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import Annotated -from fastapi import Body, HTTPException, Request, UploadFile, status +from fastapi import Body, HTTPException, Query, Request, UploadFile, status from decorators.auth import protected_route from endpoints.responses.assets import StateSchema @@ -13,6 +13,7 @@ from handler.scan_handler import scan_screenshot, scan_state from logger.formatter import BLUE from logger.formatter import highlight as hl from logger.logger import log +from models.assets import State from utils.router import APIRouter router = APIRouter( @@ -144,15 +145,39 @@ async def add_state( @protected_route(router.get, "", [Scope.ASSETS_READ]) def get_states( - request: Request, rom_id: int | None = None, platform_id: int | None = None + request: Request, + rom_id: int | None = None, + platform_id: int | None = None, ) -> list[StateSchema]: states = db_state_handler.get_states( - user_id=request.user.id, rom_id=rom_id, platform_id=platform_id + user_id=request.user.id, + rom_id=rom_id, + platform_id=platform_id, ) return [StateSchema.model_validate(state) for state in states] +@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ]) +def get_state_identifiers( + request: Request, +) -> list[int]: + """Get state identifiers endpoint + + Args: + request (Request): Fastapi Request object + + Returns: + list[int]: List of state ids + """ + states = db_state_handler.get_states( + user_id=request.user.id, + only_fields=[State.id], + ) + + return [state.id for state in states] + + @protected_route(router.get, "/{id}", [Scope.ASSETS_READ]) def get_state(request: Request, id: int) -> StateSchema: state = db_state_handler.get_state(user_id=request.user.id, id=id) diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 48d31151f..5684ec881 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -3,7 +3,7 @@ from typing import Annotated, Any, cast from fastapi import Body, Form, HTTPException from fastapi import Path as PathVar -from fastapi import Request, status +from fastapi import Query, Request, status from decorators.auth import protected_route from endpoints.forms.identity import UserForm @@ -199,7 +199,9 @@ def create_user_from_invite( @protected_route(router.get, "", [Scope.USERS_READ]) -def get_users(request: Request) -> list[UserSchema]: +def get_users( + request: Request, +) -> list[UserSchema]: """Get all users endpoint Args: @@ -209,7 +211,27 @@ def get_users(request: Request) -> list[UserSchema]: list[UserSchema]: All users stored in the RomM's database """ - return [UserSchema.model_validate(u) for u in db_user_handler.get_users()] + users = db_user_handler.get_users() + return [UserSchema.model_validate(u) for u in users] + + +@protected_route(router.get, "/identifiers", [Scope.USERS_READ]) +def get_user_identifiers( + request: Request, +) -> list[int]: + """Get all user identifiers endpoint + + Args: + request (Request): Fastapi Request object + + Returns: + list[int]: All user ids stored in the RomM's database + """ + + users = db_user_handler.get_users( + only_fields=[User.id], + ) + return [u.id for u in users] @protected_route(router.get, "/me", [Scope.ME_READ]) diff --git a/backend/handler/database/collections_handler.py b/backend/handler/database/collections_handler.py index 571b61042..3b6da308f 100644 --- a/backend/handler/database/collections_handler.py +++ b/backend/handler/database/collections_handler.py @@ -169,20 +169,17 @@ class DBCollectionsHandler(DBBaseHandler): self, type: str, limit: int | None = None, - only_fields: Sequence[QueryableAttribute] | None = None, session: Session = None, # type: ignore ) -> Sequence[VirtualCollection]: - return ( - session.scalars( - select(VirtualCollection) - .filter(or_(VirtualCollection.type == type, literal(type == "all"))) - .limit(limit) - .order_by(VirtualCollection.name.asc()) - ) - .unique() - .all() + query = ( + select(VirtualCollection) + .filter(or_(VirtualCollection.type == type, literal(type == "all"))) + .limit(limit) + .order_by(VirtualCollection.name.asc()) ) + return session.scalars(query).unique().all() + # Smart collections @begin_session def add_smart_collection( @@ -233,6 +230,9 @@ class DBCollectionsHandler(DBBaseHandler): if updated_after: query = query.filter(SmartCollection.updated_at > updated_after) + if only_fields: + query = query.options(load_only(*only_fields)) + return session.scalars(query).unique().all() @begin_session diff --git a/backend/handler/database/roms_handler.py b/backend/handler/database/roms_handler.py index 07312f15b..69d7db444 100644 --- a/backend/handler/database/roms_handler.py +++ b/backend/handler/database/roms_handler.py @@ -1061,7 +1061,7 @@ class DBRomsHandler(DBBaseHandler): rom_id: int, user_id: int, public_only: bool = False, - search: str = "", + search: str | None = "", tags: list[str] | None = None, only_fields: Sequence[QueryableAttribute] | None = None, session: Session = None, # type: ignore