diff --git a/backend/handler/filesystem/resources_handler.py b/backend/handler/filesystem/resources_handler.py index 8d8a0883d..081d45e67 100644 --- a/backend/handler/filesystem/resources_handler.py +++ b/backend/handler/filesystem/resources_handler.py @@ -15,7 +15,6 @@ from models.collection import Collection from models.rom import Rom from tasks.scheduled.convert_images_to_webp import ImageConverter from utils.context import ctx_httpx_client -from utils.validation import validate_url_for_http_request from .base_handler import CoverSize, FSHandler @@ -145,8 +144,6 @@ class FSResourcesHandler(FSHandler): return None else: # Handle HTTP URLs - validate_url_for_http_request(url_cover, "url_cover") - httpx_client = ctx_httpx_client.get() try: async with httpx_client.stream( @@ -314,8 +311,6 @@ class FSResourcesHandler(FSHandler): return None else: # Handle HTTP URLs - validate_url_for_http_request(url_screenhot, "url_screenshot") - httpx_client = ctx_httpx_client.get() try: async with httpx_client.stream( @@ -430,8 +425,6 @@ class FSResourcesHandler(FSHandler): return None else: # Handle HTTP URL - validate_url_for_http_request(url_manual, "url_manual") - httpx_client = ctx_httpx_client.get() try: async with httpx_client.stream( @@ -501,8 +494,6 @@ class FSResourcesHandler(FSHandler): # Retroachievements async def store_ra_badge(self, url: str, path: str) -> None: - validate_url_for_http_request(url, "url_badge") - httpx_client = ctx_httpx_client.get() directory, filename = os.path.split(path) @@ -569,8 +560,6 @@ class FSResourcesHandler(FSHandler): return None else: # Handle HTTP URLs - validate_url_for_http_request(url_media, "url_media") - httpx_client = ctx_httpx_client.get() try: async with httpx_client.stream( diff --git a/backend/tests/utils/test_ssrf.py b/backend/tests/utils/test_ssrf.py new file mode 100644 index 000000000..2be1c0a97 --- /dev/null +++ b/backend/tests/utils/test_ssrf.py @@ -0,0 +1,466 @@ +"""Tests for SSRF defense: URL validator + httpcore network backends.""" + +import asyncio +import socket +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import httpcore +import pytest + +from utils.ssrf import ( + SSRFProtectedAsyncBackend, + SSRFProtectedSyncBackend, + is_forbidden_ip, + parse_ip_literal, + validate_url_for_http_request, +) +from utils.validation import ValidationError + +ConnectCall = tuple[tuple[Any, ...], dict[str, Any]] + + +def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]: + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 0, "", (ip, port))] + + +def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock: + inner = MagicMock() + + def _record(*args: Any, **kwargs: Any) -> MagicMock: + connect_calls.append((args, kwargs)) + return MagicMock() + + inner.connect_tcp = AsyncMock(side_effect=_record) + return inner + + +def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock: + inner = MagicMock() + + def _record(*args: Any, **kwargs: Any) -> MagicMock: + connect_calls.append((args, kwargs)) + return MagicMock() + + inner.connect_tcp = MagicMock(side_effect=_record) + return inner + + +class TestIsForbiddenIp: + @pytest.mark.parametrize( + "ip", + [ + "127.0.0.1", + "10.0.0.1", + "192.168.1.1", + "172.16.0.1", + "169.254.169.254", + "0.0.0.0", + "224.0.0.1", + "::1", + "fc00::1", + "fe80::1", + "ff02::1", + ], + ) + def test_forbidden(self, ip): + import ipaddress + + assert is_forbidden_ip(ipaddress.ip_address(ip)) is True + + @pytest.mark.parametrize( + "ip", ["8.8.8.8", "1.1.1.1", "93.184.216.34", "2001:4860:4860::8888"] + ) + def test_allowed(self, ip): + import ipaddress + + assert is_forbidden_ip(ipaddress.ip_address(ip)) is False + + +class TestParseIpLiteral: + def test_standard(self): + assert str(parse_ip_literal("127.0.0.1")) == "127.0.0.1" + assert str(parse_ip_literal("::1")) == "::1" + + def test_non_standard_ipv4(self): + # Hex, decimal, shorthand - all map to 127.0.0.1 + assert str(parse_ip_literal("0x7f000001")) == "127.0.0.1" + assert str(parse_ip_literal("2130706433")) == "127.0.0.1" + assert str(parse_ip_literal("127.1")) == "127.0.0.1" + + def test_hostname_returns_none(self): + assert parse_ip_literal("example.com") is None + + +class TestSSRFProtectedAsyncBackend: + async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch): + """Backend resolves once, validates, and passes the pinned IP to inner.""" + calls: list[ConnectCall] = [] + inner = _stub_async_inner(calls) + backend = SSRFProtectedAsyncBackend(inner=inner) + + async def fake_getaddrinfo(host, port, *args, **kwargs): + return _addr_info("93.184.216.34", port) + + loop = asyncio.get_running_loop() + monkeypatch.setattr(loop, "getaddrinfo", fake_getaddrinfo) + + await backend.connect_tcp("example.com", 443) + + # inner must receive the resolved IP, not the original hostname. + # That is what pins the address against DNS rebinding. + assert calls[0][0][0] == "93.184.216.34" + assert calls[0][0][1] == 443 + + async def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch): + """DNS rebinding case: hostname resolves to 127.0.0.1 must fail.""" + inner = _stub_async_inner([]) + backend = SSRFProtectedAsyncBackend(inner=inner) + + async def fake_getaddrinfo(host, port, *args, **kwargs): + return _addr_info("127.0.0.1", port) + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + with pytest.raises(httpcore.ConnectError, match="forbidden IP"): + await backend.connect_tcp("127.0.0.1.nip.io", 80) + inner.connect_tcp.assert_not_called() + + async def test_mixed_resolution_rejected(self, monkeypatch): + """If any returned address is forbidden, reject - don't trust round-robin.""" + inner = _stub_async_inner([]) + backend = SSRFProtectedAsyncBackend(inner=inner) + + async def fake_getaddrinfo(host, port, *args, **kwargs): + return _addr_info("93.184.216.34", port) + _addr_info("10.0.0.1", port) + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + with pytest.raises(httpcore.ConnectError, match="forbidden IP"): + await backend.connect_tcp("mixed.example.com", 80) + inner.connect_tcp.assert_not_called() + + async def test_literal_forbidden_ip_rejected(self): + inner = _stub_async_inner([]) + backend = SSRFProtectedAsyncBackend(inner=inner) + with pytest.raises(httpcore.ConnectError, match="forbidden IP"): + await backend.connect_tcp("169.254.169.254", 80) + inner.connect_tcp.assert_not_called() + + async def test_literal_public_ip_passes_through(self): + calls: list[ConnectCall] = [] + inner = _stub_async_inner(calls) + backend = SSRFProtectedAsyncBackend(inner=inner) + await backend.connect_tcp("8.8.8.8", 443) + # Literal public IPs are passed through unchanged. + assert calls[0][0][0] == "8.8.8.8" + + async def test_non_standard_ipv4_literal_blocked(self): + """Hex/decimal IPv4 forms must be blocked, matching httpx's parsing.""" + inner = _stub_async_inner([]) + backend = SSRFProtectedAsyncBackend(inner=inner) + with pytest.raises(httpcore.ConnectError, match="forbidden IP"): + await backend.connect_tcp("2130706433", 80) # 127.0.0.1 + inner.connect_tcp.assert_not_called() + + async def test_dns_timeout_raises_connect_timeout(self, monkeypatch): + """A resolver that hangs past the caller's timeout must not block forever. + + Regression: an earlier version applied the caller's timeout only to + the TCP connect inside the inner backend, leaving `loop.getaddrinfo` + unbounded. We now wrap the lookup in `asyncio.timeout()` so a slow + resolver is bounded by the same budget the caller specified. + """ + inner = _stub_async_inner([]) + backend = SSRFProtectedAsyncBackend(inner=inner) + + async def hang_forever(*args, **kwargs): + await asyncio.sleep(3600) + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", hang_forever) + + with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"): + await backend.connect_tcp("slow.example.com", 80, timeout=0.05) + inner.connect_tcp.assert_not_called() + + async def test_dns_failure_propagates_as_connect_error(self, monkeypatch): + inner = _stub_async_inner([]) + backend = SSRFProtectedAsyncBackend(inner=inner) + + async def fake_getaddrinfo(*args, **kwargs): + raise socket.gaierror("Name or service not known") + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + with pytest.raises(httpcore.ConnectError, match="DNS resolution failed"): + await backend.connect_tcp("nonexistent.invalid", 80) + inner.connect_tcp.assert_not_called() + + +class TestSSRFProtectedSyncBackend: + def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch): + calls: list[ConnectCall] = [] + inner = _stub_sync_inner(calls) + backend = SSRFProtectedSyncBackend(inner=inner) + + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda host, port, *a, **kw: _addr_info("93.184.216.34", port), + ) + + backend.connect_tcp("example.com", 443) + assert calls[0][0][0] == "93.184.216.34" + + def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch): + inner = _stub_sync_inner([]) + backend = SSRFProtectedSyncBackend(inner=inner) + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda host, port, *a, **kw: _addr_info("127.0.0.1", port), + ) + with pytest.raises(httpcore.ConnectError, match="forbidden IP"): + backend.connect_tcp("127.0.0.1.nip.io", 80) + inner.connect_tcp.assert_not_called() + + def test_literal_forbidden_ip_rejected(self): + inner = _stub_sync_inner([]) + backend = SSRFProtectedSyncBackend(inner=inner) + with pytest.raises(httpcore.ConnectError, match="forbidden IP"): + backend.connect_tcp("10.0.0.1", 80) + inner.connect_tcp.assert_not_called() + + +class TestRequestEventHook: + """Verify the syntactic URL validator is wired as a request event hook. + + With the hook in place, every request through a context-provided client + is validated automatically; feature code does not need to call + `validate_url_for_http_request` itself. + """ + + async def test_async_hook_rejects_bad_scheme(self): + from utils.context import create_httpx_async_client + from utils.validation import ValidationError + + client = create_httpx_async_client() + try: + with pytest.raises(ValidationError, match="only http and https"): + await client.get("file:///etc/passwd") + finally: + await client.aclose() + + async def test_async_hook_rejects_internal_tld(self): + from utils.context import create_httpx_async_client + from utils.validation import ValidationError + + client = create_httpx_async_client() + try: + with pytest.raises(ValidationError, match="internal domain names"): + await client.get("http://printer.local/status") + finally: + await client.aclose() + + def test_sync_hook_rejects_literal_private_ip(self): + from utils.context import create_httpx_client + from utils.validation import ValidationError + + with create_httpx_client() as client: + with pytest.raises(ValidationError, match="private, internal"): + client.get("http://10.0.0.1/") + + +class TestInstallation: + """Verify the backend is actually wired onto httpx clients we create.""" + + def test_create_httpx_async_client_installs_backend(self): + from utils.context import create_httpx_async_client + from utils.ssrf import SSRFProtectedAsyncBackend as Async + + client = create_httpx_async_client() + try: + assert isinstance(client._transport._pool._network_backend, Async) + finally: + asyncio.run(client.aclose()) + + def test_create_httpx_client_installs_backend(self): + from utils.context import create_httpx_client + from utils.ssrf import SSRFProtectedSyncBackend as Sync + + with create_httpx_client() as client: + assert isinstance(client._transport._pool._network_backend, Sync) + + def test_proxy_transports_are_not_wrapped(self): + """Proxy mounts must keep their stock backend. + + A common deployment pattern is `HTTPS_PROXY=http://sidecar:9050`, + where `sidecar` resolves to a docker-bridge private IP. If we + wrapped the proxy transport, our SSRF backend would refuse to + connect to the operator's chosen proxy. SSRF protection at the + proxy hop is the operator's responsibility; the destination URL + is still validated by the request event hook on the client. + """ + import httpx + + from utils.ssrf import SSRFProtectedAsyncBackend as Async + from utils.ssrf import ( + install_async_ssrf_protection, + ) + + client = httpx.AsyncClient(proxy="http://proxy.invalid:3128") + try: + install_async_ssrf_protection(client) + assert isinstance(client._transport._pool._network_backend, Async) + for mount in client._mounts.values(): + if mount is None: + continue + # Proxy mount must NOT have been wrapped. + assert not isinstance(mount._pool._network_backend, Async) + finally: + asyncio.run(client.aclose()) + + +class TestValidateUrlForHttpRequest: + """Test URL validation for HTTP requests to prevent SSRF attacks.""" + + def test_valid_http_urls(self): + """Valid HTTP/HTTPS URLs pass syntactic validation without DNS lookups. + + DNS-based SSRF checks live in the HTTP client's connect path, so + this layer must not call DNS. + """ + validate_url_for_http_request("http://example.com", "test_url") + validate_url_for_http_request("https://example.com", "test_url") + validate_url_for_http_request("http://example.com/path", "test_url") + validate_url_for_http_request("https://example.com/path?query=1", "test_url") + validate_url_for_http_request("http://subdomain.example.com", "test_url") + + def test_validator_does_not_perform_dns_lookup(self, monkeypatch): + """Regression: validator must not block the event loop on DNS. + + Earlier implementations called `socket.getaddrinfo`, which both + blocked the running event loop in async media-download callers and + was defeated by DNS rebinding (the value seen here did not match + the IP httpx later connected to). We patch getaddrinfo to a poison + function so any accidental reintroduction fails this test. + """ + + def _explode(*_args, **_kwargs): + raise AssertionError( + "validate_url_for_http_request must not call DNS; " + "SSRF DNS protection lives in the HTTP client backend" + ) + + monkeypatch.setattr(socket, "getaddrinfo", _explode) + validate_url_for_http_request("http://example.com", "test_url") + + def test_invalid_empty_url(self): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request("", "test_url") + assert "cannot be empty" in exc_info.value.message + + def test_invalid_scheme(self): + for url in ( + "ftp://example.com", + "file:///etc/passwd", + "data:text/html,

