refactor stats and user handlers

This commit is contained in:
Zurdi
2024-01-17 01:41:23 +01:00
parent b4d623c617
commit 0af3ffb86f
27 changed files with 467 additions and 486 deletions

View File

@@ -0,0 +1,10 @@
from typing_extensions import TypedDict
class StatsReturn(TypedDict):
PLATFORMS: int
ROMS: int
SAVES: int
STATES: int
SCREENSHOTS: int
FILESIZE: int

View File

@@ -8,10 +8,10 @@ from exceptions.fs_exceptions import (
RomsNotFoundException,
)
from handler import (
dbh,
dbplatformh,
dbromh,
dbsaveh,
dbscreenshotsh,
dbstateh,
fsasseth,
fsplatformh,
@@ -19,6 +19,7 @@ from handler import (
fsromh,
socketh,
)
from handler.fs_handler import Asset
from handler.redis_handler import high_prio_queue, redis_url
from handler.scan_handler import (
scan_platform,
@@ -28,7 +29,6 @@ from handler.scan_handler import (
scan_state,
)
from logger.logger import log
from handler.fs_handler import Asset
def _get_socket_manager():
@@ -145,14 +145,13 @@ async def scan_platforms(
if save:
# Update file size if changed
if save.file_size_bytes != scanned_save.file_size_bytes:
dbh.update_save(
dbsaveh.update_save(
save.id, {"file_size_bytes": scanned_save.file_size_bytes}
)
continue
scanned_save.emulator = fs_emulator
rom = dbromh.get_rom_by_filename_no_tags(scanned_save.file_name_no_tags)
if rom:
scanned_save.rom_id = rom.id
dbsaveh.add_save(scanned_save)
@@ -173,7 +172,7 @@ async def scan_platforms(
if state:
# Update file size if changed
if state.file_size_bytes != scanned_state.file_size_bytes:
dbh.update_state(
dbstateh.update_state(
state.id, {"file_size_bytes": scanned_state.file_size_bytes}
)
@@ -181,66 +180,66 @@ async def scan_platforms(
scanned_state.emulator = fs_emulator
rom = dbromh.get_rom_by_filename_no_tags(scanned_state.file_name_no_tags)
if rom:
scanned_state.rom_id = rom.id
dbstateh.add_state(scanned_state)
# # Scanning screenshots
# log.info(f"\t · {len(fs_assets['screenshots'])} screenshots found")
# for fs_screenshot_filename in fs_assets["screenshots"]:
# scanned_screenshot = scan_screenshot(
# file_name=fs_screenshot_filename, platform=platform
# )
# Scanning screenshots
fs_screenshots = fsasseth.get_assets(
platform.fs_slug, rom.file_name_no_tags, Asset.SCREENSHOTS
)
log.info(f"\t · {len(fs_screenshots)} screenshots found")
for fs_screenshot_filename in fs_screenshots:
scanned_screenshot = scan_screenshot(
file_name=fs_screenshot_filename, platform=platform
)
# screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename)
# if screenshot:
# # Update file size if changed
# if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
# dbh.update_screenshot(
# screenshot.id,
# {"file_size_bytes": scanned_screenshot.file_size_bytes},
# )
# continue
screenshot = dbscreenshotsh.get_screenshot_by_filename(
fs_screenshot_filename
)
if screenshot:
# Update file size if changed
if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
dbscreenshotsh.update_screenshot(
screenshot.id,
{"file_size_bytes": scanned_screenshot.file_size_bytes},
)
continue
# rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
if rom:
scanned_screenshot.rom_id = rom.id
dbscreenshotsh.add_screenshot(scanned_screenshot)
# if rom:
# scanned_screenshot.rom_id = rom.id
# dbh.add_screenshot(scanned_screenshot)
# for fs_rom in fs_roms:
# rom = dbh.get_rom_by_filename(platform.id, fs_rom["file_name"])
# dbsaveh.purge_saves(rom.id, [s for _e, s in fs_assets["saves"]])
# dbstateh.purge_states(rom.id, [s for _e, s in fs_assets["states"]])
# dbh.purge_screenshots(rom.id, fs_assets["screenshots"])
# dbromh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms])
dbsaveh.purge_saves(rom.id, [s for _e, s in fs_saves])
dbstateh.purge_states(rom.id, [s for _e, s in fs_states])
dbscreenshotsh.purge_screenshots(rom.id, fs_screenshots)
dbromh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms])
# Scanning screenshots outside platform folders
# fs_screenshots = fsasseth.get_screenshots()
# log.info("Screenshots")
# log.info(f" · {len(fs_screenshots)} screenshots found")
# for fs_platform, fs_screenshot_filename in fs_screenshots:
# scanned_screenshot = scan_screenshot(
# file_name=fs_screenshot_filename, fs_platform=fs_platform
# )
fs_screenshots = fsasseth.get_screenshots()
log.info("Screenshots")
log.info(f" · {len(fs_screenshots)} screenshots found")
for fs_platform, fs_screenshot_filename in fs_screenshots:
scanned_screenshot = scan_screenshot(
file_name=fs_screenshot_filename, fs_platform=fs_platform
)
# screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename)
# if screenshot:
# # Update file size if changed
# if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
# dbh.update_screenshot(
# screenshot.id,
# {"file_size_bytes": scanned_screenshot.file_size_bytes},
# )
# continue
screenshot = dbscreenshotsh.get_screenshot_by_filename(fs_screenshot_filename)
if screenshot:
# Update file size if changed
if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
dbscreenshotsh.update_screenshot(
screenshot.id,
{"file_size_bytes": scanned_screenshot.file_size_bytes},
)
continue
# rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
# if rom:
# scanned_screenshot.rom_id = rom.id
# dbh.add_screenshot(scanned_screenshot)
rom = dbromh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
if rom:
scanned_screenshot.rom_id = rom.id
dbscreenshotsh.add_screenshot(scanned_screenshot)
# dbh.purge_screenshots([s for _e, s in fs_screenshots])
# dbscreenshotsh.purge_screenshots([s for _e, s in fs_screenshots])
dbplatformh.purge_platforms(fs_platforms)
log.info(emoji.emojize(":check_mark: Scan completed "))

View File

@@ -0,0 +1,23 @@
from endpoints.responses.stats import StatsReturn
from fastapi import APIRouter
from handler import dbstatsh
router = APIRouter()
@router.get("/stats")
def stats() -> StatsReturn:
"""Endpoint to return the current RomM stats
Returns:
dict: Dictionary with all the stats
"""
return {
"PLATFORMS": dbstatsh.get_platforms_count(),
"ROMS": dbstatsh.get_roms_count(),
"SAVES": dbstatsh.get_saves_count(),
"STATES": dbstatsh.get_states_count(),
"SCREENSHOTS": dbstatsh.get_screenshots_count(),
"FILESIZE": dbstatsh.get_total_filesize(),
}

View File

@@ -5,8 +5,8 @@ from decorators.auth import protected_route
from endpoints.forms.identity import UserForm
from endpoints.responses import MessageResponse
from endpoints.responses.identity import UserSchema
from fastapi import APIRouter, Depends, HTTPException, Request, status
from handler import authh, dbh, fsresourceh
from fastapi import APIRouter, Depends, HTTPException, Request
from handler import authh, dbuserh, fsresourceh
from models.user import Role, User
router = APIRouter()
@@ -41,7 +41,7 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS
role=Role[role.upper()],
)
return dbh.add_user(user)
return dbuserh.add_user(user)
@protected_route(router.get, "/users", ["users.read"])
@@ -55,7 +55,7 @@ def get_users(request: Request) -> list[UserSchema]:
list[UserSchema]: All users stored in the RomM's database
"""
return dbh.get_users()
return dbuserh.get_users()
@protected_route(router.get, "/users/me", ["me.read"])
@@ -83,7 +83,7 @@ def get_user(request: Request, id: int) -> UserSchema:
UserSchem: User stored in the RomM's database
"""
user = dbh.get_user(id)
user = dbuserh.get_user(id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
@@ -115,14 +115,14 @@ def update_user(
status_code=400,
detail="Cannot update user: ROMM_AUTH_ENABLED is set to False",
)
user = dbh.get_user(id)
user = dbuserh.get_user(id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
cleaned_data = {}
if form_data.username and form_data.username != user.username:
existing_user = dbh.get_user_by_username(form_data.username.lower())
existing_user = dbuserh.get_user_by_username(form_data.username.lower())
if existing_user:
raise HTTPException(
status_code=400, detail="Username already in use by another user"
@@ -150,7 +150,7 @@ def update_user(
file_object.write(form_data.avatar.file.read())
if cleaned_data:
dbh.update_user(id, cleaned_data)
dbuserh.update_user(id, cleaned_data)
# Log out the current user if username or password changed
creds_updated = cleaned_data.get("username") or cleaned_data.get(
@@ -159,7 +159,7 @@ def update_user(
if request.user.id == id and creds_updated:
authh.clear_session(request)
return dbh.get_user(id)
return dbuserh.get_user(id)
@protected_route(router.delete, "/users/{id}", ["users.write"])
@@ -186,7 +186,7 @@ def delete_user(request: Request, id: int) -> MessageResponse:
detail="Cannot delete user: ROMM_AUTH_ENABLED is set to False",
)
user = dbh.get_user(id)
user = dbuserh.get_user(id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
@@ -195,11 +195,11 @@ def delete_user(request: Request, id: int) -> MessageResponse:
raise HTTPException(status_code=400, detail="You cannot delete yourself")
# You can't delete the last admin user
if user.role == Role.ADMIN and len(dbh.get_admin_users()) == 1:
if user.role == Role.ADMIN and len(dbuserh.get_admin_users()) == 1:
raise HTTPException(
status_code=400, detail="You cannot delete the last admin user"
)
dbh.delete_user(id)
dbuserh.delete_user(id)
return {"msg": "User successfully deleted"}

View File

@@ -1,8 +1,11 @@
from handler.auth_handler.auth_handler import AuthHandler, OAuthHandler
from handler.auth_handler import AuthHandler, OAuthHandler
from handler.db_handler.db_platforms_handler import DBPlatformsHandler
from handler.db_handler.db_roms_handler import DBRomsHandler
from handler.db_handler.db_saves_handler import DBSavesHandler
from handler.db_handler.db_states_handler import DBStatesHandler
from handler.db_handler.db_users_handler import DBUsersHandler
from handler.db_handler.db_stats_handler import DBStatsHandler
from handler.db_handler.db_screenshots_handler import DBScreenshotsHandler
from handler.fs_handler.fs_assets_handler import FSAssetsHandler
from handler.fs_handler.fs_platforms_handler import FSPlatformsHandler
from handler.fs_handler.fs_resources_handler import FSResourceHandler
@@ -22,12 +25,10 @@ dbplatformh = DBPlatformsHandler()
dbromh = DBRomsHandler()
dbsaveh = DBSavesHandler()
dbstateh = DBStatesHandler()
dbuserh = DBUsersHandler()
dbstatsh = DBStatsHandler()
dbscreenshotsh = DBScreenshotsHandler()
fsplatformh = FSPlatformsHandler()
fsromh = FSRomsHandler()
fsasseth = FSAssetsHandler()
fsresourceh = FSResourceHandler()
from handler.db_handler.db_handler import DBHandler
dbh = DBHandler()

View File

@@ -1,5 +1,20 @@
from datetime import datetime, timedelta
from typing import Final
from config import (
ROMM_AUTH_ENABLED,
ROMM_AUTH_PASSWORD,
ROMM_AUTH_SECRET_KEY,
ROMM_AUTH_USERNAME,
)
from exceptions.auth_exceptions import OAuthCredentialsException
from fastapi import HTTPException, Request, status
from handler.redis_handler import cache
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
from starlette.requests import HTTPConnection
ALGORITHM: Final = "HS256"
DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15
@@ -26,3 +41,122 @@ FULL_SCOPES_MAP: Final = {
DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys())
WRITE_SCOPES: Final = DEFAULT_SCOPES + list(WRITE_SCOPES_MAP.keys())
FULL_SCOPES: Final = WRITE_SCOPES + list(FULL_SCOPES_MAP.keys())
class AuthHandler:
def __init__(self) -> None:
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def _verify_password(self, plain_password, hashed_password):
return self.pwd_context.verify(plain_password, hashed_password)
def get_password_hash(self, password):
return self.pwd_context.hash(password)
@staticmethod
def clear_session(req: HTTPConnection | Request):
session_id = req.session.get("session_id")
if session_id:
redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
req.session["session_id"] = None
def authenticate_user(self, username: str, password: str):
from handler import dbuserh
user = dbuserh.get_user_by_username(username)
if not user:
return None
if not self._verify_password(password, user.hashed_password):
return None
return user
async def get_current_active_user_from_session(self, conn: HTTPConnection):
from handler import dbuserh
# Check if session key already stored in cache
session_id = conn.session.get("session_id")
if not session_id:
return None
username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
if not username:
return None
# Key exists therefore user is probably authenticated
user = dbuserh.get_user_by_username(username)
if user is None:
self.clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User not found",
)
if not user.enabled:
self.clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
def create_default_admin_user(self):
from handler import dbuserh
from models.user import Role, User
if not ROMM_AUTH_ENABLED:
return
try:
dbuserh.add_user(
User(
username=ROMM_AUTH_USERNAME,
hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD),
role=Role.ADMIN,
)
)
except IntegrityError:
pass
class OAuthHandler:
def __init__(self) -> None:
pass
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
async def get_current_active_user_from_bearer_token(token: str):
from handler import dbuserh
try:
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
raise OAuthCredentialsException
username = payload.get("sub")
if username is None:
raise OAuthCredentialsException
user = dbuserh.get_user_by_username(username)
if user is None:
raise OAuthCredentialsException
if not user.enabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
)
return user, payload

View File

@@ -1,135 +0,0 @@
from datetime import datetime, timedelta
from config import (
ROMM_AUTH_ENABLED,
ROMM_AUTH_PASSWORD,
ROMM_AUTH_SECRET_KEY,
ROMM_AUTH_USERNAME,
)
from exceptions.auth_exceptions import OAuthCredentialsException
from fastapi import HTTPException, Request, status
from handler.auth_handler import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
from starlette.requests import HTTPConnection
from handler.redis_handler import cache
class AuthHandler:
def __init__(self) -> None:
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def _verify_password(self, plain_password, hashed_password):
return self.pwd_context.verify(plain_password, hashed_password)
def get_password_hash(self, password):
return self.pwd_context.hash(password)
@staticmethod
def clear_session(req: HTTPConnection | Request):
session_id = req.session.get("session_id")
if session_id:
redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
req.session["session_id"] = None
def authenticate_user(self, username: str, password: str):
from handler import dbh
user = dbh.get_user_by_username(username)
if not user:
return None
if not self._verify_password(password, user.hashed_password):
return None
return user
async def get_current_active_user_from_session(self, conn: HTTPConnection):
from handler import dbh
# Check if session key already stored in cache
session_id = conn.session.get("session_id")
if not session_id:
return None
username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
if not username:
return None
# Key exists therefore user is probably authenticated
user = dbh.get_user_by_username(username)
if user is None:
self.clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User not found",
)
if not user.enabled:
self.clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
def create_default_admin_user(self):
from handler import dbh
from models.user import Role, User
if not ROMM_AUTH_ENABLED:
return
try:
dbh.add_user(
User(
username=ROMM_AUTH_USERNAME,
hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD),
role=Role.ADMIN,
)
)
except IntegrityError:
pass
class OAuthHandler:
def __init__(self) -> None:
pass
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
async def get_current_active_user_from_bearer_token(token: str):
from handler import dbh
try:
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
raise OAuthCredentialsException
username = payload.get("sub")
if username is None:
raise OAuthCredentialsException
user = dbh.get_user_by_username(username)
if user is None:
raise OAuthCredentialsException
if not user.enabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
)
return user, payload

View File

@@ -0,0 +1,9 @@
from config.config_manager import ConfigManager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
class DBHandler:
def __init__(self) -> None:
self.engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
self.session = sessionmaker(bind=self.engine, expire_on_commit=False)

View File

@@ -1,132 +0,0 @@
from config.config_manager import ConfigManager
from decorators.database import begin_session
from models import Role, Platform, Rom, Save, Screenshot, State, User
from sqlalchemy import create_engine, delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
class DBHandler:
def __init__(self) -> None:
self.engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
self.session = sessionmaker(bind=self.engine, expire_on_commit=False)
# ========= Screenshots =========
@begin_session
def add_screenshot(self, screenshot: Screenshot, session: Session = None):
return session.merge(screenshot)
@begin_session
def get_screenshot(self, id, session: Session = None):
return session.get(Screenshot, id)
@begin_session
def get_screenshot_by_filename(self, file_name: str, session: Session = None):
return session.scalars(
select(Screenshot).filter_by(file_name=file_name).limit(1)
).first()
@begin_session
def update_screenshot(self, id: int, data: dict, session: Session = None):
session.execute(
update(Screenshot)
.where(Screenshot.id == id)
.values(**data)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def delete_screenshot(self, id: int, session: Session = None):
return session.execute(
delete(Screenshot)
.where(Screenshot.id == id)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def purge_screenshots(
self, rom_id: int, screenshots: list[str], session: Session = None
):
return session.execute(
delete(Screenshot)
.where(
Screenshot.rom_id == rom_id,
Screenshot.file_name.not_in(screenshots),
)
.execution_options(synchronize_session="evaluate")
)
# ========= Users =========
@begin_session
def add_user(self, user: User, session: Session = None):
return session.merge(user)
@begin_session
def get_user_by_username(self, username: str, session: Session = None):
return session.scalars(
select(User).filter_by(username=username).limit(1)
).first()
@begin_session
def get_user(self, id: int, session: Session = None):
return session.get(User, id)
@begin_session
def update_user(self, id: int, data: dict, session: Session = None):
session.execute(
update(User)
.where(User.id == id)
.values(**data)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def delete_user(self, id: int, session: Session = None):
return session.execute(
delete(User)
.where(User.id == id)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def get_users(self, session: Session = None):
return session.scalars(select(User)).all()
@begin_session
def get_admin_users(self, session: Session = None):
return session.scalars(select(User).filter_by(role=Role.ADMIN)).all()
# ========= Stats =========
@begin_session
def get_platforms_count(self, session: Session = None):
# Only count platforms with more then 0 roms
return session.scalar(
select(func.count())
.select_from(Platform)
.where(
select(func.count())
.select_from(Rom)
.filter_by(platform_id=Platform.id)
.as_scalar()
> 0
)
)
@begin_session
def get_roms_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Rom))
@begin_session
def get_saves_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Save))
@begin_session
def get_states_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(State))
@begin_session
def get_screenshots_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Screenshot))
@begin_session
def get_total_filesize(self, session: Session = None) -> int:
return 0

View File

@@ -1,10 +1,9 @@
from models import Platform, Rom, Save, State
from decorators.database import begin_session
from handler.db_handler import DBHandler
from models import Platform, Rom
from sqlalchemy import delete, func, or_, select
from sqlalchemy.orm import Session
from decorators.database import begin_session
from handler.db_handler.db_handler import DBHandler
class DBPlatformsHandler(DBHandler):
@begin_session

View File

@@ -1,5 +1,5 @@
from decorators.database import begin_session
from handler.db_handler.db_handler import DBHandler
from handler.db_handler import DBHandler
from models import Rom
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.orm import Session

View File

@@ -1,5 +1,5 @@
from decorators.database import begin_session
from handler.db_handler.db_handler import DBHandler
from handler.db_handler import DBHandler
from models import Save
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session

View File

@@ -0,0 +1,51 @@
from decorators.database import begin_session
from handler.db_handler import DBHandler
from models import Screenshot
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
class DBScreenshotsHandler(DBHandler):
@begin_session
def add_screenshot(self, screenshot: Screenshot, session: Session = None):
return session.merge(screenshot)
@begin_session
def get_screenshot(self, id, session: Session = None):
return session.get(Screenshot, id)
@begin_session
def get_screenshot_by_filename(self, file_name: str, session: Session = None):
return session.scalars(
select(Screenshot).filter_by(file_name=file_name).limit(1)
).first()
@begin_session
def update_screenshot(self, id: int, data: dict, session: Session = None):
session.execute(
update(Screenshot)
.where(Screenshot.id == id)
.values(**data)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def delete_screenshot(self, id: int, session: Session = None):
return session.execute(
delete(Screenshot)
.where(Screenshot.id == id)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def purge_screenshots(
self, rom_id: int, screenshots: list[str], session: Session = None
):
return session.execute(
delete(Screenshot)
.where(
Screenshot.rom_id == rom_id,
Screenshot.file_name.not_in(screenshots),
)
.execution_options(synchronize_session="evaluate")
)

View File

@@ -1,5 +1,5 @@
from decorators.database import begin_session
from handler.db_handler.db_handler import DBHandler
from handler.db_handler import DBHandler
from models import State
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session

View File

@@ -0,0 +1,42 @@
from decorators.database import begin_session
from handler.db_handler import DBHandler
from models import Platform, Rom, Save, Screenshot, State
from sqlalchemy import func, select
from sqlalchemy.orm import Session
class DBStatsHandler(DBHandler):
@begin_session
def get_platforms_count(self, session: Session = None):
# Only count platforms with more then 0 roms
return session.scalar(
select(func.count())
.select_from(Platform)
.where(
select(func.count())
.select_from(Rom)
.filter_by(platform_id=Platform.id)
.as_scalar()
> 0
)
)
@begin_session
def get_roms_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Rom))
@begin_session
def get_saves_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Save))
@begin_session
def get_states_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(State))
@begin_session
def get_screenshots_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Screenshot))
@begin_session
def get_total_filesize(self, session: Session = None) -> int:
return 0

