Files
romm/backend/tests/utils/test_ssrf.py
Georges-Antoine Assi c3adbd3f71 fix(ssrf): bound DNS lookup by caller timeout; clear lint findings
The async backend's `loop.getaddrinfo` ran without any timeout, so a
slow or hanging resolver could outlive the timeout the caller passed —
the previous code only bounded the TCP connect inside the inner
backend. Wrap the resolution in `asyncio.timeout(timeout)` and surface
the timeout as `httpcore.ConnectTimeout`.

Also tidy the test stubs (mypy func-returns-value) and add explicit
type annotations to the `calls` lists (mypy var-annotated). A targeted
`# noqa: ASYNC109` sits on the `timeout` parameter of `connect_tcp` /
`connect_unix_socket` with an explanatory comment: the rule advises
against `timeout` parameters on async APIs we author, but here we're
implementing `AsyncNetworkBackend`, and the timeout is consumed in the
asyncio-native pattern the rule endorses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 18:31:42 -04:00

450 lines
17 KiB
Python

"""Tests for SSRF defense: URL validator + httpcore network backends."""
import asyncio
import socket
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpcore
import pytest
from utils.ssrf import (
SSRFProtectedAsyncBackend,
SSRFProtectedSyncBackend,
is_forbidden_ip,
parse_ip_literal,
validate_url_for_http_request,
)
from utils.validation import ValidationError
# One recorded call to a stubbed ``connect_tcp``: (positional args, keyword args).
ConnectCall = tuple[tuple[Any, ...], dict[str, Any]]
def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]:
    """Build a fake ``getaddrinfo`` result holding one TCP entry for *ip*.

    Args:
        ip: Textual IP address; any value containing ``":"`` is treated as IPv6.
        port: Port number placed in the socket address.

    Returns:
        A one-element list of ``(family, type, proto, canonname, sockaddr)``
        tuples, shaped like the return value of ``socket.getaddrinfo``.
    """
    if ":" in ip:
        # Real resolvers return a 4-tuple (address, port, flowinfo, scope_id)
        # for AF_INET6, so the stub mirrors that shape instead of a 2-tuple.
        return [(socket.AF_INET6, socket.SOCK_STREAM, 0, "", (ip, port, 0, 0))]
    return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, port))]
def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock:
    """Create a fake inner async backend whose ``connect_tcp`` logs its calls.

    Each awaited ``connect_tcp`` appends ``(args, kwargs)`` to *connect_calls*
    and resolves to a fresh ``MagicMock``.
    """

    def _capture(*args: Any, **kwargs: Any) -> MagicMock:
        connect_calls.append((args, kwargs))
        return MagicMock()

    backend = MagicMock()
    backend.connect_tcp = AsyncMock(side_effect=_capture)
    return backend
def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock:
    """Create a fake inner sync backend whose ``connect_tcp`` logs its calls.

    Each ``connect_tcp`` invocation appends ``(args, kwargs)`` to
    *connect_calls* and returns a fresh ``MagicMock``.
    """

    def _capture(*args: Any, **kwargs: Any) -> MagicMock:
        connect_calls.append((args, kwargs))
        return MagicMock()

    backend = MagicMock()
    backend.connect_tcp = MagicMock(side_effect=_capture)
    return backend
class TestIsForbiddenIp:
    """Check ``is_forbidden_ip`` on addresses that must be blocked or allowed."""

    @pytest.mark.parametrize(
        "ip",
        [
            "127.0.0.1",
            "10.0.0.1",
            "192.168.1.1",
            "172.16.0.1",
            "169.254.169.254",
            "0.0.0.0",
            "224.0.0.1",
            "::1",
            "fc00::1",
            "fe80::1",
            "ff02::1",
        ],
    )
    def test_forbidden(self, ip):
        """Every listed IPv4/IPv6 address must be classified as forbidden."""
        from ipaddress import ip_address

        parsed = ip_address(ip)
        assert is_forbidden_ip(parsed) is True

    @pytest.mark.parametrize(
        "ip", ["8.8.8.8", "1.1.1.1", "93.184.216.34", "2001:4860:4860::8888"]
    )
    def test_allowed(self, ip):
        """Every listed address must be classified as allowed."""
        from ipaddress import ip_address

        parsed = ip_address(ip)
        assert is_forbidden_ip(parsed) is False
