From f13f929d7dab75fb4b6e95b8aa4cee122a2225c1 Mon Sep 17 00:00:00 2001 From: Georges-Antoine Assi Date: Sat, 14 Mar 2026 22:26:53 -0400 Subject: [PATCH] tweaks based on self review --- backend/endpoints/sync.py | 82 +--------------------- backend/handler/filesystem/sync_handler.py | 26 +++++-- backend/handler/sync/ssh_handler.py | 67 ++++++++++-------- backend/sync_watcher.py | 19 +++-- backend/tasks/sync_push_pull_task.py | 2 +- docker/init_scripts/init | 24 +++++++ entrypoint.sh | 8 +++ 7 files changed, 104 insertions(+), 124 deletions(-) diff --git a/backend/endpoints/sync.py b/backend/endpoints/sync.py index fca13760a..ff8cd35ed 100644 --- a/backend/endpoints/sync.py +++ b/backend/endpoints/sync.py @@ -1,6 +1,6 @@ from datetime import datetime -from fastapi import HTTPException, Request, UploadFile, status +from fastapi import HTTPException, Request, status from pydantic import BaseModel from config import TASK_TIMEOUT @@ -19,7 +19,6 @@ from handler.database import ( ) from handler.redis_handler import high_prio_queue from handler.sync.comparison import compare_save_state -from handler.sync.ssh_handler import ssh_sync_handler from logger.logger import log from models.assets import Save from models.device import SyncMode @@ -370,82 +369,3 @@ def trigger_push_pull( log.info(f"Enqueued push-pull sync for device {device.id}") return SyncSessionSchema.model_validate(sync_session) - - -@protected_route(router.post, "/devices/{device_id}/ssh-key", [Scope.DEVICES_WRITE]) -async def upload_ssh_key( - request: Request, - device_id: str, - keyFile: UploadFile, -) -> dict: - """Upload an SSH private key for a push-pull device.""" - device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) - if not device: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Device with ID {device_id} not found", - ) - - key_data = await keyFile.read() - if not key_data: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Empty key file", - ) - - # Validate it looks like a private key - key_str = key_data.decode("utf-8", errors="replace") - if "PRIVATE KEY" not in key_str: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="File does not appear to be a valid SSH private key", - ) - - key_path = ssh_sync_handler.store_key(device.id, key_data) - - # Update device sync_config with the key path - sync_config = device.sync_config or {} - sync_config["ssh_key_path"] = str(key_path) - db_device_handler.update_device( - device_id=device.id, - user_id=request.user.id, - data={"sync_config": sync_config}, - ) - - log.info(f"Stored SSH key for device {device.id}") - return {"status": "ok", "key_path": str(key_path)} - - -@protected_route( - router.delete, - "/devices/{device_id}/ssh-key", - [Scope.DEVICES_WRITE], - status_code=status.HTTP_204_NO_CONTENT, -) -def delete_ssh_key( - request: Request, - device_id: str, -) -> None: - """Remove an SSH private key for a device.""" - device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) - if not device: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Device with ID {device_id} not found", - ) - - removed = ssh_sync_handler.remove_key(device.id) - if not removed: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No SSH key found for this device", - ) - - # Remove key path from sync_config - sync_config = device.sync_config or {} - sync_config.pop("ssh_key_path", None) - db_device_handler.update_device( - device_id=device.id, - user_id=request.user.id, - data={"sync_config": sync_config}, - ) diff --git a/backend/handler/filesystem/sync_handler.py b/backend/handler/filesystem/sync_handler.py index 51b61552c..bf53b6c12 100644 --- a/backend/handler/filesystem/sync_handler.py +++ b/backend/handler/filesystem/sync_handler.py @@ -17,7 +17,8 @@ class FSSyncHandler(FSHandler): def build_incoming_path( self, device_id: str, platform_slug: str | None = None ) -> str: - parts = [device_id, "incoming"] + """Build the relative incoming path for a device (and optional platform).""" + parts = [self.base_path, device_id, "incoming"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) @@ -25,16 +26,25 @@ class FSSyncHandler(FSHandler): def build_outgoing_path( self, device_id: str, platform_slug: str | None = None ) -> str: - parts = [device_id, "outgoing"] + """Build the relative outgoing path for a device (and optional platform).""" + parts = [self.base_path, device_id, "outgoing"] + if platform_slug: + parts.append(platform_slug) + return os.path.join(*parts) + + def build_conflicts_path( + self, device_id: str, platform_slug: str | None = None + ) -> str: + """Build the relative conflicts path for a device (and optional platform).""" + parts = [self.base_path, device_id, "conflicts"] if platform_slug: parts.append(platform_slug) return os.path.join(*parts) def ensure_device_directories(self, device_id: str) -> None: """Create incoming/outgoing directory structure for a device.""" - device_base = self.base_path / device_id - incoming = device_base / "incoming" - outgoing = device_base / "outgoing" + incoming = Path(self.build_incoming_path(device_id)) + outgoing = Path(self.build_outgoing_path(device_id)) incoming.mkdir(parents=True, exist_ok=True) outgoing.mkdir(parents=True, exist_ok=True) @@ -46,7 +56,7 @@ class FSSyncHandler(FSHandler): Returns list of dicts with keys: platform_slug, file_name, full_path, file_size, mtime """ - incoming_dir = self.base_path / device_id / "incoming" + incoming_dir = self.base_path / self.build_incoming_path(device_id) if not incoming_dir.exists(): return [] @@ -84,7 +94,9 @@ class FSSyncHandler(FSHandler): self, device_id: str, platform_slug: str, file_name: str, data: bytes ) -> str: """Write a file to a device's outgoing directory.""" - outgoing_dir = self.base_path / device_id / "outgoing" / platform_slug + outgoing_dir = self.base_path / self.build_outgoing_path( + device_id, platform_slug + ) outgoing_dir.mkdir(parents=True, exist_ok=True) file_path = outgoing_dir / file_name file_path.write_bytes(data) diff --git a/backend/handler/sync/ssh_handler.py b/backend/handler/sync/ssh_handler.py index 602d45db1..ae2c5ce28 100644 --- a/backend/handler/sync/ssh_handler.py +++ b/backend/handler/sync/ssh_handler.py @@ -2,6 +2,11 @@ Provides methods to connect to remote devices via SSH, list remote save files, and perform bidirectional file transfers using SFTP. + +SSH keys are expected to be pre-mounted on the server (e.g. via Docker volume) +at the path configured by SYNC_SSH_KEYS_PATH. Keys are looked up by device_id +({SYNC_SSH_KEYS_PATH}/{device_id}.pem) or via an explicit ssh_key_path in the +device's sync_config. """ from __future__ import annotations @@ -33,46 +38,49 @@ class RemoteSaveInfo: class SSHSyncHandler: - """Handles SSH/SFTP operations for push-pull sync mode.""" + """Handles SSH/SFTP operations for push-pull sync mode. + + SSH keys are expected to be pre-mounted on the server filesystem at + SYNC_SSH_KEYS_PATH. The handler looks up keys by device_id convention + ({keys_path}/{device_id}.pem) or uses an explicit path from sync_config. + """ def __init__(self) -> None: self.keys_path = Path(SYNC_SSH_KEYS_PATH) self.keys_path.mkdir(parents=True, exist_ok=True) - def get_key_path(self, device_id: str) -> Path: - """Get the SSH key path for a device.""" - return self.keys_path / f"{device_id}.pem" + def _resolve_key_path(self, device_id: str, sync_config: dict) -> str | None: + """Resolve the SSH key path for a device. - def has_key(self, device_id: str) -> bool: - """Check if an SSH key exists for a device.""" - return self.get_key_path(device_id).is_file() + Checks, in order: + 1. Explicit ssh_key_path in sync_config + 2. Convention-based path: {SYNC_SSH_KEYS_PATH}/{device_id}.pem + """ + explicit = sync_config.get("ssh_key_path") + if explicit and os.path.isfile(explicit): + return explicit - def store_key(self, device_id: str, key_data: bytes) -> Path: - """Store an SSH private key for a device.""" - key_path = self.get_key_path(device_id) - key_path.write_bytes(key_data) - key_path.chmod(0o600) - log.info(f"Stored SSH key for device {device_id}") - return key_path + convention_path = self.keys_path / f"{device_id}.pem" + if convention_path.is_file(): + return str(convention_path) - def remove_key(self, device_id: str) -> bool: - """Remove an SSH private key for a device.""" - key_path = self.get_key_path(device_id) - if key_path.is_file(): - key_path.unlink() - log.info(f"Removed SSH key for device {device_id}") - return True - return False + return None - async def connect(self, sync_config: dict) -> asyncssh.SSHClientConnection: + async def connect( + self, sync_config: dict, device_id: str | None = None + ) -> asyncssh.SSHClientConnection: """Establish an SSH connection using device sync_config. + SSH keys should be pre-mounted on the server. The handler resolves + the key by checking sync_config.ssh_key_path first, then falls back + to the convention-based path {SYNC_SSH_KEYS_PATH}/{device_id}.pem. + sync_config should contain: - ssh_host: hostname or IP - ssh_port: port (default 22) - ssh_username: username - - ssh_key_path: path to private key (or device_id to look up) - - ssh_password: password (optional, alternative to key) + - ssh_key_path: explicit path to private key (optional) + - ssh_password: password (optional, fallback if no key found) """ host = sync_config["ssh_host"] port = sync_config.get("ssh_port", 22) @@ -85,16 +93,17 @@ class SSHSyncHandler: "known_hosts": None, # Accept all host keys (TODO: make configurable) } - # Try key-based auth first - key_path = sync_config.get("ssh_key_path") - if key_path and os.path.isfile(key_path): + # Resolve key path (explicit or convention-based) + key_path = self._resolve_key_path(device_id or "", sync_config) + if key_path: connect_kwargs["client_keys"] = [key_path] elif sync_config.get("ssh_password"): connect_kwargs["password"] = sync_config["ssh_password"] else: raise ValueError( f"No SSH authentication method available for {host}. " - "Provide ssh_key_path or ssh_password in sync_config." + f"Mount a key at {self.keys_path}/{{device_id}}.pem or " + "provide ssh_key_path/ssh_password in sync_config." ) log.info(f"Connecting to {username}@{host}:{port}") diff --git a/backend/sync_watcher.py b/backend/sync_watcher.py index eb7335d4c..b83027897 100644 --- a/backend/sync_watcher.py +++ b/backend/sync_watcher.py @@ -1,8 +1,8 @@ """Sync folder watcher for File Transfer mode. This module is invoked by watchfiles when changes are detected in the sync -folder (SYNC_BASE_PATH). It processes incoming save files from devices that -use file_transfer sync mode. +folder. It processes incoming save files from devices that use file_transfer +sync mode. The watcher is configured to run as a separate watchfiles process monitoring the sync base path. When files appear in a device's incoming/ directory, they @@ -19,7 +19,7 @@ from typing import cast import sentry_sdk -from config import ENABLE_SYNC_FOLDER_WATCHER, SENTRY_DSN, SYNC_BASE_PATH +from config import ENABLE_SYNC_FOLDER_WATCHER, SENTRY_DSN from handler.database import ( db_device_handler, db_device_save_sync_handler, @@ -46,10 +46,11 @@ Change = tuple[str, str] def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None: """Extract device_id, platform_slug, and filename from a sync incoming path. - Expected path format: {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext + Expected path format: {build_incoming_path(device_id, platform_slug)}/filename.ext + i.e. {SYNC_BASE_PATH}/{device_id}/incoming/{platform_slug}/filename.ext """ try: - rel_path = os.path.relpath(path, SYNC_BASE_PATH) + rel_path = os.path.relpath(path) parts = rel_path.split(os.sep) # Minimum: device_id / incoming / platform_slug / filename if len(parts) < 4 or parts[1] != "incoming": @@ -57,6 +58,12 @@ def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None: device_id = parts[0] platform_slug = parts[2] filename = parts[-1] + + # Validate path matches the canonical incoming path structure + expected_prefix = fs_sync_handler.build_incoming_path(device_id, platform_slug) + if not rel_path.startswith(expected_prefix): + return None + return (device_id, platform_slug, filename) except (ValueError, IndexError): return None @@ -64,7 +71,7 @@ def _extract_device_and_platform(path: str) -> tuple[str, str, str] | None: def _ensure_conflicts_dir(device_id: str, platform_slug: str) -> str: """Ensure the conflicts directory exists and return its path.""" - conflicts_dir = os.path.join(SYNC_BASE_PATH, device_id, "conflicts", platform_slug) + conflicts_dir = fs_sync_handler.build_conflicts_path(device_id, platform_slug) os.makedirs(conflicts_dir, exist_ok=True) return conflicts_dir diff --git a/backend/tasks/sync_push_pull_task.py b/backend/tasks/sync_push_pull_task.py index f15c9f3fa..952aafcd4 100644 --- a/backend/tasks/sync_push_pull_task.py +++ b/backend/tasks/sync_push_pull_task.py @@ -84,7 +84,7 @@ async def _sync_device(device: Device) -> dict: ) try: - conn = await ssh_sync_handler.connect(sync_config) + conn = await ssh_sync_handler.connect(sync_config, device_id=device.id) except Exception as e: log.error(f"Push-pull: failed to connect to device {device.id}: {e}") db_sync_session_handler.fail_session( diff --git a/docker/init_scripts/init b/docker/init_scripts/init index 34ea60ec8..f9d493160 100755 --- a/docker/init_scripts/init +++ b/docker/init_scripts/init @@ -12,6 +12,7 @@ ENABLE_RESCAN_ON_FILESYSTEM_CHANGE="${ENABLE_RESCAN_ON_FILESYSTEM_CHANGE:="false ENABLE_SCHEDULED_RESCAN="${ENABLE_SCHEDULED_RESCAN:="false"}" ENABLE_SCHEDULED_UPDATE_LAUNCHBOX_METADATA="${ENABLE_SCHEDULED_UPDATE_LAUNCHBOX_METADATA:="false"}" ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB="${ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB:="false"}" +ENABLE_SYNC_FOLDER_WATCHER="${ENABLE_SYNC_FOLDER_WATCHER:="false"}" # if REDIS_HOST is set, we assume that an external redis is used REDIS_HOST="${REDIS_HOST:=""}" @@ -236,6 +237,23 @@ start_bin_watcher() { echo "${WATCHER_PID}" >/tmp/watcher.pid } +start_bin_sync_watcher() { + info_log "Starting sync folder watcher" + if [[ ${OTEL_SDK_DISABLED:-false} == "true" ]]; then + watchfiles \ + --target-type command \ + "python3 sync_watcher.py" \ + /romm/sync & + else + watchfiles \ + --target-type command \ + "opentelemetry-instrument --service_name '${OTEL_SERVICE_NAME_PREFIX-}sync_watcher' python3 sync_watcher.py" \ + /romm/sync & + fi + SYNC_WATCHER_PID=$! + echo "${SYNC_WATCHER_PID}" >/tmp/sync_watcher.pid +} + watchdog_process_pid() { PROCESS=$1 if [[ -f "/tmp/${PROCESS}.pid" ]]; then @@ -267,6 +285,7 @@ shutdown() { # shutdown in reverse order stop_process_pid rq_worker stop_process_pid rq_scheduler + stop_process_pid sync_watcher stop_process_pid watcher stop_process_pid nginx stop_process_pid gunicorn @@ -332,6 +351,11 @@ while ! ((exited)); do watchdog_process_pid watcher fi + # only start the sync folder watcher if enabled + if [[ ${ENABLE_SYNC_FOLDER_WATCHER} == "true" ]]; then + watchdog_process_pid sync_watcher + fi + watchdog_process_pid nginx # check for died processes every 5 seconds diff --git a/entrypoint.sh b/entrypoint.sh index 7c15d1d19..4c00c1470 100644 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -77,6 +77,14 @@ watchfiles \ 'uv run python watcher.py' \ /app/romm/library & +if [[ ${ENABLE_SYNC_FOLDER_WATCHER:-false} == "true" ]]; then + echo "Starting sync folder watcher..." + watchfiles \ + --target-type command \ + 'uv run python sync_watcher.py' \ + /app/romm/sync & +fi + # Start the frontend dev server cd /app/frontend npm run dev &