diff --git a/backend/endpoints/identity.py b/backend/endpoints/identity.py index 0a628d52d..e8be3951a 100644 --- a/backend/endpoints/identity.py +++ b/backend/endpoints/identity.py @@ -15,11 +15,14 @@ from utils.auth import ( authenticate_user, create_access_token, get_password_hash, - ACCESS_TOKEN_EXPIRE_MINUTES, + get_current_active_user, ) router = APIRouter() +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +REFRESH_TOKEN_EXPIRE_DAYS = 7 + class UserSchema(BaseModel): username: str @@ -38,18 +41,43 @@ def credentials_exception(scheme: str): ) -@router.post("/token", response_model=UserSchema) +@router.post("/token") def generate_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]): user = authenticate_user(form_data.username, form_data.password) if not user: raise credentials_exception("Bearer") access_token = create_access_token( - data={"sub": user.username}, + data={"sub": user.username, "type": "access"}, expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), ) - return {"access_token": access_token, "token_type": "bearer"} + refresh_token = create_access_token( + data={"sub": user.username, "type": "refresh"}, + expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS), + ) + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "Bearer", + } + + +@router.post("/refresh_token") +def refresh_access_token(request: Request): + if not request.user.is_authenticated: + raise credentials_exception("Bearer") + + access_token = create_access_token( + data={"sub": request.user.username, "type": "access"}, + expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), + ) + + return { + "access_token": access_token, + "token_type": "Bearer", + } @router.post("/login") diff --git a/backend/utils/auth.py b/backend/utils/auth.py index 4bd7bd4e3..8c07b2c3f 100644 --- a/backend/utils/auth.py +++ b/backend/utils/auth.py @@ -14,7 +14,6 @@ from config import SECRET_KEY from utils.cache import cache ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 def verify_password(plain_password, hashed_password): @@ -60,6 +59,28 @@ credentials_exception = HTTPException( ) +async def get_current_active_user(token: str): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + except (JWTError): + raise credentials_exception + + username: str = payload.get("sub") + if username is None: + raise credentials_exception + + user = dbh.get_user(username) + if user is None: + raise credentials_exception + + if user.disabled: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user" + ) + + return user, payload.get("type") + + class BasicAuthBackend(AuthenticationBackend): async def authenticate(self, conn: HTTPConnection): # Check if session key already stored in cache @@ -83,27 +104,14 @@ class BasicAuthBackend(AuthenticationBackend): if "Authorization" not in conn.headers: return - auth = conn.headers["Authorization"] - scheme, token = auth.split() + # Returns if Authorization header is not Bearer + scheme, token = conn.headers["Authorization"].split() if scheme.lower() != "bearer": return - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - except (JWTError): - raise credentials_exception + user, token_type = await get_current_active_user(token) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - - user = dbh.get_user(username) - if user is None: - raise credentials_exception - - if user.disabled: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" - ) - - return (AuthCredentials(user.oauth_scopes), user) + return ( + AuthCredentials(user.oauth_scopes if token_type == "access" else None), + user, + )