mirror of
https://github.com/rommapp/romm.git
synced 2026-06-28 14:56:01 +00:00
Implement refresh tokens
This commit is contained in:
@@ -15,11 +15,14 @@ from utils.auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
get_password_hash,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
get_current_active_user,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
||||
|
||||
|
||||
class UserSchema(BaseModel):
|
||||
username: str
|
||||
@@ -38,18 +41,43 @@ def credentials_exception(scheme: str):
|
||||
)
|
||||
|
||||
|
||||
@router.post("/token", response_model=UserSchema)
|
||||
@router.post("/token")
|
||||
def generate_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
||||
user = authenticate_user(form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise credentials_exception("Bearer")
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username},
|
||||
data={"sub": user.username, "type": "access"},
|
||||
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": user.username, "type": "refresh"},
|
||||
expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh_token")
|
||||
def refresh_access_token(request: Request):
|
||||
if not request.user.is_authenticated:
|
||||
raise credentials_exception("Bearer")
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": request.user.username, "type": "access"},
|
||||
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
|
||||
@@ -14,7 +14,6 @@ from config import SECRET_KEY
|
||||
from utils.cache import cache
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
@@ -60,6 +59,28 @@ credentials_exception = HTTPException(
|
||||
)
|
||||
|
||||
|
||||
async def get_current_active_user(token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except (JWTError):
|
||||
raise credentials_exception
|
||||
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
||||
user = dbh.get_user(username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
if user.disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user, payload.get("type")
|
||||
|
||||
|
||||
class BasicAuthBackend(AuthenticationBackend):
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
# Check if session key already stored in cache
|
||||
@@ -83,27 +104,14 @@ class BasicAuthBackend(AuthenticationBackend):
|
||||
if "Authorization" not in conn.headers:
|
||||
return
|
||||
|
||||
auth = conn.headers["Authorization"]
|
||||
scheme, token = auth.split()
|
||||
# Returns if Authorization header is not Bearer
|
||||
scheme, token = conn.headers["Authorization"].split()
|
||||
if scheme.lower() != "bearer":
|
||||
return
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except (JWTError):
|
||||
raise credentials_exception
|
||||
user, token_type = await get_current_active_user(token)
|
||||
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
||||
user = dbh.get_user(username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
if user.disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return (AuthCredentials(user.oauth_scopes), user)
|
||||
return (
|
||||
AuthCredentials(user.oauth_scopes if token_type == "access" else None),
|
||||
user,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user