add endpoints for identifiers

This commit is contained in:
Georges-Antoine Assi
2026-01-18 22:17:37 -05:00
parent 54bfb3fec5
commit 3ccc14d4a2
9 changed files with 221 additions and 38 deletions

View File

@@ -24,7 +24,11 @@ from handler.filesystem.base_handler import CoverSize
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.collection import Collection, SmartCollection
from models.collection import (
Collection,
SmartCollection,
VirtualCollection,
)
from utils.router import APIRouter
router = APIRouter(
@@ -153,9 +157,6 @@ def get_collections(
description="Filter collections updated after this datetime (ISO 8601 format with timezone information)."
),
] = None,
only_ids: Annotated[
bool | None, Query(description="Only return list of IDs")
] = None,
) -> list[CollectionSchema]:
"""Get collections endpoint
@@ -168,12 +169,32 @@ def get_collections(
"""
collections = db_collection_handler.get_collections(
updated_after=updated_after, only_fields=[Collection.id] if only_ids else None
updated_after=updated_after,
)
return CollectionSchema.for_user(request.user.id, [c for c in collections])
@protected_route(router.get, "/identifiers", [Scope.COLLECTIONS_READ])
def get_collection_identifiers(
request: Request,
) -> list[int]:
"""Get collections identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of collection ids
"""
collections = db_collection_handler.get_collections(
only_fields=[Collection.id, Collection.user_id, Collection.is_public],
)
return [c.id for c in collections if c.user_id == request.user.id or c.is_public]
@protected_route(router.get, "/virtual", [Scope.COLLECTIONS_READ])
def get_virtual_collections(
request: Request,
@@ -215,7 +236,8 @@ def get_smart_collections(
"""
smart_collections = db_collection_handler.get_smart_collections(
request.user.id, updated_after=updated_after
request.user.id,
updated_after=updated_after,
)
return SmartCollectionSchema.for_user(
@@ -223,6 +245,27 @@ def get_smart_collections(
)
@protected_route(router.get, "/smart/identifiers", [Scope.COLLECTIONS_READ])
def get_smart_collection_identifiers(
request: Request,
) -> list[int]:
"""Get smart collections identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of smart collection ids
"""
smart_collections = db_collection_handler.get_smart_collections(
request.user.id,
only_fields=[SmartCollection.id],
)
return [s.id for s in smart_collections]
@protected_route(router.get, "/{id}", [Scope.COLLECTIONS_READ])
def get_collection(request: Request, id: int) -> CollectionSchema:
"""Get collections endpoint

View File

@@ -1,6 +1,6 @@
from typing import Annotated
from fastapi import Body, File, HTTPException, Request, UploadFile, status
from fastapi import Body, File, HTTPException, Query, Request, UploadFile, status
from fastapi.responses import FileResponse
from config import DISABLE_DOWNLOAD_ENDPOINT_AUTH
@@ -119,10 +119,28 @@ def get_platform_firmware(
Returns:
list[FirmwareSchema]: Firmware stored in the database
"""
return [
FirmwareSchema.model_validate(f)
for f in db_firmware_handler.list_firmware(platform_id=platform_id)
]
firmware = db_firmware_handler.list_firmware(
platform_id=platform_id,
)
return [FirmwareSchema.model_validate(f) for f in firmware]
@protected_route(router.get, "/identifiers", [Scope.FIRMWARE_READ])
def get_firmware_identifiers(
request: Request,
) -> list[int]:
"""Get firmware identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of firmware ids
"""
firmware = db_firmware_handler.list_firmware(
only_fields=[Firmware.id],
)
return [f.id for f in firmware]
@protected_route(

View File

@@ -16,6 +16,7 @@ from handler.scan_handler import scan_platform
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.platform import Platform
from utils.platforms import get_supported_platforms
from utils.router import APIRouter
@@ -60,10 +61,20 @@ def get_platforms(
) -> list[PlatformSchema]:
"""Retrieve platforms."""
return [
PlatformSchema.model_validate(p)
for p in db_platform_handler.get_platforms(updated_after=updated_after)
]
platforms = db_platform_handler.get_platforms(updated_after=updated_after)
return [PlatformSchema.model_validate(p) for p in platforms]
@protected_route(router.get, "/identifiers", [Scope.PLATFORMS_READ])
def get_platform_identifiers(
request: Request,
) -> list[int]:
"""Retrieve platform identifiers."""
platforms = db_platform_handler.get_platforms(
only_fields=[Platform.id],
)
return [p.id for p in platforms]
@protected_route(router.get, "/supported", [Scope.PLATFORMS_READ])

View File

@@ -66,7 +66,7 @@ from handler.metadata.ss_handler import get_preferred_media_types
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.rom import Rom
from models.rom import Rom, RomNote
from utils.database import safe_int, safe_str_to_bool
from utils.filesystem import sanitize_filename
from utils.hashing import crc32_to_hex
@@ -517,6 +517,19 @@ def get_roms(
)
@protected_route(router.get, "/identifiers", [Scope.ROMS_READ])
def get_rom_identifiers(
request: Request,
) -> list[int]:
"""Retrieve rom identifiers."""
db_roms = db_rom_handler.get_roms_scalar(
user_id=request.user.id,
only_fields=[Rom.id],
)
return [r.id for r in db_roms]
@protected_route(
router.get,
"/download",
@@ -1633,8 +1646,8 @@ async def get_rom_notes(
request: Request,
id: Annotated[int, PathVar(description="Rom internal id.", ge=1)],
public_only: bool = DEFAULT_PUBLIC_ONLY,
search: str = DEFAULT_SEARCH,
tags: list[str] = DEFAULT_TAGS,
search: str | None = DEFAULT_SEARCH,
tags: list[str] | None = DEFAULT_TAGS,
) -> list[UserNoteSchema]:
"""Get all notes for a ROM."""
from handler.database import db_rom_handler
@@ -1657,6 +1670,32 @@ async def get_rom_notes(
return [UserNoteSchema.model_validate(note) for note in notes]
@protected_route(
router.get,
"/{id}/notes/identifiers",
[Scope.ROMS_READ],
responses={status.HTTP_404_NOT_FOUND: {}},
)
async def get_rom_note_identifiers(
request: Request,
id: Annotated[int, PathVar(description="Rom internal id.", ge=1)],
) -> list[int]:
"""Get all note identifiers for a ROM."""
from handler.database import db_rom_handler
rom = db_rom_handler.get_rom(id)
if not rom:
raise RomNotFoundInDatabaseException(id)
notes = db_rom_handler.get_rom_notes(
rom_id=id,
user_id=request.user.id,
only_fields=[RomNote.id],
)
return [note.id for note in notes]
@protected_route(
router.post,
"/{id}/notes",

View File

@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import Annotated
from fastapi import Body, HTTPException, Request, UploadFile, status
from fastapi import Body, HTTPException, Query, Request, UploadFile, status
from decorators.auth import protected_route
from endpoints.responses.assets import SaveSchema
@@ -13,6 +13,7 @@ from handler.scan_handler import scan_save, scan_screenshot
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.assets import Save
from utils.router import APIRouter
router = APIRouter(
@@ -142,15 +143,39 @@ async def add_save(
@protected_route(router.get, "", [Scope.ASSETS_READ])
def get_saves(
request: Request, rom_id: int | None = None, platform_id: int | None = None
request: Request,
rom_id: int | None = None,
platform_id: int | None = None,
) -> list[SaveSchema]:
saves = db_save_handler.get_saves(
user_id=request.user.id, rom_id=rom_id, platform_id=platform_id
user_id=request.user.id,
rom_id=rom_id,
platform_id=platform_id,
)
return [SaveSchema.model_validate(save) for save in saves]
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
def get_save_identifiers(
request: Request,
) -> list[int]:
"""Get save identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of save ids
"""
saves = db_save_handler.get_saves(
user_id=request.user.id,
only_fields=[Save.id],
)
return [save.id for save in saves]
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
def get_save(request: Request, id: int) -> SaveSchema:
save = db_save_handler.get_save(user_id=request.user.id, id=id)

View File

@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import Annotated
from fastapi import Body, HTTPException, Request, UploadFile, status
from fastapi import Body, HTTPException, Query, Request, UploadFile, status
from decorators.auth import protected_route
from endpoints.responses.assets import StateSchema
@@ -13,6 +13,7 @@ from handler.scan_handler import scan_screenshot, scan_state
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.assets import State
from utils.router import APIRouter
router = APIRouter(
@@ -144,15 +145,39 @@ async def add_state(
@protected_route(router.get, "", [Scope.ASSETS_READ])
def get_states(
request: Request, rom_id: int | None = None, platform_id: int | None = None
request: Request,
rom_id: int | None = None,
platform_id: int | None = None,
) -> list[StateSchema]:
states = db_state_handler.get_states(
user_id=request.user.id, rom_id=rom_id, platform_id=platform_id
user_id=request.user.id,
rom_id=rom_id,
platform_id=platform_id,
)
return [StateSchema.model_validate(state) for state in states]
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
def get_state_identifiers(
request: Request,
) -> list[int]:
"""Get state identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of state ids
"""
states = db_state_handler.get_states(
user_id=request.user.id,
only_fields=[State.id],
)
return [state.id for state in states]
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
def get_state(request: Request, id: int) -> StateSchema:
state = db_state_handler.get_state(user_id=request.user.id, id=id)

View File

@@ -3,7 +3,7 @@ from typing import Annotated, Any, cast
from fastapi import Body, Form, HTTPException
from fastapi import Path as PathVar
from fastapi import Request, status
from fastapi import Query, Request, status
from decorators.auth import protected_route
from endpoints.forms.identity import UserForm
@@ -199,7 +199,9 @@ def create_user_from_invite(
@protected_route(router.get, "", [Scope.USERS_READ])
def get_users(request: Request) -> list[UserSchema]:
def get_users(
request: Request,
) -> list[UserSchema]:
"""Get all users endpoint
Args:
@@ -209,7 +211,27 @@ def get_users(request: Request) -> list[UserSchema]:
list[UserSchema]: All users stored in the RomM's database
"""
return [UserSchema.model_validate(u) for u in db_user_handler.get_users()]
users = db_user_handler.get_users()
return [UserSchema.model_validate(u) for u in users]
@protected_route(router.get, "/identifiers", [Scope.USERS_READ])
def get_user_identifiers(
request: Request,
) -> list[int]:
"""Get all user identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: All user ids stored in the RomM's database
"""
users = db_user_handler.get_users(
only_fields=[User.id],
)
return [u.id for u in users]
@protected_route(router.get, "/me", [Scope.ME_READ])

View File

@@ -169,20 +169,17 @@ class DBCollectionsHandler(DBBaseHandler):
self,
type: str,
limit: int | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[VirtualCollection]:
return (
session.scalars(
select(VirtualCollection)
.filter(or_(VirtualCollection.type == type, literal(type == "all")))
.limit(limit)
.order_by(VirtualCollection.name.asc())
)
.unique()
.all()
query = (
select(VirtualCollection)
.filter(or_(VirtualCollection.type == type, literal(type == "all")))
.limit(limit)
.order_by(VirtualCollection.name.asc())
)
return session.scalars(query).unique().all()
# Smart collections
@begin_session
def add_smart_collection(
@@ -233,6 +230,9 @@ class DBCollectionsHandler(DBBaseHandler):
if updated_after:
query = query.filter(SmartCollection.updated_at > updated_after)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).unique().all()
@begin_session

View File

@@ -1061,7 +1061,7 @@ class DBRomsHandler(DBBaseHandler):
rom_id: int,
user_id: int,
public_only: bool = False,
search: str = "",
search: str | None = "",
tags: list[str] | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore