Files
romm/backend/handler/database/sync_sessions_handler.py
nendo 55638d15dc fix: address bugs, security issues, and convention violations in save-sync
- Fix broken path construction in FSSyncHandler: build_* methods now
  return relative paths; sync_watcher uses paths relative to sync base
  instead of CWD (was completely non-functional in production)
- Fix SSH connection leak in push-pull task: conn.close() now in finally
- Add log.warning for disabled SSH host key verification
- Fix race condition in session operation counter: use atomic SQL
  increment instead of read-then-write
- Extract _increment_session_counter helper, add exc_info to warnings
- Replace legacy session.query() with select() in sync_sessions_handler
- Fix orphaned session: trigger_push_pull now passes session_id to job
- Fix wasteful SSH download when no matched_save exists
- Fix BaseModel import collision in sync.py (pydantic -> project base)
- Fix ORM mutation in UserSchema.from_orm_with_request: set field on
  schema instance instead of mutating live ORM object
- Mask ssh_password and ssh_key_path in DeviceSchema API response
- Fix migration PostgreSQL compatibility: condition ON UPDATE clause
  on MySQL, drop enum in downgrade
- Rename copy-paste artifact rom_user_status_enum
2026-03-16 10:56:43 +09:00

182 lines
5.5 KiB
Python

