diff --git a/backend/endpoints/auth.py b/backend/endpoints/auth.py index 4acbfed0c..29bde9f22 100644 --- a/backend/endpoints/auth.py +++ b/backend/endpoints/auth.py @@ -93,7 +93,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp TokenResponse: TypedDict with the new generated token info """ - # Suppport refreshing access tokens + # Support refreshing access tokens if form_data.grant_type == "refresh_token": token = form_data.refresh_token if not token: @@ -101,9 +101,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp status_code=status.HTTP_400_BAD_REQUEST, detail="Missing refresh token" ) - user, claims = await oauth_handler.get_current_active_user_from_bearer_token( - token - ) + user, claims = await oauth_handler.consume_refresh_token(token) if not user or not claims: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" @@ -119,7 +117,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp status_code=status.HTTP_403_FORBIDDEN, detail="User account is disabled" ) - access_token = oauth_handler.create_oauth_token( + access_token = oauth_handler.create_access_token( data={ "sub": user.username, "iss": "romm:oauth", @@ -129,10 +127,22 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp expires_delta=timedelta(seconds=ACCESS_TOKEN_EXPIRE_SECONDS), ) + refresh_token = oauth_handler.create_refresh_token( + data={ + "sub": user.username, + "iss": "romm:oauth", + "scopes": claims.get("scopes"), + "type": "refresh", + }, + expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS), + ) + return { "access_token": access_token, + "refresh_token": refresh_token, "token_type": "bearer", # trunk-ignore(bandit/B105) - "expires": ACCESS_TOKEN_EXPIRE_SECONDS, + "expires_in": ACCESS_TOKEN_EXPIRE_SECONDS, + "refresh_expires_in": REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, } # Authentication via username/password @@ -176,7 +186,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp detail="Insufficient scope", ) - access_token = oauth_handler.create_oauth_token( + access_token = oauth_handler.create_access_token( data={ "sub": user.username, "iss": "romm:oauth", @@ -186,7 +196,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp expires_delta=timedelta(seconds=ACCESS_TOKEN_EXPIRE_SECONDS), ) - refresh_token = oauth_handler.create_oauth_token( + refresh_token = oauth_handler.create_refresh_token( data={ "sub": user.username, "iss": "romm:oauth", @@ -200,7 +210,8 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp "access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer", # trunk-ignore(bandit/B105) - "expires": ACCESS_TOKEN_EXPIRE_SECONDS, + "expires_in": ACCESS_TOKEN_EXPIRE_SECONDS, + "refresh_expires_in": REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, } diff --git a/backend/endpoints/responses/oauth.py b/backend/endpoints/responses/oauth.py index f70023a9b..179f704c0 100644 --- a/backend/endpoints/responses/oauth.py +++ b/backend/endpoints/responses/oauth.py @@ -5,4 +5,5 @@ class TokenResponse(TypedDict): access_token: str refresh_token: NotRequired[str] token_type: str - expires: int + expires_in: int + refresh_expires_in: int diff --git a/backend/handler/auth/base_handler.py b/backend/handler/auth/base_handler.py index 0c2b1c460..d56240bf0 100644 --- a/backend/handler/auth/base_handler.py +++ b/backend/handler/auth/base_handler.py @@ -249,10 +249,10 @@ class OAuthHandler: pass def create_oauth_token( - self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY + self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY ) -> str: to_encode = data.copy() - expire = datetime.now(timezone.utc) + expires_delta + expire = int((datetime.now(timezone.utc) + expires_delta).timestamp()) to_encode.update({"exp": expire}) return jwt.encode( @@ -261,6 +261,75 @@ class OAuthHandler: oct_key, ) + def create_access_token( + self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY + ) -> str: + return self.create_oauth_token(data, expires_delta) + + def create_refresh_token( + self, data: dict, expires_delta: timedelta + ) -> str: + to_encode = data.copy() + expire = int((datetime.now(timezone.utc) + expires_delta).timestamp()) + jti = str(uuid.uuid4()) + + to_encode.update({ + "exp": expire, + "jti": jti, + "type": "refresh", + }) + + token = jwt.encode( + {"alg": ALGORITHM}, + to_encode, + oct_key, + ) + + redis_client.setex( + f"refresh-jti:{jti}", + int(expires_delta.total_seconds()), + "valid", + ) + + return token + + async def consume_refresh_token(self, token: str): + from handler.database import db_user_handler + + try: + payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM]) + except (BadSignatureError, DecodeError, ValueError) as exc: + raise OAuthCredentialsException from exc + + now = datetime.now(timezone.utc).timestamp() + if now > payload.claims.get("exp", 0): + raise OAuthCredentialsException + + if payload.claims.get("iss") != "romm:oauth": + raise OAuthCredentialsException + + if payload.claims.get("type") != "refresh": + raise OAuthCredentialsException + + jti = payload.claims.get("jti") + if not jti or redis_client.get(f"refresh-jti:{jti}") != b"valid": + raise OAuthCredentialsException + + redis_client.delete(f"refresh-jti:{jti}") + + username = payload.claims.get("sub") + if not username: + raise OAuthCredentialsException + + user = db_user_handler.get_user_by_username(username) + if user is None: + raise OAuthCredentialsException + + if not user.enabled: + raise UserDisabledException + + return user, payload.claims + async def get_current_active_user_from_bearer_token(self, token: str): from handler.database import db_user_handler @@ -269,6 +338,10 @@ class OAuthHandler: except (BadSignatureError, DecodeError, ValueError) as exc: raise OAuthCredentialsException from exc + now = datetime.now(timezone.utc).timestamp() + if now > payload.claims.get("exp", 0): + raise OAuthCredentialsException + issuer = payload.claims.get("iss") if not issuer or issuer != "romm:oauth": return None, None