diff --git a/backend/endpoints/responses/rom.py b/backend/endpoints/responses/rom.py index 2644c6db8..1f1d447db 100644 --- a/backend/endpoints/responses/rom.py +++ b/backend/endpoints/responses/rom.py @@ -1,11 +1,12 @@ from __future__ import annotations import re +from collections.abc import Sequence from datetime import datetime, timezone from typing import NotRequired, TypedDict, get_type_hints from fastapi import Request -from pydantic import ConfigDict, computed_field, field_validator +from pydantic import ConfigDict, Field, computed_field, field_validator from endpoints.responses.assets import SaveSchema, ScreenshotSchema, StateSchema from handler.metadata.flashpoint_handler import FlashpointMetadata @@ -18,7 +19,7 @@ from handler.metadata.moby_handler import MobyMetadata from handler.metadata.ra_handler import RAMetadata from handler.metadata.ss_handler import SSMetadata from models.collection import Collection -from models.rom import Rom, RomArchiveMember, RomFileCategory, RomUserStatus +from models.rom import Rom, RomArchiveMember, RomFile, RomFileCategory, RomUserStatus from .base import BaseModel, UTCDatetime @@ -335,16 +336,50 @@ class SiblingRomSchema(BaseModel): .lower() ) + @classmethod + def from_rom(cls, rom: Rom) -> SiblingRomSchema: + return cls( + id=rom.id, + name=rom.name, + fs_name_no_tags=rom.fs_name_no_tags, + fs_name_no_ext=rom.fs_name_no_ext, + ) + class SimpleRomSchema(RomSchema): sibling_ids: list[int] + # `files` binds to a dedicated attribute via validation alias rather than the + # `Rom.files` relationship: that keeps the full file list out of the default + # gallery payload, and avoids tripping the relationship's lazy="raise" in + # factory contexts where files aren't loaded. Populated only when requested. + files: list[RomFileSchema] = Field( + default_factory=list, validation_alias="included_files" + ) + siblings: list[SiblingRomSchema] = Field(default_factory=list) + + @field_validator("files") + def sort_files(cls, v: list[RomFileSchema]) -> list[RomFileSchema]: + return sorted(v, key=lambda x: x.file_name) + + @field_validator("siblings") + def sort_siblings(cls, v: list[SiblingRomSchema]) -> list[SiblingRomSchema]: + return sorted(v, key=lambda x: x.sort_comparator) @classmethod def from_orm_with_request( - cls, db_rom: Rom, request: Request, sibling_ids: list[int] | None = None + cls, + db_rom: Rom, + request: Request, + sibling_ids: list[int] | None = None, + files: Sequence[RomFile] | None = None, + siblings: Sequence[Rom] | None = None, ) -> SimpleRomSchema: db_rom = cls.populate_properties(db_rom, request) db_rom.sibling_ids = sibling_ids or [] # type: ignore + db_rom.included_files = files or [] # type: ignore + db_rom.siblings = ( # type: ignore + [SiblingRomSchema.from_rom(s) for s in siblings] if siblings else [] + ) return cls.model_validate(db_rom) @classmethod @@ -397,15 +432,7 @@ class DetailedRomSchema(RomSchema): db_rom = cls.populate_properties(db_rom, request) sorted_siblings = sorted( - ( - SiblingRomSchema( - id=s.id, - name=s.name, - fs_name_no_tags=s.fs_name_no_tags, - fs_name_no_ext=s.fs_name_no_ext, - ) - for s in db_rom.sibling_roms - ), + (SiblingRomSchema.from_rom(s) for s in db_rom.sibling_roms), key=lambda x: x.sort_comparator, ) db_rom.siblings = sorted_siblings # type: ignore diff --git a/backend/endpoints/roms/__init__.py b/backend/endpoints/roms/__init__.py index 06f2407b0..bb62e9b7c 100644 --- a/backend/endpoints/roms/__init__.py +++ b/backend/endpoints/roms/__init__.py @@ -482,6 +482,14 @@ def get_roms( description="Filter roms updated after this datetime (ISO 8601 format with timezone information)." ), ] = None, + with_files: Annotated[ + bool, + Query(description="Whether to include each rom's file entries."), + ] = False, + with_siblings: Annotated[ + bool, + Query(description="Whether to include each rom's sibling roms."), + ] = False, ) -> CustomLimitOffsetPage[SimpleRomSchema]: """Retrieve roms.""" unfiltered_query, order_by_attr = db_rom_handler.get_roms_query( @@ -571,8 +579,19 @@ def get_roms( rom_id_index = session.scalars(query.with_only_columns(Rom.id)).all() # type: ignore def _transform(items: Sequence[Rom]) -> list[SimpleRomSchema]: + rom_ids = [i.id for i in items] sibling_ids_by_rom = db_rom_handler.get_sibling_ids_for_roms( - [i.id for i in items], session=session + rom_ids, session=session + ) + files_by_rom = ( + db_rom_handler.get_files_for_roms(rom_ids, session=session) + if with_files + else {} + ) + siblings_by_rom = ( + db_rom_handler.get_siblings_for_roms(rom_ids, session=session) + if with_siblings + else {} ) return [ @@ -580,6 +599,8 @@ def get_roms( db_rom=item, request=request, sibling_ids=sibling_ids_by_rom.get(item.id, []), + files=files_by_rom.get(item.id, []), + siblings=siblings_by_rom.get(item.id, []), ) for item in items ] diff --git a/backend/endpoints/sockets/scan.py b/backend/endpoints/sockets/scan.py index 2d0938214..c765a8355 100644 --- a/backend/endpoints/sockets/scan.py +++ b/backend/endpoints/sockets/scan.py @@ -361,6 +361,7 @@ async def _identify_rom( "rom_user", "last_modified", "files", + "siblings", } ), ) @@ -480,7 +481,14 @@ async def _identify_rom( await socket_manager.emit( "scan:scanning_rom", SimpleRomSchema.from_orm_with_factory(_added_rom).model_dump( - exclude={"created_at", "updated_at", "rom_user", "last_modified", "files"} + exclude={ + "created_at", + "updated_at", + "rom_user", + "last_modified", + "files", + "siblings", + } ), ) diff --git a/backend/handler/database/roms_handler.py b/backend/handler/database/roms_handler.py index fc927607d..9ad1277dc 100644 --- a/backend/handler/database/roms_handler.py +++ b/backend/handler/database/roms_handler.py @@ -231,6 +231,52 @@ class DBRomsHandler(DBBaseHandler): return {rom_id: sorted(ids) for rom_id, ids in buckets.items()} + def get_files_for_roms( + self, + rom_ids: list[int], + *, + session: Session, + ) -> dict[int, list[RomFile]]: + """Return {rom_id: [RomFile, ...]} for the given rom IDs in a single query. + + Used by the list endpoint to serialize files without relying on the + query's relationship eager-load surviving pagination. + """ + if not rom_ids: + return {} + + files = session.scalars( + select(RomFile).where(RomFile.rom_id.in_(rom_ids)) + ).all() + + buckets: dict[int, list[RomFile]] = {rom_id: [] for rom_id in rom_ids} + for file in files: + buckets[file.rom_id].append(file) + + return buckets + + def get_siblings_for_roms( + self, + rom_ids: list[int], + *, + session: Session, + ) -> dict[int, list[Rom]]: + """Return {rom_id: [sibling Rom, ...]} for the given rom IDs in a single query.""" + if not rom_ids: + return {} + + rows = session.execute( + select(SiblingRom.rom_id, Rom) + .join(Rom, Rom.id == SiblingRom.sibling_rom_id) + .where(SiblingRom.rom_id.in_(rom_ids)) + ).all() + + buckets: dict[int, list[Rom]] = {rom_id: [] for rom_id in rom_ids} + for rom_id, sibling in rows: + buckets[rom_id].append(sibling) + + return buckets + def filter_by_platform_id(self, query: Query, platform_id: int): return query.filter(Rom.platform_id == platform_id) diff --git a/backend/handler/scan_handler.py b/backend/handler/scan_handler.py index cc2d01c0e..b837a1abe 100644 --- a/backend/handler/scan_handler.py +++ b/backend/handler/scan_handler.py @@ -443,6 +443,7 @@ async def scan_rom( "rom_user", "last_modified", "files", + "siblings", } ), }, diff --git a/backend/tests/endpoints/roms/test_rom.py b/backend/tests/endpoints/roms/test_rom.py index d7193f364..b8ceedb57 100644 --- a/backend/tests/endpoints/roms/test_rom.py +++ b/backend/tests/endpoints/roms/test_rom.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch from fastapi import status from fastapi.testclient import TestClient +from handler.database import db_rom_handler from handler.filesystem.resources_handler import FSResourcesHandler from handler.filesystem.roms_handler import FSRomsHandler from handler.metadata.flashpoint_handler import FlashpointHandler, FlashpointRom @@ -14,7 +15,8 @@ from handler.metadata.moby_handler import MobyGamesHandler, MobyGamesRom from handler.metadata.ra_handler import RAGameRom, RAHandler from handler.metadata.ss_handler import SSHandler, SSRom from models.platform import Platform -from models.rom import Rom +from models.rom import Rom, RomFile +from models.user import User MOCK_IGDB_ID = 11111 MOCK_MOBY_ID = 22222 @@ -80,6 +82,73 @@ def test_get_all_roms( items = body["items"] assert len(items) == 1 assert items[0]["id"] == rom.id + assert items[0]["files"] == [] + assert items[0]["siblings"] == [] + + +def test_get_all_roms_with_files( + client: TestClient, access_token: str, rom: Rom, platform: Platform +): + db_rom_handler.add_rom_file( + RomFile( + rom_id=rom.id, + file_name="test_rom.zip", + file_path=f"{platform.slug}/roms", + file_size_bytes=1024, + last_modified=1700000000.0, + ) + ) + + response = client.get( + "/api/roms", + headers={"Authorization": f"Bearer {access_token}"}, + params={"platform_id": platform.id, "with_files": True}, + ) + assert response.status_code == status.HTTP_200_OK + + item = response.json()["items"][0] + assert item["id"] == rom.id + assert len(item["files"]) == 1 + assert item["files"][0]["file_name"] == "test_rom.zip" + # with_files alone must not pull in siblings. + assert item["siblings"] == [] + + +def test_get_all_roms_with_siblings( + client: TestClient, access_token: str, platform: Platform, admin_user: User +): + siblings = [ + db_rom_handler.add_rom( + Rom( + platform_id=platform.id, + igdb_id=424242, + name=name, + slug=slug, + fs_name=f"{slug}.zip", + fs_name_no_tags=slug, + fs_name_no_ext=slug, + fs_extension="zip", + fs_path=f"{platform.slug}/roms", + ) + ) + for name, slug in (("Game A", "game_a"), ("Game B", "game_b")) + ] + for sibling in siblings: + db_rom_handler.add_rom_user(rom_id=sibling.id, user_id=admin_user.id) + + response = client.get( + "/api/roms", + headers={"Authorization": f"Bearer {access_token}"}, + params={"platform_id": platform.id, "with_siblings": True}, + ) + assert response.status_code == status.HTTP_200_OK + + items = {item["id"]: item for item in response.json()["items"]} + rom_a, rom_b = siblings + assert [s["id"] for s in items[rom_a.id]["siblings"]] == [rom_b.id] + assert [s["id"] for s in items[rom_b.id]["siblings"]] == [rom_a.id] + # with_siblings alone must not pull in files. + assert items[rom_a.id]["files"] == [] def test_get_rom_content_requires_auth(client: TestClient, rom: Rom, rom_file): diff --git a/frontend/src/__generated__/models/SimpleRomSchema.ts b/frontend/src/__generated__/models/SimpleRomSchema.ts index 920e2e5db..c2b28a3f5 100644 --- a/frontend/src/__generated__/models/SimpleRomSchema.ts +++ b/frontend/src/__generated__/models/SimpleRomSchema.ts @@ -3,6 +3,7 @@ /* tslint:disable */ /* eslint-disable */ import type { ManualMetadata } from './ManualMetadata'; +import type { RomFileSchema } from './RomFileSchema'; import type { RomFlashpointMetadata } from './RomFlashpointMetadata'; import type { RomGamelistMetadata } from './RomGamelistMetadata'; import type { RomHasheousMetadata } from './RomHasheousMetadata'; @@ -14,6 +15,7 @@ import type { RomMobyMetadata } from './RomMobyMetadata'; import type { RomRAMetadata } from './RomRAMetadata'; import type { RomSSMetadata } from './RomSSMetadata'; import type { RomUserSchema } from './RomUserSchema'; +import type { SiblingRomSchema } from './SiblingRomSchema'; export type SimpleRomSchema = { id: number; igdb_id: (number | null); @@ -84,5 +86,7 @@ export type SimpleRomSchema = { merged_screenshots: Array; merged_ra_metadata: (RomRAMetadata | null); sibling_ids: Array; + files?: Array; + siblings?: Array; };