Files
romm/backend/utils/validation.py
Georges-Antoine Assi 2b0ed2296b cleanup
2026-05-27 09:29:21 -04:00

262 lines
8.8 KiB
Python

import ipaddress
import re
import socket
from urllib.parse import urlparse
from logger.logger import log
from models.user import TEXT_FIELD_LENGTH
class ValidationError(Exception):
    """Raised when a user-supplied value fails one of this module's checks.

    Attributes:
        message (str): Human-readable reason for the failure. It is also the
            exception's only positional argument, so ``str(exc)`` returns it.
        field_name (str): Name of the field that failed validation.
    """

    def __init__(self, message: str, field_name: str = "field"):
        super().__init__(message)
        self.field_name = field_name
        self.message = message
# Pre-compiled validation patterns. They are anchored with ``\Z`` rather than
# ``$`` because ``$`` also matches just before a trailing newline, which would
# let values such as "alice\n" pass a ``match()``-based check.
USERNAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+\Z")
EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\Z")
def validate_ascii_only(value: str, field_name: str = "field") -> None:
    """Ensure ``value`` consists solely of 7-bit ASCII characters.

    Empty (falsy) values are accepted without being inspected.

    Args:
        value (str): Text to inspect.
        field_name (str): Label used in the error message.

    Raises:
        ValidationError: If any character has a code point above 127.
    """
    is_plain_ascii = not value or all(ord(ch) < 128 for ch in value)
    if is_plain_ascii:
        return

    error_text = f"{field_name} must contain only ASCII characters"
    log.error(f"Validation failed: {error_text}")
    raise ValidationError(error_text, field_name)
def validate_username(username: str) -> None:
    """Validate username format and content.

    A valid username is non-blank, ASCII-only, between 3 and
    ``TEXT_FIELD_LENGTH`` characters long, and made up solely of letters,
    digits, underscores, and hyphens.

    Args:
        username (str): The username to validate

    Raises:
        ValidationError: If the username is invalid
    """
    if not username or not username.strip():
        msg = "Username cannot be empty"
        log.error(msg)
        raise ValidationError(msg, "Username")

    validate_ascii_only(username, "Username")

    if len(username) < 3:
        msg = "Username must be at least 3 characters long"
        log.error(msg)
        raise ValidationError(msg, "Username")

    if len(username) > TEXT_FIELD_LENGTH:
        # Derive the limit from the constant so the message cannot drift out
        # of sync with the check.
        msg = f"Username must be no more than {TEXT_FIELD_LENGTH} characters long"
        log.error(msg)
        raise ValidationError(msg, "Username")

    # fullmatch() rejects a trailing newline that a "$"-anchored match() accepts.
    if not USERNAME_PATTERN.fullmatch(username):
        msg = "Username can only contain letters, numbers, underscores, and hyphens"
        # !r escapes control characters so rejected input cannot forge log lines.
        log.error(f"Validation failed: {msg} for username: {username!r}")
        raise ValidationError(msg, "Username")
def validate_password(password: str) -> None:
    """Validate password format and content.

    A valid password is not empty or whitespace-only, is ASCII-only, and is
    between 6 and ``TEXT_FIELD_LENGTH`` characters long.

    Args:
        password (str): The password to validate

    Raises:
        ValidationError: If the password is invalid
    """
    if not password or not password.strip():
        msg = "Password cannot be empty"
        log.error(msg)
        raise ValidationError(msg, "Password")

    validate_ascii_only(password, "Password")

    if len(password) < 6:
        msg = "Password must be at least 6 characters long"
        log.error(msg)
        raise ValidationError(msg, "Password")

    if len(password) > TEXT_FIELD_LENGTH:
        # Derive the limit from the constant so the message cannot drift out
        # of sync with the check.
        msg = f"Password must be no more than {TEXT_FIELD_LENGTH} characters long"
        log.error(msg)
        raise ValidationError(msg, "Password")
def validate_email(email: str) -> None:
    """Validate email format and content.

    An empty (falsy) value is accepted without further checks. Otherwise the
    value must be ASCII-only and match ``EMAIL_PATTERN`` in its entirety.

    Args:
        email (str): The email to validate

    Raises:
        ValidationError: If the email is invalid
    """
    if not email:
        return

    validate_ascii_only(email, "Email")

    # fullmatch() so a trailing newline ("a@b.com\n") is rejected; a
    # "$"-anchored match() would accept it.
    if not EMAIL_PATTERN.fullmatch(email):
        msg = "Invalid email format"
        # !r escapes control characters so rejected input cannot forge log lines.
        log.error(f"Validation failed: {msg} for email: {email!r}")
        raise ValidationError(msg, "Email")
