Files
romm/backend/utils/ssrf.py
2026-06-03 16:21:27 -04:00

333 lines
12 KiB
Python

"""SSRF defense for outbound HTTP.
Two layers, both wired onto every httpx client built by `utils.context`:
1. `validate_url_for_http_request` — a syntactic fast-fail check
installed as an httpx request event hook. Rejects non-HTTP schemes,
literal IPs in forbidden ranges (including non-standard IPv4 forms),
reserved hostnames, and internal TLDs before any socket opens.
2. `SSRFProtectedAsyncBackend` / `SSRFProtectedSyncBackend` — custom
httpcore network backends that resolve the hostname inside
`connect_tcp`, reject any address in a private/loopback/link-local/
reserved/multicast/unspecified range, then connect to that *same*
validated address. This is what defeats DNS rebinding: the address
used by the OS for the TCP connection is the one we just checked,
not a fresh lookup the attacker can answer differently. Doing this
work in the backend also avoids blocking the event loop, since the
async variant uses `loop.getaddrinfo`.
httpcore calls `start_tls(server_hostname=<URL host>)` after
`connect_tcp` returns, so TLS SNI and certificate verification still
use the original hostname even though we connect by IP.
"""
from __future__ import annotations
import asyncio
import ipaddress
import socket
import typing
from urllib.parse import urlparse
import httpcore
from httpcore._backends.auto import AutoBackend
from httpcore._backends.base import (
SOCKET_OPTION,
AsyncNetworkBackend,
AsyncNetworkStream,
NetworkBackend,
NetworkStream,
)
from httpcore._backends.sync import SyncBackend
from logger.logger import log
from utils.validation import ValidationError
def is_forbidden_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
    """Tell whether a server-side HTTP request must never target `ip`.

    An address is forbidden as soon as any of the private, loopback,
    link-local, reserved, multicast or unspecified flags is set, or when
    `ipaddress` does not classify it as globally reachable.
    """
    blocked_flags = (
        ip.is_private,
        ip.is_loopback,
        ip.is_link_local,
        ip.is_reserved,
        ip.is_multicast,
        ip.is_unspecified,
    )
    return any(blocked_flags) or not ip.is_global