class TestParseIpLiteral:
    """Check how ``parse_ip_literal`` handles IP literals and plain hostnames."""

    def test_standard(self):
        """Canonical IPv4 and IPv6 literals round-trip through ``str``."""
        loopback_v4 = parse_ip_literal("127.0.0.1")
        assert str(loopback_v4) == "127.0.0.1"
        loopback_v6 = parse_ip_literal("::1")
        assert str(loopback_v6) == "::1"

    def test_non_standard_ipv4(self):
        """Hex, decimal and shorthand spellings all map to 127.0.0.1."""
        from_hex = parse_ip_literal("0x7f000001")
        assert str(from_hex) == "127.0.0.1"
        from_decimal = parse_ip_literal("2130706433")
        assert str(from_decimal) == "127.0.0.1"
        from_shorthand = parse_ip_literal("127.1")
        assert str(from_shorthand) == "127.0.0.1"

    def test_hostname_returns_none(self):
        """A DNS name is not an IP literal, so the parser yields ``None``."""
        parsed = parse_ip_literal("example.com")
        assert parsed is None
class TestSSRFProtectedAsyncBackend:
    """Connect-path tests for ``SSRFProtectedAsyncBackend`` over a stubbed inner backend."""

    @staticmethod
    def _install_resolver(monkeypatch, resolver) -> None:
        """Replace the running loop's ``getaddrinfo`` with *resolver* for one test."""
        monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", resolver)

    async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
        """Backend resolves once, validates, and passes the pinned IP to inner."""
        recorded: list[ConnectCall] = []
        backend = SSRFProtectedAsyncBackend(inner=_stub_async_inner(recorded))

        async def resolve_public(host, port, *args, **kwargs):
            return _addr_info("93.184.216.34", port)

        self._install_resolver(monkeypatch, resolve_public)
        await backend.connect_tcp("example.com", 443)

        # Inner must be handed the resolved address rather than the hostname;
        # that is what pins the connection against DNS rebinding.
        positional, _ = recorded[0]
        assert positional[0] == "93.184.216.34"
        assert positional[1] == 443

    async def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch):
        """DNS rebinding case: hostname resolves to 127.0.0.1 must fail."""
        stub = _stub_async_inner([])
        backend = SSRFProtectedAsyncBackend(inner=stub)

        async def resolve_loopback(host, port, *args, **kwargs):
            return _addr_info("127.0.0.1", port)

        self._install_resolver(monkeypatch, resolve_loopback)
        with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
            await backend.connect_tcp("127.0.0.1.nip.io", 80)
        stub.connect_tcp.assert_not_called()

    async def test_mixed_resolution_rejected(self, monkeypatch):
        """If any returned address is forbidden, reject - don't trust round-robin."""
        stub = _stub_async_inner([])
        backend = SSRFProtectedAsyncBackend(inner=stub)

        async def resolve_mixed(host, port, *args, **kwargs):
            public = _addr_info("93.184.216.34", port)
            private = _addr_info("10.0.0.1", port)
            return public + private

        self._install_resolver(monkeypatch, resolve_mixed)
        with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
            await backend.connect_tcp("mixed.example.com", 80)
        stub.connect_tcp.assert_not_called()

    async def test_literal_forbidden_ip_rejected(self):
        """A forbidden literal IP is refused without touching the inner backend."""
        stub = _stub_async_inner([])
        backend = SSRFProtectedAsyncBackend(inner=stub)
        with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
            await backend.connect_tcp("169.254.169.254", 80)
        stub.connect_tcp.assert_not_called()

    async def test_literal_public_ip_passes_through(self):
        """Literal public IPs are passed through to inner unchanged."""
        recorded: list[ConnectCall] = []
        backend = SSRFProtectedAsyncBackend(inner=_stub_async_inner(recorded))
        await backend.connect_tcp("8.8.8.8", 443)
        positional, _ = recorded[0]
        assert positional[0] == "8.8.8.8"

    async def test_non_standard_ipv4_literal_blocked(self):
        """Hex/decimal IPv4 forms must be blocked, matching httpx's parsing."""
        stub = _stub_async_inner([])
        backend = SSRFProtectedAsyncBackend(inner=stub)
        with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
            await backend.connect_tcp("2130706433", 80)  # decimal 127.0.0.1
        stub.connect_tcp.assert_not_called()

    async def test_dns_timeout_raises_connect_timeout(self, monkeypatch):
        """A resolver that hangs past the caller's timeout must not block forever.

        Regression: an earlier version applied the caller's timeout only to
        the TCP connect inside the inner backend, leaving `loop.getaddrinfo`
        unbounded. The lookup is now wrapped in `asyncio.timeout()` so a slow
        resolver is bounded by the same budget the caller specified.
        """
        stub = _stub_async_inner([])
        backend = SSRFProtectedAsyncBackend(inner=stub)

        async def never_answers(*args, **kwargs):
            await asyncio.sleep(3600)

        self._install_resolver(monkeypatch, never_answers)
        with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"):
            await backend.connect_tcp("slow.example.com", 80, timeout=0.05)
        stub.connect_tcp.assert_not_called()

    async def test_dns_failure_propagates_as_connect_error(self, monkeypatch):
        """A resolver ``gaierror`` surfaces as ``httpcore.ConnectError``."""
        stub = _stub_async_inner([])
        backend = SSRFProtectedAsyncBackend(inner=stub)

        async def resolve_fails(*args, **kwargs):
            raise socket.gaierror("Name or service not known")

        self._install_resolver(monkeypatch, resolve_fails)
        with pytest.raises(httpcore.ConnectError, match="DNS resolution failed"):
            await backend.connect_tcp("nonexistent.invalid", 80)
        stub.connect_tcp.assert_not_called()
