diff --git a/backend/endpoints/responses/stats.py b/backend/endpoints/responses/stats.py new file mode 100644 index 000000000..ee199ecaf --- /dev/null +++ b/backend/endpoints/responses/stats.py @@ -0,0 +1,10 @@ +from typing_extensions import TypedDict + + +class StatsReturn(TypedDict): + PLATFORMS: int + ROMS: int + SAVES: int + STATES: int + SCREENSHOTS: int + FILESIZE: int diff --git a/backend/endpoints/sockets/scan.py b/backend/endpoints/sockets/scan.py index 437a7eb16..5c61cd46d 100644 --- a/backend/endpoints/sockets/scan.py +++ b/backend/endpoints/sockets/scan.py @@ -8,10 +8,10 @@ from exceptions.fs_exceptions import ( RomsNotFoundException, ) from handler import ( - dbh, dbplatformh, dbromh, dbsaveh, + dbscreenshotsh, dbstateh, fsasseth, fsplatformh, @@ -19,6 +19,7 @@ from handler import ( fsromh, socketh, ) +from handler.fs_handler import Asset from handler.redis_handler import high_prio_queue, redis_url from handler.scan_handler import ( scan_platform, @@ -28,7 +29,6 @@ from handler.scan_handler import ( scan_state, ) from logger.logger import log -from handler.fs_handler import Asset def _get_socket_manager(): @@ -145,14 +145,13 @@ async def scan_platforms( if save: # Update file size if changed if save.file_size_bytes != scanned_save.file_size_bytes: - dbh.update_save( + dbsaveh.update_save( save.id, {"file_size_bytes": scanned_save.file_size_bytes} ) continue scanned_save.emulator = fs_emulator - rom = dbromh.get_rom_by_filename_no_tags(scanned_save.file_name_no_tags) if rom: scanned_save.rom_id = rom.id dbsaveh.add_save(scanned_save) @@ -173,7 +172,7 @@ async def scan_platforms( if state: # Update file size if changed if state.file_size_bytes != scanned_state.file_size_bytes: - dbh.update_state( + dbstateh.update_state( state.id, {"file_size_bytes": scanned_state.file_size_bytes} ) @@ -181,66 +180,66 @@ async def scan_platforms( scanned_state.emulator = fs_emulator - rom = dbromh.get_rom_by_filename_no_tags(scanned_state.file_name_no_tags) if rom: scanned_state.rom_id = rom.id dbstateh.add_state(scanned_state) - # # Scanning screenshots - # log.info(f"\t · {len(fs_assets['screenshots'])} screenshots found") - # for fs_screenshot_filename in fs_assets["screenshots"]: - # scanned_screenshot = scan_screenshot( - # file_name=fs_screenshot_filename, platform=platform - # ) + # Scanning screenshots + fs_screenshots = fsasseth.get_assets( + platform.fs_slug, rom.file_name_no_tags, Asset.SCREENSHOTS + ) + log.info(f"\t · {len(fs_screenshots)} screenshots found") + for fs_screenshot_filename in fs_screenshots: + scanned_screenshot = scan_screenshot( + file_name=fs_screenshot_filename, platform=platform + ) - # screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename) - # if screenshot: - # # Update file size if changed - # if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes: - # dbh.update_screenshot( - # screenshot.id, - # {"file_size_bytes": scanned_screenshot.file_size_bytes}, - # ) - # continue + screenshot = dbscreenshotsh.get_screenshot_by_filename( + fs_screenshot_filename + ) + if screenshot: + # Update file size if changed + if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes: + dbscreenshotsh.update_screenshot( + screenshot.id, + {"file_size_bytes": scanned_screenshot.file_size_bytes}, + ) + continue - # rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags) + if rom: + scanned_screenshot.rom_id = rom.id + dbscreenshotsh.add_screenshot(scanned_screenshot) - # if rom: - # scanned_screenshot.rom_id = rom.id - # dbh.add_screenshot(scanned_screenshot) - - # for fs_rom in fs_roms: - # rom = dbh.get_rom_by_filename(platform.id, fs_rom["file_name"]) - # dbsaveh.purge_saves(rom.id, [s for _e, s in fs_assets["saves"]]) - # dbstateh.purge_states(rom.id, [s for _e, s in fs_assets["states"]]) - # dbh.purge_screenshots(rom.id, fs_assets["screenshots"]) - # dbromh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms]) + dbsaveh.purge_saves(rom.id, [s for _e, s in fs_saves]) + dbstateh.purge_states(rom.id, [s for _e, s in fs_states]) + dbscreenshotsh.purge_screenshots(rom.id, fs_screenshots) + dbromh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms]) # Scanning screenshots outside platform folders - # fs_screenshots = fsasseth.get_screenshots() - # log.info("Screenshots") - # log.info(f" · {len(fs_screenshots)} screenshots found") - # for fs_platform, fs_screenshot_filename in fs_screenshots: - # scanned_screenshot = scan_screenshot( - # file_name=fs_screenshot_filename, fs_platform=fs_platform - # ) + fs_screenshots = fsasseth.get_screenshots() + log.info("Screenshots") + log.info(f" · {len(fs_screenshots)} screenshots found") + for fs_platform, fs_screenshot_filename in fs_screenshots: + scanned_screenshot = scan_screenshot( + file_name=fs_screenshot_filename, fs_platform=fs_platform + ) - # screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename) - # if screenshot: - # # Update file size if changed - # if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes: - # dbh.update_screenshot( - # screenshot.id, - # {"file_size_bytes": scanned_screenshot.file_size_bytes}, - # ) - # continue + screenshot = dbscreenshotsh.get_screenshot_by_filename(fs_screenshot_filename) + if screenshot: + # Update file size if changed + if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes: + dbscreenshotsh.update_screenshot( + screenshot.id, + {"file_size_bytes": scanned_screenshot.file_size_bytes}, + ) + continue - # rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags) - # if rom: - # scanned_screenshot.rom_id = rom.id - # dbh.add_screenshot(scanned_screenshot) + rom = dbromh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags) + if rom: + scanned_screenshot.rom_id = rom.id + dbscreenshotsh.add_screenshot(scanned_screenshot) - # dbh.purge_screenshots([s for _e, s in fs_screenshots]) + # dbscreenshotsh.purge_screenshots([s for _e, s in fs_screenshots]) dbplatformh.purge_platforms(fs_platforms) log.info(emoji.emojize(":check_mark: Scan completed ")) diff --git a/backend/endpoints/stats.py b/backend/endpoints/stats.py new file mode 100644 index 000000000..817bbc8ba --- /dev/null +++ b/backend/endpoints/stats.py @@ -0,0 +1,23 @@ +from endpoints.responses.stats import StatsReturn +from fastapi import APIRouter +from handler import dbstatsh + +router = APIRouter() + + +@router.get("/stats") +def stats() -> StatsReturn: + """Endpoint to return the current RomM stats + + Returns: + dict: Dictionary with all the stats + """ + + return { + "PLATFORMS": dbstatsh.get_platforms_count(), + "ROMS": dbstatsh.get_roms_count(), + "SAVES": dbstatsh.get_saves_count(), + "STATES": dbstatsh.get_states_count(), + "SCREENSHOTS": dbstatsh.get_screenshots_count(), + "FILESIZE": dbstatsh.get_total_filesize(), + } diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 3c515f1fc..83cdd708d 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -5,8 +5,8 @@ from decorators.auth import protected_route from endpoints.forms.identity import UserForm from endpoints.responses import MessageResponse from endpoints.responses.identity import UserSchema -from fastapi import APIRouter, Depends, HTTPException, Request, status -from handler import authh, dbh, fsresourceh +from fastapi import APIRouter, Depends, HTTPException, Request +from handler import authh, dbuserh, fsresourceh from models.user import Role, User router = APIRouter() @@ -41,7 +41,7 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS role=Role[role.upper()], ) - return dbh.add_user(user) + return dbuserh.add_user(user) @protected_route(router.get, "/users", ["users.read"]) @@ -55,7 +55,7 @@ def get_users(request: Request) -> list[UserSchema]: list[UserSchema]: All users stored in the RomM's database """ - return dbh.get_users() + return dbuserh.get_users() @protected_route(router.get, "/users/me", ["me.read"]) @@ -83,7 +83,7 @@ def get_user(request: Request, id: int) -> UserSchema: UserSchem: User stored in the RomM's database """ - user = dbh.get_user(id) + user = dbuserh.get_user(id) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -115,14 +115,14 @@ def update_user( status_code=400, detail="Cannot update user: ROMM_AUTH_ENABLED is set to False", ) - user = dbh.get_user(id) + user = dbuserh.get_user(id) if not user: raise HTTPException(status_code=404, detail="User not found") cleaned_data = {} if form_data.username and form_data.username != user.username: - existing_user = dbh.get_user_by_username(form_data.username.lower()) + existing_user = dbuserh.get_user_by_username(form_data.username.lower()) if existing_user: raise HTTPException( status_code=400, detail="Username already in use by another user" @@ -150,7 +150,7 @@ def update_user( file_object.write(form_data.avatar.file.read()) if cleaned_data: - dbh.update_user(id, cleaned_data) + dbuserh.update_user(id, cleaned_data) # Log out the current user if username or password changed creds_updated = cleaned_data.get("username") or cleaned_data.get( @@ -159,7 +159,7 @@ def update_user( if request.user.id == id and creds_updated: authh.clear_session(request) - return dbh.get_user(id) + return dbuserh.get_user(id) @protected_route(router.delete, "/users/{id}", ["users.write"]) @@ -186,7 +186,7 @@ def delete_user(request: Request, id: int) -> MessageResponse: detail="Cannot delete user: ROMM_AUTH_ENABLED is set to False", ) - user = dbh.get_user(id) + user = dbuserh.get_user(id) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -195,11 +195,11 @@ def delete_user(request: Request, id: int) -> MessageResponse: raise HTTPException(status_code=400, detail="You cannot delete yourself") # You can't delete the last admin user - if user.role == Role.ADMIN and len(dbh.get_admin_users()) == 1: + if user.role == Role.ADMIN and len(dbuserh.get_admin_users()) == 1: raise HTTPException( status_code=400, detail="You cannot delete the last admin user" ) - dbh.delete_user(id) + dbuserh.delete_user(id) return {"msg": "User successfully deleted"} diff --git a/backend/handler/__init__.py b/backend/handler/__init__.py index 8cdeea52f..fdd2544d8 100644 --- a/backend/handler/__init__.py +++ b/backend/handler/__init__.py @@ -1,8 +1,11 @@ -from handler.auth_handler.auth_handler import AuthHandler, OAuthHandler +from handler.auth_handler import AuthHandler, OAuthHandler from handler.db_handler.db_platforms_handler import DBPlatformsHandler from handler.db_handler.db_roms_handler import DBRomsHandler from handler.db_handler.db_saves_handler import DBSavesHandler from handler.db_handler.db_states_handler import DBStatesHandler +from handler.db_handler.db_users_handler import DBUsersHandler +from handler.db_handler.db_stats_handler import DBStatsHandler +from handler.db_handler.db_screenshots_handler import DBScreenshotsHandler from handler.fs_handler.fs_assets_handler import FSAssetsHandler from handler.fs_handler.fs_platforms_handler import FSPlatformsHandler from handler.fs_handler.fs_resources_handler import FSResourceHandler @@ -22,12 +25,10 @@ dbplatformh = DBPlatformsHandler() dbromh = DBRomsHandler() dbsaveh = DBSavesHandler() dbstateh = DBStatesHandler() +dbuserh = DBUsersHandler() +dbstatsh = DBStatsHandler() +dbscreenshotsh = DBScreenshotsHandler() fsplatformh = FSPlatformsHandler() fsromh = FSRomsHandler() fsasseth = FSAssetsHandler() fsresourceh = FSResourceHandler() - - -from handler.db_handler.db_handler import DBHandler - -dbh = DBHandler() diff --git a/backend/handler/auth_handler/__init__.py b/backend/handler/auth_handler/__init__.py index 768c3b017..4aa6d0e11 100644 --- a/backend/handler/auth_handler/__init__.py +++ b/backend/handler/auth_handler/__init__.py @@ -1,5 +1,20 @@ +from datetime import datetime, timedelta from typing import Final +from config import ( + ROMM_AUTH_ENABLED, + ROMM_AUTH_PASSWORD, + ROMM_AUTH_SECRET_KEY, + ROMM_AUTH_USERNAME, +) +from exceptions.auth_exceptions import OAuthCredentialsException +from fastapi import HTTPException, Request, status +from handler.redis_handler import cache +from jose import JWTError, jwt +from passlib.context import CryptContext +from sqlalchemy.exc import IntegrityError +from starlette.requests import HTTPConnection + ALGORITHM: Final = "HS256" DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15 @@ -26,3 +41,122 @@ FULL_SCOPES_MAP: Final = { DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys()) WRITE_SCOPES: Final = DEFAULT_SCOPES + list(WRITE_SCOPES_MAP.keys()) FULL_SCOPES: Final = WRITE_SCOPES + list(FULL_SCOPES_MAP.keys()) + + +class AuthHandler: + def __init__(self) -> None: + self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + def _verify_password(self, plain_password, hashed_password): + return self.pwd_context.verify(plain_password, hashed_password) + + def get_password_hash(self, password): + return self.pwd_context.hash(password) + + @staticmethod + def clear_session(req: HTTPConnection | Request): + session_id = req.session.get("session_id") + if session_id: + redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined] + req.session["session_id"] = None + + def authenticate_user(self, username: str, password: str): + from handler import dbuserh + + user = dbuserh.get_user_by_username(username) + if not user: + return None + + if not self._verify_password(password, user.hashed_password): + return None + + return user + + async def get_current_active_user_from_session(self, conn: HTTPConnection): + from handler import dbuserh + + # Check if session key already stored in cache + session_id = conn.session.get("session_id") + if not session_id: + return None + + username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined] + if not username: + return None + + # Key exists therefore user is probably authenticated + user = dbuserh.get_user_by_username(username) + if user is None: + self.clear_session(conn) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User not found", + ) + + if not user.enabled: + self.clear_session(conn) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" + ) + + return user + + def create_default_admin_user(self): + from handler import dbuserh + from models.user import Role, User + + if not ROMM_AUTH_ENABLED: + return + + try: + dbuserh.add_user( + User( + username=ROMM_AUTH_USERNAME, + hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD), + role=Role.ADMIN, + ) + ) + except IntegrityError: + pass + + +class OAuthHandler: + def __init__(self) -> None: + pass + + def create_oauth_token(data: dict, expires_delta: timedelta | None = None): + to_encode = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY) + + to_encode.update({"exp": expire}) + + return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM) + + async def get_current_active_user_from_bearer_token(token: str): + from handler import dbuserh + + try: + payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM]) + except JWTError: + raise OAuthCredentialsException + + username = payload.get("sub") + if username is None: + raise OAuthCredentialsException + + user = dbuserh.get_user_by_username(username) + if user is None: + raise OAuthCredentialsException + + if not user.enabled: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user" + ) + + return user, payload diff --git a/backend/handler/auth_handler/auth_handler.py b/backend/handler/auth_handler/auth_handler.py deleted file mode 100644 index 1cd6810f8..000000000 --- a/backend/handler/auth_handler/auth_handler.py +++ /dev/null @@ -1,135 +0,0 @@ -from datetime import datetime, timedelta - -from config import ( - ROMM_AUTH_ENABLED, - ROMM_AUTH_PASSWORD, - ROMM_AUTH_SECRET_KEY, - ROMM_AUTH_USERNAME, -) -from exceptions.auth_exceptions import OAuthCredentialsException -from fastapi import HTTPException, Request, status -from handler.auth_handler import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY -from jose import JWTError, jwt -from passlib.context import CryptContext -from sqlalchemy.exc import IntegrityError -from starlette.requests import HTTPConnection -from handler.redis_handler import cache - - -class AuthHandler: - def __init__(self) -> None: - self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - - def _verify_password(self, plain_password, hashed_password): - return self.pwd_context.verify(plain_password, hashed_password) - - def get_password_hash(self, password): - return self.pwd_context.hash(password) - - @staticmethod - def clear_session(req: HTTPConnection | Request): - session_id = req.session.get("session_id") - if session_id: - redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined] - req.session["session_id"] = None - - def authenticate_user(self, username: str, password: str): - from handler import dbh - - user = dbh.get_user_by_username(username) - if not user: - return None - - if not self._verify_password(password, user.hashed_password): - return None - - return user - - async def get_current_active_user_from_session(self, conn: HTTPConnection): - from handler import dbh - - # Check if session key already stored in cache - session_id = conn.session.get("session_id") - if not session_id: - return None - - username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined] - if not username: - return None - - # Key exists therefore user is probably authenticated - user = dbh.get_user_by_username(username) - if user is None: - self.clear_session(conn) - - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="User not found", - ) - - if not user.enabled: - self.clear_session(conn) - - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" - ) - - return user - - def create_default_admin_user(self): - from handler import dbh - from models.user import Role, User - - if not ROMM_AUTH_ENABLED: - return - - try: - dbh.add_user( - User( - username=ROMM_AUTH_USERNAME, - hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD), - role=Role.ADMIN, - ) - ) - except IntegrityError: - pass - - -class OAuthHandler: - def __init__(self) -> None: - pass - - def create_oauth_token(data: dict, expires_delta: timedelta | None = None): - to_encode = data.copy() - - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY) - - to_encode.update({"exp": expire}) - - return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM) - - async def get_current_active_user_from_bearer_token(token: str): - from handler import dbh - - try: - payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM]) - except JWTError: - raise OAuthCredentialsException - - username = payload.get("sub") - if username is None: - raise OAuthCredentialsException - - user = dbh.get_user_by_username(username) - if user is None: - raise OAuthCredentialsException - - if not user.enabled: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user" - ) - - return user, payload diff --git a/backend/handler/db_handler/__init__.py b/backend/handler/db_handler/__init__.py index e69de29bb..9c6014df7 100644 --- a/backend/handler/db_handler/__init__.py +++ b/backend/handler/db_handler/__init__.py @@ -0,0 +1,9 @@ +from config.config_manager import ConfigManager +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +class DBHandler: + def __init__(self) -> None: + self.engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True) + self.session = sessionmaker(bind=self.engine, expire_on_commit=False) diff --git a/backend/handler/db_handler/db_handler.py b/backend/handler/db_handler/db_handler.py deleted file mode 100644 index e9d7289ae..000000000 --- a/backend/handler/db_handler/db_handler.py +++ /dev/null @@ -1,132 +0,0 @@ -from config.config_manager import ConfigManager -from decorators.database import begin_session -from models import Role, Platform, Rom, Save, Screenshot, State, User -from sqlalchemy import create_engine, delete, func, select, update -from sqlalchemy.orm import Session, sessionmaker - - -class DBHandler: - def __init__(self) -> None: - self.engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True) - self.session = sessionmaker(bind=self.engine, expire_on_commit=False) - - # ========= Screenshots ========= - @begin_session - def add_screenshot(self, screenshot: Screenshot, session: Session = None): - return session.merge(screenshot) - - @begin_session - def get_screenshot(self, id, session: Session = None): - return session.get(Screenshot, id) - - @begin_session - def get_screenshot_by_filename(self, file_name: str, session: Session = None): - return session.scalars( - select(Screenshot).filter_by(file_name=file_name).limit(1) - ).first() - - @begin_session - def update_screenshot(self, id: int, data: dict, session: Session = None): - session.execute( - update(Screenshot) - .where(Screenshot.id == id) - .values(**data) - .execution_options(synchronize_session="evaluate") - ) - - @begin_session - def delete_screenshot(self, id: int, session: Session = None): - return session.execute( - delete(Screenshot) - .where(Screenshot.id == id) - .execution_options(synchronize_session="evaluate") - ) - - @begin_session - def purge_screenshots( - self, rom_id: int, screenshots: list[str], session: Session = None - ): - return session.execute( - delete(Screenshot) - .where( - Screenshot.rom_id == rom_id, - Screenshot.file_name.not_in(screenshots), - ) - .execution_options(synchronize_session="evaluate") - ) - - # ========= Users ========= - @begin_session - def add_user(self, user: User, session: Session = None): - return session.merge(user) - - @begin_session - def get_user_by_username(self, username: str, session: Session = None): - return session.scalars( - select(User).filter_by(username=username).limit(1) - ).first() - - @begin_session - def get_user(self, id: int, session: Session = None): - return session.get(User, id) - - @begin_session - def update_user(self, id: int, data: dict, session: Session = None): - session.execute( - update(User) - .where(User.id == id) - .values(**data) - .execution_options(synchronize_session="evaluate") - ) - - @begin_session - def delete_user(self, id: int, session: Session = None): - return session.execute( - delete(User) - .where(User.id == id) - .execution_options(synchronize_session="evaluate") - ) - - @begin_session - def get_users(self, session: Session = None): - return session.scalars(select(User)).all() - - @begin_session - def get_admin_users(self, session: Session = None): - return session.scalars(select(User).filter_by(role=Role.ADMIN)).all() - - # ========= Stats ========= - @begin_session - def get_platforms_count(self, session: Session = None): - # Only count platforms with more then 0 roms - return session.scalar( - select(func.count()) - .select_from(Platform) - .where( - select(func.count()) - .select_from(Rom) - .filter_by(platform_id=Platform.id) - .as_scalar() - > 0 - ) - ) - - @begin_session - def get_roms_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(Rom)) - - @begin_session - def get_saves_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(Save)) - - @begin_session - def get_states_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(State)) - - @begin_session - def get_screenshots_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(Screenshot)) - - @begin_session - def get_total_filesize(self, session: Session = None) -> int: - return 0 diff --git a/backend/handler/db_handler/db_platforms_handler.py b/backend/handler/db_handler/db_platforms_handler.py index 6e260871a..31582d26f 100644 --- a/backend/handler/db_handler/db_platforms_handler.py +++ b/backend/handler/db_handler/db_platforms_handler.py @@ -1,10 +1,9 @@ -from models import Platform, Rom, Save, State +from decorators.database import begin_session +from handler.db_handler import DBHandler +from models import Platform, Rom from sqlalchemy import delete, func, or_, select from sqlalchemy.orm import Session -from decorators.database import begin_session -from handler.db_handler.db_handler import DBHandler - class DBPlatformsHandler(DBHandler): @begin_session diff --git a/backend/handler/db_handler/db_roms_handler.py b/backend/handler/db_handler/db_roms_handler.py index 5ef8bc795..84ba028aa 100644 --- a/backend/handler/db_handler/db_roms_handler.py +++ b/backend/handler/db_handler/db_roms_handler.py @@ -1,5 +1,5 @@ from decorators.database import begin_session -from handler.db_handler.db_handler import DBHandler +from handler.db_handler import DBHandler from models import Rom from sqlalchemy import and_, delete, func, select, update from sqlalchemy.orm import Session diff --git a/backend/handler/db_handler/db_saves_handler.py b/backend/handler/db_handler/db_saves_handler.py index b8c44b11d..49f2082d6 100644 --- a/backend/handler/db_handler/db_saves_handler.py +++ b/backend/handler/db_handler/db_saves_handler.py @@ -1,5 +1,5 @@ from decorators.database import begin_session -from handler.db_handler.db_handler import DBHandler +from handler.db_handler import DBHandler from models import Save from sqlalchemy import and_, delete, select, update from sqlalchemy.orm import Session diff --git a/backend/handler/db_handler/db_screenshots_handler.py b/backend/handler/db_handler/db_screenshots_handler.py new file mode 100644 index 000000000..e0d5e110c --- /dev/null +++ b/backend/handler/db_handler/db_screenshots_handler.py @@ -0,0 +1,51 @@ +from decorators.database import begin_session +from handler.db_handler import DBHandler +from models import Screenshot +from sqlalchemy import delete, select, update +from sqlalchemy.orm import Session + + +class DBScreenshotsHandler(DBHandler): + @begin_session + def add_screenshot(self, screenshot: Screenshot, session: Session = None): + return session.merge(screenshot) + + @begin_session + def get_screenshot(self, id, session: Session = None): + return session.get(Screenshot, id) + + @begin_session + def get_screenshot_by_filename(self, file_name: str, session: Session = None): + return session.scalars( + select(Screenshot).filter_by(file_name=file_name).limit(1) + ).first() + + @begin_session + def update_screenshot(self, id: int, data: dict, session: Session = None): + session.execute( + update(Screenshot) + .where(Screenshot.id == id) + .values(**data) + .execution_options(synchronize_session="evaluate") + ) + + @begin_session + def delete_screenshot(self, id: int, session: Session = None): + return session.execute( + delete(Screenshot) + .where(Screenshot.id == id) + .execution_options(synchronize_session="evaluate") + ) + + @begin_session + def purge_screenshots( + self, rom_id: int, screenshots: list[str], session: Session = None + ): + return session.execute( + delete(Screenshot) + .where( + Screenshot.rom_id == rom_id, + Screenshot.file_name.not_in(screenshots), + ) + .execution_options(synchronize_session="evaluate") + ) diff --git a/backend/handler/db_handler/db_states_handler.py b/backend/handler/db_handler/db_states_handler.py index 0ef902cd3..49caf9a21 100644 --- a/backend/handler/db_handler/db_states_handler.py +++ b/backend/handler/db_handler/db_states_handler.py @@ -1,5 +1,5 @@ from decorators.database import begin_session -from handler.db_handler.db_handler import DBHandler +from handler.db_handler import DBHandler from models import State from sqlalchemy import and_, delete, select, update from sqlalchemy.orm import Session diff --git a/backend/handler/db_handler/db_stats_handler.py b/backend/handler/db_handler/db_stats_handler.py new file mode 100644 index 000000000..f6c7d70a2 --- /dev/null +++ b/backend/handler/db_handler/db_stats_handler.py @@ -0,0 +1,42 @@ +from decorators.database import begin_session +from handler.db_handler import DBHandler +from models import Platform, Rom, Save, Screenshot, State +from sqlalchemy import func, select +from sqlalchemy.orm import Session + + +class DBStatsHandler(DBHandler): + @begin_session + def get_platforms_count(self, session: Session = None): + # Only count platforms with more then 0 roms + return session.scalar( + select(func.count()) + .select_from(Platform) + .where( + select(func.count()) + .select_from(Rom) + .filter_by(platform_id=Platform.id) + .as_scalar() + > 0 + ) + ) + + @begin_session + def get_roms_count(self, session: Session = None) -> int: + return session.scalar(select(func.count()).select_from(Rom)) + + @begin_session + def get_saves_count(self, session: Session = None) -> int: + return session.scalar(select(func.count()).select_from(Save)) + + @begin_session + def get_states_count(self, session: Session = None) -> int: + return session.scalar(select(func.count()).select_from(State)) + + @begin_session + def get_screenshots_count(self, session: Session = None) -> int: + return session.scalar(select(func.count()).select_from(Screenshot)) + + @begin_session + def get_total_filesize(self, session: Session = None) -> int: + return 0 diff --git a/backend/handler/db_handler/db_users_handler.py b/backend/handler/db_handler/db_users_handler.py new file mode 100644 index 000000000..6fccbb827 --- /dev/null +++ b/backend/handler/db_handler/db_users_handler.py @@ -0,0 +1,46 @@ +from decorators.database import begin_session +from handler.db_handler import DBHandler +from models import Role, User +from sqlalchemy import delete, select, update +from sqlalchemy.orm import Session + + +class DBUsersHandler(DBHandler): + @begin_session + def add_user(self, user: User, session: Session = None): + return session.merge(user) + + @begin_session + def get_user_by_username(self, username: str, session: Session = None): + return session.scalars( + select(User).filter_by(username=username).limit(1) + ).first() + + @begin_session + def get_user(self, id: int, session: Session = None): + return session.get(User, id) + + @begin_session + def update_user(self, id: int, data: dict, session: Session = None): + session.execute( + update(User) + .where(User.id == id) + .values(**data) + .execution_options(synchronize_session="evaluate") + ) + + @begin_session + def delete_user(self, id: int, session: Session = None): + return session.execute( + delete(User) + .where(User.id == id) + .execution_options(synchronize_session="evaluate") + ) + + @begin_session + def get_users(self, session: Session = None): + return session.scalars(select(User)).all() + + @begin_session + def get_admin_users(self, session: Session = None): + return session.scalars(select(User).filter_by(role=Role.ADMIN)).all() diff --git a/backend/handler/fs_handler/__init__.py b/backend/handler/fs_handler/__init__.py index ed373c626..672c2de76 100644 --- a/backend/handler/fs_handler/__init__.py +++ b/backend/handler/fs_handler/__init__.py @@ -1,7 +1,12 @@ +import os +import re +import shutil +from abc import ABC from enum import Enum from typing import Final -from config import ROMM_BASE_PATH +from config import LIBRARY_BASE_PATH, ROMM_BASE_PATH +from config.config_manager import config_manager as cm RESOURCES_BASE_PATH: Final = f"{ROMM_BASE_PATH}/resources" DEFAULT_WIDTH_COVER_L: Final = 264 # Width of big cover of IGDB @@ -79,3 +84,43 @@ class Asset(Enum): SAVES = "saves" STATES = "states" SCREENSHOTS = "screenshots" + + +class FSHandler(ABC): + def __init__(self) -> None: + pass + + @staticmethod + def get_fs_structure(fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME): + return ( + f"{folder}/{fs_slug}" + if os.path.exists(cm.config.HIGH_PRIO_STRUCTURE_PATH) + else f"{fs_slug}/{folder}" + ) + + @staticmethod + def _get_file_name_with_no_extension(file_name: str) -> str: + return re.sub(EXTENSION_REGEX, "", file_name).strip() + + @staticmethod + def get_file_name_with_no_tags(file_name: str) -> str: + file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip() + return re.split(TAG_REGEX, file_name_no_extension)[0].strip() + + @staticmethod + def parse_file_extension(file_name) -> str: + match = re.search(EXTENSION_REGEX, file_name) + return match.group(1) if match else "" + + @staticmethod + def remove_file(file_name: str, file_path: str): + try: + os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}") + except IsADirectoryError: + shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}") + + def build_upload_file_path( + self, fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME + ): + rom_path = self.get_fs_structure(fs_slug, folder=folder) + return f"{LIBRARY_BASE_PATH}/{rom_path}" diff --git a/backend/handler/fs_handler/fs_assets_handler.py b/backend/handler/fs_handler/fs_assets_handler.py index bdfb25237..f8ede9b52 100644 --- a/backend/handler/fs_handler/fs_assets_handler.py +++ b/backend/handler/fs_handler/fs_assets_handler.py @@ -7,8 +7,7 @@ import requests from config import LIBRARY_BASE_PATH from config.config_manager import config_manager as cm from fastapi import UploadFile -from handler.fs_handler import RESOURCES_BASE_PATH, Asset -from handler.fs_handler.fs_handler import FSHandler +from handler.fs_handler import RESOURCES_BASE_PATH, Asset, FSHandler from logger.logger import log @@ -86,19 +85,17 @@ class FSAssetsHandler(FSHandler): saves_file_path = f"{LIBRARY_BASE_PATH}/{assets_path}" - fs_assets: list[str] = [] - # fs_states: list[str] = [] - # fs_screenshots: list[str] = [] + assets: list[str] = [] try: emulators = list(os.walk(saves_file_path))[0][1] for emulator in emulators: - fs_assets += [ + assets += [ (emulator, file) for file in list(os.walk(f"{saves_file_path}/{emulator}"))[0][2] ] - fs_assets += [ + assets += [ (None, file) for file in list(os.walk(saves_file_path))[0][2] if file.split(".")[0] == rom_file_name_no_tags @@ -106,40 +103,7 @@ class FSAssetsHandler(FSHandler): except IndexError: pass - # states_path = self.get_fs_structure( - # platform_slug, folder=cm.config.STATES_FOLDER_NAME - # ) - # states_file_path = f"{LIBRARY_BASE_PATH}/{states_path}" - - # try: - # emulators = list(os.walk(states_file_path))[0][1] - # for emulator in emulators: - # fs_states += [ - # (emulator, file) - # for file in list(os.walk(f"{states_file_path}/{emulator}"))[0][2] - # ] - - # fs_states += [ - # (None, file) for file in list(os.walk(states_file_path))[0][2] - # ] - # except IndexError: - # pass - - # screenshots_path = self.get_fs_structure( - # platform_slug, folder=cm.config.SCREENSHOTS_FOLDER_NAME - # ) - # screenshots_file_path = f"{LIBRARY_BASE_PATH}/{screenshots_path}" - - # try: - # fs_screenshots += [ - # file for file in list(os.walk(screenshots_file_path))[0][2] - # ] - # except IndexError: - # pass - - return fs_assets - # "states": fs_states, - # "screenshots": fs_screenshots, + return assets @staticmethod def get_screenshots(): diff --git a/backend/handler/fs_handler/fs_handler.py b/backend/handler/fs_handler/fs_handler.py deleted file mode 100644 index 275a8e886..000000000 --- a/backend/handler/fs_handler/fs_handler.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import re -import shutil -from abc import ABC - -from config import LIBRARY_BASE_PATH -from config.config_manager import config_manager as cm -from handler.fs_handler import EXTENSION_REGEX, TAG_REGEX - - -class FSHandler(ABC): - def __init__(self) -> None: - pass - - @staticmethod - def get_fs_structure(fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME): - return ( - f"{folder}/{fs_slug}" - if os.path.exists(cm.config.HIGH_PRIO_STRUCTURE_PATH) - else f"{fs_slug}/{folder}" - ) - - @staticmethod - def _get_file_name_with_no_extension(file_name: str) -> str: - return re.sub(EXTENSION_REGEX, "", file_name).strip() - - @staticmethod - def get_file_name_with_no_tags(file_name: str) -> str: - file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip() - return re.split(TAG_REGEX, file_name_no_extension)[0].strip() - - @staticmethod - def parse_file_extension(file_name) -> str: - match = re.search(EXTENSION_REGEX, file_name) - return match.group(1) if match else "" - - def build_upload_file_path( - self, fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME - ): - rom_path = self.get_fs_structure(fs_slug, folder=folder) - return f"{LIBRARY_BASE_PATH}/{rom_path}" - - @staticmethod - def remove_file(file_name: str, file_path: str): - try: - os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}") - except IsADirectoryError: - shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}") diff --git a/backend/handler/fs_handler/fs_platforms_handler.py b/backend/handler/fs_handler/fs_platforms_handler.py index e73acc96f..3c51f7799 100644 --- a/backend/handler/fs_handler/fs_platforms_handler.py +++ b/backend/handler/fs_handler/fs_platforms_handler.py @@ -3,7 +3,7 @@ import os from config import LIBRARY_BASE_PATH from config.config_manager import config_manager as cm from exceptions.fs_exceptions import FolderStructureNotMatchException -from handler.fs_handler.fs_handler import FSHandler +from handler.fs_handler import FSHandler class FSPlatformsHandler(FSHandler): diff --git a/backend/handler/fs_handler/fs_resources_handler.py b/backend/handler/fs_handler/fs_resources_handler.py index 4d945fa91..240313e82 100644 --- a/backend/handler/fs_handler/fs_resources_handler.py +++ b/backend/handler/fs_handler/fs_resources_handler.py @@ -18,8 +18,8 @@ from handler.fs_handler import ( DEFAULT_WIDTH_COVER_S, RESOURCES_BASE_PATH, CoverSize, + FSHandler, ) -from handler.fs_handler.fs_handler import FSHandler from PIL import Image diff --git a/backend/handler/fs_handler/fs_roms_handler.py b/backend/handler/fs_handler/fs_roms_handler.py index 704eb59e8..575888823 100644 --- a/backend/handler/fs_handler/fs_roms_handler.py +++ b/backend/handler/fs_handler/fs_roms_handler.py @@ -2,7 +2,6 @@ import fnmatch import os import re from pathlib import Path -from models.platform import Platform from config import LIBRARY_BASE_PATH from config.config_manager import config_manager as cm @@ -13,8 +12,9 @@ from handler.fs_handler import ( REGIONS_BY_SHORTCODE, REGIONS_NAME_KEYS, TAG_REGEX, + FSHandler, ) -from handler.fs_handler.fs_handler import FSHandler +from models.platform import Platform class FSRomsHandler(FSHandler): diff --git a/backend/handler/igdb_handler.py b/backend/handler/igdb_handler.py index 25c1110be..bc9c03722 100644 --- a/backend/handler/igdb_handler.py +++ b/backend/handler/igdb_handler.py @@ -284,7 +284,7 @@ class IGDBHandler: @check_twitch_token async def get_rom(self, file_name: str, platform_idgb_id: int) -> IGDBRomType: # TODO: refactor - from handler.fs_handler.fs_handler import FSHandler + from handler.fs_handler import FSHandler get_search_term = FSHandler.get_file_name_with_no_tags search_term = get_search_term(file_name) diff --git a/backend/handler/scan_handler.py b/backend/handler/scan_handler.py index 18dd99f36..19ccc607b 100644 --- a/backend/handler/scan_handler.py +++ b/backend/handler/scan_handler.py @@ -3,7 +3,7 @@ from typing import Any import emoji from config.config_manager import config_manager as cm -from handler import fsasseth, dbh, igdbh, fsresourceh, fsromh +from handler import fsasseth, dbplatformh, igdbh, fsresourceh, fsromh from logger.logger import log from models import Platform, Rom, Save, Screenshot, State diff --git a/backend/handler/tests/conftest.py b/backend/handler/tests/conftest.py index aec56fb7b..abddd004e 100644 --- a/backend/handler/tests/conftest.py +++ b/backend/handler/tests/conftest.py @@ -6,8 +6,7 @@ from sqlalchemy.orm import sessionmaker from config.config_manager import ConfigManager from models import Platform, Rom, User, Save, State, Screenshot from models.user import Role -from utils.auth import get_password_hash -from .. import dbh +from handler import dbh, dbuserh, dbplatformh, dbromh, dbsaveh, dbstateh, authh engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True) session = sessionmaker(bind=engine, expire_on_commit=False) @@ -34,7 +33,7 @@ def platform(): platform = Platform( name="test_platform", slug="test_platform_slug", fs_slug="test_platform_slug" ) - return dbh.add_platform(platform) + return dbplatformh.add_platform(platform) @pytest.fixture @@ -49,7 +48,7 @@ def rom(platform: Platform): file_path=f"{platform.slug}/roms", file_size_bytes=1000.0, ) - return dbh.add_rom(rom) + return dbromh.add_rom(rom) @pytest.fixture @@ -64,7 +63,7 @@ def save(rom: Rom): file_path=f"{rom.platform_slug}/saves/test_emulator", file_size_bytes=1.0, ) - return dbh.add_save(save) + return dbsaveh.add_save(save) @pytest.fixture @@ -79,7 +78,7 @@ def state(rom: Rom): file_path=f"{rom.platform_slug}/states/test_emulator", file_size_bytes=2.0, ) - return dbh.add_state(state) + return dbstateh.add_state(state) @pytest.fixture @@ -98,27 +97,27 @@ def screenshot(rom: Rom): def admin_user(): user = User( username="test_admin", - hashed_password=get_password_hash("test_admin_password"), + hashed_password=authh.get_password_hash("test_admin_password"), role=Role.ADMIN, ) - return dbh.add_user(user) + return dbuserh.add_user(user) @pytest.fixture def editor_user(): user = User( username="test_editor", - hashed_password=get_password_hash("test_editor_password"), + hashed_password=authh.get_password_hash("test_editor_password"), role=Role.EDITOR, ) - return dbh.add_user(user) + return dbuserh.add_user(user) @pytest.fixture def viewer_user(): user = User( username="test_viewer", - hashed_password=get_password_hash("test_viewer_password"), + hashed_password=authh.get_password_hash("test_viewer_password"), role=Role.VIEWER, ) - return dbh.add_user(user) + return dbuserh.add_user(user) diff --git a/backend/handler/tests/test_db_handler.py b/backend/handler/tests/test_db_handler.py index 55d2984c1..de1826535 100644 --- a/backend/handler/tests/test_db_handler.py +++ b/backend/handler/tests/test_db_handler.py @@ -184,14 +184,14 @@ def test_screenshots(screenshot): rom = dbh.get_rom(screenshot.rom_id) assert len(rom.screenshots) == 2 - screenshot = dbh.get_screenshot(rom.screenshots[0].id) + screenshot = dbscreenshotsh.get_screenshot(rom.screenshots[0].id) assert screenshot.file_name == "test_screenshot.png" - dbh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"}) + dbscreenshotsh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"}) screenshot = dbh.get_screenshot(screenshot.id) assert screenshot.file_name == "test_screenshot_2.png" - dbh.delete_screenshot(screenshot.id) + dbscreenshotsh.delete_screenshot(screenshot.id) rom = dbh.get_rom(screenshot.rom_id) assert len(rom.screenshots) == 1 diff --git a/backend/main.py b/backend/main.py index 835219ecc..3d6272a8c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,6 +1,5 @@ import re import sys -from typing_extensions import TypedDict import alembic.config import uvicorn @@ -17,12 +16,13 @@ from endpoints import ( tasks, user, webrcade, + stats, ) from endpoints.sockets import scan from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi_pagination import add_pagination -from handler import authh, dbh, ghh, socketh +from handler import authh, dbuserh, ghh, socketh from handler.auth_handler.hybrid_auth import HybridAuthBackend from handler.auth_handler.middleware import CustomCSRFMiddleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -71,44 +71,18 @@ app.include_router(states.router) app.include_router(tasks.router) app.include_router(webrcade.router) app.include_router(config.router) +app.include_router(stats.router) add_pagination(app) app.mount("/ws", socketh.socket_app) -class StatsReturn(TypedDict): - PLATFORMS: int - ROMS: int - SAVES: int - STATES: int - SCREENSHOTS: int - FILESIZE: int - - -@app.get("/stats") -def stats() -> StatsReturn: - """Endpoint to return the current RomM stats - - Returns: - dict: Dictionary with all the stats - """ - - return { - "PLATFORMS": dbh.get_platforms_count(), - "ROMS": dbh.get_roms_count(), - "SAVES": dbh.get_saves_count(), - "STATES": dbh.get_states_count(), - "SCREENSHOTS": dbh.get_screenshots_count(), - "FILESIZE": dbh.get_total_filesize(), - } - - @app.on_event("startup") def startup() -> None: """Event to handle RomM startup logic.""" # Create default admin user if no admin user exists - if len(dbh.get_admin_users()) == 0 and "pytest" not in sys.modules: + if len(dbuserh.get_admin_users()) == 0 and "pytest" not in sys.modules: authh.create_default_admin_user()