View File

@@ -0,0 +1,46 @@
from decorators.database import begin_session
from handler.db_handler import DBHandler
from models import Role, User
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
class DBUsersHandler(DBHandler):
@begin_session
def add_user(self, user: User, session: Session = None):
return session.merge(user)
@begin_session
def get_user_by_username(self, username: str, session: Session = None):
return session.scalars(
select(User).filter_by(username=username).limit(1)
).first()
@begin_session
def get_user(self, id: int, session: Session = None):
return session.get(User, id)
@begin_session
def update_user(self, id: int, data: dict, session: Session = None):
session.execute(
update(User)
.where(User.id == id)
.values(**data)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def delete_user(self, id: int, session: Session = None):
return session.execute(
delete(User)
.where(User.id == id)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def get_users(self, session: Session = None):
return session.scalars(select(User)).all()
@begin_session
def get_admin_users(self, session: Session = None):
return session.scalars(select(User).filter_by(role=Role.ADMIN)).all()

View File

@@ -1,7 +1,12 @@
import os
import re
import shutil
from abc import ABC
from enum import Enum
from typing import Final
from config import ROMM_BASE_PATH
from config import LIBRARY_BASE_PATH, ROMM_BASE_PATH
from config.config_manager import config_manager as cm
RESOURCES_BASE_PATH: Final = f"{ROMM_BASE_PATH}/resources"
DEFAULT_WIDTH_COVER_L: Final = 264 # Width of big cover of IGDB
@@ -79,3 +84,43 @@ class Asset(Enum):
SAVES = "saves"
STATES = "states"
SCREENSHOTS = "screenshots"
class FSHandler(ABC):
def __init__(self) -> None:
pass
@staticmethod
def get_fs_structure(fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME):
return (
f"{folder}/{fs_slug}"
if os.path.exists(cm.config.HIGH_PRIO_STRUCTURE_PATH)
else f"{fs_slug}/{folder}"
)
@staticmethod
def _get_file_name_with_no_extension(file_name: str) -> str:
return re.sub(EXTENSION_REGEX, "", file_name).strip()
@staticmethod
def get_file_name_with_no_tags(file_name: str) -> str:
file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip()
return re.split(TAG_REGEX, file_name_no_extension)[0].strip()
@staticmethod
def parse_file_extension(file_name) -> str:
match = re.search(EXTENSION_REGEX, file_name)
return match.group(1) if match else ""
@staticmethod
def remove_file(file_name: str, file_path: str):
try:
os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
except IsADirectoryError:
shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
def build_upload_file_path(
self, fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME
):
rom_path = self.get_fs_structure(fs_slug, folder=folder)
return f"{LIBRARY_BASE_PATH}/{rom_path}"

View File

@@ -7,8 +7,7 @@ import requests
from config import LIBRARY_BASE_PATH
from config.config_manager import config_manager as cm
from fastapi import UploadFile
from handler.fs_handler import RESOURCES_BASE_PATH, Asset
from handler.fs_handler.fs_handler import FSHandler
from handler.fs_handler import RESOURCES_BASE_PATH, Asset, FSHandler
from logger.logger import log
@@ -86,19 +85,17 @@ class FSAssetsHandler(FSHandler):
saves_file_path = f"{LIBRARY_BASE_PATH}/{assets_path}"
fs_assets: list[str] = []
# fs_states: list[str] = []
# fs_screenshots: list[str] = []
assets: list[str] = []
try:
emulators = list(os.walk(saves_file_path))[0][1]
for emulator in emulators:
fs_assets += [
assets += [
(emulator, file)
for file in list(os.walk(f"{saves_file_path}/{emulator}"))[0][2]
]
fs_assets += [
assets += [
(None, file)
for file in list(os.walk(saves_file_path))[0][2]
if file.split(".")[0] == rom_file_name_no_tags
@@ -106,40 +103,7 @@ class FSAssetsHandler(FSHandler):
except IndexError:
pass
# states_path = self.get_fs_structure(
# platform_slug, folder=cm.config.STATES_FOLDER_NAME
# )
# states_file_path = f"{LIBRARY_BASE_PATH}/{states_path}"
# try:
# emulators = list(os.walk(states_file_path))[0][1]
# for emulator in emulators:
# fs_states += [
# (emulator, file)
# for file in list(os.walk(f"{states_file_path}/{emulator}"))[0][2]
# ]
# fs_states += [
# (None, file) for file in list(os.walk(states_file_path))[0][2]
# ]
# except IndexError:
# pass
# screenshots_path = self.get_fs_structure(
# platform_slug, folder=cm.config.SCREENSHOTS_FOLDER_NAME
# )
# screenshots_file_path = f"{LIBRARY_BASE_PATH}/{screenshots_path}"
# try:
# fs_screenshots += [
# file for file in list(os.walk(screenshots_file_path))[0][2]
# ]
# except IndexError:
# pass
return fs_assets
# "states": fs_states,
# "screenshots": fs_screenshots,
return assets
@staticmethod
def get_screenshots():

View File

@@ -1,48 +0,0 @@
import os
import re
import shutil
from abc import ABC
from config import LIBRARY_BASE_PATH
from config.config_manager import config_manager as cm
from handler.fs_handler import EXTENSION_REGEX, TAG_REGEX
class FSHandler(ABC):
def __init__(self) -> None:
pass
@staticmethod
def get_fs_structure(fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME):
return (
f"{folder}/{fs_slug}"
if os.path.exists(cm.config.HIGH_PRIO_STRUCTURE_PATH)
else f"{fs_slug}/{folder}"
)
@staticmethod
def _get_file_name_with_no_extension(file_name: str) -> str:
return re.sub(EXTENSION_REGEX, "", file_name).strip()
@staticmethod
def get_file_name_with_no_tags(file_name: str) -> str:
file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip()
return re.split(TAG_REGEX, file_name_no_extension)[0].strip()
@staticmethod
def parse_file_extension(file_name) -> str:
match = re.search(EXTENSION_REGEX, file_name)
return match.group(1) if match else ""
def build_upload_file_path(
self, fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME
):
rom_path = self.get_fs_structure(fs_slug, folder=folder)
return f"{LIBRARY_BASE_PATH}/{rom_path}"
@staticmethod
def remove_file(file_name: str, file_path: str):
try:
os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
except IsADirectoryError:
shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")

View File

@@ -3,7 +3,7 @@ import os
from config import LIBRARY_BASE_PATH
from config.config_manager import config_manager as cm
from exceptions.fs_exceptions import FolderStructureNotMatchException
from handler.fs_handler.fs_handler import FSHandler
from handler.fs_handler import FSHandler
class FSPlatformsHandler(FSHandler):

View File

@@ -18,8 +18,8 @@ from handler.fs_handler import (
DEFAULT_WIDTH_COVER_S,
RESOURCES_BASE_PATH,
CoverSize,
FSHandler,
)
from handler.fs_handler.fs_handler import FSHandler
from PIL import Image

View File

@@ -2,7 +2,6 @@ import fnmatch
import os
import re
from pathlib import Path
from models.platform import Platform
from config import LIBRARY_BASE_PATH
from config.config_manager import config_manager as cm
@@ -13,8 +12,9 @@ from handler.fs_handler import (
REGIONS_BY_SHORTCODE,
REGIONS_NAME_KEYS,
TAG_REGEX,
FSHandler,
)
from handler.fs_handler.fs_handler import FSHandler
from models.platform import Platform
class FSRomsHandler(FSHandler):

View File

@@ -284,7 +284,7 @@ class IGDBHandler:
@check_twitch_token
async def get_rom(self, file_name: str, platform_idgb_id: int) -> IGDBRomType:
# TODO: refactor
from handler.fs_handler.fs_handler import FSHandler
from handler.fs_handler import FSHandler
get_search_term = FSHandler.get_file_name_with_no_tags
search_term = get_search_term(file_name)

View File

@@ -3,7 +3,7 @@ from typing import Any
import emoji
from config.config_manager import config_manager as cm
from handler import fsasseth, dbh, igdbh, fsresourceh, fsromh
from handler import fsasseth, dbplatformh, igdbh, fsresourceh, fsromh
from logger.logger import log
from models import Platform, Rom, Save, Screenshot, State

View File

@@ -6,8 +6,7 @@ from sqlalchemy.orm import sessionmaker
from config.config_manager import ConfigManager
from models import Platform, Rom, User, Save, State, Screenshot
from models.user import Role
from utils.auth import get_password_hash
from .. import dbh
from handler import dbh, dbuserh, dbplatformh, dbromh, dbsaveh, dbstateh, authh
engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
session = sessionmaker(bind=engine, expire_on_commit=False)
@@ -34,7 +33,7 @@ def platform():
platform = Platform(
name="test_platform", slug="test_platform_slug", fs_slug="test_platform_slug"
)
return dbh.add_platform(platform)
return dbplatformh.add_platform(platform)
@pytest.fixture
@@ -49,7 +48,7 @@ def rom(platform: Platform):
file_path=f"{platform.slug}/roms",
file_size_bytes=1000.0,
)
return dbh.add_rom(rom)
return dbromh.add_rom(rom)
@pytest.fixture
@@ -64,7 +63,7 @@ def save(rom: Rom):
file_path=f"{rom.platform_slug}/saves/test_emulator",
file_size_bytes=1.0,
)
return dbh.add_save(save)
return dbsaveh.add_save(save)
@pytest.fixture
@@ -79,7 +78,7 @@ def state(rom: Rom):
file_path=f"{rom.platform_slug}/states/test_emulator",
file_size_bytes=2.0,
)
return dbh.add_state(state)
return dbstateh.add_state(state)
@pytest.fixture
@@ -98,27 +97,27 @@ def screenshot(rom: Rom):
def admin_user():
user = User(
username="test_admin",
hashed_password=get_password_hash("test_admin_password"),
hashed_password=authh.get_password_hash("test_admin_password"),
role=Role.ADMIN,
)
return dbh.add_user(user)
return dbuserh.add_user(user)
@pytest.fixture
def editor_user():
user = User(
username="test_editor",
hashed_password=get_password_hash("test_editor_password"),
hashed_password=authh.get_password_hash("test_editor_password"),
role=Role.EDITOR,
)
return dbh.add_user(user)
return dbuserh.add_user(user)
@pytest.fixture
def viewer_user():
user = User(
username="test_viewer",
hashed_password=get_password_hash("test_viewer_password"),
hashed_password=authh.get_password_hash("test_viewer_password"),
role=Role.VIEWER,
)
return dbh.add_user(user)
return dbuserh.add_user(user)

View File

@@ -184,14 +184,14 @@ def test_screenshots(screenshot):
rom = dbh.get_rom(screenshot.rom_id)
assert len(rom.screenshots) == 2
screenshot = dbh.get_screenshot(rom.screenshots[0].id)
screenshot = dbscreenshotsh.get_screenshot(rom.screenshots[0].id)
assert screenshot.file_name == "test_screenshot.png"
dbh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"})
dbscreenshotsh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"})
screenshot = dbh.get_screenshot(screenshot.id)
assert screenshot.file_name == "test_screenshot_2.png"
dbh.delete_screenshot(screenshot.id)
dbscreenshotsh.delete_screenshot(screenshot.id)
rom = dbh.get_rom(screenshot.rom_id)
assert len(rom.screenshots) == 1

View File

@@ -1,6 +1,5 @@
import re
import sys
from typing_extensions import TypedDict
import alembic.config
import uvicorn
@@ -17,12 +16,13 @@ from endpoints import (
tasks,
user,
webrcade,
stats,
)
from endpoints.sockets import scan
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
from handler import authh, dbh, ghh, socketh
from handler import authh, dbuserh, ghh, socketh
from handler.auth_handler.hybrid_auth import HybridAuthBackend
from handler.auth_handler.middleware import CustomCSRFMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -71,44 +71,18 @@ app.include_router(states.router)
app.include_router(tasks.router)
app.include_router(webrcade.router)
app.include_router(config.router)
app.include_router(stats.router)
add_pagination(app)
app.mount("/ws", socketh.socket_app)
class StatsReturn(TypedDict):
PLATFORMS: int
ROMS: int
SAVES: int
STATES: int
SCREENSHOTS: int
FILESIZE: int
@app.get("/stats")
def stats() -> StatsReturn:
"""Endpoint to return the current RomM stats
Returns:
dict: Dictionary with all the stats
"""
return {
"PLATFORMS": dbh.get_platforms_count(),
"ROMS": dbh.get_roms_count(),
"SAVES": dbh.get_saves_count(),
"STATES": dbh.get_states_count(),
"SCREENSHOTS": dbh.get_screenshots_count(),
"FILESIZE": dbh.get_total_filesize(),
}
@app.on_event("startup")
def startup() -> None:
"""Event to handle RomM startup logic."""
# Create default admin user if no admin user exists
if len(dbh.get_admin_users()) == 0 and "pytest" not in sys.modules:
if len(dbuserh.get_admin_users()) == 0 and "pytest" not in sys.modules:
authh.create_default_admin_user()