class TestSSRFProtectedSyncBackend:
    """Connect-path tests for ``SSRFProtectedSyncBackend`` over a stubbed inner backend."""

    @staticmethod
    def _resolver_returning(ip):
        """Build a ``socket.getaddrinfo`` stand-in that always resolves to *ip*."""

        def _resolve(host, port, *a, **kw):
            return _addr_info(ip, port)

        return _resolve

    def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
        """A public resolution is accepted and its IP is what inner receives."""
        recorded: list[ConnectCall] = []
        backend = SSRFProtectedSyncBackend(inner=_stub_sync_inner(recorded))
        monkeypatch.setattr(
            socket, "getaddrinfo", self._resolver_returning("93.184.216.34")
        )
        backend.connect_tcp("example.com", 443)
        positional, _ = recorded[0]
        assert positional[0] == "93.184.216.34"

    def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch):
        """A hostname that resolves to loopback is refused before connecting."""
        stub = _stub_sync_inner([])
        backend = SSRFProtectedSyncBackend(inner=stub)
        monkeypatch.setattr(
            socket, "getaddrinfo", self._resolver_returning("127.0.0.1")
        )
        with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
            backend.connect_tcp("127.0.0.1.nip.io", 80)
        stub.connect_tcp.assert_not_called()

    def test_literal_forbidden_ip_rejected(self):
        """A forbidden literal IP is refused without touching the inner backend."""
        stub = _stub_sync_inner([])
        backend = SSRFProtectedSyncBackend(inner=stub)
        with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
            backend.connect_tcp("10.0.0.1", 80)
        stub.connect_tcp.assert_not_called()
