diff --git a/backend/endpoints/tests/test_identity.py b/backend/endpoints/tests/test_identity.py index 5bfbbff26..e66043d34 100644 --- a/backend/endpoints/tests/test_identity.py +++ b/backend/endpoints/tests/test_identity.py @@ -131,6 +131,22 @@ def test_add_user_from_unauthorized_user( assert response.status_code == expected_status_code +def test_add_user_with_existing_username(client, access_token, admin_user): + response = client.post( + "/api/users", + params={ + "username": admin_user.username, + "password": "new_user_password", + "role": Role.VIEWER.value, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert response.status_code == HTTPStatus.BAD_REQUEST + + response = response.json() + assert response["detail"] == f"Username {admin_user.username} already exists" + + def test_update_user(client, access_token, editor_user): assert editor_user.role == Role.EDITOR diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 113d04387..3bed5391b 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -47,7 +47,8 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS detail="Forbidden", ) - if username in [user.username for user in db_user_handler.get_users()]: + existing_user = db_user_handler.get_user_by_username(username) + if existing_user: msg = f"Username {username} already exists" log.error(msg) raise HTTPException( diff --git a/backend/handler/database/users_handler.py b/backend/handler/database/users_handler.py index f0a6eafa4..92a467424 100644 --- a/backend/handler/database/users_handler.py +++ b/backend/handler/database/users_handler.py @@ -12,11 +12,13 @@ class DBUsersHandler(DBBaseHandler): return session.merge(user) @begin_session - def get_user_by_username(self, username: str, session: Session = None): + def get_user_by_username( + self, username: str, session: Session = None + ) -> User | None: return session.scalar(select(User).filter_by(username=username).limit(1)) @begin_session - def get_user(self, id: int, session: Session = None) -> User: + def get_user(self, id: int, session: Session = None) -> User | None: return session.get(User, id) @begin_session