Implement refresh tokens

This commit is contained in:
Georges-Antoine Assi
2023-08-11 00:30:30 -04:00
parent a8a402d250
commit 6cf6f6ca78
2 changed files with 61 additions and 25 deletions

View File

@@ -15,11 +15,14 @@ from utils.auth import (
authenticate_user,
create_access_token,
get_password_hash,
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_active_user,
)
router = APIRouter()
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
class UserSchema(BaseModel):
username: str
@@ -38,18 +41,43 @@ def credentials_exception(scheme: str):
)
@router.post("/token", response_model=UserSchema)
@router.post("/token")
def generate_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
user = authenticate_user(form_data.username, form_data.password)
if not user:
raise credentials_exception("Bearer")
access_token = create_access_token(
data={"sub": user.username},
data={"sub": user.username, "type": "access"},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
return {"access_token": access_token, "token_type": "bearer"}
refresh_token = create_access_token(
data={"sub": user.username, "type": "refresh"},
expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "Bearer",
}
@router.post("/refresh_token")
def refresh_access_token(request: Request):
if not request.user.is_authenticated:
raise credentials_exception("Bearer")
access_token = create_access_token(
data={"sub": request.user.username, "type": "access"},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
return {
"access_token": access_token,
"token_type": "Bearer",
}
@router.post("/login")

View File

@@ -14,7 +14,6 @@ from config import SECRET_KEY
from utils.cache import cache
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
def verify_password(plain_password, hashed_password):
@@ -60,6 +59,28 @@ credentials_exception = HTTPException(
)
async def get_current_active_user(token: str):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except (JWTError):
raise credentials_exception
username: str = payload.get("sub")
if username is None:
raise credentials_exception
user = dbh.get_user(username)
if user is None:
raise credentials_exception
if user.disabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
)
return user, payload.get("type")
class BasicAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
# Check if session key already stored in cache
@@ -83,27 +104,14 @@ class BasicAuthBackend(AuthenticationBackend):
if "Authorization" not in conn.headers:
return
auth = conn.headers["Authorization"]
scheme, token = auth.split()
# Returns if Authorization header is not Bearer
scheme, token = conn.headers["Authorization"].split()
if scheme.lower() != "bearer":
return
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except (JWTError):
raise credentials_exception
user, token_type = await get_current_active_user(token)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
user = dbh.get_user(username)
if user is None:
raise credentials_exception
if user.disabled:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return (AuthCredentials(user.oauth_scopes), user)
return (
AuthCredentials(user.oauth_scopes if token_type == "access" else None),
user,
)