# Hosts that validate_url_for_http_request() rejects by name, before any IP
# parsing or DNS lookup takes place.
RESERVED_HOSTNAMES = [
    "localhost",
    # IPv4 loopback and unspecified addresses
    "127.0.0.1",
    "0.0.0.0",  # trunk-ignore(bandit/B104)
    # IPv6 loopback and unspecified addresses
    "::1",
    "::",
]
def _is_forbidden_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return True if the IP address is in a range that must not be reached
by a server-side HTTP request (private, loopback, link-local, reserved,
or multicast)."""
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_reserved
or ip.is_multicast
or ip.is_unspecified
)
def _parse_ip_literal(
    hostname: str,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
    """Interpret ``hostname`` as an IP address literal, if it is one.

    ``ipaddress.ip_address()`` only accepts canonical notation, but HTTP
    clients also accept non-standard IPv4 spellings such as hex integers
    (0x7f000001), decimal integers (2130706433), shorthand dotted quads
    (127.1) and octal (0177.0.0.1). ``socket.inet_aton()`` decodes those
    forms, so it is tried as a fallback.

    Args:
        hostname (str): Host component of a URL.

    Returns:
        The parsed address, or None if ``hostname`` is not an IP literal.
    """
    try:
        return ipaddress.ip_address(hostname)
    except ValueError:
        pass
    try:
        return ipaddress.IPv4Address(socket.inet_aton(hostname))
    except (OSError, ValueError):
        # OSError: not an IPv4 spelling inet_aton() understands.
        # ValueError: e.g. an embedded NUL character in the hostname.
        return None


def validate_url_for_http_request(url: str, field_name: str = "URL") -> None:
    """Validate URL to prevent Server-Side Request Forgery (SSRF) attacks.

    This function validates that:
    - The URL scheme is http or https only
    - The port, if present, is a valid port number
    - The host is not a reserved hostname (localhost, 127.0.0.1, etc.),
      ignoring a trailing dot
    - If the host is a literal IP address (including non-standard IPv4
      spellings such as 0x7f000001 or 127.1), it is not private, internal,
      reserved, multicast, or otherwise forbidden by ``_is_forbidden_ip``
    - The host does not use internal TLDs (.local, .internal, .localhost)
    - The host, if a domain name, resolves via DNS, and every address it
      resolves to passes the same IP check

    Note:
        This function does not return the addresses it resolved, so a caller's
        HTTP client will typically resolve the hostname again. A DNS answer
        that changes between the two lookups (DNS rebinding) is therefore not
        caught by this check alone.

    Args:
        url (str): The URL to validate
        field_name (str): The name of the field for error messages

    Raises:
        ValidationError: If the URL is invalid or potentially dangerous
    """
    if not url or not url.strip():
        msg = f"{field_name} cannot be empty"
        log.error(msg)
        raise ValidationError(msg, field_name)

    try:
        parsed = urlparse(url)
    except Exception as e:
        msg = f"Invalid {field_name}: unable to parse URL"
        log.error(f"{msg}: {str(e)}")
        raise ValidationError(msg, field_name) from e

    # Validate scheme - only allow http and https
    if parsed.scheme not in ["http", "https"]:
        msg = f"Invalid {field_name}: only http and https schemes are allowed"
        log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'")
        raise ValidationError(msg, field_name)

    hostname = parsed.hostname
    # A trailing dot ("localhost.") is the fully-qualified spelling of the
    # same host; strip it so the name-based checks below cannot be sidestepped.
    normalized_host = hostname.lower().rstrip(".") if hostname else ""
    if not normalized_host:
        msg = f"Invalid {field_name}: missing hostname"
        log.error(msg)
        raise ValidationError(msg, field_name)

    # Reading .port raises ValueError for a non-numeric or out-of-range port;
    # report that as a ValidationError instead of letting it escape.
    try:
        port = parsed.port
    except ValueError as e:
        msg = f"Invalid {field_name}: invalid port"
        log.error(f"{msg} - {e}")
        raise ValidationError(msg, field_name) from e

    if normalized_host in RESERVED_HOSTNAMES:
        msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed"
        log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
        raise ValidationError(msg, field_name)

    ip = _parse_ip_literal(normalized_host)
    if ip is not None:
        if _is_forbidden_ip(ip):
            msg = f"Invalid {field_name}: private, internal, reserved, or multicast IP addresses are not allowed"
            log.error(f"SSRF prevention: {msg} - IP '{ip}'")
            raise ValidationError(msg, field_name)
        # A literal IP needs no DNS resolution.
        return

    # Block common internal TLDs
    internal_tlds = (".local", ".internal", ".localhost")
    if normalized_host.endswith(internal_tlds):
        msg = f"Invalid {field_name}: internal domain names are not allowed"
        log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
        raise ValidationError(msg, field_name)

    # Resolve the hostname and reject it if ANY answer points at a forbidden
    # address. Without this, an attacker-controlled (or wildcard) DNS name such
    # as `127.0.0.1.nip.io` slips past the literal-IP checks above and the
    # subsequent HTTP request reaches an internal target (SSRF).
    try:
        addr_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
    except (OSError, ValueError) as e:
        # OSError covers socket.gaierror; ValueError covers UnicodeError from
        # IDNA encoding (e.g. empty or over-long labels) and embedded NULs.
        msg = f"Invalid {field_name}: hostname could not be resolved"
        log.error(f"SSRF prevention: {msg} - hostname '{hostname}': {e}")
        raise ValidationError(msg, field_name) from e

    for *_, sockaddr in addr_infos:
        try:
            resolved_ip = ipaddress.ip_address(sockaddr[0])
        except ValueError:
            # Fail closed: an address that cannot be classified cannot be
            # shown to be safe.
            resolved_ip = None
        if resolved_ip is None or _is_forbidden_ip(resolved_ip):
            msg = (
                f"Invalid {field_name}: hostname resolves to a private, "
                f"internal, reserved, or multicast IP address"
            )
            log.error(
                f"SSRF prevention: {msg} - hostname '{hostname}' -> '{sockaddr[0]}'"
            )
            raise ValidationError(msg, field_name)