diff --git a/backend/alembic/versions/0073_sync_sessions.py b/backend/alembic/versions/0073_sync_sessions.py index cd0b70312..2e16d5e57 100644 --- a/backend/alembic/versions/0073_sync_sessions.py +++ b/backend/alembic/versions/0073_sync_sessions.py @@ -21,7 +21,7 @@ depends_on = None def upgrade() -> None: connection = op.get_bind() if is_postgresql(connection): - rom_user_status_enum = ENUM( + sync_session_status_enum = ENUM( "PENDING", "IN_PROGRESS", "COMPLETED", @@ -30,9 +30,9 @@ def upgrade() -> None: name="syncsessionstatus", create_type=False, ) - rom_user_status_enum.create(connection, checkfirst=False) + sync_session_status_enum.create(connection, checkfirst=False) else: - rom_user_status_enum = sa.Enum( + sync_session_status_enum = sa.Enum( "PENDING", "IN_PROGRESS", "COMPLETED", @@ -86,7 +86,11 @@ def upgrade() -> None: "updated_at", sa.TIMESTAMP(timezone=True), nullable=False, - server_default=sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + server_default=( + sa.text("CURRENT_TIMESTAMP") + if is_postgresql(connection) + else sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), ), sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"), sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), @@ -107,3 +111,7 @@ def downgrade() -> None: op.drop_index("ix_sync_sessions_device_id", table_name="sync_sessions") op.drop_table("sync_sessions") + + connection = op.get_bind() + if is_postgresql(connection): + ENUM(name="syncsessionstatus").drop(connection, checkfirst=True) diff --git a/backend/endpoints/responses/device.py b/backend/endpoints/responses/device.py index 22870ff52..9ac0226fd 100644 --- a/backend/endpoints/responses/device.py +++ b/backend/endpoints/responses/device.py @@ -1,9 +1,13 @@ -from pydantic import ConfigDict +from typing import Any + +from pydantic import ConfigDict, field_serializer from models.device import SyncMode from .base import BaseModel, UTCDatetime +SENSITIVE_SYNC_CONFIG_KEYS = {"ssh_password", "ssh_key_path"} + class DeviceSyncSchema(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -34,6 +38,16 @@ class DeviceSchema(BaseModel): created_at: UTCDatetime updated_at: UTCDatetime + @field_serializer("sync_config") + @classmethod + def mask_sensitive_fields(cls, v: dict | None) -> dict[str, Any] | None: + if not v: + return v + return { + k: "********" if k in SENSITIVE_SYNC_CONFIG_KEYS else val + for k, val in v.items() + } + class DeviceCreateResponse(BaseModel): device_id: str diff --git a/backend/endpoints/responses/identity.py b/backend/endpoints/responses/identity.py index 674b2a8d7..29a9b504e 100644 --- a/backend/endpoints/responses/identity.py +++ b/backend/endpoints/responses/identity.py @@ -42,8 +42,9 @@ class UserSchema(BaseModel): if not db_user: return None - db_user.current_device_id = request.session.get("device_id") # type: ignore - return cls.model_validate(db_user) + schema = cls.model_validate(db_user) + schema.current_device_id = request.session.get("device_id") + return schema class InviteLinkSchema(BaseModel): diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index 66eb616e3..4fe481b39 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -96,6 +96,19 @@ def _resolve_device( return device +def _increment_session_counter(session_id: int, user_id: int) -> None: + try: + session = db_sync_session_handler.get_session( + session_id=session_id, user_id=user_id + ) + if session: + db_sync_session_handler.increment_operations_completed( + session_id=session_id, + ) + except Exception: + log.warning(f"Failed to update sync session {session_id}", exc_info=True) + + router = APIRouter( prefix="/saves", tags=["saves"], @@ -247,17 +260,7 @@ async def add_save( db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) if session_id: - try: - session = db_sync_session_handler.get_session( - session_id=session_id, user_id=request.user.id - ) - if session: - db_sync_session_handler.update_session( - session_id=session_id, - data={"operations_completed": session.operations_completed + 1}, - ) - except Exception: - log.warning(f"Failed to update sync session {session_id}") + _increment_session_counter(session_id, request.user.id) if slot and autocleanup: slot_saves = db_save_handler.get_saves( @@ -454,17 +457,7 @@ def download_save( db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) if session_id: - try: - session = db_sync_session_handler.get_session( - session_id=session_id, user_id=request.user.id - ) - if session: - db_sync_session_handler.update_session( - session_id=session_id, - data={"operations_completed": session.operations_completed + 1}, - ) - except Exception: - log.warning(f"Failed to update sync session {session_id}") + _increment_session_counter(session_id, request.user.id) return FileResponse(path=str(file_path), filename=save.file_name) diff --git a/backend/endpoints/sync.py b/backend/endpoints/sync.py index ff8cd35ed..bd659db7c 100644 --- a/backend/endpoints/sync.py +++ b/backend/endpoints/sync.py @@ -1,10 +1,10 @@ from datetime import datetime from fastapi import HTTPException, Request, status -from pydantic import BaseModel from config import TASK_TIMEOUT from decorators.auth import protected_route +from endpoints.responses.base import BaseModel from endpoints.responses.sync import ( SyncNegotiateResponse, SyncOperationSchema, @@ -359,6 +359,7 @@ def trigger_push_pull( high_prio_queue.enqueue( "tasks.sync_push_pull_task.run_push_pull_sync", device_id=device.id, + session_id=sync_session.id, force=True, job_timeout=TASK_TIMEOUT, meta={ diff --git a/backend/handler/database/sync_sessions_handler.py b/backend/handler/database/sync_sessions_handler.py index 8fed624d5..7e21afb1e 100644 --- a/backend/handler/database/sync_sessions_handler.py +++ b/backend/handler/database/sync_sessions_handler.py @@ -75,7 +75,22 @@ class DBSyncSessionsHandler(DBBaseHandler): .values(**data) .execution_options(synchronize_session="evaluate") ) - return session.query(SyncSession).filter_by(id=session_id).one() + return session.scalar(select(SyncSession).filter_by(id=session_id)) + + @begin_session + def increment_operations_completed( + self, + session_id: int, + session: Session = None, # type: ignore + ) -> None: + session.execute( + update(SyncSession) + .where(SyncSession.id == session_id) + .values( + operations_completed=SyncSession.operations_completed + 1, + ) + .execution_options(synchronize_session="evaluate") + ) @begin_session def complete_session( @@ -96,7 +111,7 @@ class DBSyncSessionsHandler(DBBaseHandler): ) .execution_options(synchronize_session="evaluate") ) - return session.query(SyncSession).filter_by(id=session_id).one() + return session.scalar(select(SyncSession).filter_by(id=session_id)) @begin_session def fail_session( @@ -115,7 +130,7 @@ class DBSyncSessionsHandler(DBBaseHandler): ) .execution_options(synchronize_session="evaluate") ) - return session.query(SyncSession).filter_by(id=session_id).one() + return session.scalar(select(SyncSession).filter_by(id=session_id)) @begin_session def cancel_active_sessions( diff --git a/backend/handler/filesystem/sync_handler.py b/backend/handler/filesystem/sync_handler.py index bf53b6c12..3aa4d9dbf 100644 --- a/backend/handler/filesystem/sync_handler.py +++ b/backend/handler/filesystem/sync_handler.py @@ -17,8 +17,7 @@ class FSSyncHandler(FSHandler): def build_incoming_path( self, device_id: str, platform_slug: str | None = None ) -> str: - """Build the relative incoming path for a device (and optional platform).""" - parts = [self.base_path, device_id, "incoming"] + parts = [device_id, "incoming"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) @@ -26,8 +25,7 @@ class FSSyncHandler(FSHandler): def build_outgoing_path( self, device_id: str, platform_slug: str | None = None ) -> str: - """Build the relative outgoing path for a device (and optional platform).""" - parts = [self.base_path, device_id, "outgoing"] + parts = [device_id, "outgoing"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) @@ -35,16 +33,14 @@ class FSSyncHandler(FSHandler): def build_conflicts_path( self, device_id: str, platform_slug: str | None = None ) -> str: - """Build the relative conflicts path for a device (and optional platform).""" - parts = [self.base_path, device_id, "conflicts"] + parts = [device_id, "conflicts"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) def ensure_device_directories(self, device_id: str) -> None: - """Create incoming/outgoing directory structure for a device.""" - incoming = Path(self.build_incoming_path(device_id)) - outgoing = Path(self.build_outgoing_path(device_id)) + incoming = self.base_path / self.build_incoming_path(device_id) + outgoing = self.base_path / self.build_outgoing_path(device_id) incoming.mkdir(parents=True, exist_ok=True) outgoing.mkdir(parents=True, exist_ok=True) diff --git a/backend/handler/sync/ssh_handler.py b/backend/handler/sync/ssh_handler.py index ae2c5ce28..74bdabd56 100644 --- a/backend/handler/sync/ssh_handler.py +++ b/backend/handler/sync/ssh_handler.py @@ -106,6 +106,10 @@ class SSHSyncHandler: "provide ssh_key_path/ssh_password in sync_config." ) + log.warning( + f"SSH host key verification disabled for {host} -- " + "connection is vulnerable to MITM attacks" + ) log.info(f"Connecting to {username}@{host}:{port}") return await asyncssh.connect(**connect_kwargs) diff --git a/backend/sync_watcher.py b/backend/sync_watcher.py index b83027897..827cb7cf6 100644 --- a/backend/sync_watcher.py +++ b/backend/sync_watcher.py @@ -46,32 +46,26 @@ Change = tuple[str, str] def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None: """Extract device_id, platform_slug, and filename from a sync incoming path. - Expected path format: {build_incoming_path(device_id, platform_slug)}/filename.ext - i.e. {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext + Expected path format: {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext """ try: - rel_path = os.path.relpath(path) + rel_path = os.path.relpath(path, start=str(fs_sync_handler.base_path)) parts = rel_path.split(os.sep) - # Minimum: device_id / incoming / platform_slug / filename if len(parts) < 4 or parts[1] != "incoming": return None device_id = parts[0] platform_slug = parts[2] filename = parts[-1] - - # Validate path matches the canonical incoming path structure - expected_prefix = fs_sync_handler.build_incoming_path(device_id, platform_slug) - if not rel_path.startswith(expected_prefix): - return None - return (device_id, platform_slug, filename) except (ValueError, IndexError): return None def _ensure_conflicts_dir(device_id: str, platform_slug: str) -> str: - """Ensure the conflicts directory exists and return its path.""" - conflicts_dir = fs_sync_handler.build_conflicts_path(device_id, platform_slug) + conflicts_dir = str( + fs_sync_handler.base_path + / fs_sync_handler.build_conflicts_path(device_id, platform_slug) + ) os.makedirs(conflicts_dir, exist_ok=True) return conflicts_dir diff --git a/backend/tasks/sync_push_pull_task.py b/backend/tasks/sync_push_pull_task.py index 952aafcd4..370d24bec 100644 --- a/backend/tasks/sync_push_pull_task.py +++ b/backend/tasks/sync_push_pull_task.py @@ -26,7 +26,11 @@ from models.sync_session import SyncSessionStatus from tasks.tasks import PeriodicTask, TaskType -async def run_push_pull_sync(device_id: str | None = None, force: bool = False) -> dict: +async def run_push_pull_sync( + device_id: str | None = None, + session_id: int | None = None, + force: bool = False, +) -> dict: """Execute push-pull sync for one or all push_pull devices.""" if not ENABLE_SYNC_PUSH_PULL and not force: log.info("Push-pull sync not enabled, skipping") @@ -50,13 +54,13 @@ async def run_push_pull_sync(device_id: str | None = None, force: bool = False) for device in devices: if not device.sync_enabled: continue - result = await _sync_device(device) + result = await _sync_device(device, session_id=session_id) results.append(result) return {"status": "completed", "device_results": results} -async def _sync_device(device: Device) -> dict: +async def _sync_device(device: Device, session_id: int | None = None) -> dict: """Perform push-pull sync for a single device.""" sync_config = device.sync_config or {} if not sync_config.get("ssh_host"): @@ -71,10 +75,18 @@ async def _sync_device(device: Device) -> dict: emit_sync_started, ) - # Create sync session - sync_session = db_sync_session_handler.create_session( - device_id=device.id, user_id=device.user_id - ) + if session_id: + sync_session = db_sync_session_handler.get_session( + session_id=session_id, user_id=device.user_id + ) + if not sync_session: + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=device.user_id + ) + else: + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=device.user_id + ) await emit_sync_started( user_id=device.user_id, @@ -110,7 +122,6 @@ async def _sync_device(device: Device) -> dict: db_sync_session_handler.complete_session(session_id=sync_session.id) return {"device_id": device.id, "status": "no_directories"} - # List remote saves remote_saves = await ssh_sync_handler.list_remote_saves(conn, save_directories) log.info( f"Push-pull: found {len(remote_saves)} remote saves on device {device.id}" @@ -126,7 +137,6 @@ async def _sync_device(device: Device) -> dict: operations_planned = len(remote_saves) - # Process each remote save for remote_save in remote_saves: try: action = await _process_remote_save(device, conn, remote_save) @@ -158,14 +168,11 @@ async def _sync_device(device: Device) -> dict: current_file=remote_save.file_name, ) - # Check for server saves that need to be pushed to the device push_count = await _push_missing_saves( device, conn, remote_saves, save_directories ) completed += push_count - conn.close() - except Exception as e: log.error(f"Push-pull sync failed for device {device.id}: {e}", exc_info=True) db_sync_session_handler.fail_session( @@ -178,6 +185,8 @@ async def _sync_device(device: Device) -> dict: error_message=str(e), ) return {"device_id": device.id, "status": "failed", "error": str(e)} + finally: + conn.close() db_sync_session_handler.complete_session( session_id=sync_session.id, @@ -227,21 +236,11 @@ async def _process_remote_save( break if not matched_save: - # New save from device - download it - local_path, content_hash = await ssh_sync_handler.download_save( - conn, remote_save.path + log.info( + f"Push-pull: remote save {hl(remote_save.file_name)} " + f"on platform {remote_save.platform_slug} - no matching server save, skipping" ) - try: - # We have the file locally, but we need a ROM to attach it to. - # Without a clear ROM match, skip for now. - log.info( - f"Push-pull: new remote save {hl(remote_save.file_name)} " - f"on platform {remote_save.platform_slug} - no matching server save" - ) - return "skipped" - finally: - if os.path.exists(local_path): - os.unlink(local_path) + return "skipped" # Compare with existing save device_sync = db_device_save_sync_handler.get_sync( diff --git a/backend/tests/handler/filesystem/test_sync_handler.py b/backend/tests/handler/filesystem/test_sync_handler.py index 78ffcaa78..b6a156b71 100644 --- a/backend/tests/handler/filesystem/test_sync_handler.py +++ b/backend/tests/handler/filesystem/test_sync_handler.py @@ -27,36 +27,28 @@ class TestFSSyncHandler: def test_build_incoming_path(self, handler: FSSyncHandler): path = handler.build_incoming_path("device-1") - assert "device-1" in path - assert "incoming" in path + assert path == os.path.join("device-1", "incoming") def test_build_incoming_path_with_platform(self, handler): path = handler.build_incoming_path("device-1", "gba") - assert "device-1" in path - assert "incoming" in path - assert "gba" in path + assert path == os.path.join("device-1", "incoming", "gba") def test_build_outgoing_path(self, handler: FSSyncHandler): path = handler.build_outgoing_path("device-1") - assert "device-1" in path - assert "outgoing" in path + assert path == os.path.join("device-1", "outgoing") def test_build_outgoing_path_with_platform(self, handler: FSSyncHandler): path = handler.build_outgoing_path("device-1", "snes") - assert "device-1" in path - assert "outgoing" in path - assert "snes" in path + assert path == os.path.join("device-1", "outgoing", "snes") def test_build_conflicts_path(self, handler: FSSyncHandler): path = handler.build_conflicts_path("device-1", "gba") - assert "device-1" in path - assert "conflicts" in path - assert "gba" in path + assert path == os.path.join("device-1", "conflicts", "gba") def test_ensure_device_directories(self, handler: FSSyncHandler, temp_dir): handler.ensure_device_directories("test-device") - incoming = handler.build_incoming_path("test-device") - outgoing = handler.build_outgoing_path("test-device") + incoming = handler.base_path / handler.build_incoming_path("test-device") + outgoing = handler.base_path / handler.build_outgoing_path("test-device") assert os.path.isdir(incoming) assert os.path.isdir(outgoing) @@ -65,9 +57,10 @@ class TestFSSyncHandler: assert result == [] def test_list_incoming_files(self, handler: FSSyncHandler, temp_dir): - # Set up: create incoming/platform/file structure handler.ensure_device_directories("dev-1") - incoming_path = handler.build_incoming_path("dev-1", "gba") + incoming_path = str( + handler.base_path / handler.build_incoming_path("dev-1", "gba") + ) os.makedirs(incoming_path, exist_ok=True) test_file = os.path.join(incoming_path, "save.sav") with open(test_file, "wb") as f: @@ -114,7 +107,7 @@ class TestFSSyncHandler: def test_remove_incoming_file(self, handler: FSSyncHandler, temp_dir): handler.ensure_device_directories("dev-1") - incoming = handler.build_incoming_path("dev-1", "gba") + incoming = str(handler.base_path / handler.build_incoming_path("dev-1", "gba")) os.makedirs(incoming, exist_ok=True) test_file = os.path.join(incoming, "to_remove.sav") with open(test_file, "wb") as f: