From eba2971ffbb5ffe00ef006f2403db23efce5074a Mon Sep 17 00:00:00 2001 From: Michael Manganiello Date: Mon, 14 Oct 2024 01:08:33 -0300 Subject: [PATCH] fix: Simplify query that validates new username already exists Instead of fetching all users and checking if the new username is present in the list, we can directly query the database for the username. --- backend/endpoints/tests/test_identity.py | 16 ++++++++++++++++ backend/endpoints/user.py | 3 ++- backend/handler/database/users_handler.py | 6 ++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/backend/endpoints/tests/test_identity.py b/backend/endpoints/tests/test_identity.py index 5bfbbff26..e66043d34 100644 --- a/backend/endpoints/tests/test_identity.py +++ b/backend/endpoints/tests/test_identity.py @@ -131,6 +131,22 @@ def test_add_user_from_unauthorized_user( assert response.status_code == expected_status_code +def test_add_user_with_existing_username(client, access_token, admin_user): + response = client.post( + "/api/users", + params={ + "username": admin_user.username, + "password": "new_user_password", + "role": Role.VIEWER.value, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert response.status_code == HTTPStatus.BAD_REQUEST + + response = response.json() + assert response["detail"] == f"Username {admin_user.username} already exists" + + def test_update_user(client, access_token, editor_user): assert editor_user.role == Role.EDITOR diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 113d04387..3bed5391b 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -47,7 +47,8 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS detail="Forbidden", ) - if username in [user.username for user in db_user_handler.get_users()]: + existing_user = db_user_handler.get_user_by_username(username) + if existing_user: msg = f"Username {username} already exists" log.error(msg) raise HTTPException( diff --git a/backend/handler/database/users_handler.py b/backend/handler/database/users_handler.py index f0a6eafa4..92a467424 100644 --- a/backend/handler/database/users_handler.py +++ b/backend/handler/database/users_handler.py @@ -12,11 +12,13 @@ class DBUsersHandler(DBBaseHandler): return session.merge(user) @begin_session - def get_user_by_username(self, username: str, session: Session = None): + def get_user_by_username( + self, username: str, session: Session = None + ) -> User | None: return session.scalar(select(User).filter_by(username=username).limit(1)) @begin_session - def get_user(self, id: int, session: Session = None) -> User: + def get_user(self, id: int, session: Session = None) -> User | None: return session.get(User, id) @begin_session