Add opt-in files/siblings expansion to GET /api/roms

This commit is contained in:
nendo
2026-06-07 15:09:24 +09:00
parent bb82cf470a
commit 37f0feab8c
7 changed files with 191 additions and 15 deletions

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
import re
from collections.abc import Sequence
from datetime import datetime, timezone
from typing import NotRequired, TypedDict, get_type_hints
from fastapi import Request
from pydantic import ConfigDict, computed_field, field_validator
from pydantic import ConfigDict, Field, computed_field, field_validator
from endpoints.responses.assets import SaveSchema, ScreenshotSchema, StateSchema
from handler.metadata.flashpoint_handler import FlashpointMetadata
@@ -18,7 +19,7 @@ from handler.metadata.moby_handler import MobyMetadata
from handler.metadata.ra_handler import RAMetadata
from handler.metadata.ss_handler import SSMetadata
from models.collection import Collection
from models.rom import Rom, RomArchiveMember, RomFileCategory, RomUserStatus
from models.rom import Rom, RomArchiveMember, RomFile, RomFileCategory, RomUserStatus
from .base import BaseModel, UTCDatetime
@@ -335,16 +336,50 @@ class SiblingRomSchema(BaseModel):
.lower()
)
@classmethod
def from_rom(cls, rom: Rom) -> SiblingRomSchema:
return cls(
id=rom.id,
name=rom.name,
fs_name_no_tags=rom.fs_name_no_tags,
fs_name_no_ext=rom.fs_name_no_ext,
)
class SimpleRomSchema(RomSchema):
sibling_ids: list[int]
# `files` binds to a dedicated attribute via validation alias rather than the
# `Rom.files` relationship: that keeps the full file list out of the default
# gallery payload, and avoids tripping the relationship's lazy="raise" in
# factory contexts where files aren't loaded. Populated only when requested.
files: list[RomFileSchema] = Field(
default_factory=list, validation_alias="included_files"
)
siblings: list[SiblingRomSchema] = Field(default_factory=list)
@field_validator("files")
def sort_files(cls, v: list[RomFileSchema]) -> list[RomFileSchema]:
return sorted(v, key=lambda x: x.file_name)
@field_validator("siblings")
def sort_siblings(cls, v: list[SiblingRomSchema]) -> list[SiblingRomSchema]:
return sorted(v, key=lambda x: x.sort_comparator)
@classmethod
def from_orm_with_request(
cls, db_rom: Rom, request: Request, sibling_ids: list[int] | None = None
cls,
db_rom: Rom,
request: Request,
sibling_ids: list[int] | None = None,
files: Sequence[RomFile] | None = None,
siblings: Sequence[Rom] | None = None,
) -> SimpleRomSchema:
db_rom = cls.populate_properties(db_rom, request)
db_rom.sibling_ids = sibling_ids or [] # type: ignore
db_rom.included_files = files or [] # type: ignore
db_rom.siblings = ( # type: ignore
[SiblingRomSchema.from_rom(s) for s in siblings] if siblings else []
)
return cls.model_validate(db_rom)
@classmethod
@@ -397,15 +432,7 @@ class DetailedRomSchema(RomSchema):
db_rom = cls.populate_properties(db_rom, request)
sorted_siblings = sorted(
(
SiblingRomSchema(
id=s.id,
name=s.name,
fs_name_no_tags=s.fs_name_no_tags,
fs_name_no_ext=s.fs_name_no_ext,
)
for s in db_rom.sibling_roms
),
(SiblingRomSchema.from_rom(s) for s in db_rom.sibling_roms),
key=lambda x: x.sort_comparator,
)
db_rom.siblings = sorted_siblings # type: ignore

View File