from collections.abc import Sequence
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from decorators.database import begin_session
from models.sync_session import SyncSession, SyncSessionStatus
from .base_handler import DBBaseHandler
class DBSyncSessionsHandler(DBBaseHandler):
    """Database access layer for ``SyncSession`` rows.

    Every method is wrapped in ``begin_session`` and declares a ``session``
    parameter defaulting to ``None``; presumably the decorator supplies a live
    SQLAlchemy ``Session`` when the caller omits one (the decorator is defined
    elsewhere -- confirm there).
    """

    # Statuses that count as "still running". Shared by the active-session
    # lookup and the bulk cancel so the two can never disagree.
    _ACTIVE_STATUSES = (
        SyncSessionStatus.PENDING,
        SyncSessionStatus.IN_PROGRESS,
    )

    @begin_session
    def create_session(
        self,
        device_id: str,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Create a new sync session in ``PENDING`` state.

        Args:
            device_id: Identifier of the device the sync session belongs to.
            user_id: Identifier of the owning user.
            session: SQLAlchemy session to operate on.

        Returns:
            The newly added ``SyncSession``, with ``initiated_at`` set to the
            current UTC time.
        """
        sync_session = SyncSession(
            device_id=device_id,
            user_id=user_id,
            status=SyncSessionStatus.PENDING,
            initiated_at=datetime.now(timezone.utc),
        )
        session.add(sync_session)
        # Flush so the INSERT is emitted now and database-generated values
        # (presumably the primary key) are available on the returned object.
        session.flush()
        return sync_session

    @begin_session
    def get_session(
        self,
        session_id: int,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> SyncSession | None:
        """Fetch one sync session by id, scoped to its owning user.

        Args:
            session_id: Primary key of the sync session.
            user_id: Owner the row must belong to.
            session: SQLAlchemy session to operate on.

        Returns:
            The matching ``SyncSession``, or ``None`` when no row has both the
            given id and the given user.
        """
        return session.scalar(
            select(SyncSession).filter_by(id=session_id, user_id=user_id).limit(1)
        )

    @begin_session
    def get_active_session(
        self,
        device_id: str,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> SyncSession | None:
        """Return the most recently initiated active session for a device.

        "Active" means the status is ``PENDING`` or ``IN_PROGRESS``.

        Args:
            device_id: Device whose sessions are searched.
            user_id: Owner the row must belong to.
            session: SQLAlchemy session to operate on.

        Returns:
            The newest (by ``initiated_at``) active ``SyncSession``, or
            ``None`` when the device has no active session.
        """
        return session.scalar(
            select(SyncSession)
            .filter(
                SyncSession.device_id == device_id,
                SyncSession.user_id == user_id,
                SyncSession.status.in_(self._ACTIVE_STATUSES),
            )
            .order_by(SyncSession.initiated_at.desc())
            .limit(1)
        )

    @begin_session
    def update_session(
        self,
        session_id: int,
        data: dict,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Apply arbitrary column updates to a sync session and re-read it.

        Args:
            session_id: Primary key of the sync session to update.
            data: Mapping of column name to new value, passed straight to
                ``Update.values()``. May be empty, in which case no UPDATE is
                issued.
            session: SQLAlchemy session to operate on.

        Returns:
            The row as re-selected after the update. Note that the lookup
            yields ``None`` at runtime when no row has ``session_id``.
        """
        # An UPDATE without any SET entries is not meaningful; skip it and
        # just return the current row when there is nothing to change.
        if data:
            session.execute(
                update(SyncSession)
                .where(SyncSession.id == session_id)
                .values(**data)
                .execution_options(synchronize_session="evaluate")
            )
        return session.scalar(select(SyncSession).filter_by(id=session_id))

    @begin_session
    def increment_operations_completed(
        self,
        session_id: int,
        session: Session = None,  # type: ignore
    ) -> None:
        """Increase ``operations_completed`` by one for a sync session.

        The increment is expressed as a column expression
        (``operations_completed + 1``) so the addition is performed by the
        database in the UPDATE itself rather than as a read-then-write in
        Python.

        Args:
            session_id: Primary key of the sync session to update.
            session: SQLAlchemy session to operate on.
        """
        session.execute(
            update(SyncSession)
            .where(SyncSession.id == session_id)
            .values(
                operations_completed=SyncSession.operations_completed + 1,
            )
            .execution_options(synchronize_session="evaluate")
        )

    @begin_session
    def complete_session(
        self,
        session_id: int,
        operations_completed: int = 0,
        operations_failed: int = 0,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Mark a sync session as ``COMPLETED`` and record its final counters.

        Args:
            session_id: Primary key of the sync session to complete.
            operations_completed: Value written to ``operations_completed``.
                This overwrites whatever the column currently holds,
                including any value built up via
                ``increment_operations_completed``.
            operations_failed: Value written to ``operations_failed``.
            session: SQLAlchemy session to operate on.

        Returns:
            The row as re-selected after the update. Note that the lookup
            yields ``None`` at runtime when no row has ``session_id``.
        """
        session.execute(
            update(SyncSession)
            .where(SyncSession.id == session_id)
            .values(
                status=SyncSessionStatus.COMPLETED,
                completed_at=datetime.now(timezone.utc),
                operations_completed=operations_completed,
                operations_failed=operations_failed,
            )
            .execution_options(synchronize_session="evaluate")
        )
        return session.scalar(select(SyncSession).filter_by(id=session_id))

    @begin_session
    def fail_session(
        self,
        session_id: int,
        error_message: str | None = None,
        session: Session = None,  # type: ignore
    ) -> SyncSession:
        """Mark a sync session as ``FAILED`` and store an optional message.

        Args:
            session_id: Primary key of the sync session to fail.
            error_message: Text written to ``error_message``; ``None`` stores
                a NULL/empty message.
            session: SQLAlchemy session to operate on.

        Returns:
            The row as re-selected after the update. Note that the lookup
            yields ``None`` at runtime when no row has ``session_id``.
        """
        session.execute(
            update(SyncSession)
            .where(SyncSession.id == session_id)
            .values(
                status=SyncSessionStatus.FAILED,
                # ``completed_at`` doubles as the "finished" timestamp for
                # failed sessions as well.
                completed_at=datetime.now(timezone.utc),
                error_message=error_message,
            )
            .execution_options(synchronize_session="evaluate")
        )
        return session.scalar(select(SyncSession).filter_by(id=session_id))

    @begin_session
    def cancel_active_sessions(
        self,
        device_id: str,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> int:
        """Cancel all active sessions for a device.

        Every ``PENDING`` or ``IN_PROGRESS`` session of the given device and
        user is set to ``CANCELLED`` with ``completed_at`` set to the current
        UTC time, in a single UPDATE statement.

        Args:
            device_id: Device whose active sessions are cancelled.
            user_id: Owner the rows must belong to.
            session: SQLAlchemy session to operate on.

        Returns:
            Count of cancelled sessions (the UPDATE's ``rowcount``).
        """
        result = session.execute(
            update(SyncSession)
            .where(
                SyncSession.device_id == device_id,
                SyncSession.user_id == user_id,
                SyncSession.status.in_(self._ACTIVE_STATUSES),
            )
            .values(
                status=SyncSessionStatus.CANCELLED,
                completed_at=datetime.now(timezone.utc),
            )
            .execution_options(synchronize_session="evaluate")
        )
        return result.rowcount

    @begin_session
    def get_sessions(
        self,
        user_id: int,
        device_id: str | None = None,
        status: SyncSessionStatus | None = None,
        limit: int = 50,
        session: Session = None,  # type: ignore
    ) -> Sequence[SyncSession]:
        """List a user's sync sessions, newest first.

        Args:
            user_id: Owner whose sessions are listed.
            device_id: Optional device filter. Any falsy value (``None`` or
                an empty string) disables the filter.
            status: Optional status filter. A falsy value disables the filter.
            limit: Maximum number of rows returned (default 50).
            session: SQLAlchemy session to operate on.

        Returns:
            Matching ``SyncSession`` rows ordered by ``initiated_at``
            descending, at most ``limit`` of them.
        """
        query = select(SyncSession).filter_by(user_id=user_id)
        if device_id:
            query = query.filter_by(device_id=device_id)
        if status:
            query = query.filter_by(status=status)
        query = query.order_by(SyncSession.initiated_at.desc()).limit(limit)
        return session.scalars(query).all()