test

", + "javascript:alert(1)", # XSS attack vector + ): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert "only http and https schemes are allowed" in exc_info.value.message + + def test_invalid_localhost(self): + for url in ( + "http://localhost", + "http://127.0.0.1", + "http://[::1]", + "http://0.0.0.0", + ): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert ( + "localhost and reserved hostnames are not allowed" + in exc_info.value.message + ) + + def test_invalid_private_ipv4_addresses(self): + for url in ( + "http://10.0.0.1", + "http://192.168.1.1", + "http://172.16.0.1", + "http://172.31.255.254", + ): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert ( + "private, internal, reserved, or multicast IP addresses are not allowed" + in exc_info.value.message + ) + + def test_invalid_loopback_addresses(self): + # 127.0.0.1 itself is in RESERVED_HOSTNAMES; these cover the rest of 127/8. + for url in ("http://127.0.0.2", "http://127.255.255.255"): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert ( + "private, internal, reserved, or multicast IP addresses are not allowed" + in exc_info.value.message + ) + + def test_invalid_private_ipv6_addresses(self): + for url in ( + "http://[fe80::1]", # link-local + "http://[fc00::1]", # unique local + "http://[fd00::1]", # unique local + ): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert ( + "private, internal, reserved, or multicast IP addresses are not allowed" + in exc_info.value.message + ) + + def test_invalid_multicast_addresses(self): + for url in ("http://224.0.0.1", "http://[ff02::1]"): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert ( + "private, internal, reserved, or multicast IP addresses are not allowed" + in exc_info.value.message + ) + + def test_invalid_internal_tlds(self): + for url in ( + "http://server.local", + "http://server.internal", + "http://server.localhost", + ): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert "internal domain names are not allowed" in exc_info.value.message + + def test_invalid_non_standard_ip_representations(self): + """Non-standard IPv4 forms (hex, decimal, shorthand) are SSRF bypass vectors.""" + cases = [ + "http://0x7f000001", # hex 127.0.0.1 + "http://2130706433", # decimal 127.0.0.1 + "http://127.1", # shorthand 127.0.0.1 + "http://0x0a000001", # hex 10.0.0.1 + "http://3232235777", # decimal 192.168.1.1 + "http://0xa9fea9fe", # hex 169.254.169.254 (cloud metadata) + ] + for url in cases: + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request(url, "test_url") + assert ( + "private, internal, reserved, or multicast IP addresses are not allowed" + in exc_info.value.message + ) + + def test_invalid_missing_hostname(self): + with pytest.raises(ValidationError) as exc_info: + validate_url_for_http_request("http://", "test_url") + assert "missing hostname" in exc_info.value.message diff --git a/backend/tests/utils/test_validation.py b/backend/tests/utils/test_validation.py index b55720579..73923cbd5 100644 --- a/backend/tests/utils/test_validation.py +++ b/backend/tests/utils/test_validation.py @@ -7,7 +7,6 @@ from utils.validation import ( validate_ascii_only, validate_email, validate_password, - validate_url_for_http_request, validate_username, ) @@ -158,226 +157,3 @@ class TestValidateEmail: with pytest.raises(ValidationError) as exc_info: validate_email("résumé@example.com") assert "ASCII characters" in exc_info.value.message - - -class TestValidateUrlForHttpRequest: - """Test URL validation for HTTP requests to prevent SSRF attacks.""" - - def test_valid_http_urls(self): - """Test that valid HTTP/HTTPS URLs pass validation.""" - validate_url_for_http_request("http://example.com", "test_url") - validate_url_for_http_request("https://example.com", "test_url") - validate_url_for_http_request("http://example.com/path", "test_url") - validate_url_for_http_request("https://example.com/path?query=1", "test_url") - validate_url_for_http_request("http://subdomain.example.com", "test_url") - - def test_invalid_empty_url(self): - """Test that empty URLs fail validation.""" - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("", "test_url") - assert "cannot be empty" in exc_info.value.message - - def test_invalid_scheme(self): - """Test that non-HTTP/HTTPS schemes fail validation.""" - # FTP scheme - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("ftp://example.com", "test_url") - assert "only http and https schemes are allowed" in exc_info.value.message - - # File scheme - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("file:///etc/passwd", "test_url") - assert "only http and https schemes are allowed" in exc_info.value.message - - # Data scheme - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("data:text/html,

