mirror of
https://github.com/rommapp/romm.git
synced 2026-06-28 06:46:00 +00:00
Place auth behind flag
This commit is contained in:
@@ -53,4 +53,4 @@ ROMM_DB_DRIVER = os.environ.get("ROMM_DB_DRIVER", "sqlite")
|
||||
ROMM_AUTH_ENABLED = os.environ.get("ROMM_AUTH_ENABLED", "false") == "true"
|
||||
ROMM_AUTH_USERNAME = os.environ.get("ROMM_AUTH_USERNAME", "admin")
|
||||
ROMM_AUTH_PASSWORD = os.environ.get("ROMM_AUTH_PASSWORD", "admin")
|
||||
ROMM_AUTH_SECRET_KEY = os.environ.get("ROMM_AUTH_SECRET_KEY", secrets.token_hex(32))
|
||||
ROMM_SECRET_KEY = os.environ.get("ROMM_SECRET_KEY", secrets.token_hex(32))
|
||||
|
||||
@@ -149,7 +149,9 @@ class DBHandler:
|
||||
def get_user(self, username: str):
|
||||
try:
|
||||
with self.session.begin() as session:
|
||||
return session.scalars(select(User).filter_by(username=username)).first()
|
||||
return session.scalars(
|
||||
select(User).filter_by(username=username)
|
||||
).first()
|
||||
except ProgrammingError as e:
|
||||
self.raise_error(e)
|
||||
|
||||
|
||||
@@ -7,11 +7,10 @@ from fastapi_pagination import add_pagination
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from config import DEV_PORT, DEV_HOST, ROMM_AUTH_SECRET_KEY
|
||||
from config import DEV_PORT, DEV_HOST, ROMM_SECRET_KEY
|
||||
from endpoints import search, platform, rom, identity, oauth, scan # noqa
|
||||
from utils.socket import socket_app
|
||||
from utils.auth import BasicAuthBackend, CustomCSRFMiddleware
|
||||
|
||||
from utils.auth import BasicAuthBackend, CustomCSRFMiddleware, create_default_admin_user
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
@@ -29,13 +28,13 @@ app.add_middleware(
|
||||
)
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=ROMM_AUTH_SECRET_KEY,
|
||||
secret_key=ROMM_SECRET_KEY,
|
||||
same_site="strict",
|
||||
https_only=False,
|
||||
)
|
||||
app.add_middleware(
|
||||
CustomCSRFMiddleware,
|
||||
secret=ROMM_AUTH_SECRET_KEY,
|
||||
secret=ROMM_SECRET_KEY,
|
||||
exempt_urls=[re.compile(r"^/oauth/.*"), re.compile(r"^/ws")],
|
||||
)
|
||||
|
||||
@@ -64,5 +63,8 @@ if __name__ == "__main__":
|
||||
# Run migrations
|
||||
alembic.config.main(argv=["upgrade", "head"])
|
||||
|
||||
# Create default admin user
|
||||
create_default_admin_user()
|
||||
|
||||
# Run application
|
||||
uvicorn.run("main:app", host=DEV_HOST, port=DEV_PORT, reload=True)
|
||||
|
||||
@@ -11,9 +11,9 @@ class Role(enum.Enum):
|
||||
ADMIN = 2
|
||||
|
||||
|
||||
VIEWER_SCOPES = ["me.read", "me.write", "roms.read", "platforms.read"]
|
||||
EDITOR_SCOPES = VIEWER_SCOPES + ["roms.write", "platforms.write"]
|
||||
ADMIN_SCOPES = EDITOR_SCOPES + ["users.read", "users.write"]
|
||||
DEFAULT_SCOPES = ["me.read", "me.write", "roms.read", "platforms.read"]
|
||||
WRITE_SCOPES = DEFAULT_SCOPES + ["roms.write", "platforms.write"]
|
||||
FULL_SCOPES = WRITE_SCOPES + ["users.read", "users.write"]
|
||||
|
||||
|
||||
class User(BaseModel, SimpleUser):
|
||||
@@ -27,9 +27,9 @@ class User(BaseModel, SimpleUser):
|
||||
@property
|
||||
def oauth_scopes(self):
|
||||
if self.role == Role.ADMIN:
|
||||
return ADMIN_SCOPES
|
||||
return FULL_SCOPES
|
||||
|
||||
if self.role == Role.EDITOR:
|
||||
return EDITOR_SCOPES
|
||||
return WRITE_SCOPES
|
||||
|
||||
return VIEWER_SCOPES
|
||||
return DEFAULT_SCOPES
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
@@ -12,8 +13,14 @@ from starlette_csrf import CSRFMiddleware
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from handler import dbh
|
||||
from config import ROMM_AUTH_SECRET_KEY
|
||||
from config import (
|
||||
ROMM_SECRET_KEY,
|
||||
ROMM_AUTH_ENABLED,
|
||||
ROMM_AUTH_USERNAME,
|
||||
ROMM_AUTH_PASSWORD,
|
||||
)
|
||||
from utils.cache import cache
|
||||
from models.user import User, Role, FULL_SCOPES
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
@@ -51,7 +58,7 @@ def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return jwt.encode(to_encode, ROMM_SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
@@ -63,7 +70,7 @@ credentials_exception = HTTPException(
|
||||
|
||||
async def get_current_active_user_from_token(token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, ROMM_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except (JWTError):
|
||||
raise credentials_exception
|
||||
|
||||
@@ -82,16 +89,17 @@ async def get_current_active_user_from_token(token: str):
|
||||
|
||||
return user, payload
|
||||
|
||||
|
||||
async def get_current_active_user_from_session(conn: HTTPConnection):
|
||||
# Check if session key already stored in cache
|
||||
session_id = conn.session.get("session_id")
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
|
||||
username = cache.get(f"romm:{session_id}")
|
||||
if not username:
|
||||
return None
|
||||
|
||||
|
||||
# Key exists therefore user is authenticated
|
||||
user = dbh.get_user(username)
|
||||
if user is None:
|
||||
@@ -104,8 +112,12 @@ async def get_current_active_user_from_session(conn: HTTPConnection):
|
||||
|
||||
return user
|
||||
|
||||
|
||||
class BasicAuthBackend(AuthenticationBackend):
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
if not ROMM_AUTH_ENABLED:
|
||||
return (AuthCredentials(FULL_SCOPES), None)
|
||||
|
||||
# Check if session key already stored in cache
|
||||
user = await get_current_active_user_from_session(conn)
|
||||
if user:
|
||||
@@ -125,7 +137,7 @@ class BasicAuthBackend(AuthenticationBackend):
|
||||
# Only access tokens can request resources
|
||||
if payload.get("type") == "access":
|
||||
return (AuthCredentials(user.oauth_scopes), user)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -134,5 +146,21 @@ class CustomCSRFMiddleware(CSRFMiddleware):
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
def create_default_admin_user():
|
||||
if not ROMM_AUTH_ENABLED:
|
||||
return
|
||||
|
||||
try:
|
||||
dbh.add_user(
|
||||
User(
|
||||
username=ROMM_AUTH_USERNAME,
|
||||
hashed_password=get_password_hash(ROMM_AUTH_PASSWORD),
|
||||
role=Role.ADMIN,
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
@@ -24,7 +24,7 @@ REDIS_HOST=127.0.0.1
|
||||
REDIS_PORT=6379
|
||||
|
||||
# Authentication
|
||||
ROMM_SECRET_KEY=
|
||||
ROMM_AUTH_ENABLED=true
|
||||
ROMM_AUTH_USERNAME=admin
|
||||
ROMM_AUTH_PASSWORD=admin
|
||||
ROMM_AUTH_SECRET_KEY=
|
||||
|
||||
@@ -20,7 +20,7 @@ services:
|
||||
- ROMM_AUTH_ENABLED=true # [Optional] Will enable user management and require authentication to access the interface (default to false)
|
||||
- ROMM_AUTH_USERNAME=admin # [Optional] Username for default admin user
|
||||
- ROMM_AUTH_PASSWORD=<admin password> # [Optional] Password for default admin user (defaults to admin)
|
||||
- ROMM_AUTH_SECRET_KEY=<secret key> # [Optional] Used to encrypt user passwords, generate one with `openssl rand -hex 32`
|
||||
- ROMM_SECRET_KEY=<secret key> # [Optional] Used to encrypt user passwords, generate one with `openssl rand -hex 32`
|
||||
volumes:
|
||||
- '/path/to/library:/romm/library'
|
||||
- '/path/to/resources:/romm/resources' # [Optional] Path where roms metadata (covers) are stored
|
||||
|
||||
Reference in New Issue
Block a user