diff --git a/RFC-0001-plan.md b/RFC-0001-plan.md new file mode 100644 index 000000000..0d3366dd9 --- /dev/null +++ b/RFC-0001-plan.md @@ -0,0 +1,192 @@ +# RFC 0001: Device Registration and Multi-Protocol Save Synchronization — Implementation Plan + +## Context + +Cross-device save synchronization is needed for RomM's expanding client ecosystem. The RFC defines three sync modes (API, File Transfer, Push-Pull) to accommodate devices with varying capabilities. The codebase already has foundational pieces: a `Device` model with `SyncMode` enum, `DeviceSaveSync` tracking, and save endpoints with conflict detection. This plan builds on those foundations. + +Implementation is split into 3 phases. **Phase 1 (API Mode)** is the immediate priority as it enables smart client apps. Phases 2 and 3 are outlined but will be planned in detail when Phase 1 is complete. + +--- + +## Phase 1: Foundation & API Sync Mode + +### 1. Database Schema Changes + +**Modify `backend/models/device.py`** — Add `sync_config` JSON column: + +```python +sync_config: Mapped[dict | None] = mapped_column(JSON, nullable=True) +``` + +Stores mode-specific configuration (SSH config for push-pull, sync folder for file transfer, save format preferences). 
+ +**Create `backend/models/sync_session.py`** — New `SyncSession` model: + +- `id`: Integer PK, autoincrement +- `device_id`: String FK -> devices.id (CASCADE) +- `user_id`: Integer FK -> users.id (CASCADE) +- `status`: Enum (pending, in_progress, completed, failed, cancelled) +- `initiated_at`: TIMESTAMP +- `completed_at`: TIMESTAMP (nullable) +- `operations_planned`: Integer (default 0) +- `operations_completed`: Integer (default 0) +- `operations_failed`: Integer (default 0) +- `error_message`: String(1000), nullable +- Relationships: `device`, `user` + +**Create `backend/alembic/versions/0073_sync_sessions.py`** — Migration: + +- Creates `sync_sessions` table +- Adds `sync_config` JSON column to `devices` +- Adds indexes on `device_id`, `user_id`, `status` + +### 2. Shared Comparison Algorithm + +**Create `backend/handler/sync/__init__.py`** and **`backend/handler/sync/comparison.py`** + +Extract the save comparison logic into a reusable function used by all three sync modes: + +```python +def compare_save_state( + client_hash: str | None, + client_updated_at: datetime, + server_hash: str | None, + server_updated_at: datetime, + device_last_synced_at: datetime | None, +) -> SyncComparisonResult: # action: upload | download | conflict | no_op +``` + +Rules: + +- Same hash → `no_op` +- Client has save, server doesn't → `upload` +- Server has save, client doesn't → `download` (or `no_op` if device previously synced and intentionally deleted) +- Both have save, different hash → compare timestamps; client newer → `upload`, server newer → `download`, ambiguous → `conflict` + +### 3. Sync Session DB Handler + +**Create `backend/handler/database/sync_sessions_handler.py`** + +Methods: `create_session`, `get_session`, `get_active_session`, `update_session`, `complete_session`, `fail_session`, `get_sessions` + +**Modify `backend/handler/database/__init__.py`** — Register `db_sync_session_handler`. + +### 4. 
Sync Negotiation Endpoint + +**Create `backend/endpoints/sync.py`** — Router at `/sync` + +Core endpoint: `POST /api/sync/negotiate` + +Request payload: + +```python +class ClientSaveState(BaseModel): + rom_id: int + file_name: str + slot: str | None = None + emulator: str | None = None + content_hash: str | None = None + updated_at: datetime + file_size_bytes: int + +class SyncNegotiatePayload(BaseModel): + device_id: str + saves: list[ClientSaveState] +``` + +Response (in **`backend/endpoints/responses/sync.py`**): + +```python +class SyncOperation(BaseModel): + action: Literal["upload", "download", "conflict", "no_op"] + rom_id: int + save_id: int | None = None + file_name: str + slot: str | None = None + reason: str + server_updated_at: datetime | None = None + server_content_hash: str | None = None + +class SyncNegotiateResponse(BaseModel): + session_id: int + operations: list[SyncOperation] + total_upload: int + total_download: int + total_conflict: int + total_no_op: int +``` + +Negotiation logic: + +1. Validate device exists and belongs to user +2. Cancel any existing active session for this device +3. Create new `SyncSession` (pending) +4. For each client save: look up server save by `(rom_id, user_id, file_name)` or `(rom_id, user_id, slot)`, then run comparison algorithm +5. For server saves the client didn't mention: skip saves untracked on this device and saves unchanged since the device's last sync (treated as deleted on the client); all others → `download` +6. Return operations list + +Additional endpoints: + +- `POST /api/sync/sessions/{session_id}/complete` — mark session done +- `GET /api/sync/sessions` — list the current user's sessions (optionally filtered by `device_id`) +- `GET /api/sync/sessions/{session_id}` — session detail + +### 5. Modifications to Existing Code + +**`backend/endpoints/device.py`** — Add `sync_mode` and `sync_config` to `DeviceCreatePayload` and `DeviceUpdatePayload`. + +**`backend/endpoints/responses/device.py`** — Add `sync_config: dict | None = None` to `DeviceSchema`. 
+ +**`backend/endpoints/saves.py`** — Add optional `session_id: int | None = None` parameter to `add_save` and `download_save`. When provided, increment `operations_completed` on the session. + +**`backend/main.py`** — Register sync router. + +### 6. Files Summary + +| Action | File | +| ------ | ----------------------------------------------------------------- | +| Create | `backend/models/sync_session.py` | +| Create | `backend/alembic/versions/0073_sync_sessions.py` | +| Create | `backend/handler/sync/__init__.py` | +| Create | `backend/handler/sync/comparison.py` | +| Create | `backend/handler/database/sync_sessions_handler.py` | +| Create | `backend/endpoints/sync.py` | +| Create | `backend/endpoints/responses/sync.py` | +| Modify | `backend/models/device.py` — add `sync_config` column | +| Modify | `backend/handler/database/__init__.py` — register handler | +| Modify | `backend/endpoints/device.py` — sync_config/sync_mode in payloads | +| Modify | `backend/endpoints/responses/device.py` — sync_config in schema | +| Modify | `backend/endpoints/saves.py` — optional session_id | +| Modify | `backend/main.py` — register router | + +--- + +## Phase 2: File Transfer Mode (Outline) + +- Add `SYNC_BASE_PATH` config and `ENABLE_SYNC_FOLDER_WATCHER` flag +- Create sync folder watcher (`backend/watcher_sync.py`) following `backend/watcher.py` patterns +- Create periodic scan task (`backend/tasks/sync_folder_task.py`) as fallback +- Auto-create device sync folder structure on FILE_TRANSFER registration +- Folder convention: `/romm/sync/{device_id}/incoming/{platform_slug}/` and `.../outgoing/` +- Reuse comparison algorithm from Phase 1 +- Extend `fs_asset_handler` with sync folder path builders + +## Phase 3: Push-Pull Mode (Outline) + +- Add `asyncssh` dependency +- Create SSH/SFTP handler (`backend/handler/sync/ssh_handler.py`) +- Create periodic push-pull task (`backend/tasks/sync_push_pull_task.py`) +- Add manual trigger endpoint: `POST 
/api/sync/devices/{device_id}/push-pull` +- Add SSH key management endpoints +- WebSocket notifications for sync progress (`backend/endpoints/sockets/sync.py`) +- Reuse comparison algorithm from Phase 1 + +--- + +## Verification + +1. **Unit tests**: Test comparison algorithm with all edge cases (same hash, client-only, server-only, newer/older, conflicts) +2. **Integration tests**: Full negotiate → upload/download → complete session flow (in `backend/tests/endpoints/test_sync.py`) +3. **Device tests**: Registration with sync_config, update sync_mode +4. **Session lifecycle**: Create, complete, fail, cancel +5. **Manual testing**: Use the existing save endpoints with session_id tracking diff --git a/backend/alembic/versions/0073_sync_sessions.py b/backend/alembic/versions/0073_sync_sessions.py new file mode 100644 index 000000000..38e92090f --- /dev/null +++ b/backend/alembic/versions/0073_sync_sessions.py @@ -0,0 +1,83 @@ +"""Add sync_sessions table and sync_config to devices + +Revision ID: 0073_sync_sessions +Revises: 0072_client_tokens +Create Date: 2026-03-14 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +revision = "0073_sync_sessions" +down_revision = "0072_client_tokens" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "sync_sessions", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("device_id", sa.String(length=255), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "status", + sa.Enum( + "pending", + "in_progress", + "completed", + "failed", + "cancelled", + name="syncsessionstatus", + ), + nullable=False, + server_default="pending", + ), + sa.Column( + "initiated_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP"), + ), + sa.Column("completed_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column( + "operations_planned", sa.Integer(), nullable=False, server_default="0" 
), + sa.Column( + "operations_completed", sa.Integer(), nullable=False, server_default="0" + ), + sa.Column( + "operations_failed", sa.Integer(), nullable=False, server_default="0" + ), + sa.Column("error_message", sa.String(length=1000), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP"), + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + ), + sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_sync_sessions_device_id", "sync_sessions", ["device_id"]) + op.create_index("ix_sync_sessions_user_id", "sync_sessions", ["user_id"]) + op.create_index("ix_sync_sessions_status", "sync_sessions", ["status"]) + + op.add_column("devices", sa.Column("sync_config", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("devices", "sync_config") + + op.drop_index("ix_sync_sessions_status", table_name="sync_sessions") + op.drop_index("ix_sync_sessions_user_id", table_name="sync_sessions") + op.drop_index("ix_sync_sessions_device_id", table_name="sync_sessions") + op.drop_table("sync_sessions") diff --git a/backend/config/__init__.py b/backend/config/__init__.py index c1dd10a88..1482a8990 100644 --- a/backend/config/__init__.py +++ b/backend/config/__init__.py @@ -212,6 +212,23 @@ SCHEDULED_RETROACHIEVEMENTS_PROGRESS_SYNC_CRON: Final[str] = _get_env( "0 4 * * *", # At 4:00 AM every day ) +# SYNC +SYNC_BASE_PATH: Final[str] = _get_env("SYNC_BASE_PATH", f"{ROMM_BASE_PATH}/sync") +ENABLE_SYNC_FOLDER_WATCHER: Final[bool] = safe_str_to_bool( + _get_env("ENABLE_SYNC_FOLDER_WATCHER") +) +SYNC_FOLDER_SCAN_DELAY: Final[int] = safe_int( + _get_env("SYNC_FOLDER_SCAN_DELAY"), 2 # 2 minutes +) +ENABLE_SYNC_PUSH_PULL: 
Final[bool] = safe_str_to_bool(_get_env("ENABLE_SYNC_PUSH_PULL")) +SYNC_PUSH_PULL_CRON: Final[str] = _get_env( + "SYNC_PUSH_PULL_CRON", + "*/30 * * * *", # Every 30 minutes +) +SYNC_SSH_KEYS_PATH: Final[str] = _get_env( + "SYNC_SSH_KEYS_PATH", f"{ROMM_BASE_PATH}/sync/keys" +) + # EMULATION DISABLE_EMULATOR_JS: Final[bool] = safe_str_to_bool(_get_env("DISABLE_EMULATOR_JS")) DISABLE_RUFFLE_RS: Final[bool] = safe_str_to_bool(_get_env("DISABLE_RUFFLE_RS")) diff --git a/backend/endpoints/device.py b/backend/endpoints/device.py index 8911dfbf4..7d42ea5d1 100644 --- a/backend/endpoints/device.py +++ b/backend/endpoints/device.py @@ -8,8 +8,9 @@ from decorators.auth import protected_route from endpoints.responses.device import DeviceCreateResponse, DeviceSchema from handler.auth.constants import Scope from handler.database import db_device_handler, db_device_save_sync_handler +from handler.filesystem import fs_sync_handler from logger.logger import log -from models.device import Device +from models.device import Device, SyncMode from utils.router import APIRouter router = APIRouter( @@ -26,6 +27,8 @@ class DeviceCreatePayload(BaseModel): ip_address: str | None = None mac_address: str | None = None hostname: str | None = None + sync_mode: SyncMode | None = None + sync_config: dict | None = None allow_existing: bool = True allow_duplicate: bool = False reset_syncs: bool = False @@ -46,6 +49,8 @@ class DeviceUpdatePayload(BaseModel): mac_address: str | None = None hostname: str | None = None sync_enabled: bool | None = None + sync_mode: SyncMode | None = None + sync_config: dict | None = None @protected_route(router.post, "", [Scope.DEVICES_WRITE]) @@ -107,12 +112,21 @@ def register_device( ip_address=payload.ip_address, mac_address=payload.mac_address, hostname=payload.hostname, + sync_mode=payload.sync_mode, + sync_config=payload.sync_config, last_seen=now, ) db_device = db_device_handler.add_device(device) log.info(f"Registered device {device_id} for user 
{request.user.username}") + # Auto-create sync folders for file_transfer devices + if payload.sync_mode == SyncMode.FILE_TRANSFER: + try: + fs_sync_handler.ensure_device_directories(device_id) + except Exception: + log.warning(f"Failed to create sync directories for device {device_id}") + return DeviceCreateResponse( device_id=db_device.id, name=db_device.name, diff --git a/backend/endpoints/responses/device.py b/backend/endpoints/responses/device.py index cfed1f0ba..2b76d778b 100644 --- a/backend/endpoints/responses/device.py +++ b/backend/endpoints/responses/device.py @@ -28,6 +28,7 @@ class DeviceSchema(BaseModel): hostname: str | None sync_mode: SyncMode sync_enabled: bool + sync_config: dict | None last_seen: datetime | None created_at: datetime updated_at: datetime diff --git a/backend/endpoints/responses/sync.py b/backend/endpoints/responses/sync.py new file mode 100644 index 000000000..6c2c89c21 --- /dev/null +++ b/backend/endpoints/responses/sync.py @@ -0,0 +1,43 @@ +from datetime import datetime +from typing import Literal + +from .base import BaseModel + + +class SyncOperationSchema(BaseModel): + action: Literal["upload", "download", "conflict", "no_op"] + rom_id: int + save_id: int | None = None + file_name: str + slot: str | None = None + emulator: str | None = None + reason: str + server_updated_at: datetime | None = None + server_content_hash: str | None = None + + +class SyncNegotiateResponse(BaseModel): + session_id: int + operations: list[SyncOperationSchema] + total_upload: int + total_download: int + total_conflict: int + total_no_op: int + + +class SyncSessionSchema(BaseModel): + id: int + device_id: str + user_id: int + status: str + initiated_at: datetime + completed_at: datetime | None = None + operations_planned: int + operations_completed: int + operations_failed: int + error_message: str | None = None + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True diff --git a/backend/endpoints/saves.py 
b/backend/endpoints/saves.py index f111be828..9aa7803bc 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -17,6 +17,7 @@ from handler.database import ( db_rom_handler, db_save_handler, db_screenshot_handler, + db_sync_session_handler, ) from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_save, scan_screenshot @@ -116,6 +117,7 @@ async def add_save( emulator: str | None = None, slot: str | None = None, device_id: str | None = None, + session_id: int | None = None, overwrite: bool = False, autocleanup: bool = False, autocleanup_limit: int = 10, @@ -244,6 +246,19 @@ async def add_save( ) db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) + if session_id: + try: + session = db_sync_session_handler.get_session( + session_id=session_id, user_id=request.user.id + ) + if session: + db_sync_session_handler.update_session( + session_id=session_id, + data={"operations_completed": session.operations_completed + 1}, + ) + except Exception: + log.warning(f"Failed to update sync session {session_id}") + if slot and autocleanup: slot_saves = db_save_handler.get_saves( user_id=request.user.id, @@ -401,6 +416,7 @@ def download_save( request: Request, id: int, device_id: str | None = None, + session_id: int | None = None, optimistic: bool = True, ) -> FileResponse: """Download a save file.""" @@ -437,6 +453,19 @@ def download_save( ) db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) + if session_id: + try: + session = db_sync_session_handler.get_session( + session_id=session_id, user_id=request.user.id + ) + if session: + db_sync_session_handler.update_session( + session_id=session_id, + data={"operations_completed": session.operations_completed + 1}, + ) + except Exception: + log.warning(f"Failed to update sync session {session_id}") + return FileResponse(path=str(file_path), filename=save.file_name) diff --git a/backend/endpoints/sockets/sync.py 
b/backend/endpoints/sockets/sync.py new file mode 100644 index 000000000..61c80f7ef --- /dev/null +++ b/backend/endpoints/sockets/sync.py @@ -0,0 +1,127 @@ +"""WebSocket events for sync progress notifications. + +Emits events: +- sync:started - when a sync session begins +- sync:progress - periodic updates during sync +- sync:completed - when a sync session finishes +- sync:conflict - when a conflict is detected +- sync:error - when a sync operation fails + +Uses AsyncRedisManager in write-only mode so these can be called from +RQ background workers (push-pull task, folder watcher) that don't have +access to the main socket server instance. +""" + +import socketio # type: ignore + +from config import REDIS_URL + + +def _get_socket_manager() -> socketio.AsyncRedisManager: + """Create a write-only Redis manager for emitting from background tasks.""" + return socketio.AsyncRedisManager(REDIS_URL, write_only=True) + + +async def emit_sync_started( + user_id: int, + device_id: str, + session_id: int, + sync_mode: str, +) -> None: + """Notify that a sync session has started.""" + sm = _get_socket_manager() + await sm.emit( + "sync:started", + { + "device_id": device_id, + "session_id": session_id, + "sync_mode": sync_mode, + }, + room=f"user:{user_id}", + ) + + +async def emit_sync_progress( + user_id: int, + device_id: str, + session_id: int, + operations_completed: int, + operations_planned: int, + current_file: str | None = None, +) -> None: + """Notify sync progress update.""" + sm = _get_socket_manager() + await sm.emit( + "sync:progress", + { + "device_id": device_id, + "session_id": session_id, + "operations_completed": operations_completed, + "operations_planned": operations_planned, + "current_file": current_file, + }, + room=f"user:{user_id}", + ) + + +async def emit_sync_completed( + user_id: int, + device_id: str, + session_id: int, + operations_completed: int, + operations_failed: int, +) -> None: + """Notify that a sync session has completed.""" + sm = 
_get_socket_manager() + await sm.emit( + "sync:completed", + { + "device_id": device_id, + "session_id": session_id, + "operations_completed": operations_completed, + "operations_failed": operations_failed, + }, + room=f"user:{user_id}", + ) + + +async def emit_sync_conflict( + user_id: int, + device_id: str, + session_id: int, + file_name: str, + rom_id: int, + reason: str, +) -> None: + """Notify that a sync conflict was detected.""" + sm = _get_socket_manager() + await sm.emit( + "sync:conflict", + { + "device_id": device_id, + "session_id": session_id, + "file_name": file_name, + "rom_id": rom_id, + "reason": reason, + }, + room=f"user:{user_id}", + ) + + +async def emit_sync_error( + user_id: int, + device_id: str, + session_id: int, + error_message: str, +) -> None: + """Notify that a sync error occurred.""" + sm = _get_socket_manager() + await sm.emit( + "sync:error", + { + "device_id": device_id, + "session_id": session_id, + "error": error_message, + }, + room=f"user:{user_id}", + ) diff --git a/backend/endpoints/sync.py b/backend/endpoints/sync.py new file mode 100644 index 000000000..fca13760a --- /dev/null +++ b/backend/endpoints/sync.py @@ -0,0 +1,451 @@ +from datetime import datetime + +from fastapi import HTTPException, Request, UploadFile, status +from pydantic import BaseModel + +from config import TASK_TIMEOUT +from decorators.auth import protected_route +from endpoints.responses.sync import ( + SyncNegotiateResponse, + SyncOperationSchema, + SyncSessionSchema, +) +from handler.auth.constants import Scope +from handler.database import ( + db_device_handler, + db_device_save_sync_handler, + db_save_handler, + db_sync_session_handler, +) +from handler.redis_handler import high_prio_queue +from handler.sync.comparison import compare_save_state +from handler.sync.ssh_handler import ssh_sync_handler +from logger.logger import log +from models.assets import Save +from models.device import SyncMode +from models.sync_session import SyncSessionStatus +from 
utils.datetime import to_utc +from utils.router import APIRouter + +router = APIRouter( + prefix="/sync", + tags=["sync"], +) + + +class ClientSaveState(BaseModel): + rom_id: int + file_name: str + slot: str | None = None + emulator: str | None = None + content_hash: str | None = None + updated_at: datetime + file_size_bytes: int + + +class SyncNegotiatePayload(BaseModel): + device_id: str + saves: list[ClientSaveState] + + +class SyncCompletePayload(BaseModel): + operations_completed: int = 0 + operations_failed: int = 0 + + +@protected_route(router.post, "/negotiate", [Scope.ASSETS_READ, Scope.DEVICES_READ]) +def negotiate_sync( + request: Request, + payload: SyncNegotiatePayload, +) -> SyncNegotiateResponse: + """Negotiate sync operations between a client device and the server. + + The client sends its current save state, and the server returns a list of + operations (upload, download, conflict, no_op) to bring both sides in sync. + """ + device = db_device_handler.get_device( + device_id=payload.device_id, user_id=request.user.id + ) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {payload.device_id} not found", + ) + + if not device.sync_enabled: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Sync is disabled for this device", + ) + + # Cancel any existing active sessions for this device + cancelled = db_sync_session_handler.cancel_active_sessions( + device_id=device.id, user_id=request.user.id + ) + if cancelled: + log.info(f"Cancelled {cancelled} active sync session(s) for device {device.id}") + + # Create a new sync session + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=request.user.id + ) + + operations: list[SyncOperationSchema] = [] + + # Build a set of server saves for this user, keyed by (rom_id, file_name) + # We'll also track which server saves were mentioned by the client + server_saves = 
db_save_handler.get_saves(user_id=request.user.id) + server_save_map: dict[tuple[int, str], Save] = {} + for save in server_saves: + server_save_map[(save.rom_id, save.file_name)] = save + + # Get all sync records for this device + all_save_ids = [s.id for s in server_saves] + device_syncs = db_device_save_sync_handler.get_syncs_for_device_and_saves( + device_id=device.id, save_ids=all_save_ids + ) + sync_by_save_id = {s.save_id: s for s in device_syncs} + + # Track which server saves were referenced by the client + matched_server_save_ids: set[int] = set() + + # Process each client save + for client_save in payload.saves: + key = (client_save.rom_id, client_save.file_name) + server_save = server_save_map.get(key) + + if server_save is None: + # Client has a save the server doesn't -> upload + operations.append( + SyncOperationSchema( + action="upload", + rom_id=client_save.rom_id, + save_id=None, + file_name=client_save.file_name, + slot=client_save.slot, + emulator=client_save.emulator, + reason="Save exists on client but not on server", + ) + ) + continue + + matched_server_save_ids.add(server_save.id) + device_sync = sync_by_save_id.get(server_save.id) + + # Skip untracked saves + if device_sync and device_sync.is_untracked: + operations.append( + SyncOperationSchema( + action="no_op", + rom_id=server_save.rom_id, + save_id=server_save.id, + file_name=server_save.file_name, + slot=server_save.slot, + emulator=server_save.emulator, + reason="Save is untracked on this device", + ) + ) + continue + + result = compare_save_state( + client_hash=client_save.content_hash, + client_updated_at=client_save.updated_at, + server_hash=server_save.content_hash, + server_updated_at=server_save.updated_at, + device_last_synced_at=device_sync.last_synced_at if device_sync else None, + ) + + operations.append( + SyncOperationSchema( + action=result.action, + rom_id=server_save.rom_id, + save_id=server_save.id, + file_name=server_save.file_name, + slot=server_save.slot, + 
emulator=server_save.emulator, + reason=result.reason, + server_updated_at=server_save.updated_at, + server_content_hash=server_save.content_hash, + ) + ) + + # Check for server saves the client didn't mention + for save in server_saves: + if save.id in matched_server_save_ids: + continue + + device_sync = sync_by_save_id.get(save.id) + + # Skip untracked saves + if device_sync and device_sync.is_untracked: + continue + + # If device has synced this save before and the save hasn't changed, + # the client intentionally deleted it - treat as no_op + if device_sync: + synced_ts = to_utc(device_sync.last_synced_at) + save_ts = to_utc(save.updated_at) + if save_ts <= synced_ts: + # Save hasn't changed since device last synced - client deleted it + continue + + # Save changed after device last synced - device should download + operations.append( + SyncOperationSchema( + action="download", + rom_id=save.rom_id, + save_id=save.id, + file_name=save.file_name, + slot=save.slot, + emulator=save.emulator, + reason="Server save updated since last sync, not present on client", + server_updated_at=save.updated_at, + server_content_hash=save.content_hash, + ) + ) + else: + # Device has never synced this save - download it + operations.append( + SyncOperationSchema( + action="download", + rom_id=save.rom_id, + save_id=save.id, + file_name=save.file_name, + slot=save.slot, + emulator=save.emulator, + reason="Save exists on server but not on client", + server_updated_at=save.updated_at, + server_content_hash=save.content_hash, + ) + ) + + # Update session with operation counts + total_upload = sum(1 for op in operations if op.action == "upload") + total_download = sum(1 for op in operations if op.action == "download") + total_conflict = sum(1 for op in operations if op.action == "conflict") + total_no_op = sum(1 for op in operations if op.action == "no_op") + + db_sync_session_handler.update_session( + session_id=sync_session.id, + data={ + "status": SyncSessionStatus.IN_PROGRESS, + 
"operations_planned": total_upload + total_download + total_conflict, + }, + ) + + # Update device last_seen + db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) + + log.info( + f"Sync negotiation for device {device.id}: " + f"{total_upload} uploads, {total_download} downloads, " + f"{total_conflict} conflicts, {total_no_op} no-ops" + ) + + return SyncNegotiateResponse( + session_id=sync_session.id, + operations=operations, + total_upload=total_upload, + total_download=total_download, + total_conflict=total_conflict, + total_no_op=total_no_op, + ) + + +@protected_route(router.post, "/sessions/{session_id}/complete", [Scope.DEVICES_WRITE]) +def complete_sync_session( + request: Request, + session_id: int, + payload: SyncCompletePayload, +) -> SyncSessionSchema: + """Mark a sync session as completed.""" + sync_session = db_sync_session_handler.get_session( + session_id=session_id, user_id=request.user.id + ) + if not sync_session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Sync session with ID {session_id} not found", + ) + + if sync_session.status not in ( + SyncSessionStatus.PENDING, + SyncSessionStatus.IN_PROGRESS, + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Session is already {sync_session.status}", + ) + + completed = db_sync_session_handler.complete_session( + session_id=session_id, + operations_completed=payload.operations_completed, + operations_failed=payload.operations_failed, + ) + + log.info( + f"Sync session {session_id} completed: " + f"{payload.operations_completed} succeeded, {payload.operations_failed} failed" + ) + + return SyncSessionSchema.model_validate(completed) + + +@protected_route(router.get, "/sessions", [Scope.DEVICES_READ]) +def get_sync_sessions( + request: Request, + device_id: str | None = None, + limit: int = 50, +) -> list[SyncSessionSchema]: + """List sync sessions for the current user.""" + sessions = 
db_sync_session_handler.get_sessions( + user_id=request.user.id, + device_id=device_id, + limit=limit, + ) + return [SyncSessionSchema.model_validate(s) for s in sessions] + + +@protected_route(router.get, "/sessions/{session_id}", [Scope.DEVICES_READ]) +def get_sync_session( + request: Request, + session_id: int, +) -> SyncSessionSchema: + """Get a specific sync session.""" + sync_session = db_sync_session_handler.get_session( + session_id=session_id, user_id=request.user.id + ) + if not sync_session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Sync session with ID {session_id} not found", + ) + + return SyncSessionSchema.model_validate(sync_session) + + +# --- Push-Pull Mode Endpoints --- + + +@protected_route(router.post, "/devices/{device_id}/push-pull", [Scope.DEVICES_WRITE]) +def trigger_push_pull( + request: Request, + device_id: str, +) -> SyncSessionSchema: + """Manually trigger a push-pull sync for a specific device.""" + device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {device_id} not found", + ) + + if device.sync_mode != SyncMode.PUSH_PULL: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Device is not in push_pull sync mode", + ) + + if not device.sync_enabled: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Sync is disabled for this device", + ) + + # Create a session and enqueue the job + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=request.user.id + ) + + high_prio_queue.enqueue( + "tasks.sync_push_pull_task.run_push_pull_sync", + device_id=device.id, + force=True, + job_timeout=TASK_TIMEOUT, + meta={ + "task_name": "Push-Pull Sync", + "task_type": "sync", + }, + ) + + log.info(f"Enqueued push-pull sync for device {device.id}") + return SyncSessionSchema.model_validate(sync_session) + 
+ +@protected_route(router.post, "/devices/{device_id}/ssh-key", [Scope.DEVICES_WRITE]) +async def upload_ssh_key( + request: Request, + device_id: str, + keyFile: UploadFile, +) -> dict: + """Upload an SSH private key for a push-pull device.""" + device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {device_id} not found", + ) + + key_data = await keyFile.read() + if not key_data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Empty key file", + ) + + # Validate it looks like a private key + key_str = key_data.decode("utf-8", errors="replace") + if "PRIVATE KEY" not in key_str: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="File does not appear to be a valid SSH private key", + ) + + key_path = ssh_sync_handler.store_key(device.id, key_data) + + # Update device sync_config with the key path + sync_config = device.sync_config or {} + sync_config["ssh_key_path"] = str(key_path) + db_device_handler.update_device( + device_id=device.id, + user_id=request.user.id, + data={"sync_config": sync_config}, + ) + + log.info(f"Stored SSH key for device {device.id}") + return {"status": "ok", "key_path": str(key_path)} + + +@protected_route( + router.delete, + "/devices/{device_id}/ssh-key", + [Scope.DEVICES_WRITE], + status_code=status.HTTP_204_NO_CONTENT, +) +def delete_ssh_key( + request: Request, + device_id: str, +) -> None: + """Remove an SSH private key for a device.""" + device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {device_id} not found", + ) + + removed = ssh_sync_handler.remove_key(device.id) + if not removed: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No SSH key found for this device", + ) + + # Remove 
key path from sync_config + sync_config = device.sync_config or {} + sync_config.pop("ssh_key_path", None) + db_device_handler.update_device( + device_id=device.id, + user_id=request.user.id, + data={"sync_config": sync_config}, + ) diff --git a/backend/handler/database/__init__.py b/backend/handler/database/__init__.py index 1b0867cc4..c14d871b5 100644 --- a/backend/handler/database/__init__.py +++ b/backend/handler/database/__init__.py @@ -9,6 +9,7 @@ from .saves_handler import DBSavesHandler from .screenshots_handler import DBScreenshotsHandler from .states_handler import DBStatesHandler from .stats_handler import DBStatsHandler +from .sync_sessions_handler import DBSyncSessionsHandler from .users_handler import DBUsersHandler db_client_token_handler = DBClientTokensHandler() @@ -22,4 +23,5 @@ db_save_handler = DBSavesHandler() db_screenshot_handler = DBScreenshotsHandler() db_state_handler = DBStatesHandler() db_stats_handler = DBStatsHandler() +db_sync_session_handler = DBSyncSessionsHandler() db_user_handler = DBUsersHandler() diff --git a/backend/handler/database/devices_handler.py b/backend/handler/database/devices_handler.py index 81318aa9a..d0ef1dc28 100644 --- a/backend/handler/database/devices_handler.py +++ b/backend/handler/database/devices_handler.py @@ -5,7 +5,7 @@ from sqlalchemy import delete, select, update from sqlalchemy.orm import Session from decorators.database import begin_session -from models.device import Device +from models.device import Device, SyncMode from .base_handler import DBBaseHandler @@ -57,6 +57,15 @@ class DBDevicesHandler(DBBaseHandler): return None + @begin_session + def get_device_by_id( + self, + device_id: str, + session: Session = None, # type: ignore + ) -> Device | None: + """Get a device by ID without user filtering (for server-side operations).""" + return session.scalar(select(Device).filter_by(id=device_id).limit(1)) + @begin_session def get_devices( self, @@ -65,6 +74,15 @@ class DBDevicesHandler(DBBaseHandler): 
) -> Sequence[Device]: return session.scalars(select(Device).filter_by(user_id=user_id)).all() + @begin_session + def get_all_devices_by_sync_mode( + self, + sync_mode: SyncMode, + session: Session = None, # type: ignore + ) -> Sequence[Device]: + """Get all devices with a specific sync mode (across all users).""" + return session.scalars(select(Device).filter_by(sync_mode=sync_mode)).all() + @begin_session def update_device( self, diff --git a/backend/handler/database/sync_sessions_handler.py b/backend/handler/database/sync_sessions_handler.py new file mode 100644 index 000000000..8fed624d5 --- /dev/null +++ b/backend/handler/database/sync_sessions_handler.py @@ -0,0 +1,166 @@ +from collections.abc import Sequence +from datetime import datetime, timezone + +from sqlalchemy import select, update +from sqlalchemy.orm import Session + +from decorators.database import begin_session +from models.sync_session import SyncSession, SyncSessionStatus + +from .base_handler import DBBaseHandler + + +class DBSyncSessionsHandler(DBBaseHandler): + @begin_session + def create_session( + self, + device_id: str, + user_id: int, + session: Session = None, # type: ignore + ) -> SyncSession: + sync_session = SyncSession( + device_id=device_id, + user_id=user_id, + status=SyncSessionStatus.PENDING, + initiated_at=datetime.now(timezone.utc), + ) + session.add(sync_session) + session.flush() + return sync_session + + @begin_session + def get_session( + self, + session_id: int, + user_id: int, + session: Session = None, # type: ignore + ) -> SyncSession | None: + return session.scalar( + select(SyncSession).filter_by(id=session_id, user_id=user_id).limit(1) + ) + + @begin_session + def get_active_session( + self, + device_id: str, + user_id: int, + session: Session = None, # type: ignore + ) -> SyncSession | None: + return session.scalar( + select(SyncSession) + .filter( + SyncSession.device_id == device_id, + SyncSession.user_id == user_id, + SyncSession.status.in_( + [ + 
SyncSessionStatus.PENDING, + SyncSessionStatus.IN_PROGRESS, + ] + ), + ) + .order_by(SyncSession.initiated_at.desc()) + .limit(1) + ) + + @begin_session + def update_session( + self, + session_id: int, + data: dict, + session: Session = None, # type: ignore + ) -> SyncSession: + session.execute( + update(SyncSession) + .where(SyncSession.id == session_id) + .values(**data) + .execution_options(synchronize_session="evaluate") + ) + return session.query(SyncSession).filter_by(id=session_id).one() + + @begin_session + def complete_session( + self, + session_id: int, + operations_completed: int = 0, + operations_failed: int = 0, + session: Session = None, # type: ignore + ) -> SyncSession: + session.execute( + update(SyncSession) + .where(SyncSession.id == session_id) + .values( + status=SyncSessionStatus.COMPLETED, + completed_at=datetime.now(timezone.utc), + operations_completed=operations_completed, + operations_failed=operations_failed, + ) + .execution_options(synchronize_session="evaluate") + ) + return session.query(SyncSession).filter_by(id=session_id).one() + + @begin_session + def fail_session( + self, + session_id: int, + error_message: str | None = None, + session: Session = None, # type: ignore + ) -> SyncSession: + session.execute( + update(SyncSession) + .where(SyncSession.id == session_id) + .values( + status=SyncSessionStatus.FAILED, + completed_at=datetime.now(timezone.utc), + error_message=error_message, + ) + .execution_options(synchronize_session="evaluate") + ) + return session.query(SyncSession).filter_by(id=session_id).one() + + @begin_session + def cancel_active_sessions( + self, + device_id: str, + user_id: int, + session: Session = None, # type: ignore + ) -> int: + """Cancel all active sessions for a device. 
Returns count of cancelled sessions.""" + result = session.execute( + update(SyncSession) + .where( + SyncSession.device_id == device_id, + SyncSession.user_id == user_id, + SyncSession.status.in_( + [ + SyncSessionStatus.PENDING, + SyncSessionStatus.IN_PROGRESS, + ] + ), + ) + .values( + status=SyncSessionStatus.CANCELLED, + completed_at=datetime.now(timezone.utc), + ) + .execution_options(synchronize_session="evaluate") + ) + return result.rowcount + + @begin_session + def get_sessions( + self, + user_id: int, + device_id: str | None = None, + status: SyncSessionStatus | None = None, + limit: int = 50, + session: Session = None, # type: ignore + ) -> Sequence[SyncSession]: + query = select(SyncSession).filter_by(user_id=user_id) + + if device_id: + query = query.filter_by(device_id=device_id) + + if status: + query = query.filter_by(status=status) + + query = query.order_by(SyncSession.initiated_at.desc()).limit(limit) + return session.scalars(query).all() diff --git a/backend/handler/filesystem/__init__.py b/backend/handler/filesystem/__init__.py index 871ccdb6e..d11d2df22 100644 --- a/backend/handler/filesystem/__init__.py +++ b/backend/handler/filesystem/__init__.py @@ -3,9 +3,11 @@ from .firmware_handler import FSFirmwareHandler from .platforms_handler import FSPlatformsHandler from .resources_handler import FSResourcesHandler from .roms_handler import FSRomsHandler +from .sync_handler import FSSyncHandler fs_asset_handler = FSAssetsHandler() fs_firmware_handler = FSFirmwareHandler() fs_platform_handler = FSPlatformsHandler() fs_rom_handler = FSRomsHandler() fs_resource_handler = FSResourcesHandler() +fs_sync_handler = FSSyncHandler() diff --git a/backend/handler/filesystem/sync_handler.py b/backend/handler/filesystem/sync_handler.py new file mode 100644 index 000000000..51b61552c --- /dev/null +++ b/backend/handler/filesystem/sync_handler.py @@ -0,0 +1,102 @@ +import hashlib +import os +from pathlib import Path + +from config import SYNC_BASE_PATH +from 
logger.logger import log + +from .base_handler import FSHandler + + +class FSSyncHandler(FSHandler): + """Filesystem handler for sync folder operations (File Transfer mode).""" + + def __init__(self) -> None: + super().__init__(base_path=SYNC_BASE_PATH) + + def build_incoming_path( + self, device_id: str, platform_slug: str | None = None + ) -> str: + parts = [device_id, "incoming"] + if platform_slug: + parts.append(platform_slug) + return os.path.join(*parts) + + def build_outgoing_path( + self, device_id: str, platform_slug: str | None = None + ) -> str: + parts = [device_id, "outgoing"] + if platform_slug: + parts.append(platform_slug) + return os.path.join(*parts) + + def ensure_device_directories(self, device_id: str) -> None: + """Create incoming/outgoing directory structure for a device.""" + device_base = self.base_path / device_id + incoming = device_base / "incoming" + outgoing = device_base / "outgoing" + + incoming.mkdir(parents=True, exist_ok=True) + outgoing.mkdir(parents=True, exist_ok=True) + + log.info(f"Ensured sync directories for device {device_id}") + + def list_incoming_files(self, device_id: str) -> list[dict]: + """List all files in a device's incoming directory. 
+ + Returns list of dicts with keys: platform_slug, file_name, full_path, file_size, mtime + """ + incoming_dir = self.base_path / device_id / "incoming" + if not incoming_dir.exists(): + return [] + + results = [] + for platform_dir in incoming_dir.iterdir(): + if not platform_dir.is_dir(): + continue + platform_slug = platform_dir.name + for file_path in platform_dir.rglob("*"): + if not file_path.is_file(): + continue + stat = file_path.stat() + results.append( + { + "platform_slug": platform_slug, + "file_name": file_path.name, + "full_path": str(file_path), + "relative_path": str(file_path.relative_to(incoming_dir)), + "file_size": stat.st_size, + "mtime": stat.st_mtime, + } + ) + + return results + + def compute_file_hash(self, file_path: str) -> str: + """Compute MD5 hash of a file synchronously (for watcher context).""" + hash_obj = hashlib.md5(usedforsecurity=False) + with open(file_path, "rb") as f: + while chunk := f.read(8192): + hash_obj.update(chunk) + return hash_obj.hexdigest() + + def write_outgoing_file( + self, device_id: str, platform_slug: str, file_name: str, data: bytes + ) -> str: + """Write a file to a device's outgoing directory.""" + outgoing_dir = self.base_path / device_id / "outgoing" / platform_slug + outgoing_dir.mkdir(parents=True, exist_ok=True) + file_path = outgoing_dir / file_name + file_path.write_bytes(data) + return str(file_path) + + def remove_incoming_file(self, full_path: str) -> None: + """Remove a processed file from the incoming directory.""" + path = Path(full_path) + if path.exists() and path.is_file(): + # Validate the file is within our base path + try: + path.resolve().relative_to(self.base_path.resolve()) + except ValueError: + raise ValueError(f"Path {full_path} is outside the sync base directory") + path.unlink() diff --git a/backend/handler/sync/__init__.py b/backend/handler/sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/handler/sync/comparison.py 
b/backend/handler/sync/comparison.py new file mode 100644 index 000000000..980eb872a --- /dev/null +++ b/backend/handler/sync/comparison.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal, NamedTuple + +from utils.datetime import to_utc + +SyncAction = Literal["upload", "download", "conflict", "no_op"] + + +class SyncComparisonResult(NamedTuple): + action: SyncAction + reason: str + + +def compare_save_state( + *, + client_hash: str | None, + client_updated_at: datetime, + server_hash: str | None, + server_updated_at: datetime, + device_last_synced_at: datetime | None, +) -> SyncComparisonResult: + """Compare client and server save state to determine the sync action. + + Returns a (action, reason) tuple where action is one of: + - upload: client save should be uploaded to server + - download: server save should be downloaded to client + - conflict: both sides changed, needs resolution + - no_op: saves are in sync + """ + client_ts = to_utc(client_updated_at) + server_ts = to_utc(server_updated_at) + + # If hashes match, saves are identical + if client_hash and server_hash and client_hash == server_hash: + return SyncComparisonResult("no_op", "Content is identical") + + # If we have a last sync timestamp, use it to determine which side changed + if device_last_synced_at: + synced_ts = to_utc(device_last_synced_at) + client_changed = client_ts > synced_ts + server_changed = server_ts > synced_ts + + if client_changed and server_changed: + return SyncComparisonResult( + "conflict", "Both sides changed since last sync" + ) + + if client_changed: + return SyncComparisonResult("upload", "Client save is newer than last sync") + + if server_changed: + return SyncComparisonResult( + "download", "Server save is newer than last sync" + ) + + return SyncComparisonResult("no_op", "No changes since last sync") + + # No sync history: fall back to timestamp comparison + if client_ts > server_ts: + return 
SyncComparisonResult("upload", "Client save is newer (no sync history)") + + if server_ts > client_ts: + return SyncComparisonResult( + "download", "Server save is newer (no sync history)" + ) + + # Same timestamp, different hashes (or missing hashes) + if client_hash != server_hash: + return SyncComparisonResult("conflict", "Same timestamp but different content") + + return SyncComparisonResult("no_op", "Saves appear identical") diff --git a/backend/handler/sync/ssh_handler.py b/backend/handler/sync/ssh_handler.py new file mode 100644 index 000000000..602d45db1 --- /dev/null +++ b/backend/handler/sync/ssh_handler.py @@ -0,0 +1,210 @@ +"""SSH/SFTP handler for Push-Pull sync mode. + +Provides methods to connect to remote devices via SSH, list remote save files, +and perform bidirectional file transfers using SFTP. +""" + +from __future__ import annotations + +import hashlib +import os +import tempfile +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import asyncssh + +from config import SYNC_SSH_KEYS_PATH +from logger.logger import log + + +@dataclass +class RemoteSaveInfo: + """Information about a save file on a remote device.""" + + path: str + file_name: str + platform_slug: str + file_size: int + mtime: datetime + content_hash: str | None = None + + +class SSHSyncHandler: + """Handles SSH/SFTP operations for push-pull sync mode.""" + + def __init__(self) -> None: + self.keys_path = Path(SYNC_SSH_KEYS_PATH) + self.keys_path.mkdir(parents=True, exist_ok=True) + + def get_key_path(self, device_id: str) -> Path: + """Get the SSH key path for a device.""" + return self.keys_path / f"{device_id}.pem" + + def has_key(self, device_id: str) -> bool: + """Check if an SSH key exists for a device.""" + return self.get_key_path(device_id).is_file() + + def store_key(self, device_id: str, key_data: bytes) -> Path: + """Store an SSH private key for a device.""" + key_path = 
self.get_key_path(device_id) + key_path.write_bytes(key_data) + key_path.chmod(0o600) + log.info(f"Stored SSH key for device {device_id}") + return key_path + + def remove_key(self, device_id: str) -> bool: + """Remove an SSH private key for a device.""" + key_path = self.get_key_path(device_id) + if key_path.is_file(): + key_path.unlink() + log.info(f"Removed SSH key for device {device_id}") + return True + return False + + async def connect(self, sync_config: dict) -> asyncssh.SSHClientConnection: + """Establish an SSH connection using device sync_config. + + sync_config should contain: + - ssh_host: hostname or IP + - ssh_port: port (default 22) + - ssh_username: username + - ssh_key_path: path to private key (or device_id to look up) + - ssh_password: password (optional, alternative to key) + """ + host = sync_config["ssh_host"] + port = sync_config.get("ssh_port", 22) + username = sync_config.get("ssh_username", "root") + + connect_kwargs: dict[str, Any] = { + "host": host, + "port": port, + "username": username, + "known_hosts": None, # Accept all host keys (TODO: make configurable) + } + + # Try key-based auth first + key_path = sync_config.get("ssh_key_path") + if key_path and os.path.isfile(key_path): + connect_kwargs["client_keys"] = [key_path] + elif sync_config.get("ssh_password"): + connect_kwargs["password"] = sync_config["ssh_password"] + else: + raise ValueError( + f"No SSH authentication method available for {host}. " + "Provide ssh_key_path or ssh_password in sync_config." + ) + + log.info(f"Connecting to {username}@{host}:{port}") + return await asyncssh.connect(**connect_kwargs) + + async def list_remote_saves( + self, + conn: asyncssh.SSHClientConnection, + save_directories: list[dict], + ) -> list[RemoteSaveInfo]: + """List save files on a remote device. + + save_directories is a list of dicts with keys: + - platform_slug: str + - path: str (remote directory path) + - extension: str (optional, file extension filter, e.g. 
".srm") + """ + results: list[RemoteSaveInfo] = [] + + async with conn.start_sftp_client() as sftp: + for dir_config in save_directories: + platform_slug = dir_config["platform_slug"] + remote_path = dir_config["path"] + extension = dir_config.get("extension", "") + + try: + entries = await sftp.listdir(remote_path) + except asyncssh.SFTPNoSuchFile: + log.warning(f"Remote directory not found: {remote_path}") + continue + + for entry in entries: + if extension and not entry.endswith(extension): + continue + + full_remote_path = f"{remote_path}/{entry}" + try: + attrs = await sftp.stat(full_remote_path) + if not attrs.type == asyncssh.constants.FILEXFER_TYPE_REGULAR: + continue + + mtime = datetime.fromtimestamp( + attrs.mtime or 0, tz=timezone.utc + ) + results.append( + RemoteSaveInfo( + path=full_remote_path, + file_name=entry, + platform_slug=platform_slug, + file_size=attrs.size or 0, + mtime=mtime, + ) + ) + except asyncssh.SFTPError as e: + log.warning(f"Failed to stat {full_remote_path}: {e}") + + return results + + async def download_save( + self, + conn: asyncssh.SSHClientConnection, + remote_path: str, + local_path: str | None = None, + ) -> tuple[str, str]: + """Download a save file from a remote device. + + Returns (local_temp_path, content_hash). 
+ """ + if local_path is None: + fd, local_path = tempfile.mkstemp(prefix="romm_sync_") + os.close(fd) + + async with conn.start_sftp_client() as sftp: + await sftp.get(remote_path, local_path) + + # Compute hash + hash_obj = hashlib.md5(usedforsecurity=False) + with open(local_path, "rb") as f: + while chunk := f.read(8192): + hash_obj.update(chunk) + + return local_path, hash_obj.hexdigest() + + async def upload_save( + self, + conn: asyncssh.SSHClientConnection, + local_path: str, + remote_path: str, + ) -> None: + """Upload a save file to a remote device.""" + async with conn.start_sftp_client() as sftp: + # Ensure remote directory exists + remote_dir = os.path.dirname(remote_path) + try: + await sftp.mkdir(remote_dir) + except asyncssh.SFTPError: + pass # Directory likely already exists + + await sftp.put(local_path, remote_path) + log.info(f"Uploaded {local_path} -> {remote_path}") + + async def delete_remote_save( + self, + conn: asyncssh.SSHClientConnection, + remote_path: str, + ) -> None: + """Delete a save file from a remote device.""" + async with conn.start_sftp_client() as sftp: + await sftp.remove(remote_path) + log.info(f"Deleted remote file: {remote_path}") + + +ssh_sync_handler = SSHSyncHandler() diff --git a/backend/main.py b/backend/main.py index fad20037a..88d90cafe 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,6 +15,7 @@ from startup import main import endpoints.sockets.netplay # noqa import endpoints.sockets.scan # noqa +import endpoints.sockets.sync # noqa from config import ( DEV_HOST, DEV_PORT, @@ -42,6 +43,7 @@ from endpoints.screenshots import router as screenshots_router from endpoints.search import router as search_router from endpoints.states import router as states_router from endpoints.stats import router as stats_router +from endpoints.sync import router as sync_router from endpoints.tasks import router as tasks_router from endpoints.user import router as user_router from handler.auth.hybrid_auth import 
HybridAuthBackend @@ -131,6 +133,7 @@ app.include_router(rom_router, prefix="/api") app.include_router(search_router, prefix="/api") app.include_router(saves_router, prefix="/api") app.include_router(states_router, prefix="/api") +app.include_router(sync_router, prefix="/api") app.include_router(tasks_router, prefix="/api") app.include_router(feeds_router, prefix="/api") app.include_router(configs_router, prefix="/api") diff --git a/backend/models/device.py b/backend/models/device.py index 15d24febc..911471d26 100644 --- a/backend/models/device.py +++ b/backend/models/device.py @@ -4,7 +4,7 @@ import enum from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import TIMESTAMP, Boolean, Enum, ForeignKey, String +from sqlalchemy import JSON, TIMESTAMP, Boolean, Enum, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship from models.base import BaseModel @@ -38,6 +38,7 @@ class Device(BaseModel): sync_mode: Mapped[SyncMode] = mapped_column(Enum(SyncMode), default=SyncMode.API) sync_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + sync_config: Mapped[dict | None] = mapped_column(JSON, nullable=True) last_seen: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True)) diff --git a/backend/models/sync_session.py b/backend/models/sync_session.py new file mode 100644 index 000000000..505611b26 --- /dev/null +++ b/backend/models/sync_session.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import enum +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import TIMESTAMP, Enum, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from models.base import BaseModel + +if TYPE_CHECKING: + from models.device import Device + from models.user import User + + +class SyncSessionStatus(enum.StrEnum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + 
+class SyncSession(BaseModel): + __tablename__ = "sync_sessions" + __table_args__ = {"extend_existing": True} + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + device_id: Mapped[str] = mapped_column( + String(255), ForeignKey("devices.id", ondelete="CASCADE"), index=True + ) + user_id: Mapped[int] = mapped_column( + ForeignKey("users.id", ondelete="CASCADE"), index=True + ) + + status: Mapped[SyncSessionStatus] = mapped_column( + Enum(SyncSessionStatus), default=SyncSessionStatus.PENDING, index=True + ) + initiated_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True)) + completed_at: Mapped[datetime | None] = mapped_column( + TIMESTAMP(timezone=True), nullable=True + ) + + operations_planned: Mapped[int] = mapped_column(Integer, default=0) + operations_completed: Mapped[int] = mapped_column(Integer, default=0) + operations_failed: Mapped[int] = mapped_column(Integer, default=0) + error_message: Mapped[str | None] = mapped_column(String(1000), nullable=True) + + device: Mapped[Device] = relationship(lazy="joined") + user: Mapped[User] = relationship(lazy="joined") diff --git a/backend/startup.py b/backend/startup.py index e366e195e..9181960ad 100644 --- a/backend/startup.py +++ b/backend/startup.py @@ -11,6 +11,7 @@ from config import ( ENABLE_SCHEDULED_RETROACHIEVEMENTS_PROGRESS_SYNC, ENABLE_SCHEDULED_UPDATE_LAUNCHBOX_METADATA, ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB, + ENABLE_SYNC_PUSH_PULL, SENTRY_DSN, ) from handler.metadata.base_handler import ( @@ -33,6 +34,7 @@ from tasks.scheduled.sync_retroachievements_progress import ( ) from tasks.scheduled.update_launchbox_metadata import update_launchbox_metadata_task from tasks.scheduled.update_switch_titledb import update_switch_titledb_task +from tasks.sync_push_pull_task import sync_push_pull_task from utils import get_version from utils.cache import conditionally_set_cache from utils.context import initialize_context @@ -65,6 +67,9 @@ async def main() -> None: if 
ENABLE_SCHEDULED_RETROACHIEVEMENTS_PROGRESS_SYNC: log.info("Starting scheduled RetroAchievements progress sync") sync_retroachievements_progress_task.init() + if ENABLE_SYNC_PUSH_PULL: + log.info("Starting scheduled push-pull sync") + sync_push_pull_task.init() log.info("Initializing cache with fixtures data") await conditionally_set_cache( diff --git a/backend/sync_watcher.py b/backend/sync_watcher.py new file mode 100644 index 000000000..eb7335d4c --- /dev/null +++ b/backend/sync_watcher.py @@ -0,0 +1,333 @@ +"""Sync folder watcher for File Transfer mode. + +This module is invoked by watchfiles when changes are detected in the sync +folder (SYNC_BASE_PATH). It processes incoming save files from devices that +use file_transfer sync mode. + +The watcher is configured to run as a separate watchfiles process monitoring +the sync base path. When files appear in a device's incoming/ directory, they +are matched to ROMs and processed as save uploads. +""" + +import asyncio +import json +import os +import shutil +from collections.abc import Sequence +from datetime import datetime, timezone +from typing import cast + +import sentry_sdk + +from config import ENABLE_SYNC_FOLDER_WATCHER, SENTRY_DSN, SYNC_BASE_PATH +from handler.database import ( + db_device_handler, + db_device_save_sync_handler, + db_platform_handler, + db_save_handler, + db_sync_session_handler, +) +from handler.filesystem import fs_asset_handler, fs_sync_handler +from handler.sync.comparison import compare_save_state +from logger.formatter import highlight as hl +from logger.logger import log +from models.device import SyncMode +from models.sync_session import SyncSessionStatus +from utils import get_version + +sentry_sdk.init( + dsn=SENTRY_DSN, + release=f"romm@{get_version()}", +) + +Change = tuple[str, str] + + +def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None: + """Extract device_id, platform_slug, and filename from a sync incoming path. 
+ + Expected path format: {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext + """ + try: + rel_path = os.path.relpath(path, SYNC_BASE_PATH) + parts = rel_path.split(os.sep) + # Minimum: device_id / incoming / platform_slug / filename + if len(parts) < 4 or parts[1] != "incoming": + return None + device_id = parts[0] + platform_slug = parts[2] + filename = parts[-1] + return (device_id, platform_slug, filename) + except (ValueError, IndexError): + return None + + +def _ensure_conflicts_dir(device_id: str, platform_slug: str) -> str: + """Ensure the conflicts directory exists and return its path.""" + conflicts_dir = os.path.join(SYNC_BASE_PATH, device_id, "conflicts", platform_slug) + os.makedirs(conflicts_dir, exist_ok=True) + return conflicts_dir + + +def process_sync_changes(changes: Sequence[Change]) -> None: + """Process file changes detected in the sync folder.""" + if not ENABLE_SYNC_FOLDER_WATCHER: + return + + # Only process added/modified files in incoming directories + added_files: list[tuple[str, str, str, str]] = ( + [] + ) # (device_id, platform_slug, filename, full_path) + for _event_type, change_path in changes: + src_path = os.fsdecode(change_path) + + # Only process files (not directories) + if not os.path.isfile(src_path): + continue + + parsed = _extract_device_and_platform(src_path) + if not parsed: + continue + + device_id, platform_slug, filename = parsed + added_files.append((device_id, platform_slug, filename, src_path)) + + if not added_files: + return + + # Group by device + by_device: dict[str, list[tuple[str, str, str]]] = {} + for device_id, platform_slug, filename, full_path in added_files: + by_device.setdefault(device_id, []).append((platform_slug, filename, full_path)) + + for device_id, files in by_device.items(): + _process_device_incoming(device_id, files) + + +def _process_device_incoming( + device_id: str, + files: list[tuple[str, str, str]], # (platform_slug, filename, full_path) +) -> None: + """Process 
incoming files for a single device.""" + from endpoints.sockets.sync import ( + emit_sync_completed, + emit_sync_error, + emit_sync_started, + ) + + # Look up device - try all users since file transfer is server-side + device = db_device_handler.get_device_by_id(device_id) + if not device: + log.warning(f"Sync watcher: unknown device {device_id}, skipping") + return + + if device.sync_mode != SyncMode.FILE_TRANSFER: + log.warning( + f"Sync watcher: device {device_id} is not in file_transfer mode, skipping" + ) + return + + if not device.sync_enabled: + log.info(f"Sync watcher: device {device_id} sync is disabled, skipping") + return + + # Create a sync session + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=device.user_id + ) + db_sync_session_handler.update_session( + session_id=sync_session.id, + data={ + "status": SyncSessionStatus.IN_PROGRESS, + "operations_planned": len(files), + }, + ) + + asyncio.run( + emit_sync_started( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + sync_mode="file_transfer", + ) + ) + + completed = 0 + failed = 0 + + for platform_slug, filename, full_path in files: + try: + _process_incoming_file( + device, sync_session.id, platform_slug, filename, full_path + ) + completed += 1 + except Exception: + log.error( + f"Sync watcher: failed to process {filename} for device {device_id}", + exc_info=True, + ) + failed += 1 + + # Complete the session + db_sync_session_handler.complete_session( + session_id=sync_session.id, + operations_completed=completed, + operations_failed=failed, + ) + + if failed > 0: + asyncio.run( + emit_sync_error( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + error_message=f"{failed} file(s) failed to process", + ) + ) + + asyncio.run( + emit_sync_completed( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + operations_completed=completed, + operations_failed=failed, + ) + ) + + 
log.info( + f"Sync watcher: device {device_id} processed {completed} files, {failed} failures" + ) + + +def _process_incoming_file( + device, session_id: int, platform_slug: str, filename: str, full_path: str +) -> None: + """Process a single incoming file from a device's sync folder.""" + from endpoints.sockets.sync import emit_sync_conflict + + # Look up platform + platform = db_platform_handler.get_platform_by_fs_slug(platform_slug) + if not platform: + log.warning(f"Sync watcher: unknown platform slug {platform_slug}") + return + + # Compute hash of incoming file + file_hash = fs_sync_handler.compute_file_hash(full_path) + file_size = os.path.getsize(full_path) + file_mtime = datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc) + + # Try to find matching saves on this platform for this user + saves_on_platform = db_save_handler.get_saves( + user_id=device.user_id, + platform_id=platform.id, + ) + + matched_save = None + for save in saves_on_platform: + if save.file_name == filename: + matched_save = save + break + + if matched_save: + # Compare with existing save + device_sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=matched_save.id + ) + result = compare_save_state( + client_hash=file_hash, + client_updated_at=file_mtime, + server_hash=matched_save.content_hash, + server_updated_at=matched_save.updated_at, + device_last_synced_at=device_sync.last_synced_at if device_sync else None, + ) + + if result.action == "no_op": + log.debug(f"Sync watcher: {filename} is already in sync, skipping") + fs_sync_handler.remove_incoming_file(full_path) + return + + if result.action == "upload": + # Client file is newer - update server save + log.info( + f"Sync watcher: updating save {hl(filename)} from device {device.id}" + ) + with open(full_path, "rb") as f: + file_data = f.read() + asyncio.run( + fs_asset_handler.write_file( + file=file_data, + path=matched_save.file_path, + filename=matched_save.file_name, + ) + ) + 
db_save_handler.update_save( + matched_save.id, + { + "file_size_bytes": file_size, + "content_hash": file_hash, + }, + ) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=matched_save.id, + synced_at=datetime.now(timezone.utc), + ) + fs_sync_handler.remove_incoming_file(full_path) + + elif result.action == "conflict": + log.warning( + f"Sync watcher: conflict detected for {filename} " + f"on device {device.id}: {result.reason}" + ) + # Move conflicting file to conflicts directory + conflicts_dir = _ensure_conflicts_dir(device.id, platform_slug) + conflict_path = os.path.join(conflicts_dir, filename) + shutil.move(full_path, conflict_path) + log.info(f"Sync watcher: moved conflicting file to {conflict_path}") + + # Emit socket notification for conflict + asyncio.run( + emit_sync_conflict( + user_id=device.user_id, + device_id=device.id, + session_id=session_id, + file_name=filename, + rom_id=matched_save.rom_id, + reason=result.reason, + ) + ) + + elif result.action == "download": + # Server is newer - write server save to device's outgoing directory + log.info( + f"Sync watcher: server save is newer for {filename}, " + f"writing to outgoing" + ) + server_file_path = f"{matched_save.file_path}/{matched_save.file_name}" + server_full_path = fs_asset_handler.validate_path(server_file_path) + with open(str(server_full_path), "rb") as f: + server_data = f.read() + fs_sync_handler.write_outgoing_file( + device_id=device.id, + platform_slug=platform_slug, + file_name=filename, + data=server_data, + ) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=matched_save.id, + synced_at=datetime.now(timezone.utc), + ) + fs_sync_handler.remove_incoming_file(full_path) + else: + log.info( + f"Sync watcher: new file {hl(filename)} from device {device.id} " + f"on platform {platform_slug} - no matching save found, skipping" + ) + + +if __name__ == "__main__": + changes = cast(list[Change], json.loads(os.getenv("WATCHFILES_CHANGES", 
"[]"))) + if changes: + process_sync_changes(changes) diff --git a/backend/tasks/sync_folder_task.py b/backend/tasks/sync_folder_task.py new file mode 100644 index 000000000..39269b1fe --- /dev/null +++ b/backend/tasks/sync_folder_task.py @@ -0,0 +1,64 @@ +"""Periodic task to scan sync folders for unprocessed files. + +This serves as a fallback for the file watcher, handling cases where +filesystem events are missed (e.g., server restart, NFS mounts). +""" + +from typing import Any + +from config import ENABLE_SYNC_FOLDER_WATCHER +from handler.database import db_device_handler +from handler.filesystem import fs_sync_handler +from logger.logger import log +from models.device import SyncMode +from tasks.tasks import Task, TaskType + + +class SyncFolderScanTask(Task): + """Scan device sync folders for unprocessed incoming files.""" + + def __init__(self) -> None: + super().__init__( + title="Sync Folder Scan", + description="Scan device sync folders for new save files", + task_type=TaskType.SYNC, + enabled=ENABLE_SYNC_FOLDER_WATCHER, + manual_run=True, + ) + + async def run(self, *args: Any, **kwargs: Any) -> dict: + if not self.enabled: + log.info("Sync folder scan not enabled, skipping") + return {"status": "disabled"} + + # Get all file_transfer devices + devices = db_device_handler.get_all_devices_by_sync_mode(SyncMode.FILE_TRANSFER) + if not devices: + log.info("No file_transfer devices found") + return {"status": "no_devices"} + + total_files = 0 + for device in devices: + if not device.sync_enabled: + continue + + incoming_files = fs_sync_handler.list_incoming_files(device.id) + if incoming_files: + log.info( + f"Sync folder scan: found {len(incoming_files)} files " + f"for device {device.id}" + ) + # Import here to avoid circular imports + from sync_watcher import _process_device_incoming + + file_tuples = [ + (f["platform_slug"], f["file_name"], f["full_path"]) + for f in incoming_files + ] + _process_device_incoming(device.id, file_tuples) + total_files += 
len(incoming_files) + + return {"status": "completed", "files_processed": total_files} + + +sync_folder_scan_task = SyncFolderScanTask() diff --git a/backend/tasks/sync_push_pull_task.py b/backend/tasks/sync_push_pull_task.py new file mode 100644 index 000000000..f15c9f3fa --- /dev/null +++ b/backend/tasks/sync_push_pull_task.py @@ -0,0 +1,420 @@ +"""Background task for Push-Pull sync mode. + +Connects to devices via SSH/SFTP, scans their save directories, +and performs bidirectional sync operations. +""" + +import os +from datetime import datetime, timezone +from typing import Any + +from config import ENABLE_SYNC_PUSH_PULL, SYNC_PUSH_PULL_CRON +from handler.database import ( + db_device_handler, + db_device_save_sync_handler, + db_platform_handler, + db_save_handler, + db_sync_session_handler, +) +from handler.filesystem import fs_asset_handler +from handler.sync.comparison import compare_save_state +from handler.sync.ssh_handler import ssh_sync_handler +from logger.formatter import highlight as hl +from logger.logger import log +from models.device import Device, SyncMode +from models.sync_session import SyncSessionStatus +from tasks.tasks import PeriodicTask, TaskType + + +async def run_push_pull_sync(device_id: str | None = None, force: bool = False) -> dict: + """Execute push-pull sync for one or all push_pull devices.""" + if not ENABLE_SYNC_PUSH_PULL and not force: + log.info("Push-pull sync not enabled, skipping") + return {"status": "disabled"} + + if device_id: + device = db_device_handler.get_device_by_id(device_id) + if not device: + return {"status": "error", "message": f"Device {device_id} not found"} + devices = [device] + else: + devices = list( + db_device_handler.get_all_devices_by_sync_mode(SyncMode.PUSH_PULL) + ) + + if not devices: + log.info("No push_pull devices found") + return {"status": "no_devices"} + + results = [] + for device in devices: + if not device.sync_enabled: + continue + result = await _sync_device(device) + 
results.append(result) + + return {"status": "completed", "device_results": results} + + +async def _sync_device(device: Device) -> dict: + """Perform push-pull sync for a single device.""" + sync_config = device.sync_config or {} + if not sync_config.get("ssh_host"): + log.warning(f"Push-pull device {device.id} has no ssh_host configured") + return {"device_id": device.id, "status": "error", "message": "No ssh_host"} + + from endpoints.sockets.sync import ( + emit_sync_completed, + emit_sync_conflict, + emit_sync_error, + emit_sync_progress, + emit_sync_started, + ) + + # Create sync session + sync_session = db_sync_session_handler.create_session( + device_id=device.id, user_id=device.user_id + ) + + await emit_sync_started( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + sync_mode="push_pull", + ) + + try: + conn = await ssh_sync_handler.connect(sync_config) + except Exception as e: + log.error(f"Push-pull: failed to connect to device {device.id}: {e}") + db_sync_session_handler.fail_session( + session_id=sync_session.id, error_message=str(e) + ) + await emit_sync_error( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + error_message=str(e), + ) + return {"device_id": device.id, "status": "connection_failed", "error": str(e)} + + completed = 0 + failed = 0 + + try: + save_directories = sync_config.get("save_directories", []) + if not save_directories: + log.warning( + f"Push-pull device {device.id} has no save_directories configured" + ) + db_sync_session_handler.complete_session(session_id=sync_session.id) + return {"device_id": device.id, "status": "no_directories"} + + # List remote saves + remote_saves = await ssh_sync_handler.list_remote_saves(conn, save_directories) + log.info( + f"Push-pull: found {len(remote_saves)} remote saves on device {device.id}" + ) + + db_sync_session_handler.update_session( + session_id=sync_session.id, + data={ + "status": SyncSessionStatus.IN_PROGRESS, + 
"operations_planned": len(remote_saves), + }, + ) + + operations_planned = len(remote_saves) + + # Process each remote save + for remote_save in remote_saves: + try: + action = await _process_remote_save(device, conn, remote_save) + if action == "conflict": + await emit_sync_conflict( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + file_name=remote_save.file_name, + rom_id=0, + reason=f"Conflict detected for {remote_save.file_name}", + ) + if action != "skipped": + completed += 1 + except Exception: + log.error( + f"Push-pull: failed to process {remote_save.file_name} " + f"on device {device.id}", + exc_info=True, + ) + failed += 1 + + await emit_sync_progress( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + operations_completed=completed + failed, + operations_planned=operations_planned, + current_file=remote_save.file_name, + ) + + # Check for server saves that need to be pushed to the device + push_count = await _push_missing_saves( + device, conn, remote_saves, save_directories + ) + completed += push_count + + conn.close() + + except Exception as e: + log.error(f"Push-pull sync failed for device {device.id}: {e}", exc_info=True) + db_sync_session_handler.fail_session( + session_id=sync_session.id, error_message=str(e) + ) + await emit_sync_error( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + error_message=str(e), + ) + return {"device_id": device.id, "status": "failed", "error": str(e)} + + db_sync_session_handler.complete_session( + session_id=sync_session.id, + operations_completed=completed, + operations_failed=failed, + ) + db_device_handler.update_last_seen(device_id=device.id, user_id=device.user_id) + + await emit_sync_completed( + user_id=device.user_id, + device_id=device.id, + session_id=sync_session.id, + operations_completed=completed, + operations_failed=failed, + ) + + log.info( + f"Push-pull sync for device {device.id}: " + f"{completed} 
completed, {failed} failed" + ) + return { + "device_id": device.id, + "status": "completed", + "completed": completed, + "failed": failed, + } + + +async def _process_remote_save( + device: Device, + conn, + remote_save, +) -> str: + """Process a single remote save file. Returns action taken.""" + # Look up platform + platform = db_platform_handler.get_platform_by_fs_slug(remote_save.platform_slug) + if not platform: + log.debug(f"Unknown platform slug: {remote_save.platform_slug}") + return "skipped" + + # Find matching server save + saves = db_save_handler.get_saves(user_id=device.user_id, platform_id=platform.id) + matched_save = None + for save in saves: + if save.file_name == remote_save.file_name: + matched_save = save + break + + if not matched_save: + # New save from device - download it + local_path, content_hash = await ssh_sync_handler.download_save( + conn, remote_save.path + ) + try: + # We have the file locally, but we need a ROM to attach it to. + # Without a clear ROM match, skip for now. 
+ log.info( + f"Push-pull: new remote save {hl(remote_save.file_name)} " + f"on platform {remote_save.platform_slug} - no matching server save" + ) + return "skipped" + finally: + if os.path.exists(local_path): + os.unlink(local_path) + + # Compare with existing save + device_sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=matched_save.id + ) + + # Download remote file to get its hash + local_path, remote_hash = await ssh_sync_handler.download_save( + conn, remote_save.path + ) + + try: + result = compare_save_state( + client_hash=remote_hash, + client_updated_at=remote_save.mtime, + server_hash=matched_save.content_hash, + server_updated_at=matched_save.updated_at, + device_last_synced_at=device_sync.last_synced_at if device_sync else None, + ) + + if result.action == "no_op": + # Update sync tracking even for no-ops + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=matched_save.id, + synced_at=datetime.now(timezone.utc), + ) + return "no_op" + + if result.action == "upload": + # Remote is newer - pull to server + log.info( + f"Push-pull: pulling {hl(remote_save.file_name)} from device {device.id}" + ) + with open(local_path, "rb") as f: + file_data = f.read() + await fs_asset_handler.write_file( + file=file_data, + path=matched_save.file_path, + filename=matched_save.file_name, + ) + db_save_handler.update_save( + matched_save.id, + { + "file_size_bytes": remote_save.file_size, + "content_hash": remote_hash, + }, + ) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=matched_save.id, + synced_at=datetime.now(timezone.utc), + ) + return "pulled" + + if result.action == "download": + # Server is newer - push to device + log.info( + f"Push-pull: pushing {hl(matched_save.file_name)} to device {device.id}" + ) + server_file_path = f"{matched_save.file_path}/{matched_save.file_name}" + server_full_path = fs_asset_handler.validate_path(server_file_path) + await ssh_sync_handler.upload_save( + 
conn, str(server_full_path), remote_save.path + ) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=matched_save.id, + synced_at=datetime.now(timezone.utc), + ) + return "pushed" + + if result.action == "conflict": + log.warning( + f"Push-pull: conflict for {remote_save.file_name} " + f"on device {device.id}: {result.reason}" + ) + return "conflict" + + finally: + if os.path.exists(local_path): + os.unlink(local_path) + + return "skipped" + + +async def _push_missing_saves( + device: Device, + conn, + remote_saves, + save_directories: list[dict], +) -> int: + """Push server saves that are missing from the device.""" + pushed = 0 + + # Build set of remote filenames per platform + remote_files: dict[str, set[str]] = {} + for rs in remote_saves: + remote_files.setdefault(rs.platform_slug, set()).add(rs.file_name) + + # Build path lookup from save_directories config + platform_paths: dict[str, str] = {} + for dir_config in save_directories: + platform_paths[dir_config["platform_slug"]] = dir_config["path"] + + # Check server saves for each configured platform + for dir_config in save_directories: + platform_slug = dir_config["platform_slug"] + platform = db_platform_handler.get_platform_by_fs_slug(platform_slug) + if not platform: + continue + + server_saves = db_save_handler.get_saves( + user_id=device.user_id, platform_id=platform.id + ) + + remote_set = remote_files.get(platform_slug, set()) + remote_dir = platform_paths.get(platform_slug, "") + + for save in server_saves: + if save.file_name in remote_set: + continue + + # Check if device has synced this before (intentional delete) + device_sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + if device_sync and device_sync.is_untracked: + continue + + # Push to device + if remote_dir: + try: + server_file_path = f"{save.file_path}/{save.file_name}" + server_full_path = fs_asset_handler.validate_path(server_file_path) + remote_path = 
f"{remote_dir}/{save.file_name}" + await ssh_sync_handler.upload_save( + conn, str(server_full_path), remote_path + ) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=save.id, + synced_at=datetime.now(timezone.utc), + ) + pushed += 1 + log.info( + f"Push-pull: pushed missing save {hl(save.file_name)} " + f"to device {device.id}" + ) + except Exception: + log.error( + f"Push-pull: failed to push {save.file_name} to device {device.id}", + exc_info=True, + ) + + return pushed + + +class SyncPushPullTask(PeriodicTask): + """Periodic task to run push-pull sync for all configured devices.""" + + def __init__(self) -> None: + super().__init__( + title="Push-Pull Sync", + description="Sync saves with devices via SSH/SFTP", + task_type=TaskType.SYNC, + enabled=ENABLE_SYNC_PUSH_PULL, + cron_string=SYNC_PUSH_PULL_CRON, + func="tasks.sync_push_pull_task.run_push_pull_sync", + ) + + async def run(self, *args: Any, **kwargs: Any) -> Any: + return await run_push_pull_sync(**kwargs) + + +sync_push_pull_task = SyncPushPullTask() diff --git a/backend/tasks/tasks.py b/backend/tasks/tasks.py index 076de6726..799c5535b 100644 --- a/backend/tasks/tasks.py +++ b/backend/tasks/tasks.py @@ -36,6 +36,7 @@ class TaskType(str, Enum): CONVERSION = "conversion" CLEANUP = "cleanup" UPDATE = "update" + SYNC = "sync" WATCHER = "watcher" GENERIC = "generic" diff --git a/pyproject.toml b/pyproject.toml index d9c01e84e..27683ba9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "SQLAlchemy[mariadb-connector,mysql-connector,postgresql-psycopg] ~= 2.0", "Unidecode ~= 1.3", "aiohttp ~= 3.12", + "asyncssh ~= 2.17", "alembic ~= 1.13", "anyio ~= 4.4", "authlib ~= 1.6.5", diff --git a/uv.lock b/uv.lock index 8b7ded1c2..a7538ecf3 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.13" resolution-markers = [ "platform_python_implementation != 'PyPy'", @@ -167,6 +167,19 @@ wheels = [ { 
url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, ] +[[package]] +name = "asyncssh" +version = "2.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/d5/957886c316466349d55c4de6a688a10a98295c0b4429deb8db1a17f3eb19/asyncssh-2.22.0.tar.gz", hash = "sha256:c3ce72b01be4f97b40e62844dd384227e5ff5a401a3793007c42f86a5c8eb537", size = 540523, upload-time = "2025-12-21T23:38:30.5Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/ae/0da2f2214fc183338af1afe5a103a2052fd03464e8eafbd827abff58a4d0/asyncssh-2.22.0-py3-none-any.whl", hash = "sha256:d16465ccdf1ed20eba1131b14415b155e047f6f5be0d19f39c2e0b61331ee0e7", size = 374938, upload-time = "2025-12-21T23:38:28.976Z" }, +] + [[package]] name = "attrs" version = "25.3.0" @@ -686,6 +699,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/cf/f5c0b23309070ae93de75c90d29300751a5aacefc0a3ed1b1d8edb28f08b/greenlet-3.2.3-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:500b8689aa9dd1ab26872a34084503aeddefcb438e2e7317b89b11eaea1901ad", size = 270732, upload-time = "2025-06-05T16:10:08.26Z" }, { url = "https://files.pythonhosted.org/packages/48/ae/91a957ba60482d3fecf9be49bc3948f341d706b52ddb9d83a70d42abd498/greenlet-3.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a07d3472c2a93117af3b0136f246b2833fdc0b542d4a9799ae5f41c28323faef", size = 639033, upload-time = "2025-06-05T16:38:53.983Z" }, { url = "https://files.pythonhosted.org/packages/6f/df/20ffa66dd5a7a7beffa6451bdb7400d66251374ab40b99981478c69a67a8/greenlet-3.2.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash 
= "sha256:8704b3768d2f51150626962f4b9a9e4a17d2e37c8a8d9867bbd9fa4eb938d3b3", size = 652999, upload-time = "2025-06-05T16:41:37.89Z" }, + { url = "https://files.pythonhosted.org/packages/51/b4/ebb2c8cb41e521f1d72bf0465f2f9a2fd803f674a88db228887e6847077e/greenlet-3.2.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5035d77a27b7c62db6cf41cf786cfe2242644a7a337a0e155c80960598baab95", size = 647368, upload-time = "2025-06-05T16:48:21.467Z" }, { url = "https://files.pythonhosted.org/packages/8e/6a/1e1b5aa10dced4ae876a322155705257748108b7fd2e4fae3f2a091fe81a/greenlet-3.2.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2d8aa5423cd4a396792f6d4580f88bdc6efcb9205891c9d40d20f6e670992efb", size = 650037, upload-time = "2025-06-05T16:13:06.402Z" }, { url = "https://files.pythonhosted.org/packages/26/f2/ad51331a157c7015c675702e2d5230c243695c788f8f75feba1af32b3617/greenlet-3.2.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2c724620a101f8170065d7dded3f962a2aea7a7dae133a009cada42847e04a7b", size = 608402, upload-time = "2025-06-05T16:12:51.91Z" }, { url = "https://files.pythonhosted.org/packages/26/bc/862bd2083e6b3aff23300900a956f4ea9a4059de337f5c8734346b9b34fc/greenlet-3.2.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:873abe55f134c48e1f2a6f53f7d1419192a3d1a4e873bace00499a4e45ea6af0", size = 1119577, upload-time = "2025-06-05T16:36:49.787Z" }, @@ -694,6 +708,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d8/ca/accd7aa5280eb92b70ed9e8f7fd79dc50a2c21d8c73b9a0856f5b564e222/greenlet-3.2.3-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:3d04332dddb10b4a211b68111dabaee2e1a073663d117dc10247b5b1642bac86", size = 271479, upload-time = "2025-06-05T16:10:47.525Z" }, { url = "https://files.pythonhosted.org/packages/55/71/01ed9895d9eb49223280ecc98a557585edfa56b3d0e965b9fa9f7f06b6d9/greenlet-3.2.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:8186162dffde068a465deab08fc72c767196895c39db26ab1c17c0b77a6d8b97", size = 683952, upload-time = "2025-06-05T16:38:55.125Z" }, { url = "https://files.pythonhosted.org/packages/ea/61/638c4bdf460c3c678a0a1ef4c200f347dff80719597e53b5edb2fb27ab54/greenlet-3.2.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f4bfbaa6096b1b7a200024784217defedf46a07c2eee1a498e94a1b5f8ec5728", size = 696917, upload-time = "2025-06-05T16:41:38.959Z" }, + { url = "https://files.pythonhosted.org/packages/22/cc/0bd1a7eb759d1f3e3cc2d1bc0f0b487ad3cc9f34d74da4b80f226fde4ec3/greenlet-3.2.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:ed6cfa9200484d234d8394c70f5492f144b20d4533f69262d530a1a082f6ee9a", size = 692443, upload-time = "2025-06-05T16:48:23.113Z" }, { url = "https://files.pythonhosted.org/packages/67/10/b2a4b63d3f08362662e89c103f7fe28894a51ae0bc890fabf37d1d780e52/greenlet-3.2.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:02b0df6f63cd15012bed5401b47829cfd2e97052dc89da3cfaf2c779124eb892", size = 692995, upload-time = "2025-06-05T16:13:07.972Z" }, { url = "https://files.pythonhosted.org/packages/5a/c6/ad82f148a4e3ce9564056453a71529732baf5448ad53fc323e37efe34f66/greenlet-3.2.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86c2d68e87107c1792e2e8d5399acec2487a4e993ab76c792408e59394d52141", size = 655320, upload-time = "2025-06-05T16:12:53.453Z" }, { url = "https://files.pythonhosted.org/packages/5c/4f/aab73ecaa6b3086a4c89863d94cf26fa84cbff63f52ce9bc4342b3087a06/greenlet-3.2.3-cp314-cp314-win_amd64.whl", hash = "sha256:8c47aae8fbbfcf82cc13327ae802ba13c9c36753b67e760023fd116bc124a62a", size = 301236, upload-time = "2025-06-05T16:15:20.111Z" }, @@ -2100,6 +2115,7 @@ dependencies = [ { name = "aiohttp" }, { name = "alembic" }, { name = "anyio" }, + { name = "asyncssh" }, { name = "authlib" }, { name = "colorama" }, { name = "defusedxml" }, @@ -2168,6 +2184,7 @@ 
requires-dist = [ { name = "aiohttp", specifier = "~=3.12" }, { name = "alembic", specifier = "~=1.13" }, { name = "anyio", specifier = "~=4.4" }, + { name = "asyncssh", specifier = "~=2.17" }, { name = "authlib", specifier = "~=1.6.5" }, { name = "colorama", specifier = "~=0.4" }, { name = "defusedxml", specifier = "~=0.7" },