def parse_ip_literal(
    host: str,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
    """Return the parsed IP if `host` is a literal address, else None.

    Standard IPv4/IPv6 text is handled by `ipaddress.ip_address`. Anything
    it rejects is retried with `socket.inet_aton`, which also understands
    the non-standard IPv4 spellings (hex, decimal, shorthand such as
    `127.1`).

    Args:
        host: Host text to inspect, e.g. the hostname part of a URL.

    Returns:
        The parsed address object, or None when `host` is not an IP
        literal in any accepted form.
    """
    try:
        return ipaddress.ip_address(host)
    except ValueError:
        pass
    try:
        packed = socket.inet_aton(host)
    except (OSError, ValueError):
        # OSError: not a valid inet_aton form. ValueError: the string could
        # not even be handed to the C call (embedded NUL, lone surrogate).
        # Either way this is not an IP literal, so report None rather than
        # letting an unexpected exception type escape to the caller.
        return None
    return ipaddress.IPv4Address(packed)
def _pick_safe_address(addr_infos: typing.Iterable[typing.Any], host: str) -> str:
    """Vet every resolved record and return the IP literal to connect to.

    The name is rejected as a whole as soon as one record points into a
    forbidden range; such records are not merely skipped. Otherwise a
    malicious DNS server could mix public and private answers and rely on
    the client to round-robin onto the private one.

    Args:
        addr_infos: `getaddrinfo`-style records; the last item of each
            record is the socket address, whose first item is the IP text.
        host: The hostname that was resolved (used in messages only).

    Returns:
        The IP text of the first record that parsed as an address.

    Raises:
        httpcore.ConnectError: If any record is a forbidden address, or if
            no record yields a parseable address.
    """
    first_safe: str | None = None
    for *_ignored, endpoint in addr_infos:
        try:
            candidate = ipaddress.ip_address(endpoint[0])
        except (ValueError, IndexError):
            # Unparseable record: it was never validated, so never selected.
            continue
        if is_forbidden_ip(candidate):
            reason = (
                f"SSRF prevention: hostname {host!r} resolves to forbidden "
                f"IP {candidate}"
            )
            log.error(reason)
            raise httpcore.ConnectError(reason)
        if first_safe is None:
            first_safe = endpoint[0]
    if first_safe is not None:
        return first_safe
    raise httpcore.ConnectError(f"No usable addresses for {host!r}")
def _check_literal(host: str) -> bool:
    """Classify `host` as a validated IP literal or a name needing DNS.

    Returns:
        True when `host` is an IP literal outside every forbidden range;
        False when `host` is not an IP literal at all, in which case the
        caller must resolve it and validate the result.

    Raises:
        httpcore.ConnectError: If `host` is an IP literal inside a
            forbidden range.
    """
    parsed = parse_ip_literal(host)
    if parsed is not None and is_forbidden_ip(parsed):
        reason = f"SSRF prevention: connection to forbidden IP {parsed}"
        log.error(reason)
        raise httpcore.ConnectError(reason)
    return parsed is not None
class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
    """Async backend that validates resolved IPs before establishing TCP.

    Wraps another `AsyncNetworkBackend` and delegates the actual socket
    work to it; only `connect_tcp` adds behaviour (resolve, validate, then
    connect to the validated address).
    """

    def __init__(self, inner: AsyncNetworkBackend | None = None) -> None:
        # Explicit None check (same as the sync backend) so that a supplied
        # backend is never swapped out merely because it evaluates falsy.
        self._inner = inner if inner is not None else AutoBackend()

    # `timeout` parameter is required by AsyncNetworkBackend.connect_tcp;
    # ruff/ASYNC109 advises against timeout parameters on async APIs *we*
    # author, but we are implementing an external interface here. The
    # timeout is consumed via `asyncio.timeout()` below, which is the
    # asyncio-native pattern ASYNC109 endorses.
    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,  # noqa: ASYNC109
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Validate the resolved IP, then connect via the inner backend.

        A literal IP is validated and passed straight through. A hostname
        is resolved with `loop.getaddrinfo`, every answer is validated,
        and the connection is made to the validated address itself so no
        second lookup can return something different.

        `timeout` is treated as one budget for the whole operation: the
        DNS lookup is bounded by it via `asyncio.timeout()`, and the inner
        TCP connect only receives whatever time is left afterwards.

        Raises:
            httpcore.ConnectError: If the address is forbidden, DNS
                resolution fails, or no usable address is returned.
            httpcore.ConnectTimeout: If DNS resolution exceeds `timeout`.
        """
        if _check_literal(host):
            return await self._inner.connect_tcp(
                host, port, timeout, local_address, socket_options
            )

        loop = asyncio.get_running_loop()
        started = loop.time()
        try:
            async with asyncio.timeout(timeout):
                addr_infos = await loop.getaddrinfo(
                    host, port, type=socket.SOCK_STREAM
                )
        except socket.gaierror as exc:
            raise httpcore.ConnectError(
                f"DNS resolution failed for {host!r}: {exc}"
            ) from exc
        except TimeoutError as exc:
            raise httpcore.ConnectTimeout(
                f"DNS resolution timed out for {host!r}"
            ) from exc

        pinned_ip = _pick_safe_address(addr_infos, host)

        # Charge the time spent resolving against the caller's budget so
        # DNS + connect together cannot take up to twice `timeout`.
        connect_timeout = timeout
        if timeout is not None:
            connect_timeout = max(timeout - (loop.time() - started), 0.0)

        return await self._inner.connect_tcp(
            pinned_ip, port, connect_timeout, local_address, socket_options
        )

    # See note on connect_tcp re: ASYNC109 / interface implementation.
    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,  # noqa: ASYNC109
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Delegate unchanged to the inner backend (no IP to validate)."""
        return await self._inner.connect_unix_socket(path, timeout, socket_options)

    async def sleep(self, seconds: float) -> None:
        """Delegate unchanged to the inner backend."""
        await self._inner.sleep(seconds)
class SSRFProtectedSyncBackend(NetworkBackend):
    """Blocking backend that vets the resolved IP before opening TCP.

    All socket work is delegated to a wrapped `NetworkBackend`; only
    `connect_tcp` adds the resolve-and-validate step in front of it.
    """

    def __init__(self, inner: NetworkBackend | None = None) -> None:
        self._inner = SyncBackend() if inner is None else inner

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        """Connect to `host`, pinning hostnames to a validated address.

        A literal IP that passes validation is used as given. A hostname
        is resolved with a blocking `socket.getaddrinfo` call (`timeout`
        is not applied to that lookup), every answer is validated, and the
        inner backend is asked to connect to the validated IP.

        Raises:
            httpcore.ConnectError: If the address is forbidden, DNS
                resolution fails, or no usable address is returned.
        """
        target = host
        if not _check_literal(host):
            try:
                resolved = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
            except socket.gaierror as exc:
                raise httpcore.ConnectError(
                    f"DNS resolution failed for {host!r}: {exc}"
                ) from exc
            target = _pick_safe_address(resolved, host)
        return self._inner.connect_tcp(
            target, port, timeout, local_address, socket_options
        )

    def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        """Hand the request to the inner backend untouched."""
        return self._inner.connect_unix_socket(path, timeout, socket_options)
def install_async_ssrf_protection(client: typing.Any) -> None:
    """Put the SSRF-validating backend in front of the client's connection pool.

    httpx does not expose `network_backend` through its public transport
    API, so the private `_pool._network_backend` attribute is replaced
    after the client has been constructed. Calling this again on an
    already-protected client leaves it as it is.
    """
    pool = client._transport._pool
    current = pool._network_backend
    if isinstance(current, SSRFProtectedAsyncBackend):
        return
    pool._network_backend = SSRFProtectedAsyncBackend(inner=current)
def install_sync_ssrf_protection(client: typing.Any) -> None:
    """Put the SSRF-validating sync backend in front of the client's pool.

    Replaces the private `_pool._network_backend` attribute of the client's
    transport with a wrapper around the existing backend. Calling this
    again on an already-protected client leaves it as it is.
    """
    pool = client._transport._pool
    current = pool._network_backend
    if isinstance(current, SSRFProtectedSyncBackend):
        return
    pool._network_backend = SSRFProtectedSyncBackend(inner=current)
# Hosts that `validate_url_for_http_request` rejects outright. The URL's
# hostname is lower-cased and checked for exact membership in this list.
# Entries cover the `localhost` name plus the IPv4/IPv6 loopback and
# unspecified ("any") address literals.
RESERVED_HOSTNAMES = [
"localhost",
"127.0.0.1",
"0.0.0.0", # trunk-ignore(bandit/B104)
"::1",
"::",
]
def validate_url_for_http_request(url: str, field_name: str = "URL") -> None:
    """Syntactically validate a URL before passing it to an HTTP client.

    Fast-fail check for cases that don't need DNS to detect:
    - The URL scheme is http or https only
    - If the host is a literal IP address, it is not private/internal/reserved
    - The host is not a reserved hostname (localhost, 127.0.0.1, etc.)
    - The host does not use internal TLDs (.local, .internal, .localhost)

    The host is matched case-insensitively and with any trailing root dot
    removed, so `localhost.` and `printer.local.` are rejected exactly like
    their dot-less spellings.

    Wired in as an httpx request event hook by `utils.context.create_httpx_*`,
    so every request that goes through a context-provided client runs this
    automatically — direct calls from feature code are not required.

    Dynamic SSRF protection (rejecting hostnames that resolve to a private
    IP, including DNS-rebinding names like `127.0.0.1.nip.io`) happens
    inside the HTTP client's connect path via the SSRF-protected network
    backends. Doing the DNS check here would (a) be defeated by DNS
    rebinding because httpx re-resolves at connect time, and (b) block the
    event loop, since `socket.getaddrinfo` is synchronous and most callers
    are async.

    Args:
        url (str): The URL to validate
        field_name (str): The name of the field for error messages

    Raises:
        ValidationError: If the URL is syntactically invalid or matches one
            of the static SSRF deny patterns.
    """
    if not url or not url.strip():
        msg = f"{field_name} cannot be empty"
        log.error(msg)
        raise ValidationError(msg, field_name)

    try:
        parsed = urlparse(url)
    except ValueError as e:
        msg = f"Invalid {field_name}: unable to parse URL"
        log.error(f"{msg}: {e}")
        raise ValidationError(msg, field_name) from e

    # Validate scheme - only allow http and https
    if parsed.scheme not in ("http", "https"):
        msg = f"Invalid {field_name}: only http and https schemes are allowed"
        log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'")
        raise ValidationError(msg, field_name)

    # Extract hostname
    hostname = parsed.hostname
    if not hostname:
        msg = f"Invalid {field_name}: missing hostname"
        log.error(msg)
        raise ValidationError(msg, field_name)

    # A fully-qualified name with a trailing root dot names the same host
    # as the dot-less form. Match on the normalised spelling so the dot
    # cannot be used to slip past the deny lists below; log the original.
    normalized = hostname.lower().rstrip(".")

    # Block reserved hostnames that are commonly used to refer to internal services.
    if normalized in RESERVED_HOSTNAMES:
        msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed"
        log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
        raise ValidationError(msg, field_name)

    # Try to parse hostname as a literal IP (standard or non-standard form).
    # HTTP clients accept hex (0x7f000001), decimal (2130706433), and
    # shorthand-dotted (127.1) integers via inet_aton, so we mirror that.
    ip = parse_ip_literal(normalized)
    if ip is not None:
        if is_forbidden_ip(ip):
            msg = (
                f"Invalid {field_name}: private, internal, reserved, "
                "or multicast IP addresses are not allowed"
            )
            log.error(f"SSRF prevention: {msg} - IP '{ip}'")
            raise ValidationError(msg, field_name)
        # A literal, permitted IP has no TLD to inspect.
        return

    # Block common internal TLDs
    internal_tlds = (".local", ".internal", ".localhost")
    if normalized.endswith(internal_tlds):
        msg = f"Invalid {field_name}: internal domain names are not allowed"
        log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
        raise ValidationError(msg, field_name)