class TestRequestEventHook:
    """Verify the syntactic URL validator is wired as a request event hook.

    Requests made through the context-provided clients are expected to be
    validated by the hook itself, so feature code does not have to call
    `validate_url_for_http_request` on its own.
    """

    @staticmethod
    async def _expect_async_rejection(url: str, pattern: str) -> None:
        """GET *url* on a fresh async client and require a matching ValidationError."""
        from utils.context import create_httpx_async_client

        client = create_httpx_async_client()
        try:
            with pytest.raises(ValidationError, match=pattern):
                await client.get(url)
        finally:
            # Always release the client, even when the expectation fails.
            await client.aclose()

    async def test_async_hook_rejects_bad_scheme(self):
        """A non-http(s) scheme is rejected by the async client."""
        await self._expect_async_rejection(
            "file:///etc/passwd", "only http and https"
        )

    async def test_async_hook_rejects_internal_tld(self):
        """A ``.local`` hostname is rejected by the async client."""
        await self._expect_async_rejection(
            "http://printer.local/status", "internal domain names"
        )

    def test_sync_hook_rejects_literal_private_ip(self):
        """A literal private IP is rejected by the sync client."""
        from utils.context import create_httpx_client

        with create_httpx_client() as client:
            with pytest.raises(ValidationError, match="private, internal"):
                client.get("http://10.0.0.1/")
class TestInstallation:
    """Verify the backend is actually wired onto httpx clients we create."""

    async def test_create_httpx_async_client_installs_backend(self):
        """Every transport on the async client uses the async SSRF backend.

        The test is a coroutine so the client is closed with ``await`` on the
        test's own event loop, rather than by starting a second loop with
        ``asyncio.run`` just for cleanup.
        """
        from utils.context import create_httpx_async_client
        from utils.ssrf import SSRFProtectedAsyncBackend as Async
        from utils.ssrf import _iter_client_transports

        client = create_httpx_async_client()
        try:
            transports = list(_iter_client_transports(client))
            assert transports, "expected at least one transport on the client"
            for transport in transports:
                # Reaches into private attributes to find the pool's backend.
                assert isinstance(transport._pool._network_backend, Async)
        finally:
            await client.aclose()

    def test_create_httpx_client_installs_backend(self):
        """Every transport on the sync client uses the sync SSRF backend."""
        from utils.context import create_httpx_client
        from utils.ssrf import SSRFProtectedSyncBackend as Sync
        from utils.ssrf import _iter_client_transports

        with create_httpx_client() as client:
            transports = list(_iter_client_transports(client))
            assert transports, "expected at least one transport on the client"
            for transport in transports:
                # Reaches into private attributes to find the pool's backend.
                assert isinstance(transport._pool._network_backend, Sync)
