mirror of
https://github.com/rommapp/romm.git
synced 2026-06-28 06:46:00 +00:00
fix: oauth token invalidation on expiration date and rotating refresh token
This commit is contained in:
@@ -249,10 +249,10 @@ class OAuthHandler:
|
||||
pass
|
||||
|
||||
def create_oauth_token(
|
||||
self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
|
||||
self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
|
||||
) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = int((datetime.now(timezone.utc) + expires_delta).timestamp())
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
return jwt.encode(
|
||||
@@ -261,6 +261,75 @@ class OAuthHandler:
|
||||
oct_key,
|
||||
)
|
||||
|
||||
def create_access_token(
|
||||
self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
|
||||
) -> str:
|
||||
return self.create_oauth_token(data, expires_delta)
|
||||
|
||||
def create_refresh_token(
|
||||
self, data: dict, expires_delta: timedelta
|
||||
) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = int((datetime.now(timezone.utc) + expires_delta).timestamp())
|
||||
jti = str(uuid.uuid4())
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"jti": jti,
|
||||
"type": "refresh",
|
||||
})
|
||||
|
||||
token = jwt.encode(
|
||||
{"alg": ALGORITHM},
|
||||
to_encode,
|
||||
oct_key,
|
||||
)
|
||||
|
||||
redis_client.setex(
|
||||
f"refresh-jti:{jti}",
|
||||
int(expires_delta.total_seconds()),
|
||||
"valid",
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
async def consume_refresh_token(self, token: str):
|
||||
from handler.database import db_user_handler
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
|
||||
except (BadSignatureError, DecodeError, ValueError) as exc:
|
||||
raise OAuthCredentialsException from exc
|
||||
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
if now > payload.claims.get("exp", 0):
|
||||
raise OAuthCredentialsException
|
||||
|
||||
if payload.claims.get("iss") != "romm:oauth":
|
||||
raise OAuthCredentialsException
|
||||
|
||||
if payload.claims.get("type") != "refresh":
|
||||
raise OAuthCredentialsException
|
||||
|
||||
jti = payload.claims.get("jti")
|
||||
if not jti or redis_client.get(f"refresh-jti:{jti}") != b"valid":
|
||||
raise OAuthCredentialsException
|
||||
|
||||
redis_client.delete(f"refresh-jti:{jti}")
|
||||
|
||||
username = payload.claims.get("sub")
|
||||
if not username:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
user = db_user_handler.get_user_by_username(username)
|
||||
if user is None:
|
||||
raise OAuthCredentialsException
|
||||
|
||||
if not user.enabled:
|
||||
raise UserDisabledException
|
||||
|
||||
return user, payload.claims
|
||||
|
||||
async def get_current_active_user_from_bearer_token(self, token: str):
|
||||
from handler.database import db_user_handler
|
||||
|
||||
@@ -269,6 +338,10 @@ class OAuthHandler:
|
||||
except (BadSignatureError, DecodeError, ValueError) as exc:
|
||||
raise OAuthCredentialsException from exc
|
||||
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
if now > payload.claims.get("exp", 0):
|
||||
raise OAuthCredentialsException
|
||||
|
||||
issuer = payload.claims.get("iss")
|
||||
if not issuer or issuer != "romm:oauth":
|
||||
return None, None
|
||||
|
||||
Reference in New Issue
Block a user