mirror of
https://github.com/rommapp/romm.git
synced 2026-06-28 06:46:00 +00:00
refactor stats and user handlers
This commit is contained in:
10
backend/endpoints/responses/stats.py
Normal file
10
backend/endpoints/responses/stats.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class StatsReturn(TypedDict):
|
||||
PLATFORMS: int
|
||||
ROMS: int
|
||||
SAVES: int
|
||||
STATES: int
|
||||
SCREENSHOTS: int
|
||||
FILESIZE: int
|
||||
@@ -8,10 +8,10 @@ from exceptions.fs_exceptions import (
|
||||
RomsNotFoundException,
|
||||
)
|
||||
from handler import (
|
||||
dbh,
|
||||
dbplatformh,
|
||||
dbromh,
|
||||
dbsaveh,
|
||||
dbscreenshotsh,
|
||||
dbstateh,
|
||||
fsasseth,
|
||||
fsplatformh,
|
||||
@@ -19,6 +19,7 @@ from handler import (
|
||||
fsromh,
|
||||
socketh,
|
||||
)
|
||||
from handler.fs_handler import Asset
|
||||
from handler.redis_handler import high_prio_queue, redis_url
|
||||
from handler.scan_handler import (
|
||||
scan_platform,
|
||||
@@ -28,7 +29,6 @@ from handler.scan_handler import (
|
||||
scan_state,
|
||||
)
|
||||
from logger.logger import log
|
||||
from handler.fs_handler import Asset
|
||||
|
||||
|
||||
def _get_socket_manager():
|
||||
@@ -145,14 +145,13 @@ async def scan_platforms(
|
||||
if save:
|
||||
# Update file size if changed
|
||||
if save.file_size_bytes != scanned_save.file_size_bytes:
|
||||
dbh.update_save(
|
||||
dbsaveh.update_save(
|
||||
save.id, {"file_size_bytes": scanned_save.file_size_bytes}
|
||||
)
|
||||
continue
|
||||
|
||||
scanned_save.emulator = fs_emulator
|
||||
|
||||
rom = dbromh.get_rom_by_filename_no_tags(scanned_save.file_name_no_tags)
|
||||
if rom:
|
||||
scanned_save.rom_id = rom.id
|
||||
dbsaveh.add_save(scanned_save)
|
||||
@@ -173,7 +172,7 @@ async def scan_platforms(
|
||||
if state:
|
||||
# Update file size if changed
|
||||
if state.file_size_bytes != scanned_state.file_size_bytes:
|
||||
dbh.update_state(
|
||||
dbstateh.update_state(
|
||||
state.id, {"file_size_bytes": scanned_state.file_size_bytes}
|
||||
)
|
||||
|
||||
@@ -181,66 +180,66 @@ async def scan_platforms(
|
||||
|
||||
scanned_state.emulator = fs_emulator
|
||||
|
||||
rom = dbromh.get_rom_by_filename_no_tags(scanned_state.file_name_no_tags)
|
||||
if rom:
|
||||
scanned_state.rom_id = rom.id
|
||||
dbstateh.add_state(scanned_state)
|
||||
|
||||
# # Scanning screenshots
|
||||
# log.info(f"\t · {len(fs_assets['screenshots'])} screenshots found")
|
||||
# for fs_screenshot_filename in fs_assets["screenshots"]:
|
||||
# scanned_screenshot = scan_screenshot(
|
||||
# file_name=fs_screenshot_filename, platform=platform
|
||||
# )
|
||||
# Scanning screenshots
|
||||
fs_screenshots = fsasseth.get_assets(
|
||||
platform.fs_slug, rom.file_name_no_tags, Asset.SCREENSHOTS
|
||||
)
|
||||
log.info(f"\t · {len(fs_screenshots)} screenshots found")
|
||||
for fs_screenshot_filename in fs_screenshots:
|
||||
scanned_screenshot = scan_screenshot(
|
||||
file_name=fs_screenshot_filename, platform=platform
|
||||
)
|
||||
|
||||
# screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename)
|
||||
# if screenshot:
|
||||
# # Update file size if changed
|
||||
# if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
|
||||
# dbh.update_screenshot(
|
||||
# screenshot.id,
|
||||
# {"file_size_bytes": scanned_screenshot.file_size_bytes},
|
||||
# )
|
||||
# continue
|
||||
screenshot = dbscreenshotsh.get_screenshot_by_filename(
|
||||
fs_screenshot_filename
|
||||
)
|
||||
if screenshot:
|
||||
# Update file size if changed
|
||||
if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
|
||||
dbscreenshotsh.update_screenshot(
|
||||
screenshot.id,
|
||||
{"file_size_bytes": scanned_screenshot.file_size_bytes},
|
||||
)
|
||||
continue
|
||||
|
||||
# rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
|
||||
if rom:
|
||||
scanned_screenshot.rom_id = rom.id
|
||||
dbscreenshotsh.add_screenshot(scanned_screenshot)
|
||||
|
||||
# if rom:
|
||||
# scanned_screenshot.rom_id = rom.id
|
||||
# dbh.add_screenshot(scanned_screenshot)
|
||||
|
||||
# for fs_rom in fs_roms:
|
||||
# rom = dbh.get_rom_by_filename(platform.id, fs_rom["file_name"])
|
||||
# dbsaveh.purge_saves(rom.id, [s for _e, s in fs_assets["saves"]])
|
||||
# dbstateh.purge_states(rom.id, [s for _e, s in fs_assets["states"]])
|
||||
# dbh.purge_screenshots(rom.id, fs_assets["screenshots"])
|
||||
# dbromh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms])
|
||||
dbsaveh.purge_saves(rom.id, [s for _e, s in fs_saves])
|
||||
dbstateh.purge_states(rom.id, [s for _e, s in fs_states])
|
||||
dbscreenshotsh.purge_screenshots(rom.id, fs_screenshots)
|
||||
dbromh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms])
|
||||
|
||||
# Scanning screenshots outside platform folders
|
||||
# fs_screenshots = fsasseth.get_screenshots()
|
||||
# log.info("Screenshots")
|
||||
# log.info(f" · {len(fs_screenshots)} screenshots found")
|
||||
# for fs_platform, fs_screenshot_filename in fs_screenshots:
|
||||
# scanned_screenshot = scan_screenshot(
|
||||
# file_name=fs_screenshot_filename, fs_platform=fs_platform
|
||||
# )
|
||||
fs_screenshots = fsasseth.get_screenshots()
|
||||
log.info("Screenshots")
|
||||
log.info(f" · {len(fs_screenshots)} screenshots found")
|
||||
for fs_platform, fs_screenshot_filename in fs_screenshots:
|
||||
scanned_screenshot = scan_screenshot(
|
||||
file_name=fs_screenshot_filename, fs_platform=fs_platform
|
||||
)
|
||||
|
||||
# screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename)
|
||||
# if screenshot:
|
||||
# # Update file size if changed
|
||||
# if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
|
||||
# dbh.update_screenshot(
|
||||
# screenshot.id,
|
||||
# {"file_size_bytes": scanned_screenshot.file_size_bytes},
|
||||
# )
|
||||
# continue
|
||||
screenshot = dbscreenshotsh.get_screenshot_by_filename(fs_screenshot_filename)
|
||||
if screenshot:
|
||||
# Update file size if changed
|
||||
if screenshot.file_size_bytes != scanned_screenshot.file_size_bytes:
|
||||
dbscreenshotsh.update_screenshot(
|
||||
screenshot.id,
|
||||
{"file_size_bytes": scanned_screenshot.file_size_bytes},
|
||||
)
|
||||
continue
|
||||
|
||||
# rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
|
||||
# if rom:
|
||||
# scanned_screenshot.rom_id = rom.id
|
||||
# dbh.add_screenshot(scanned_screenshot)
|
||||
rom = dbromh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
|
||||
if rom:
|
||||
scanned_screenshot.rom_id = rom.id
|
||||
dbscreenshotsh.add_screenshot(scanned_screenshot)
|
||||
|
||||
# dbh.purge_screenshots([s for _e, s in fs_screenshots])
|
||||
# dbscreenshotsh.purge_screenshots([s for _e, s in fs_screenshots])
|
||||
dbplatformh.purge_platforms(fs_platforms)
|
||||
|
||||
log.info(emoji.emojize(":check_mark: Scan completed "))
|
||||
|
||||
23
backend/endpoints/stats.py
Normal file
23
backend/endpoints/stats.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from endpoints.responses.stats import StatsReturn
|
||||
from fastapi import APIRouter
|
||||
from handler import dbstatsh
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
def stats() -> StatsReturn:
|
||||
"""Endpoint to return the current RomM stats
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with all the stats
|
||||
"""
|
||||
|
||||
return {
|
||||
"PLATFORMS": dbstatsh.get_platforms_count(),
|
||||
"ROMS": dbstatsh.get_roms_count(),
|
||||
"SAVES": dbstatsh.get_saves_count(),
|
||||
"STATES": dbstatsh.get_states_count(),
|
||||
"SCREENSHOTS": dbstatsh.get_screenshots_count(),
|
||||
"FILESIZE": dbstatsh.get_total_filesize(),
|
||||
}
|
||||
@@ -5,8 +5,8 @@ from decorators.auth import protected_route
|
||||
from endpoints.forms.identity import UserForm
|
||||
from endpoints.responses import MessageResponse
|
||||
from endpoints.responses.identity import UserSchema
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from handler import authh, dbh, fsresourceh
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from handler import authh, dbuserh, fsresourceh
|
||||
from models.user import Role, User
|
||||
|
||||
router = APIRouter()
|
||||
@@ -41,7 +41,7 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS
|
||||
role=Role[role.upper()],
|
||||
)
|
||||
|
||||
return dbh.add_user(user)
|
||||
return dbuserh.add_user(user)
|
||||
|
||||
|
||||
@protected_route(router.get, "/users", ["users.read"])
|
||||
@@ -55,7 +55,7 @@ def get_users(request: Request) -> list[UserSchema]:
|
||||
list[UserSchema]: All users stored in the RomM's database
|
||||
"""
|
||||
|
||||
return dbh.get_users()
|
||||
return dbuserh.get_users()
|
||||
|
||||
|
||||
@protected_route(router.get, "/users/me", ["me.read"])
|
||||
@@ -83,7 +83,7 @@ def get_user(request: Request, id: int) -> UserSchema:
|
||||
UserSchem: User stored in the RomM's database
|
||||
"""
|
||||
|
||||
user = dbh.get_user(id)
|
||||
user = dbuserh.get_user(id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
@@ -115,14 +115,14 @@ def update_user(
|
||||
status_code=400,
|
||||
detail="Cannot update user: ROMM_AUTH_ENABLED is set to False",
|
||||
)
|
||||
user = dbh.get_user(id)
|
||||
user = dbuserh.get_user(id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
cleaned_data = {}
|
||||
|
||||
if form_data.username and form_data.username != user.username:
|
||||
existing_user = dbh.get_user_by_username(form_data.username.lower())
|
||||
existing_user = dbuserh.get_user_by_username(form_data.username.lower())
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Username already in use by another user"
|
||||
@@ -150,7 +150,7 @@ def update_user(
|
||||
file_object.write(form_data.avatar.file.read())
|
||||
|
||||
if cleaned_data:
|
||||
dbh.update_user(id, cleaned_data)
|
||||
dbuserh.update_user(id, cleaned_data)
|
||||
|
||||
# Log out the current user if username or password changed
|
||||
creds_updated = cleaned_data.get("username") or cleaned_data.get(
|
||||
@@ -159,7 +159,7 @@ def update_user(
|
||||
if request.user.id == id and creds_updated:
|
||||
authh.clear_session(request)
|
||||
|
||||
return dbh.get_user(id)
|
||||
return dbuserh.get_user(id)
|
||||
|
||||
|
||||
@protected_route(router.delete, "/users/{id}", ["users.write"])
|
||||
@@ -186,7 +186,7 @@ def delete_user(request: Request, id: int) -> MessageResponse:
|
||||
detail="Cannot delete user: ROMM_AUTH_ENABLED is set to False",
|
||||
)
|
||||
|
||||
user = dbh.get_user(id)
|
||||
user = dbuserh.get_user(id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
@@ -195,11 +195,11 @@ def delete_user(request: Request, id: int) -> MessageResponse:
|
||||
raise HTTPException(status_code=400, detail="You cannot delete yourself")
|
||||
|
||||
# You can't delete the last admin user
|
||||
if user.role == Role.ADMIN and len(dbh.get_admin_users()) == 1:
|
||||
if user.role == Role.ADMIN and len(dbuserh.get_admin_users()) == 1:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You cannot delete the last admin user"
|
||||
)
|
||||
|
||||
dbh.delete_user(id)
|
||||
dbuserh.delete_user(id)
|
||||
|
||||
return {"msg": "User successfully deleted"}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from handler.auth_handler.auth_handler import AuthHandler, OAuthHandler
|
||||
from handler.auth_handler import AuthHandler, OAuthHandler
|
||||
from handler.db_handler.db_platforms_handler import DBPlatformsHandler
|
||||
from handler.db_handler.db_roms_handler import DBRomsHandler
|
||||
from handler.db_handler.db_saves_handler import DBSavesHandler
|
||||
from handler.db_handler.db_states_handler import DBStatesHandler
|
||||
from handler.db_handler.db_users_handler import DBUsersHandler
|
||||
from handler.db_handler.db_stats_handler import DBStatsHandler
|
||||
from handler.db_handler.db_screenshots_handler import DBScreenshotsHandler
|
||||
from handler.fs_handler.fs_assets_handler import FSAssetsHandler
|
||||
from handler.fs_handler.fs_platforms_handler import FSPlatformsHandler
|
||||
from handler.fs_handler.fs_resources_handler import FSResourceHandler
|
||||
@@ -22,12 +25,10 @@ dbplatformh = DBPlatformsHandler()
|
||||
dbromh = DBRomsHandler()
|
||||
dbsaveh = DBSavesHandler()
|
||||
dbstateh = DBStatesHandler()
|
||||
dbuserh = DBUsersHandler()
|
||||
dbstatsh = DBStatsHandler()
|
||||
dbscreenshotsh = DBScreenshotsHandler()
|
||||
fsplatformh = FSPlatformsHandler()
|
||||
fsromh = FSRomsHandler()
|
||||
fsasseth = FSAssetsHandler()
|
||||
fsresourceh = FSResourceHandler()
|
||||
|
||||
|
||||
from handler.db_handler.db_handler import DBHandler
|
||||
|
||||
dbh = DBHandler()
|
||||
|
||||
@@ -1,5 +1,20 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Final
|
||||
|
||||
from config import (
|
||||
ROMM_AUTH_ENABLED,
|
||||
ROMM_AUTH_PASSWORD,
|
||||
ROMM_AUTH_SECRET_KEY,
|
||||
ROMM_AUTH_USERNAME,
|
||||
)
|
||||
from exceptions.auth_exceptions import OAuthCredentialsException
|
||||
from fastapi import HTTPException, Request, status
|
||||
from handler.redis_handler import cache
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from starlette.requests import HTTPConnection
|
||||
|
||||
ALGORITHM: Final = "HS256"
|
||||
DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15
|
||||
|
||||
@@ -26,3 +41,122 @@ FULL_SCOPES_MAP: Final = {
|
||||
DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys())
|
||||
WRITE_SCOPES: Final = DEFAULT_SCOPES + list(WRITE_SCOPES_MAP.keys())
|
||||
FULL_SCOPES: Final = WRITE_SCOPES + list(FULL_SCOPES_MAP.keys())
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
def __init__(self) -> None:
|
||||
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def _verify_password(self, plain_password, hashed_password):
|
||||
return self.pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(self, password):
|
||||
return self.pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def clear_session(req: HTTPConnection | Request):
|
||||
session_id = req.session.get("session_id")
|
||||
if session_id:
|
||||
redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
|
||||
req.session["session_id"] = None
|
||||
|
||||
def authenticate_user(self, username: str, password: str):
|
||||
from handler import dbuserh
|
||||
|
||||
user = dbuserh.get_user_by_username(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not self._verify_password(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
async def get_current_active_user_from_session(self, conn: HTTPConnection):
|
||||
from handler import dbuserh
|
||||
|
||||
# Check if session key already stored in cache
|
||||
session_id = conn.session.get("session_id")
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
|
||||
if not username:
|
||||
return None
|
||||
|
||||
# Key exists therefore user is probably authenticated
|
||||
user = dbuserh.get_user_by_username(username)
|
||||
if user is None:
|
||||
self.clear_session(conn)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
if not user.enabled:
|
||||
self.clear_session(conn)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
def create_default_admin_user(self):
|
||||
from handler import dbuserh
|
||||
from models.user import Role, User
|
||||
|
||||
if not ROMM_AUTH_ENABLED:
|
||||
return
|
||||
|
||||
try:
|
||||
dbuserh.add_user(
|
||||
User(
|
||||
username=ROMM_AUTH_USERNAME,
|
||||
hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD),
|
||||
role=Role.ADMIN,
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
|
||||
class OAuthHandler:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
async def get_current_active_user_from_bearer_token(token: str):
|
||||
from handler import dbuserh
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
user = dbuserh.get_user_by_username(username)
|
||||
if user is None:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
if not user.enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user, payload
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from config import (
|
||||
ROMM_AUTH_ENABLED,
|
||||
ROMM_AUTH_PASSWORD,
|
||||
ROMM_AUTH_SECRET_KEY,
|
||||
ROMM_AUTH_USERNAME,
|
||||
)
|
||||
from exceptions.auth_exceptions import OAuthCredentialsException
|
||||
from fastapi import HTTPException, Request, status
|
||||
from handler.auth_handler import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from starlette.requests import HTTPConnection
|
||||
from handler.redis_handler import cache
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
def __init__(self) -> None:
|
||||
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def _verify_password(self, plain_password, hashed_password):
|
||||
return self.pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(self, password):
|
||||
return self.pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def clear_session(req: HTTPConnection | Request):
|
||||
session_id = req.session.get("session_id")
|
||||
if session_id:
|
||||
redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
|
||||
req.session["session_id"] = None
|
||||
|
||||
def authenticate_user(self, username: str, password: str):
|
||||
from handler import dbh
|
||||
|
||||
user = dbh.get_user_by_username(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not self._verify_password(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
async def get_current_active_user_from_session(self, conn: HTTPConnection):
|
||||
from handler import dbh
|
||||
|
||||
# Check if session key already stored in cache
|
||||
session_id = conn.session.get("session_id")
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
|
||||
if not username:
|
||||
return None
|
||||
|
||||
# Key exists therefore user is probably authenticated
|
||||
user = dbh.get_user_by_username(username)
|
||||
if user is None:
|
||||
self.clear_session(conn)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
if not user.enabled:
|
||||
self.clear_session(conn)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
def create_default_admin_user(self):
|
||||
from handler import dbh
|
||||
from models.user import Role, User
|
||||
|
||||
if not ROMM_AUTH_ENABLED:
|
||||
return
|
||||
|
||||
try:
|
||||
dbh.add_user(
|
||||
User(
|
||||
username=ROMM_AUTH_USERNAME,
|
||||
hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD),
|
||||
role=Role.ADMIN,
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
|
||||
class OAuthHandler:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
async def get_current_active_user_from_bearer_token(token: str):
|
||||
from handler import dbh
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
user = dbh.get_user_by_username(username)
|
||||
if user is None:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
if not user.enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user, payload
|
||||
@@ -0,0 +1,9 @@
|
||||
from config.config_manager import ConfigManager
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
class DBHandler:
|
||||
def __init__(self) -> None:
|
||||
self.engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
|
||||
self.session = sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
from config.config_manager import ConfigManager
|
||||
from decorators.database import begin_session
|
||||
from models import Role, Platform, Rom, Save, Screenshot, State, User
|
||||
from sqlalchemy import create_engine, delete, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
|
||||
class DBHandler:
|
||||
def __init__(self) -> None:
|
||||
self.engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
|
||||
self.session = sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
# ========= Screenshots =========
|
||||
@begin_session
|
||||
def add_screenshot(self, screenshot: Screenshot, session: Session = None):
|
||||
return session.merge(screenshot)
|
||||
|
||||
@begin_session
|
||||
def get_screenshot(self, id, session: Session = None):
|
||||
return session.get(Screenshot, id)
|
||||
|
||||
@begin_session
|
||||
def get_screenshot_by_filename(self, file_name: str, session: Session = None):
|
||||
return session.scalars(
|
||||
select(Screenshot).filter_by(file_name=file_name).limit(1)
|
||||
).first()
|
||||
|
||||
@begin_session
|
||||
def update_screenshot(self, id: int, data: dict, session: Session = None):
|
||||
session.execute(
|
||||
update(Screenshot)
|
||||
.where(Screenshot.id == id)
|
||||
.values(**data)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def delete_screenshot(self, id: int, session: Session = None):
|
||||
return session.execute(
|
||||
delete(Screenshot)
|
||||
.where(Screenshot.id == id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def purge_screenshots(
|
||||
self, rom_id: int, screenshots: list[str], session: Session = None
|
||||
):
|
||||
return session.execute(
|
||||
delete(Screenshot)
|
||||
.where(
|
||||
Screenshot.rom_id == rom_id,
|
||||
Screenshot.file_name.not_in(screenshots),
|
||||
)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
# ========= Users =========
|
||||
@begin_session
|
||||
def add_user(self, user: User, session: Session = None):
|
||||
return session.merge(user)
|
||||
|
||||
@begin_session
|
||||
def get_user_by_username(self, username: str, session: Session = None):
|
||||
return session.scalars(
|
||||
select(User).filter_by(username=username).limit(1)
|
||||
).first()
|
||||
|
||||
@begin_session
|
||||
def get_user(self, id: int, session: Session = None):
|
||||
return session.get(User, id)
|
||||
|
||||
@begin_session
|
||||
def update_user(self, id: int, data: dict, session: Session = None):
|
||||
session.execute(
|
||||
update(User)
|
||||
.where(User.id == id)
|
||||
.values(**data)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def delete_user(self, id: int, session: Session = None):
|
||||
return session.execute(
|
||||
delete(User)
|
||||
.where(User.id == id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_users(self, session: Session = None):
|
||||
return session.scalars(select(User)).all()
|
||||
|
||||
@begin_session
|
||||
def get_admin_users(self, session: Session = None):
|
||||
return session.scalars(select(User).filter_by(role=Role.ADMIN)).all()
|
||||
|
||||
# ========= Stats =========
|
||||
@begin_session
|
||||
def get_platforms_count(self, session: Session = None):
|
||||
# Only count platforms with more then 0 roms
|
||||
return session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Platform)
|
||||
.where(
|
||||
select(func.count())
|
||||
.select_from(Rom)
|
||||
.filter_by(platform_id=Platform.id)
|
||||
.as_scalar()
|
||||
> 0
|
||||
)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_roms_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(Rom))
|
||||
|
||||
@begin_session
|
||||
def get_saves_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(Save))
|
||||
|
||||
@begin_session
|
||||
def get_states_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(State))
|
||||
|
||||
@begin_session
|
||||
def get_screenshots_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(Screenshot))
|
||||
|
||||
@begin_session
|
||||
def get_total_filesize(self, session: Session = None) -> int:
|
||||
return 0
|
||||
@@ -1,10 +1,9 @@
|
||||
from models import Platform, Rom, Save, State
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler import DBHandler
|
||||
from models import Platform, Rom
|
||||
from sqlalchemy import delete, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler.db_handler import DBHandler
|
||||
|
||||
|
||||
class DBPlatformsHandler(DBHandler):
|
||||
@begin_session
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler.db_handler import DBHandler
|
||||
from handler.db_handler import DBHandler
|
||||
from models import Rom
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler.db_handler import DBHandler
|
||||
from handler.db_handler import DBHandler
|
||||
from models import Save
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
51
backend/handler/db_handler/db_screenshots_handler.py
Normal file
51
backend/handler/db_handler/db_screenshots_handler.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler import DBHandler
|
||||
from models import Screenshot
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class DBScreenshotsHandler(DBHandler):
|
||||
@begin_session
|
||||
def add_screenshot(self, screenshot: Screenshot, session: Session = None):
|
||||
return session.merge(screenshot)
|
||||
|
||||
@begin_session
|
||||
def get_screenshot(self, id, session: Session = None):
|
||||
return session.get(Screenshot, id)
|
||||
|
||||
@begin_session
|
||||
def get_screenshot_by_filename(self, file_name: str, session: Session = None):
|
||||
return session.scalars(
|
||||
select(Screenshot).filter_by(file_name=file_name).limit(1)
|
||||
).first()
|
||||
|
||||
@begin_session
|
||||
def update_screenshot(self, id: int, data: dict, session: Session = None):
|
||||
session.execute(
|
||||
update(Screenshot)
|
||||
.where(Screenshot.id == id)
|
||||
.values(**data)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def delete_screenshot(self, id: int, session: Session = None):
|
||||
return session.execute(
|
||||
delete(Screenshot)
|
||||
.where(Screenshot.id == id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def purge_screenshots(
|
||||
self, rom_id: int, screenshots: list[str], session: Session = None
|
||||
):
|
||||
return session.execute(
|
||||
delete(Screenshot)
|
||||
.where(
|
||||
Screenshot.rom_id == rom_id,
|
||||
Screenshot.file_name.not_in(screenshots),
|
||||
)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler.db_handler import DBHandler
|
||||
from handler.db_handler import DBHandler
|
||||
from models import State
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
42
backend/handler/db_handler/db_stats_handler.py
Normal file
42
backend/handler/db_handler/db_stats_handler.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler import DBHandler
|
||||
from models import Platform, Rom, Save, Screenshot, State
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class DBStatsHandler(DBHandler):
|
||||
@begin_session
|
||||
def get_platforms_count(self, session: Session = None):
|
||||
# Only count platforms with more then 0 roms
|
||||
return session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Platform)
|
||||
.where(
|
||||
select(func.count())
|
||||
.select_from(Rom)
|
||||
.filter_by(platform_id=Platform.id)
|
||||
.as_scalar()
|
||||
> 0
|
||||
)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_roms_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(Rom))
|
||||
|
||||
@begin_session
|
||||
def get_saves_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(Save))
|
||||
|
||||
@begin_session
|
||||
def get_states_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(State))
|
||||
|
||||
@begin_session
|
||||
def get_screenshots_count(self, session: Session = None) -> int:
|
||||
return session.scalar(select(func.count()).select_from(Screenshot))
|
||||
|
||||
@begin_session
|
||||
def get_total_filesize(self, session: Session = None) -> int:
|
||||
return 0
|
||||
46
backend/handler/db_handler/db_users_handler.py
Normal file
46
backend/handler/db_handler/db_users_handler.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from decorators.database import begin_session
|
||||
from handler.db_handler import DBHandler
|
||||
from models import Role, User
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class DBUsersHandler(DBHandler):
|
||||
@begin_session
|
||||
def add_user(self, user: User, session: Session = None):
|
||||
return session.merge(user)
|
||||
|
||||
@begin_session
|
||||
def get_user_by_username(self, username: str, session: Session = None):
|
||||
return session.scalars(
|
||||
select(User).filter_by(username=username).limit(1)
|
||||
).first()
|
||||
|
||||
@begin_session
|
||||
def get_user(self, id: int, session: Session = None):
|
||||
return session.get(User, id)
|
||||
|
||||
@begin_session
|
||||
def update_user(self, id: int, data: dict, session: Session = None):
|
||||
session.execute(
|
||||
update(User)
|
||||
.where(User.id == id)
|
||||
.values(**data)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def delete_user(self, id: int, session: Session = None):
|
||||
return session.execute(
|
||||
delete(User)
|
||||
.where(User.id == id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_users(self, session: Session = None):
|
||||
return session.scalars(select(User)).all()
|
||||
|
||||
@begin_session
|
||||
def get_admin_users(self, session: Session = None):
|
||||
return session.scalars(select(User).filter_by(role=Role.ADMIN)).all()
|
||||
@@ -1,7 +1,12 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import Final
|
||||
|
||||
from config import ROMM_BASE_PATH
|
||||
from config import LIBRARY_BASE_PATH, ROMM_BASE_PATH
|
||||
from config.config_manager import config_manager as cm
|
||||
|
||||
RESOURCES_BASE_PATH: Final = f"{ROMM_BASE_PATH}/resources"
|
||||
DEFAULT_WIDTH_COVER_L: Final = 264 # Width of big cover of IGDB
|
||||
@@ -79,3 +84,43 @@ class Asset(Enum):
|
||||
SAVES = "saves"
|
||||
STATES = "states"
|
||||
SCREENSHOTS = "screenshots"
|
||||
|
||||
|
||||
class FSHandler(ABC):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_fs_structure(fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME):
|
||||
return (
|
||||
f"{folder}/{fs_slug}"
|
||||
if os.path.exists(cm.config.HIGH_PRIO_STRUCTURE_PATH)
|
||||
else f"{fs_slug}/{folder}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_file_name_with_no_extension(file_name: str) -> str:
|
||||
return re.sub(EXTENSION_REGEX, "", file_name).strip()
|
||||
|
||||
@staticmethod
|
||||
def get_file_name_with_no_tags(file_name: str) -> str:
|
||||
file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip()
|
||||
return re.split(TAG_REGEX, file_name_no_extension)[0].strip()
|
||||
|
||||
@staticmethod
|
||||
def parse_file_extension(file_name) -> str:
|
||||
match = re.search(EXTENSION_REGEX, file_name)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
@staticmethod
|
||||
def remove_file(file_name: str, file_path: str):
|
||||
try:
|
||||
os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
|
||||
except IsADirectoryError:
|
||||
shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
|
||||
|
||||
def build_upload_file_path(
|
||||
self, fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME
|
||||
):
|
||||
rom_path = self.get_fs_structure(fs_slug, folder=folder)
|
||||
return f"{LIBRARY_BASE_PATH}/{rom_path}"
|
||||
|
||||
@@ -7,8 +7,7 @@ import requests
|
||||
from config import LIBRARY_BASE_PATH
|
||||
from config.config_manager import config_manager as cm
|
||||
from fastapi import UploadFile
|
||||
from handler.fs_handler import RESOURCES_BASE_PATH, Asset
|
||||
from handler.fs_handler.fs_handler import FSHandler
|
||||
from handler.fs_handler import RESOURCES_BASE_PATH, Asset, FSHandler
|
||||
from logger.logger import log
|
||||
|
||||
|
||||
@@ -86,19 +85,17 @@ class FSAssetsHandler(FSHandler):
|
||||
|
||||
saves_file_path = f"{LIBRARY_BASE_PATH}/{assets_path}"
|
||||
|
||||
fs_assets: list[str] = []
|
||||
# fs_states: list[str] = []
|
||||
# fs_screenshots: list[str] = []
|
||||
assets: list[str] = []
|
||||
|
||||
try:
|
||||
emulators = list(os.walk(saves_file_path))[0][1]
|
||||
for emulator in emulators:
|
||||
fs_assets += [
|
||||
assets += [
|
||||
(emulator, file)
|
||||
for file in list(os.walk(f"{saves_file_path}/{emulator}"))[0][2]
|
||||
]
|
||||
|
||||
fs_assets += [
|
||||
assets += [
|
||||
(None, file)
|
||||
for file in list(os.walk(saves_file_path))[0][2]
|
||||
if file.split(".")[0] == rom_file_name_no_tags
|
||||
@@ -106,40 +103,7 @@ class FSAssetsHandler(FSHandler):
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
# states_path = self.get_fs_structure(
|
||||
# platform_slug, folder=cm.config.STATES_FOLDER_NAME
|
||||
# )
|
||||
# states_file_path = f"{LIBRARY_BASE_PATH}/{states_path}"
|
||||
|
||||
# try:
|
||||
# emulators = list(os.walk(states_file_path))[0][1]
|
||||
# for emulator in emulators:
|
||||
# fs_states += [
|
||||
# (emulator, file)
|
||||
# for file in list(os.walk(f"{states_file_path}/{emulator}"))[0][2]
|
||||
# ]
|
||||
|
||||
# fs_states += [
|
||||
# (None, file) for file in list(os.walk(states_file_path))[0][2]
|
||||
# ]
|
||||
# except IndexError:
|
||||
# pass
|
||||
|
||||
# screenshots_path = self.get_fs_structure(
|
||||
# platform_slug, folder=cm.config.SCREENSHOTS_FOLDER_NAME
|
||||
# )
|
||||
# screenshots_file_path = f"{LIBRARY_BASE_PATH}/{screenshots_path}"
|
||||
|
||||
# try:
|
||||
# fs_screenshots += [
|
||||
# file for file in list(os.walk(screenshots_file_path))[0][2]
|
||||
# ]
|
||||
# except IndexError:
|
||||
# pass
|
||||
|
||||
return fs_assets
|
||||
# "states": fs_states,
|
||||
# "screenshots": fs_screenshots,
|
||||
return assets
|
||||
|
||||
@staticmethod
|
||||
def get_screenshots():
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from abc import ABC
|
||||
|
||||
from config import LIBRARY_BASE_PATH
|
||||
from config.config_manager import config_manager as cm
|
||||
from handler.fs_handler import EXTENSION_REGEX, TAG_REGEX
|
||||
|
||||
|
||||
class FSHandler(ABC):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_fs_structure(fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME):
|
||||
return (
|
||||
f"{folder}/{fs_slug}"
|
||||
if os.path.exists(cm.config.HIGH_PRIO_STRUCTURE_PATH)
|
||||
else f"{fs_slug}/{folder}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_file_name_with_no_extension(file_name: str) -> str:
|
||||
return re.sub(EXTENSION_REGEX, "", file_name).strip()
|
||||
|
||||
@staticmethod
|
||||
def get_file_name_with_no_tags(file_name: str) -> str:
|
||||
file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip()
|
||||
return re.split(TAG_REGEX, file_name_no_extension)[0].strip()
|
||||
|
||||
@staticmethod
|
||||
def parse_file_extension(file_name) -> str:
|
||||
match = re.search(EXTENSION_REGEX, file_name)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def build_upload_file_path(
|
||||
self, fs_slug: str, folder: str = cm.config.ROMS_FOLDER_NAME
|
||||
):
|
||||
rom_path = self.get_fs_structure(fs_slug, folder=folder)
|
||||
return f"{LIBRARY_BASE_PATH}/{rom_path}"
|
||||
|
||||
@staticmethod
|
||||
def remove_file(file_name: str, file_path: str):
|
||||
try:
|
||||
os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
|
||||
except IsADirectoryError:
|
||||
shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from config import LIBRARY_BASE_PATH
|
||||
from config.config_manager import config_manager as cm
|
||||
from exceptions.fs_exceptions import FolderStructureNotMatchException
|
||||
from handler.fs_handler.fs_handler import FSHandler
|
||||
from handler.fs_handler import FSHandler
|
||||
|
||||
|
||||
class FSPlatformsHandler(FSHandler):
|
||||
|
||||
@@ -18,8 +18,8 @@ from handler.fs_handler import (
|
||||
DEFAULT_WIDTH_COVER_S,
|
||||
RESOURCES_BASE_PATH,
|
||||
CoverSize,
|
||||
FSHandler,
|
||||
)
|
||||
from handler.fs_handler.fs_handler import FSHandler
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import fnmatch
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from models.platform import Platform
|
||||
|
||||
from config import LIBRARY_BASE_PATH
|
||||
from config.config_manager import config_manager as cm
|
||||
@@ -13,8 +12,9 @@ from handler.fs_handler import (
|
||||
REGIONS_BY_SHORTCODE,
|
||||
REGIONS_NAME_KEYS,
|
||||
TAG_REGEX,
|
||||
FSHandler,
|
||||
)
|
||||
from handler.fs_handler.fs_handler import FSHandler
|
||||
from models.platform import Platform
|
||||
|
||||
|
||||
class FSRomsHandler(FSHandler):
|
||||
|
||||
@@ -284,7 +284,7 @@ class IGDBHandler:
|
||||
@check_twitch_token
|
||||
async def get_rom(self, file_name: str, platform_idgb_id: int) -> IGDBRomType:
|
||||
# TODO: refactor
|
||||
from handler.fs_handler.fs_handler import FSHandler
|
||||
from handler.fs_handler import FSHandler
|
||||
|
||||
get_search_term = FSHandler.get_file_name_with_no_tags
|
||||
search_term = get_search_term(file_name)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
import emoji
|
||||
from config.config_manager import config_manager as cm
|
||||
from handler import fsasseth, dbh, igdbh, fsresourceh, fsromh
|
||||
from handler import fsasseth, dbplatformh, igdbh, fsresourceh, fsromh
|
||||
from logger.logger import log
|
||||
from models import Platform, Rom, Save, Screenshot, State
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from config.config_manager import ConfigManager
|
||||
from models import Platform, Rom, User, Save, State, Screenshot
|
||||
from models.user import Role
|
||||
from utils.auth import get_password_hash
|
||||
from .. import dbh
|
||||
from handler import dbh, dbuserh, dbplatformh, dbromh, dbsaveh, dbstateh, authh
|
||||
|
||||
engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
|
||||
session = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
@@ -34,7 +33,7 @@ def platform():
|
||||
platform = Platform(
|
||||
name="test_platform", slug="test_platform_slug", fs_slug="test_platform_slug"
|
||||
)
|
||||
return dbh.add_platform(platform)
|
||||
return dbplatformh.add_platform(platform)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -49,7 +48,7 @@ def rom(platform: Platform):
|
||||
file_path=f"{platform.slug}/roms",
|
||||
file_size_bytes=1000.0,
|
||||
)
|
||||
return dbh.add_rom(rom)
|
||||
return dbromh.add_rom(rom)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -64,7 +63,7 @@ def save(rom: Rom):
|
||||
file_path=f"{rom.platform_slug}/saves/test_emulator",
|
||||
file_size_bytes=1.0,
|
||||
)
|
||||
return dbh.add_save(save)
|
||||
return dbsaveh.add_save(save)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -79,7 +78,7 @@ def state(rom: Rom):
|
||||
file_path=f"{rom.platform_slug}/states/test_emulator",
|
||||
file_size_bytes=2.0,
|
||||
)
|
||||
return dbh.add_state(state)
|
||||
return dbstateh.add_state(state)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -98,27 +97,27 @@ def screenshot(rom: Rom):
|
||||
def admin_user():
|
||||
user = User(
|
||||
username="test_admin",
|
||||
hashed_password=get_password_hash("test_admin_password"),
|
||||
hashed_password=authh.get_password_hash("test_admin_password"),
|
||||
role=Role.ADMIN,
|
||||
)
|
||||
return dbh.add_user(user)
|
||||
return dbuserh.add_user(user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = User(
|
||||
username="test_editor",
|
||||
hashed_password=get_password_hash("test_editor_password"),
|
||||
hashed_password=authh.get_password_hash("test_editor_password"),
|
||||
role=Role.EDITOR,
|
||||
)
|
||||
return dbh.add_user(user)
|
||||
return dbuserh.add_user(user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def viewer_user():
|
||||
user = User(
|
||||
username="test_viewer",
|
||||
hashed_password=get_password_hash("test_viewer_password"),
|
||||
hashed_password=authh.get_password_hash("test_viewer_password"),
|
||||
role=Role.VIEWER,
|
||||
)
|
||||
return dbh.add_user(user)
|
||||
return dbuserh.add_user(user)
|
||||
|
||||
@@ -184,14 +184,14 @@ def test_screenshots(screenshot):
|
||||
rom = dbh.get_rom(screenshot.rom_id)
|
||||
assert len(rom.screenshots) == 2
|
||||
|
||||
screenshot = dbh.get_screenshot(rom.screenshots[0].id)
|
||||
screenshot = dbscreenshotsh.get_screenshot(rom.screenshots[0].id)
|
||||
assert screenshot.file_name == "test_screenshot.png"
|
||||
|
||||
dbh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"})
|
||||
dbscreenshotsh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"})
|
||||
screenshot = dbh.get_screenshot(screenshot.id)
|
||||
assert screenshot.file_name == "test_screenshot_2.png"
|
||||
|
||||
dbh.delete_screenshot(screenshot.id)
|
||||
dbscreenshotsh.delete_screenshot(screenshot.id)
|
||||
|
||||
rom = dbh.get_rom(screenshot.rom_id)
|
||||
assert len(rom.screenshots) == 1
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import re
|
||||
import sys
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import alembic.config
|
||||
import uvicorn
|
||||
@@ -17,12 +16,13 @@ from endpoints import (
|
||||
tasks,
|
||||
user,
|
||||
webrcade,
|
||||
stats,
|
||||
)
|
||||
from endpoints.sockets import scan
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_pagination import add_pagination
|
||||
from handler import authh, dbh, ghh, socketh
|
||||
from handler import authh, dbuserh, ghh, socketh
|
||||
from handler.auth_handler.hybrid_auth import HybridAuthBackend
|
||||
from handler.auth_handler.middleware import CustomCSRFMiddleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
@@ -71,44 +71,18 @@ app.include_router(states.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(webrcade.router)
|
||||
app.include_router(config.router)
|
||||
app.include_router(stats.router)
|
||||
|
||||
add_pagination(app)
|
||||
app.mount("/ws", socketh.socket_app)
|
||||
|
||||
|
||||
class StatsReturn(TypedDict):
|
||||
PLATFORMS: int
|
||||
ROMS: int
|
||||
SAVES: int
|
||||
STATES: int
|
||||
SCREENSHOTS: int
|
||||
FILESIZE: int
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
def stats() -> StatsReturn:
|
||||
"""Endpoint to return the current RomM stats
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with all the stats
|
||||
"""
|
||||
|
||||
return {
|
||||
"PLATFORMS": dbh.get_platforms_count(),
|
||||
"ROMS": dbh.get_roms_count(),
|
||||
"SAVES": dbh.get_saves_count(),
|
||||
"STATES": dbh.get_states_count(),
|
||||
"SCREENSHOTS": dbh.get_screenshots_count(),
|
||||
"FILESIZE": dbh.get_total_filesize(),
|
||||
}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
"""Event to handle RomM startup logic."""
|
||||
|
||||
# Create default admin user if no admin user exists
|
||||
if len(dbh.get_admin_users()) == 0 and "pytest" not in sys.modules:
|
||||
if len(dbuserh.get_admin_users()) == 0 and "pytest" not in sys.modules:
|
||||
authh.create_default_admin_user()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user