@@ -482,6 +482,14 @@ def get_roms(
description="Filter roms updated after this datetime (ISO 8601 format with timezone information)."
),
] = None,
with_files: Annotated[
bool,
Query(description="Whether to include each rom's file entries."),
] = False,
with_siblings: Annotated[
bool,
Query(description="Whether to include each rom's sibling roms."),
] = False,
) -> CustomLimitOffsetPage[SimpleRomSchema]:
"""Retrieve roms."""
unfiltered_query, order_by_attr = db_rom_handler.get_roms_query(
@@ -571,8 +579,19 @@ def get_roms(
rom_id_index = session.scalars(query.with_only_columns(Rom.id)).all() # type: ignore
def _transform(items: Sequence[Rom]) -> list[SimpleRomSchema]:
rom_ids = [i.id for i in items]
sibling_ids_by_rom = db_rom_handler.get_sibling_ids_for_roms(
[i.id for i in items], session=session
rom_ids, session=session
)
files_by_rom = (
db_rom_handler.get_files_for_roms(rom_ids, session=session)
if with_files
else {}
)
siblings_by_rom = (
db_rom_handler.get_siblings_for_roms(rom_ids, session=session)
if with_siblings
else {}
)
return [
@@ -580,6 +599,8 @@ def get_roms(
db_rom=item,
request=request,
sibling_ids=sibling_ids_by_rom.get(item.id, []),
files=files_by_rom.get(item.id, []),
siblings=siblings_by_rom.get(item.id, []),
)
for item in items
]

View File

@@ -361,6 +361,7 @@ async def _identify_rom(
"rom_user",
"last_modified",
"files",
"siblings",
}
),
)
@@ -480,7 +481,14 @@ async def _identify_rom(
await socket_manager.emit(
"scan:scanning_rom",
SimpleRomSchema.from_orm_with_factory(_added_rom).model_dump(
exclude={"created_at", "updated_at", "rom_user", "last_modified", "files"}
exclude={
"created_at",
"updated_at",
"rom_user",
"last_modified",
"files",
"siblings",
}
),
)

View File

@@ -231,6 +231,52 @@ class DBRomsHandler(DBBaseHandler):
return {rom_id: sorted(ids) for rom_id, ids in buckets.items()}
def get_files_for_roms(
self,
rom_ids: list[int],
*,
session: Session,
) -> dict[int, list[RomFile]]:
"""Return {rom_id: [RomFile, ...]} for the given rom IDs in a single query.
Used by the list endpoint to serialize files without relying on the
query's relationship eager-load surviving pagination.
"""
if not rom_ids:
return {}
files = session.scalars(
select(RomFile).where(RomFile.rom_id.in_(rom_ids))
).all()
buckets: dict[int, list[RomFile]] = {rom_id: [] for rom_id in rom_ids}
for file in files:
buckets[file.rom_id].append(file)
return buckets
def get_siblings_for_roms(
self,
rom_ids: list[int],
*,
session: Session,
) -> dict[int, list[Rom]]:
"""Return {rom_id: [sibling Rom, ...]} for the given rom IDs in a single query."""
if not rom_ids:
return {}
rows = session.execute(
select(SiblingRom.rom_id, Rom)
.join(Rom, Rom.id == SiblingRom.sibling_rom_id)
.where(SiblingRom.rom_id.in_(rom_ids))
).all()
buckets: dict[int, list[Rom]] = {rom_id: [] for rom_id in rom_ids}
for rom_id, sibling in rows:
buckets[rom_id].append(sibling)
return buckets
def filter_by_platform_id(self, query: Query, platform_id: int):
return query.filter(Rom.platform_id == platform_id)

View File

@@ -443,6 +443,7 @@ async def scan_rom(
"rom_user",
"last_modified",
"files",
"siblings",
}
),
},

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
from fastapi import status
from fastapi.testclient import TestClient
from handler.database import db_rom_handler
from handler.filesystem.resources_handler import FSResourcesHandler
from handler.filesystem.roms_handler import FSRomsHandler
from handler.metadata.flashpoint_handler import FlashpointHandler, FlashpointRom
@@ -14,7 +15,8 @@ from handler.metadata.moby_handler import MobyGamesHandler, MobyGamesRom
from handler.metadata.ra_handler import RAGameRom, RAHandler
from handler.metadata.ss_handler import SSHandler, SSRom
from models.platform import Platform
from models.rom import Rom
from models.rom import Rom, RomFile
from models.user import User
MOCK_IGDB_ID = 11111
MOCK_MOBY_ID = 22222
@@ -80,6 +82,73 @@ def test_get_all_roms(
items = body["items"]
assert len(items) == 1
assert items[0]["id"] == rom.id
assert items[0]["files"] == []
assert items[0]["siblings"] == []
def test_get_all_roms_with_files(
client: TestClient, access_token: str, rom: Rom, platform: Platform
):
db_rom_handler.add_rom_file(
RomFile(
rom_id=rom.id,
file_name="test_rom.zip",
file_path=f"{platform.slug}/roms",
file_size_bytes=1024,
last_modified=1700000000.0,
)
)
response = client.get(
"/api/roms",
headers={"Authorization": f"Bearer {access_token}"},
params={"platform_id": platform.id, "with_files": True},
)
assert response.status_code == status.HTTP_200_OK
item = response.json()["items"][0]
assert item["id"] == rom.id
assert len(item["files"]) == 1
assert item["files"][0]["file_name"] == "test_rom.zip"
# with_files alone must not pull in siblings.
assert item["siblings"] == []
def test_get_all_roms_with_siblings(
client: TestClient, access_token: str, platform: Platform, admin_user: User
):
siblings = [
db_rom_handler.add_rom(
Rom(
platform_id=platform.id,
igdb_id=424242,
name=name,
slug=slug,
fs_name=f"{slug}.zip",
fs_name_no_tags=slug,
fs_name_no_ext=slug,
fs_extension="zip",
fs_path=f"{platform.slug}/roms",
)
)
for name, slug in (("Game A", "game_a"), ("Game B", "game_b"))
]
for sibling in siblings:
db_rom_handler.add_rom_user(rom_id=sibling.id, user_id=admin_user.id)
response = client.get(
"/api/roms",
headers={"Authorization": f"Bearer {access_token}"},
params={"platform_id": platform.id, "with_siblings": True},
)
assert response.status_code == status.HTTP_200_OK
items = {item["id"]: item for item in response.json()["items"]}
rom_a, rom_b = siblings
assert [s["id"] for s in items[rom_a.id]["siblings"]] == [rom_b.id]
assert [s["id"] for s in items[rom_b.id]["siblings"]] == [rom_a.id]
# with_siblings alone must not pull in files.
assert items[rom_a.id]["files"] == []
def test_get_rom_content_requires_auth(client: TestClient, rom: Rom, rom_file):

View File

@@ -3,6 +3,7 @@
/* tslint:disable */
/* eslint-disable */
import type { ManualMetadata } from './ManualMetadata';
import type { RomFileSchema } from './RomFileSchema';
import type { RomFlashpointMetadata } from './RomFlashpointMetadata';
import type { RomGamelistMetadata } from './RomGamelistMetadata';
import type { RomHasheousMetadata } from './RomHasheousMetadata';
@@ -14,6 +15,7 @@ import type { RomMobyMetadata } from './RomMobyMetadata';
import type { RomRAMetadata } from './RomRAMetadata';
import type { RomSSMetadata } from './RomSSMetadata';
import type { RomUserSchema } from './RomUserSchema';
import type { SiblingRomSchema } from './SiblingRomSchema';
export type SimpleRomSchema = {
id: number;
igdb_id: (number | null);
@@ -84,5 +86,7 @@ export type SimpleRomSchema = {
merged_screenshots: Array<string>;
merged_ra_metadata: (RomRAMetadata | null);
sibling_ids: Array<number>;
files?: Array<RomFileSchema>;
siblings?: Array<SiblingRomSchema>;
};