test

", "test_url") - assert "only http and https schemes are allowed" in exc_info.value.message - - # JavaScript scheme (XSS attack vector) - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("javascript:alert(1)", "test_url") - assert "only http and https schemes are allowed" in exc_info.value.message - - def test_invalid_localhost(self): - """Test that localhost and reserved hostnames fail validation.""" - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://localhost", "test_url") - assert ( - "localhost and reserved hostnames are not allowed" in exc_info.value.message - ) - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://127.0.0.1", "test_url") - assert ( - "localhost and reserved hostnames are not allowed" in exc_info.value.message - ) - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://[::1]", "test_url") - assert ( - "localhost and reserved hostnames are not allowed" in exc_info.value.message - ) - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://0.0.0.0", "test_url") - assert ( - "localhost and reserved hostnames are not allowed" in exc_info.value.message - ) - - def test_invalid_private_ipv4_addresses(self): - """Test that private IPv4 addresses fail validation.""" - # 10.x.x.x range - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://10.0.0.1", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # 192.168.x.x range - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://192.168.1.1", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # 172.16.x.x - 172.31.x.x range - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://172.16.0.1", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://172.31.255.254", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - def test_invalid_loopback_addresses(self): - """Test that loopback addresses fail validation.""" - # 127.x.x.x range - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://127.0.0.2", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://127.255.255.255", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - def test_invalid_private_ipv6_addresses(self): - """Test that private/link-local IPv6 addresses fail validation.""" - # Link-local IPv6: fe80::/10 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://[fe80::1]", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # Unique local address: fc00::/7 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://[fc00::1]", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://[fd00::1]", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - def test_invalid_multicast_addresses(self): - """Test that multicast addresses fail validation.""" - # IPv4 multicast: 224.0.0.0/4 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://224.0.0.1", "test_url") - assert "multicast addresses are not allowed" in exc_info.value.message - - # IPv6 multicast: ff00::/8 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://[ff02::1]", "test_url") - assert "multicast addresses are not allowed" in exc_info.value.message - - def test_invalid_internal_tlds(self): - """Test that internal TLDs fail validation.""" - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://server.local", "test_url") - assert "internal domain names are not allowed" in exc_info.value.message - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://server.internal", "test_url") - assert "internal domain names are not allowed" in exc_info.value.message - - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://server.localhost", "test_url") - assert "internal domain names are not allowed" in exc_info.value.message - - def test_invalid_non_standard_ip_representations(self): - """Test that non-standard IP representations are blocked (SSRF bypass vectors).""" - # Hexadecimal integer for 127.0.0.1 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://0x7f000001", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # Decimal integer for 127.0.0.1 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://2130706433", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # Shorthand dotted for 127.0.0.1 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://127.1", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # Hexadecimal integer for 10.0.0.1 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://0x0a000001", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # Decimal integer for 192.168.1.1 - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://3232235777", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - # Hexadecimal integer for 169.254.169.254 (cloud metadata) - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://0xa9fea9fe", "test_url") - assert ( - "private, internal, and reserved IP addresses are not allowed" - in exc_info.value.message - ) - - def test_invalid_missing_hostname(self): - """Test that URLs without hostnames fail validation.""" - with pytest.raises(ValidationError) as exc_info: - validate_url_for_http_request("http://", "test_url") - assert "missing hostname" in exc_info.value.message diff --git a/backend/utils/context.py b/backend/utils/context.py index 4f9188aaf..b49faf1fa 100644 --- a/backend/utils/context.py +++ b/backend/utils/context.py @@ -8,6 +8,11 @@ import httpx from fastapi import Request, Response from config import has_proxy_env +from utils.ssrf import ( + install_async_ssrf_protection, + install_sync_ssrf_protection, + validate_url_for_http_request, +) _T = TypeVar("_T") @@ -15,16 +20,34 @@ ctx_aiohttp_session: ContextVar[aiohttp.ClientSession] = ContextVar("aiohttp_ses ctx_httpx_client: ContextVar[httpx.AsyncClient] = ContextVar("httpx_client") +def _validate_request_url_sync(request: httpx.Request) -> None: + validate_url_for_http_request(str(request.url)) + + +async def _validate_request_url_async(request: httpx.Request) -> None: + validate_url_for_http_request(str(request.url)) + + def create_aiohttp_session() -> aiohttp.ClientSession: return aiohttp.ClientSession(trust_env=has_proxy_env()) def create_httpx_async_client() -> httpx.AsyncClient: - return httpx.AsyncClient(trust_env=has_proxy_env()) + client = httpx.AsyncClient( + trust_env=has_proxy_env(), + event_hooks={"request": [_validate_request_url_async]}, + ) + install_async_ssrf_protection(client) + return client def create_httpx_client() -> httpx.Client: - return httpx.Client(trust_env=has_proxy_env()) + client = httpx.Client( + trust_env=has_proxy_env(), + event_hooks={"request": [_validate_request_url_sync]}, + ) + install_sync_ssrf_protection(client) + return client @asynccontextmanager diff --git a/backend/utils/ssrf.py b/backend/utils/ssrf.py new file mode 100644 index 000000000..118f0afa7 --- /dev/null +++ b/backend/utils/ssrf.py @@ -0,0 +1,331 @@ +"""SSRF defense for outbound HTTP. + +Two layers, both wired onto every httpx client built by `utils.context`: + + 1. `validate_url_for_http_request` — a syntactic fast-fail check + installed as an httpx request event hook. Rejects non-HTTP schemes, + literal IPs in forbidden ranges (including non-standard IPv4 forms), + reserved hostnames, and internal TLDs before any socket opens. + + 2. `SSRFProtectedAsyncBackend` / `SSRFProtectedSyncBackend` — custom + httpcore network backends that resolve the hostname inside + `connect_tcp`, reject any address in a private/loopback/link-local/ + reserved/multicast/unspecified range, then connect to that *same* + validated address. This is what defeats DNS rebinding: the address + used by the OS for the TCP connection is the one we just checked, + not a fresh lookup the attacker can answer differently. Doing this + work in the backend also avoids blocking the event loop, since the + async variant uses `loop.getaddrinfo`. + +httpcore calls `start_tls(server_hostname=)` after +`connect_tcp` returns, so TLS SNI and certificate verification still +use the original hostname even though we connect by IP. +""" + +from __future__ import annotations + +import asyncio +import ipaddress +import socket +import typing +from urllib.parse import urlparse + +import httpcore +from httpcore._backends.auto import AutoBackend +from httpcore._backends.base import ( + SOCKET_OPTION, + AsyncNetworkBackend, + AsyncNetworkStream, + NetworkBackend, + NetworkStream, +) +from httpcore._backends.sync import SyncBackend + +from logger.logger import log +from utils.validation import ValidationError + + +def is_forbidden_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + """Return True if the IP must not be reached from a server-side HTTP request.""" + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_reserved + or ip.is_multicast + or ip.is_unspecified + ) + + +def parse_ip_literal( + host: str, +) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None: + """Return the parsed IP if `host` is a literal address, else None. + + Accepts non-standard IPv4 forms (hex, decimal, shorthand) via + `socket.inet_aton`, which is what HTTP clients themselves accept. + """ + try: + return ipaddress.ip_address(host) + except ValueError: + pass + try: + packed = socket.inet_aton(host) + except OSError: + return None + return ipaddress.IPv4Address(packed) + + +def _pick_safe_address(addr_infos: typing.Iterable[typing.Any], host: str) -> str: + """Validate every resolved address and return the literal IP to connect to. + + All returned addresses are checked: if any falls in a forbidden range + we reject the whole name, rather than just skipping that record. A + malicious DNS server can otherwise mix public and private answers + and rely on the client to round-robin. + """ + chosen: str | None = None + for *_, sockaddr in addr_infos: + try: + ip = ipaddress.ip_address(sockaddr[0]) + except (ValueError, IndexError): + continue + if is_forbidden_ip(ip): + msg = ( + f"SSRF prevention: hostname {host!r} resolves to forbidden " f"IP {ip}" + ) + log.error(msg) + raise httpcore.ConnectError(msg) + if chosen is None: + chosen = sockaddr[0] + if chosen is None: + raise httpcore.ConnectError(f"No usable addresses for {host!r}") + return chosen + + +def _check_literal(host: str) -> bool: + """Return True if `host` is a literal IP that has been validated as safe. + + Raises httpcore.ConnectError if it is a literal IP in a forbidden range. + Returns False if `host` is a hostname (caller must resolve and validate). + """ + literal = parse_ip_literal(host) + if literal is None: + return False + if is_forbidden_ip(literal): + msg = f"SSRF prevention: connection to forbidden IP {literal}" + log.error(msg) + raise httpcore.ConnectError(msg) + return True + + +class SSRFProtectedAsyncBackend(AsyncNetworkBackend): + """Async backend that validates resolved IPs before establishing TCP.""" + + def __init__(self, inner: AsyncNetworkBackend | None = None) -> None: + self._inner = inner or AutoBackend() + + # `timeout` parameter is required by AsyncNetworkBackend.connect_tcp; + # ruff/ASYNC109 advises against timeout parameters on async APIs *we* + # author, but we are implementing an external interface here. The + # timeout is consumed via `asyncio.timeout()` below, which is the + # asyncio-native pattern ASYNC109 endorses. + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, # noqa: ASYNC109 + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + """Validate the resolved IP, then connect via the inner backend. + + The DNS lookup is wrapped in `asyncio.timeout()` so a slow + resolver is bounded by the caller's timeout. The previous code + only timed out the TCP connect inside the inner backend, leaving + `loop.getaddrinfo` unbounded. + """ + if _check_literal(host): + return await self._inner.connect_tcp( + host, port, timeout, local_address, socket_options + ) + + loop = asyncio.get_running_loop() + try: + async with asyncio.timeout(timeout): + addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM) + except socket.gaierror as exc: + raise httpcore.ConnectError( + f"DNS resolution failed for {host!r}: {exc}" + ) from exc + except TimeoutError as exc: + raise httpcore.ConnectTimeout( + f"DNS resolution timed out for {host!r}" + ) from exc + + pinned_ip = _pick_safe_address(addr_infos, host) + return await self._inner.connect_tcp( + pinned_ip, port, timeout, local_address, socket_options + ) + + # See note on connect_tcp re: ASYNC109 / interface implementation. + async def connect_unix_socket( + self, + path: str, + timeout: float | None = None, # noqa: ASYNC109 + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> AsyncNetworkStream: + return await self._inner.connect_unix_socket(path, timeout, socket_options) + + async def sleep(self, seconds: float) -> None: + await self._inner.sleep(seconds) + + +class SSRFProtectedSyncBackend(NetworkBackend): + """Sync backend that validates resolved IPs before establishing TCP.""" + + def __init__(self, inner: NetworkBackend | None = None) -> None: + self._inner = inner if inner is not None else SyncBackend() + + def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + if _check_literal(host): + return self._inner.connect_tcp( + host, port, timeout, local_address, socket_options + ) + + try: + addr_infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM) + except socket.gaierror as exc: + raise httpcore.ConnectError( + f"DNS resolution failed for {host!r}: {exc}" + ) from exc + + pinned_ip = _pick_safe_address(addr_infos, host) + return self._inner.connect_tcp( + pinned_ip, port, timeout, local_address, socket_options + ) + + def connect_unix_socket( + self, + path: str, + timeout: float | None = None, + socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + ) -> NetworkStream: + return self._inner.connect_unix_socket(path, timeout, socket_options) + + +def install_async_ssrf_protection(client: typing.Any) -> None: + """Wrap the client's default transport so SSRF validation runs at connect time. + + httpx does not expose `network_backend` through its public transport + API, so we mutate `_pool._network_backend` after construction. + """ + pool = client._transport._pool + if not isinstance(pool._network_backend, SSRFProtectedAsyncBackend): + pool._network_backend = SSRFProtectedAsyncBackend(inner=pool._network_backend) + + +def install_sync_ssrf_protection(client: typing.Any) -> None: + """Sync counterpart of `install_async_ssrf_protection`.""" + pool = client._transport._pool + if not isinstance(pool._network_backend, SSRFProtectedSyncBackend): + pool._network_backend = SSRFProtectedSyncBackend(inner=pool._network_backend) + + +RESERVED_HOSTNAMES = [ + "localhost", + "127.0.0.1", + "0.0.0.0", # trunk-ignore(bandit/B104) + "::1", + "::", +] + + +def validate_url_for_http_request(url: str, field_name: str = "URL") -> None: + """Syntactically validate a URL before passing it to an HTTP client. + + Fast-fail check for cases that don't need DNS to detect: + + - The URL scheme is http or https only + - If the host is a literal IP address, it is not private/internal/reserved + - The host is not a reserved hostname (localhost, 127.0.0.1, etc.) + - The host does not use internal TLDs (.local, .internal, .localhost) + + Wired in as an httpx request event hook by `utils.context.create_httpx_*`, + so every request that goes through a context-provided client runs this + automatically — direct calls from feature code are not required. + + Dynamic SSRF protection (rejecting hostnames that resolve to a private + IP, including DNS-rebinding names like `127.0.0.1.nip.io`) happens + inside the HTTP client's connect path via the backends above. Doing + the DNS check here would (a) be defeated by DNS rebinding because + httpx re-resolves at connect time, and (b) block the event loop, + since `socket.getaddrinfo` is synchronous and most callers are async. + + Args: + url (str): The URL to validate + field_name (str): The name of the field for error messages + + Raises: + ValidationError: If the URL is syntactically invalid or matches one + of the static SSRF deny patterns. + """ + if not url or not url.strip(): + msg = f"{field_name} cannot be empty" + log.error(msg) + raise ValidationError(msg, field_name) + + try: + parsed = urlparse(url) + except Exception as e: + msg = f"Invalid {field_name}: unable to parse URL" + log.error(f"{msg}: {str(e)}") + raise ValidationError(msg, field_name) from e + + # Validate scheme - only allow http and https + if parsed.scheme not in ["http", "https"]: + msg = f"Invalid {field_name}: only http and https schemes are allowed" + log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'") + raise ValidationError(msg, field_name) + + # Extract hostname + hostname = parsed.hostname + if not hostname: + msg = f"Invalid {field_name}: missing hostname" + log.error(msg) + raise ValidationError(msg, field_name) + + # Block reserved hostnames that are commonly used to refer to internal services. + if hostname.lower() in RESERVED_HOSTNAMES: + msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed" + log.error(f"SSRF prevention: {msg} - hostname '{hostname}'") + raise ValidationError(msg, field_name) + + # Try to parse hostname as a literal IP (standard or non-standard form). + # HTTP clients accept hex (0x7f000001), decimal (2130706433), and + # shorthand-dotted (127.1) integers via inet_aton, so we mirror that. + ip = parse_ip_literal(hostname) + if ip is not None: + if is_forbidden_ip(ip): + msg = ( + f"Invalid {field_name}: private, internal, reserved, " + "or multicast IP addresses are not allowed" + ) + log.error(f"SSRF prevention: {msg} - IP '{ip}'") + raise ValidationError(msg, field_name) + return + + # Block common internal TLDs + hostname_lower = hostname.lower() + internal_tlds = [".local", ".internal", ".localhost"] + if any(hostname_lower.endswith(tld) for tld in internal_tlds): + msg = f"Invalid {field_name}: internal domain names are not allowed" + log.error(f"SSRF prevention: {msg} - hostname '{hostname}'") + raise ValidationError(msg, field_name) diff --git a/backend/utils/validation.py b/backend/utils/validation.py index 27ddff568..5e07a038b 100644 --- a/backend/utils/validation.py +++ b/backend/utils/validation.py @@ -1,7 +1,4 @@ -import ipaddress import re -import socket -from urllib.parse import urlparse from logger.logger import log from models.user import TEXT_FIELD_LENGTH @@ -118,112 +115,3 @@ def validate_email(email: str) -> None: msg = "Invalid email format" log.error(f"Validation failed: {msg} for email: {email}") raise ValidationError(msg, "Email") - - -# Check for localhost and reserved hostnames -RESERVED_HOSTNAMES = [ - "localhost", - "127.0.0.1", - "0.0.0.0", # trunk-ignore(bandit/B104) - "::1", - "::", -] - - -def validate_url_for_http_request(url: str, field_name: str = "URL") -> None: - """Validate URL to prevent Server-Side Request Forgery (SSRF) attacks. - - This function validates that: - - The URL scheme is http or https only - - If the host is a literal IP address, it is not private/internal/reserved - - The host is not a reserved hostname (localhost, 127.0.0.1, etc.) - - The host does not use internal TLDs (.local, .internal, .localhost) - - Note: This function does NOT perform DNS resolution. Domain names that resolve - to private IPs will not be detected (DNS rebinding/internal DNS bypass possible). - It only checks literal IP addresses in the hostname. - - Args: - url (str): The URL to validate - field_name (str): The name of the field for error messages - - Raises: - ValidationError: If the URL is invalid or potentially dangerous - """ - if not url or not url.strip(): - msg = f"{field_name} cannot be empty" - log.error(msg) - raise ValidationError(msg, field_name) - - try: - parsed = urlparse(url) - except Exception as e: - msg = f"Invalid {field_name}: unable to parse URL" - log.error(f"{msg}: {str(e)}") - raise ValidationError(msg, field_name) from e - - # Validate scheme - only allow http and https - if parsed.scheme not in ["http", "https"]: - msg = f"Invalid {field_name}: only http and https schemes are allowed" - log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'") - raise ValidationError(msg, field_name) - - # Extract hostname - hostname = parsed.hostname - if not hostname: - msg = f"Invalid {field_name}: missing hostname" - log.error(msg) - raise ValidationError(msg, field_name) - - if hostname.lower() in RESERVED_HOSTNAMES: - msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed" - log.error(f"SSRF prevention: {msg} - hostname '{hostname}'") - raise ValidationError(msg, field_name) - - # Try to resolve hostname as IP address - try: - ip = ipaddress.ip_address(hostname) - - # Block private/internal/link-local IP addresses - if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: - msg = f"Invalid {field_name}: private, internal, and reserved IP addresses are not allowed" - log.error(f"SSRF prevention: {msg} - IP '{ip}'") - raise ValidationError(msg, field_name) - - # Block multicast addresses - if ip.is_multicast: - msg = f"Invalid {field_name}: multicast addresses are not allowed" - log.error(f"SSRF prevention: {msg} - IP '{ip}'") - raise ValidationError(msg, field_name) - - except ValueError as e: - # ipaddress.ip_address() only handles standard notation. HTTP clients - # also accept hex integers (0x7f000001), decimal integers (2130706433), - # shorthand dotted (127.1), and octal (0177.0.0.1). Use socket.inet_aton() - # which handles these non-standard IPv4 representations. - try: - packed = socket.inet_aton(hostname) - ip = ipaddress.IPv4Address(packed) - - if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: - msg = f"Invalid {field_name}: private, internal, and reserved IP addresses are not allowed" - log.error(f"SSRF prevention: {msg} - IP '{ip}'") - raise ValidationError(msg, field_name) - - if ip.is_multicast: - msg = f"Invalid {field_name}: multicast addresses are not allowed" - log.error(f"SSRF prevention: {msg} - IP '{ip}'") - raise ValidationError(msg, field_name) - - except OSError: - pass # Not an IP address at all - fall through to domain name checks - - # Additional checks for suspicious domain patterns - hostname_lower = hostname.lower() - - # Block common internal TLDs - internal_tlds = [".local", ".internal", ".localhost"] - if any(hostname_lower.endswith(tld) for tld in internal_tlds): - msg = f"Invalid {field_name}: internal domain names are not allowed" - log.error(f"SSRF prevention: {msg} - hostname '{hostname}'") - raise ValidationError(msg, field_name) from e diff --git a/pyproject.toml b/pyproject.toml index 1ac0b1cde..1440eacf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "SQLAlchemy[mariadb-connector,mysql-connector,postgresql-psycopg] ~= 2.0", "Unidecode ~= 1.3", "aiohttp~=3.13", - "asyncssh ~= 2.17", + "asyncssh~=2.23", "alembic ~= 1.16", "anyio ~= 4.4", "authlib~=1.6.12", diff --git a/uv.lock b/uv.lock index 25c21e7f3..4d3f3d2fd 100644 --- a/uv.lock +++ b/uv.lock @@ -7,7 +7,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-21T08:54:16.815535417Z" +exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -176,15 +176,15 @@ wheels = [ [[package]] name = "asyncssh" -version = "2.22.0" +version = "2.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/d5/957886c316466349d55c4de6a688a10a98295c0b4429deb8db1a17f3eb19/asyncssh-2.22.0.tar.gz", hash = "sha256:c3ce72b01be4f97b40e62844dd384227e5ff5a401a3793007c42f86a5c8eb537", size = 540523, upload-time = "2025-12-21T23:38:30.5Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/fd/c34fe7e30838b4b9cc91903da26a62c6d33b673c731b3d951fcd70ab1889/asyncssh-2.23.0.tar.gz", hash = "sha256:8c54760953c1f2cf282591bcba5c8c70efc48d645bbf26bd2307a9c66a0ed1a7", size = 542154, upload-time = "2026-05-09T03:15:01.856Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/ae/0da2f2214fc183338af1afe5a103a2052fd03464e8eafbd827abff58a4d0/asyncssh-2.22.0-py3-none-any.whl", hash = "sha256:d16465ccdf1ed20eba1131b14415b155e047f6f5be0d19f39c2e0b61331ee0e7", size = 374938, upload-time = "2025-12-21T23:38:28.976Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b5/b1a3979f4840d1271ca8e0978dbccfb18ad2d33b4ece85cf77122fb46e5f/asyncssh-2.23.0-py3-none-any.whl", hash = "sha256:14108bfdaae17457f0c1841e883ad934271bbfdd46458aa4c4d0973451940ad0", size = 375687, upload-time = "2026-05-09T03:15:00.221Z" }, ] [[package]] @@ -2255,7 +2255,7 @@ requires-dist = [ { name = "aiohttp", specifier = "~=3.13" }, { name = "alembic", specifier = "~=1.16" }, { name = "anyio", specifier = "~=4.4" }, - { name = "asyncssh", specifier = "~=2.17" }, + { name = "asyncssh", specifier = "~=2.23" }, { name = "authlib", specifier = "~=1.6.12" }, { name = "bcrypt", specifier = "<5" }, { name = "colorama", specifier = "~=0.4" },