Files
romm/backend/handler/database/client_tokens_handler.py
nendo e0b25fbc6c feat(client-tokens): add client API tokens with QR pairing flow
Long-lived, revocable, scope-restricted tokens for external clients
(mobile apps, retro handhelds, third-party tools). Includes:

- Backend: model, migration, DB handler, auth integration (rmm_ prefix
  routing in HybridAuthBackend), CRUD + pairing + exchange endpoints,
  rate limiting, scope intersection enforcement, admin oversight
- Frontend: settings page with token management table, stepped
  create/deliver dialog (config -> copy/pair), QR code with RomM logo,
  admin token table, standalone /pair page for QR scan landing
- /pair page supports custom-scheme callbacks for app deep linking and
  falls back to displaying the code for manual entry
- 33 backend tests across 5 classes (CRUD, auth, isolation, pairing,
  admin)
2026-03-11 10:56:35 +09:00

144 lines
4.1 KiB
Python

from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, joinedload
from decorators.database import begin_session
from models.client_token import ClientToken
from utils.datetime import to_utc
from .base_handler import DBBaseHandler
# Minimum gap between two persisted ``last_used_at`` stamps for the same
# token; ``DBClientTokensHandler.update_last_used`` skips the write inside it.
LAST_USED_DEBOUNCE = timedelta(seconds=5 * 60)
class DBClientTokensHandler(DBBaseHandler):
    """Persistence helpers for :class:`ClientToken` rows.

    Every public method receives its ``session`` through ``@begin_session``
    (presumably supplied/managed by the decorator when the caller omits it —
    NOTE(review): confirm against ``decorators.database``).
    """

    @staticmethod
    def _ownership_criteria(token_id: int, user_id: int | None) -> list:
        """Build WHERE criteria matching ``token_id``, narrowed to ``user_id`` when given."""
        criteria = [ClientToken.id == token_id]
        if user_id is not None:
            criteria.append(ClientToken.user_id == user_id)
        return criteria

    @begin_session
    def add_token(
        self,
        token: ClientToken,
        session: Session = None,  # type: ignore
    ) -> ClientToken:
        """Merge ``token`` into the session and return the session-bound instance."""
        persisted = session.merge(token)
        return persisted

    @begin_session
    def get_token_by_hash(
        self,
        hashed_token: str,
        session: Session = None,  # type: ignore
    ) -> ClientToken | None:
        """Look up a token by its stored hash; ``None`` when no row matches."""
        query = select(ClientToken).where(ClientToken.hashed_token == hashed_token)
        return session.scalar(query)

    @begin_session
    def get_tokens_by_user(
        self,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> Sequence[ClientToken]:
        """Return every token owned by ``user_id``, newest ``created_at`` first."""
        newest_first = ClientToken.created_at.desc()
        query = (
            select(ClientToken)
            .where(ClientToken.user_id == user_id)
            .order_by(newest_first)
        )
        return session.scalars(query).all()

    @begin_session
    def get_all_tokens(
        self,
        session: Session = None,  # type: ignore
    ) -> Sequence[ClientToken]:
        """Return all tokens across users, newest first, eager-loading ``user``."""
        query = (
            select(ClientToken)
            .options(joinedload(ClientToken.user))
            .order_by(ClientToken.created_at.desc())
        )
        rows = session.scalars(query)
        # De-duplicate the ORM entities returned alongside the joined eager load.
        return rows.unique().all()

    @begin_session
    def delete_token(
        self,
        token_id: int,
        user_id: int | None = None,
        session: Session = None,  # type: ignore
    ) -> int:
        """Delete token ``token_id``, restricted to ``user_id``'s tokens when given.

        Returns the driver-reported number of deleted rows (0 when nothing matched).
        """
        stmt = (
            delete(ClientToken)
            .where(*self._ownership_criteria(token_id, user_id))
            .execution_options(synchronize_session="evaluate")
        )
        return session.execute(stmt).rowcount

    @begin_session
    def update_last_used(
        self,
        token_id: int,
        session: Session = None,  # type: ignore
    ) -> None:
        """Stamp ``last_used_at`` with the current UTC time, debounced.

        The write is skipped while the previous stamp is younger than
        ``LAST_USED_DEBOUNCE``. Does nothing when ``token_id`` does not exist.
        """
        now = datetime.now(timezone.utc)
        token = session.get(ClientToken, token_id)
        if token is None:
            return

        previous = token.last_used_at
        # Recently stamped: skip the write (presumably to avoid a DB write on
        # every authenticated request). ``to_utc`` is assumed to make the
        # stored value comparable with the aware ``now`` — confirm in
        # utils.datetime.
        if previous is not None and now - to_utc(previous) < LAST_USED_DEBOUNCE:
            return

        stamp = (
            update(ClientToken)
            .where(ClientToken.id == token_id)
            .values(last_used_at=now)
            .execution_options(synchronize_session="evaluate")
        )
        session.execute(stamp)

    @begin_session
    def update_hashed_token(
        self,
        token_id: int,
        new_hash: str,
        user_id: int | None = None,
        session: Session = None,  # type: ignore
    ) -> ClientToken | None:
        """Replace the stored hash of ``token_id`` and clear ``last_used_at``.

        When ``user_id`` is given, only that user's token is updated. Returns
        the token re-read via ``session.get``, or ``None`` when the UPDATE
        matched no row.
        """
        stmt = (
            update(ClientToken)
            .where(*self._ownership_criteria(token_id, user_id))
            .values(hashed_token=new_hash, last_used_at=None)
            .execution_options(synchronize_session="evaluate")
        )
        if session.execute(stmt).rowcount == 0:
            return None
        return session.get(ClientToken, token_id)

    @begin_session
    def count_tokens_by_user(
        self,
        user_id: int,
        session: Session = None,  # type: ignore
    ) -> int:
        """Return how many tokens ``user_id`` owns."""
        query = (
            select(func.count())
            .select_from(ClientToken)
            .where(ClientToken.user_id == user_id)
        )
        total = session.scalar(query)
        # Fall back to 0 should the scalar come back as None.
        return total or 0

    @begin_session
    def get_token(
        self,
        token_id: int,
        user_id: int | None = None,
        session: Session = None,  # type: ignore
    ) -> ClientToken | None:
        """Fetch token ``token_id``, restricted to ``user_id``'s tokens when given."""
        query = select(ClientToken).where(*self._ownership_criteria(token_id, user_id))
        return session.scalar(query)