class TestValidateUrlForHttpRequest:
    """Test URL validation for HTTP requests to prevent SSRF attacks.

    Multi-URL cases are parametrized so each URL is reported as its own test
    and one failing URL cannot hide the ones listed after it.
    """

    # Message expected for every rejected literal IP address.
    BLOCKED_IP_MESSAGE = (
        "private, internal, reserved, or multicast IP addresses are not allowed"
    )

    @pytest.mark.parametrize(
        "url",
        [
            "http://example.com",
            "https://example.com",
            "http://example.com/path",
            "https://example.com/path?query=1",
            "http://subdomain.example.com",
        ],
    )
    def test_valid_http_urls(self, url):
        """Valid HTTP/HTTPS URLs pass syntactic validation without DNS lookups.

        DNS-based SSRF checks live in the HTTP client's connect path, so
        this layer must not call DNS.
        """
        validate_url_for_http_request(url, "test_url")

    def test_validator_does_not_perform_dns_lookup(self, monkeypatch):
        """Regression: validator must not block the event loop on DNS.

        Earlier implementations called `socket.getaddrinfo`, which both
        blocked the running event loop in async media-download callers and
        was defeated by DNS rebinding (the value seen here did not match
        the IP httpx later connected to). getaddrinfo is patched to a poison
        function so any accidental reintroduction fails this test.
        """

        def _explode(*_args, **_kwargs):
            raise AssertionError(
                "validate_url_for_http_request must not call DNS; "
                "SSRF DNS protection lives in the HTTP client backend"
            )

        monkeypatch.setattr(socket, "getaddrinfo", _explode)
        validate_url_for_http_request("http://example.com", "test_url")

    def test_invalid_empty_url(self):
        """An empty URL is rejected with a "cannot be empty" message."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request("", "test_url")
        assert "cannot be empty" in exc_info.value.message

    @pytest.mark.parametrize(
        "url",
        [
            "ftp://example.com",
            "file:///etc/passwd",
            "data:text/html,<h1>test</h1>",
            "javascript:alert(1)",  # XSS attack vector
        ],
    )
    def test_invalid_scheme(self, url):
        """Any scheme other than http/https is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert "only http and https schemes are allowed" in exc_info.value.message

    @pytest.mark.parametrize(
        "url",
        [
            "http://localhost",
            "http://127.0.0.1",
            "http://[::1]",
            "http://0.0.0.0",
        ],
    )
    def test_invalid_localhost(self, url):
        """Localhost and reserved hostnames are rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert (
            "localhost and reserved hostnames are not allowed"
            in exc_info.value.message
        )

    @pytest.mark.parametrize(
        "url",
        [
            "http://10.0.0.1",
            "http://192.168.1.1",
            "http://172.16.0.1",
            "http://172.31.255.254",
        ],
    )
    def test_invalid_private_ipv4_addresses(self, url):
        """Literal private IPv4 addresses are rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert self.BLOCKED_IP_MESSAGE in exc_info.value.message

    # 127.0.0.1 itself is in RESERVED_HOSTNAMES; these cover the rest of 127/8.
    @pytest.mark.parametrize("url", ["http://127.0.0.2", "http://127.255.255.255"])
    def test_invalid_loopback_addresses(self, url):
        """Other addresses in 127/8 are rejected as forbidden IPs."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert self.BLOCKED_IP_MESSAGE in exc_info.value.message

    @pytest.mark.parametrize(
        "url",
        [
            "http://[fe80::1]",  # link-local
            "http://[fc00::1]",  # unique local
            "http://[fd00::1]",  # unique local
        ],
    )
    def test_invalid_private_ipv6_addresses(self, url):
        """Literal link-local and unique-local IPv6 addresses are rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert self.BLOCKED_IP_MESSAGE in exc_info.value.message

    @pytest.mark.parametrize("url", ["http://224.0.0.1", "http://[ff02::1]"])
    def test_invalid_multicast_addresses(self, url):
        """Literal multicast addresses are rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert self.BLOCKED_IP_MESSAGE in exc_info.value.message

    @pytest.mark.parametrize(
        "url",
        [
            "http://server.local",
            "http://server.internal",
            "http://server.localhost",
        ],
    )
    def test_invalid_internal_tlds(self, url):
        """Hostnames under internal-only TLDs are rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert "internal domain names are not allowed" in exc_info.value.message

    @pytest.mark.parametrize(
        "url",
        [
            "http://0x7f000001",  # hex 127.0.0.1
            "http://2130706433",  # decimal 127.0.0.1
            "http://127.1",  # shorthand 127.0.0.1
            "http://0x0a000001",  # hex 10.0.0.1
            "http://3232235777",  # decimal 192.168.1.1
            "http://0xa9fea9fe",  # hex 169.254.169.254 (cloud metadata)
        ],
    )
    def test_invalid_non_standard_ip_representations(self, url):
        """Non-standard IPv4 forms (hex, decimal, shorthand) are SSRF bypass vectors."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request(url, "test_url")
        assert self.BLOCKED_IP_MESSAGE in exc_info.value.message

    def test_invalid_missing_hostname(self):
        """A URL with no host component is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_url_for_http_request("http://", "test_url")
        assert "missing hostname" in exc_info.value.message