fix: oauth token invalidation on expiration date and rotating refresh token

This commit is contained in:
HydroSulphide
2026-03-09 17:03:29 +01:00
parent 115fad85e9
commit 41f64eb42b
3 changed files with 97 additions and 12 deletions

View File

@@ -249,10 +249,10 @@ class OAuthHandler:
pass
def create_oauth_token(
self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
) -> str:
to_encode = data.copy()
expire = datetime.now(timezone.utc) + expires_delta
expire = int((datetime.now(timezone.utc) + expires_delta).timestamp())
to_encode.update({"exp": expire})
return jwt.encode(
@@ -261,6 +261,75 @@ class OAuthHandler:
oct_key,
)
def create_access_token(
self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
) -> str:
return self.create_oauth_token(data, expires_delta)
def create_refresh_token(
self, data: dict, expires_delta: timedelta
) -> str:
to_encode = data.copy()
expire = int((datetime.now(timezone.utc) + expires_delta).timestamp())
jti = str(uuid.uuid4())
to_encode.update({
"exp": expire,
"jti": jti,
"type": "refresh",
})
token = jwt.encode(
{"alg": ALGORITHM},
to_encode,
oct_key,
)
redis_client.setex(
f"refresh-jti:{jti}",
int(expires_delta.total_seconds()),
"valid",
)
return token
async def consume_refresh_token(self, token: str):
from handler.database import db_user_handler
try:
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
except (BadSignatureError, DecodeError, ValueError) as exc:
raise OAuthCredentialsException from exc
now = datetime.now(timezone.utc).timestamp()
if now > payload.claims.get("exp", 0):
raise OAuthCredentialsException
if payload.claims.get("iss") != "romm:oauth":
raise OAuthCredentialsException
if payload.claims.get("type") != "refresh":
raise OAuthCredentialsException
jti = payload.claims.get("jti")
if not jti or redis_client.get(f"refresh-jti:{jti}") != b"valid":
raise OAuthCredentialsException
redis_client.delete(f"refresh-jti:{jti}")
username = payload.claims.get("sub")
if not username:
raise OAuthCredentialsException
user = db_user_handler.get_user_by_username(username)
if user is None:
raise OAuthCredentialsException
if not user.enabled:
raise UserDisabledException
return user, payload.claims
async def get_current_active_user_from_bearer_token(self, token: str):
from handler.database import db_user_handler
@@ -269,6 +338,10 @@ class OAuthHandler:
except (BadSignatureError, DecodeError, ValueError) as exc:
raise OAuthCredentialsException from exc
now = datetime.now(timezone.utc).timestamp()
if now > payload.claims.get("exp", 0):
raise OAuthCredentialsException
issuer = payload.claims.get("iss")
if not issuer or issuer != "romm:oauth":
return None, None