Files
romm/backend/handler/filesystem/assets_handler.py
Georges-Antoine Assi e3aaa106a2 perf(backend): reuse libmagic instance for image upload validation
magic.Magic(mime=True) loads the magic database from disk on construction;
instantiating it per request was adding pointless overhead to every avatar
and artwork upload. Share a module-level instance guarded by a lock (the
underlying magic_t handle is not thread-safe), and surface MagicException
as a 400 so a sniffing failure fails closed instead of bubbling a 500.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-09 10:14:38 -04:00

153 lines
5.1 KiB
Python

import hashlib
import os
import threading
import zipfile
from typing import Final
import magic
from fastapi import HTTPException, UploadFile, status
from config import ASSETS_BASE_PATH
from logger.logger import log
from models.user import User
from .base_handler import FSHandler
# Allow-list of image MIME types, mapped to the file extension used for each.
# A type listed here is trusted both as an avatar upload and for inline
# serving from the raw asset endpoint.
SAFE_IMAGE_MIME_TYPES: Final[dict[str, str]] = {
    "image/png": "png",
    "image/jpeg": "jpg",
    "image/webp": "webp",
    "image/gif": "gif",
}
# Serializes access to the shared detector below: the underlying magic_t
# handle is not thread-safe, and endpoints that call the validator may run in
# worker threads under sync routes.
_MIME_DETECTOR_LOCK = threading.Lock()

# One process-wide Magic instance. Constructing it makes libmagic load its
# database from disk (a few MB), so it is built once at import time rather
# than per request.
_MIME_DETECTOR = magic.Magic(mime=True)
def validate_image_upload(upload: UploadFile, *, label: str = "Image") -> str:
    """Check that an upload is an allowed image and return its extension.

    The first 4096 bytes of the upload are sniffed with libmagic and the
    detected MIME type is looked up in SAFE_IMAGE_MIME_TYPES. The file cursor
    is rewound to 0 once the header has been read.

    Args:
        upload: The uploaded file to inspect.
        label: Human-readable name of the upload, used in error messages.

    Returns:
        The extension mapped to the detected MIME type.

    Raises:
        HTTPException: 400 if libmagic raises MagicException while sniffing,
            or if the detected type is not in SAFE_IMAGE_MIME_TYPES.
    """
    stream = upload.file
    stream.seek(0)
    leading_bytes = stream.read(4096)
    stream.seek(0)

    try:
        # The shared detector is only ever used while holding its lock.
        with _MIME_DETECTOR_LOCK:
            mime_type = _MIME_DETECTOR.from_buffer(leading_bytes)
    except magic.MagicException as exc:
        # A sniffing failure is reported as a client error (fail closed).
        log.error(f"libmagic failed to sniff uploaded {label.lower()}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Could not determine {label.lower()} file type",
        ) from exc

    extension = SAFE_IMAGE_MIME_TYPES.get(mime_type)
    if extension:
        return extension

    rejection_detail = (
        f"{label} must be a PNG, JPEG, WebP, or GIF image "
        f"(detected {mime_type or 'unknown type'})"
    )
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=rejection_detail,
    )
class FSAssetsHandler(FSHandler):
    """Filesystem handler rooted at ASSETS_BASE_PATH.

    Builds the relative folder paths used for user avatars, saves, states and
    screenshots, and computes content hashes of stored files.
    """

    def __init__(self) -> None:
        super().__init__(base_path=ASSETS_BASE_PATH)

    def user_folder_path(self, user: User) -> str:
        """Return the relative folder ``users/<user.fs_safe_folder_name>``."""
        return os.path.join("users", user.fs_safe_folder_name)

    # /users/557365723a31/profile
    def build_avatar_path(self, user: User) -> str:
        """Return the relative folder holding the user's avatar."""
        return os.path.join(self.user_folder_path(user), "profile")

    def _build_asset_file_path(
        self,
        user: User,
        folder: str,
        platform_fs_slug: str,
        rom_id: int,
        emulator: str | None = None,
    ) -> str:
        """Join ``<user folder>/<folder>/<platform_fs_slug>/<rom_id>[/<emulator>]``.

        The emulator segment is appended only when ``emulator`` is truthy.
        """
        user_folder_path = self.user_folder_path(user)
        assets_path = os.path.join(
            user_folder_path, folder, platform_fs_slug, str(rom_id)
        )
        if emulator:
            assets_path = os.path.join(assets_path, emulator)
        return assets_path

    # /users/557365723a31/saves/n64/{rom.id}/mupen64plus/
    def build_saves_file_path(
        self,
        user: User,
        platform_fs_slug: str,
        rom_id: int,
        emulator: str | None = None,
    ) -> str:
        """Return the relative saves folder for a user, platform and rom."""
        return self._build_asset_file_path(
            user, "saves", platform_fs_slug, rom_id, emulator
        )

    # /users/557365723a31/states/n64/{rom.id}/mupen64plus
    def build_states_file_path(
        self,
        user: User,
        platform_fs_slug: str,
        rom_id: int,
        emulator: str | None = None,
    ) -> str:
        """Return the relative states folder for a user, platform and rom."""
        return self._build_asset_file_path(
            user, "states", platform_fs_slug, rom_id, emulator
        )

    # /users/557365723a31/screenshots/n64/{rom.id}
    def build_screenshots_file_path(
        self, user: User, platform_fs_slug: str, rom_id: int
    ) -> str:
        """Return the relative screenshots folder for a user, platform and rom."""
        return self._build_asset_file_path(
            user, "screenshots", platform_fs_slug, rom_id
        )

    async def _compute_file_hash(self, file_path: str) -> str:
        """Return the MD5 hex digest of a file read through ``self.stream_file``.

        The file is consumed in 8192-byte chunks. MD5 is used as a content
        fingerprint only (``usedforsecurity=False``).
        """
        hash_obj = hashlib.md5(usedforsecurity=False)
        async with await self.stream_file(file_path=file_path) as f:
            while chunk := await f.read(8192):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()

    async def _compute_zip_hash(self, zip_path: str) -> str:
        """Return an MD5 hex digest derived from a zip archive's member contents.

        Each non-directory member contributes a ``name:md5(content)`` line;
        the lines are ordered by member name, joined with newlines and hashed
        again. ``zip_path`` is resolved against ``self.base_path``.
        """
        with zipfile.ZipFile(f"{self.base_path}/{zip_path}", "r") as zf:
            file_hashes = []
            for name in sorted(zf.namelist()):
                # Names ending in "/" are directory entries and carry no content.
                if not name.endswith("/"):
                    content = zf.read(name)
                    file_hash = hashlib.md5(content, usedforsecurity=False).hexdigest()
                    file_hashes.append(f"{name}:{file_hash}")
            combined = "\n".join(file_hashes)
            return hashlib.md5(combined.encode(), usedforsecurity=False).hexdigest()

    async def compute_content_hash(self, file_path: str) -> str | None:
        """Return a content hash for ``file_path``, or None if hashing fails.

        Zip archives are hashed via ``_compute_zip_hash``; every other file is
        hashed via ``_compute_file_hash``. This is best-effort: any exception
        is logged at debug level and turned into ``None``.
        """
        try:
            # Probe the same on-disk location that _compute_zip_hash opens.
            # is_zipfile() returns False for a path that does not exist, so
            # testing the bare relative path would skip the zip branch
            # whenever the working directory is not the base path.
            if zipfile.is_zipfile(f"{self.base_path}/{file_path}"):
                return await self._compute_zip_hash(file_path)
            return await self._compute_file_hash(file_path)
        except Exception as e:
            log.debug(f"Failed to compute content hash for {file_path}: {e}")
            return None