Merge branch 'master' into feat/soundtrack-support

This commit is contained in:
Zurdi
2026-05-28 11:20:58 +02:00
committed by GitHub
8 changed files with 828 additions and 355 deletions

View File

@@ -15,7 +15,6 @@ from models.collection import Collection
from models.rom import Rom
from tasks.scheduled.convert_images_to_webp import ImageConverter
from utils.context import ctx_httpx_client
from utils.validation import validate_url_for_http_request
from .base_handler import CoverSize, FSHandler
@@ -145,8 +144,6 @@ class FSResourcesHandler(FSHandler):
return None
else:
# Handle HTTP URLs
validate_url_for_http_request(url_cover, "url_cover")
httpx_client = ctx_httpx_client.get()
try:
async with httpx_client.stream(
@@ -314,8 +311,6 @@ class FSResourcesHandler(FSHandler):
return None
else:
# Handle HTTP URLs
validate_url_for_http_request(url_screenhot, "url_screenshot")
httpx_client = ctx_httpx_client.get()
try:
async with httpx_client.stream(
@@ -430,8 +425,6 @@ class FSResourcesHandler(FSHandler):
return None
else:
# Handle HTTP URL
validate_url_for_http_request(url_manual, "url_manual")
httpx_client = ctx_httpx_client.get()
try:
async with httpx_client.stream(
@@ -501,8 +494,6 @@ class FSResourcesHandler(FSHandler):
# Retroachievements
async def store_ra_badge(self, url: str, path: str) -> None:
validate_url_for_http_request(url, "url_badge")
httpx_client = ctx_httpx_client.get()
directory, filename = os.path.split(path)
@@ -569,8 +560,6 @@ class FSResourcesHandler(FSHandler):
return None
else:
# Handle HTTP URLs
validate_url_for_http_request(url_media, "url_media")
httpx_client = ctx_httpx_client.get()
try:
async with httpx_client.stream(

View File

@@ -0,0 +1,466 @@
"""Tests for SSRF defense: URL validator + httpcore network backends."""
import asyncio
import socket
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpcore
import pytest
from utils.ssrf import (
SSRFProtectedAsyncBackend,
SSRFProtectedSyncBackend,
is_forbidden_ip,
parse_ip_literal,
validate_url_for_http_request,
)
from utils.validation import ValidationError
ConnectCall = tuple[tuple[Any, ...], dict[str, Any]]
def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]:
family = socket.AF_INET6 if ":" in ip else socket.AF_INET
return [(family, socket.SOCK_STREAM, 0, "", (ip, port))]
def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock:
inner = MagicMock()
def _record(*args: Any, **kwargs: Any) -> MagicMock:
connect_calls.append((args, kwargs))
return MagicMock()
inner.connect_tcp = AsyncMock(side_effect=_record)
return inner
def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock:
inner = MagicMock()
def _record(*args: Any, **kwargs: Any) -> MagicMock:
connect_calls.append((args, kwargs))
return MagicMock()
inner.connect_tcp = MagicMock(side_effect=_record)
return inner
class TestIsForbiddenIp:
@pytest.mark.parametrize(
"ip",
[
"127.0.0.1",
"10.0.0.1",
"192.168.1.1",
"172.16.0.1",
"169.254.169.254",
"0.0.0.0",
"224.0.0.1",
"::1",
"fc00::1",
"fe80::1",
"ff02::1",
],
)
def test_forbidden(self, ip):
import ipaddress
assert is_forbidden_ip(ipaddress.ip_address(ip)) is True
@pytest.mark.parametrize(
"ip", ["8.8.8.8", "1.1.1.1", "93.184.216.34", "2001:4860:4860::8888"]
)
def test_allowed(self, ip):
import ipaddress
assert is_forbidden_ip(ipaddress.ip_address(ip)) is False
class TestParseIpLiteral:
def test_standard(self):
assert str(parse_ip_literal("127.0.0.1")) == "127.0.0.1"
assert str(parse_ip_literal("::1")) == "::1"
def test_non_standard_ipv4(self):
# Hex, decimal, shorthand - all map to 127.0.0.1
assert str(parse_ip_literal("0x7f000001")) == "127.0.0.1"
assert str(parse_ip_literal("2130706433")) == "127.0.0.1"
assert str(parse_ip_literal("127.1")) == "127.0.0.1"
def test_hostname_returns_none(self):
assert parse_ip_literal("example.com") is None
class TestSSRFProtectedAsyncBackend:
async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
"""Backend resolves once, validates, and passes the pinned IP to inner."""
calls: list[ConnectCall] = []
inner = _stub_async_inner(calls)
backend = SSRFProtectedAsyncBackend(inner=inner)
async def fake_getaddrinfo(host, port, *args, **kwargs):
return _addr_info("93.184.216.34", port)
loop = asyncio.get_running_loop()
monkeypatch.setattr(loop, "getaddrinfo", fake_getaddrinfo)
await backend.connect_tcp("example.com", 443)
# inner must receive the resolved IP, not the original hostname.
# That is what pins the address against DNS rebinding.
assert calls[0][0][0] == "93.184.216.34"
assert calls[0][0][1] == 443
async def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch):
"""DNS rebinding case: hostname resolves to 127.0.0.1 must fail."""
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
async def fake_getaddrinfo(host, port, *args, **kwargs):
return _addr_info("127.0.0.1", port)
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
await backend.connect_tcp("127.0.0.1.nip.io", 80)
inner.connect_tcp.assert_not_called()
async def test_mixed_resolution_rejected(self, monkeypatch):
"""If any returned address is forbidden, reject - don't trust round-robin."""
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
async def fake_getaddrinfo(host, port, *args, **kwargs):
return _addr_info("93.184.216.34", port) + _addr_info("10.0.0.1", port)
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
await backend.connect_tcp("mixed.example.com", 80)
inner.connect_tcp.assert_not_called()
async def test_literal_forbidden_ip_rejected(self):
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
await backend.connect_tcp("169.254.169.254", 80)
inner.connect_tcp.assert_not_called()
async def test_literal_public_ip_passes_through(self):
calls: list[ConnectCall] = []
inner = _stub_async_inner(calls)
backend = SSRFProtectedAsyncBackend(inner=inner)
await backend.connect_tcp("8.8.8.8", 443)
# Literal public IPs are passed through unchanged.
assert calls[0][0][0] == "8.8.8.8"
async def test_non_standard_ipv4_literal_blocked(self):
"""Hex/decimal IPv4 forms must be blocked, matching httpx's parsing."""
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
await backend.connect_tcp("2130706433", 80) # 127.0.0.1
inner.connect_tcp.assert_not_called()
async def test_dns_timeout_raises_connect_timeout(self, monkeypatch):
"""A resolver that hangs past the caller's timeout must not block forever.
Regression: an earlier version applied the caller's timeout only to
the TCP connect inside the inner backend, leaving `loop.getaddrinfo`
unbounded. We now wrap the lookup in `asyncio.timeout()` so a slow
resolver is bounded by the same budget the caller specified.
"""
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
async def hang_forever(*args, **kwargs):
await asyncio.sleep(3600)
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", hang_forever)
with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"):
await backend.connect_tcp("slow.example.com", 80, timeout=0.05)
inner.connect_tcp.assert_not_called()
async def test_dns_failure_propagates_as_connect_error(self, monkeypatch):
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
async def fake_getaddrinfo(*args, **kwargs):
raise socket.gaierror("Name or service not known")
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
with pytest.raises(httpcore.ConnectError, match="DNS resolution failed"):
await backend.connect_tcp("nonexistent.invalid", 80)
inner.connect_tcp.assert_not_called()
class TestSSRFProtectedSyncBackend:
def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
calls: list[ConnectCall] = []
inner = _stub_sync_inner(calls)
backend = SSRFProtectedSyncBackend(inner=inner)
monkeypatch.setattr(
socket,
"getaddrinfo",
lambda host, port, *a, **kw: _addr_info("93.184.216.34", port),
)
backend.connect_tcp("example.com", 443)
assert calls[0][0][0] == "93.184.216.34"
def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch):
inner = _stub_sync_inner([])
backend = SSRFProtectedSyncBackend(inner=inner)
monkeypatch.setattr(
socket,
"getaddrinfo",
lambda host, port, *a, **kw: _addr_info("127.0.0.1", port),
)
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
backend.connect_tcp("127.0.0.1.nip.io", 80)
inner.connect_tcp.assert_not_called()
def test_literal_forbidden_ip_rejected(self):
inner = _stub_sync_inner([])
backend = SSRFProtectedSyncBackend(inner=inner)
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
backend.connect_tcp("10.0.0.1", 80)
inner.connect_tcp.assert_not_called()
class TestRequestEventHook:
"""Verify the syntactic URL validator is wired as a request event hook.
With the hook in place, every request through a context-provided client
is validated automatically; feature code does not need to call
`validate_url_for_http_request` itself.
"""
async def test_async_hook_rejects_bad_scheme(self):
from utils.context import create_httpx_async_client
from utils.validation import ValidationError
client = create_httpx_async_client()
try:
with pytest.raises(ValidationError, match="only http and https"):
await client.get("file:///etc/passwd")
finally:
await client.aclose()
async def test_async_hook_rejects_internal_tld(self):
from utils.context import create_httpx_async_client
from utils.validation import ValidationError
client = create_httpx_async_client()
try:
with pytest.raises(ValidationError, match="internal domain names"):
await client.get("http://printer.local/status")
finally:
await client.aclose()
def test_sync_hook_rejects_literal_private_ip(self):
from utils.context import create_httpx_client
from utils.validation import ValidationError
with create_httpx_client() as client:
with pytest.raises(ValidationError, match="private, internal"):
client.get("http://10.0.0.1/")
class TestInstallation:
"""Verify the backend is actually wired onto httpx clients we create."""
def test_create_httpx_async_client_installs_backend(self):
from utils.context import create_httpx_async_client
from utils.ssrf import SSRFProtectedAsyncBackend as Async
client = create_httpx_async_client()
try:
assert isinstance(client._transport._pool._network_backend, Async)
finally:
asyncio.run(client.aclose())
def test_create_httpx_client_installs_backend(self):
from utils.context import create_httpx_client
from utils.ssrf import SSRFProtectedSyncBackend as Sync
with create_httpx_client() as client:
assert isinstance(client._transport._pool._network_backend, Sync)
def test_proxy_transports_are_not_wrapped(self):
"""Proxy mounts must keep their stock backend.
A common deployment pattern is `HTTPS_PROXY=http://sidecar:9050`,
where `sidecar` resolves to a docker-bridge private IP. If we
wrapped the proxy transport, our SSRF backend would refuse to
connect to the operator's chosen proxy. SSRF protection at the
proxy hop is the operator's responsibility; the destination URL
is still validated by the request event hook on the client.
"""
import httpx
from utils.ssrf import SSRFProtectedAsyncBackend as Async
from utils.ssrf import (
install_async_ssrf_protection,
)
client = httpx.AsyncClient(proxy="http://proxy.invalid:3128")
try:
install_async_ssrf_protection(client)
assert isinstance(client._transport._pool._network_backend, Async)
for mount in client._mounts.values():
if mount is None:
continue
# Proxy mount must NOT have been wrapped.
assert not isinstance(mount._pool._network_backend, Async)
finally:
asyncio.run(client.aclose())
class TestValidateUrlForHttpRequest:
"""Test URL validation for HTTP requests to prevent SSRF attacks."""
def test_valid_http_urls(self):
"""Valid HTTP/HTTPS URLs pass syntactic validation without DNS lookups.
DNS-based SSRF checks live in the HTTP client's connect path, so
this layer must not call DNS.
"""
validate_url_for_http_request("http://example.com", "test_url")
validate_url_for_http_request("https://example.com", "test_url")
validate_url_for_http_request("http://example.com/path", "test_url")
validate_url_for_http_request("https://example.com/path?query=1", "test_url")
validate_url_for_http_request("http://subdomain.example.com", "test_url")
def test_validator_does_not_perform_dns_lookup(self, monkeypatch):
"""Regression: validator must not block the event loop on DNS.
Earlier implementations called `socket.getaddrinfo`, which both
blocked the running event loop in async media-download callers and
was defeated by DNS rebinding (the value seen here did not match
the IP httpx later connected to). We patch getaddrinfo to a poison
function so any accidental reintroduction fails this test.
"""
def _explode(*_args, **_kwargs):
raise AssertionError(
"validate_url_for_http_request must not call DNS; "
"SSRF DNS protection lives in the HTTP client backend"
)
monkeypatch.setattr(socket, "getaddrinfo", _explode)
validate_url_for_http_request("http://example.com", "test_url")
def test_invalid_empty_url(self):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("", "test_url")
assert "cannot be empty" in exc_info.value.message
def test_invalid_scheme(self):
for url in (
"ftp://example.com",
"file:///etc/passwd",
"data:text/html,<h1>test</h1>",
"javascript:alert(1)", # XSS attack vector
):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert "only http and https schemes are allowed" in exc_info.value.message
def test_invalid_localhost(self):
for url in (
"http://localhost",
"http://127.0.0.1",
"http://[::1]",
"http://0.0.0.0",
):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert (
"localhost and reserved hostnames are not allowed"
in exc_info.value.message
)
def test_invalid_private_ipv4_addresses(self):
for url in (
"http://10.0.0.1",
"http://192.168.1.1",
"http://172.16.0.1",
"http://172.31.255.254",
):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert (
"private, internal, reserved, or multicast IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_loopback_addresses(self):
# 127.0.0.1 itself is in RESERVED_HOSTNAMES; these cover the rest of 127/8.
for url in ("http://127.0.0.2", "http://127.255.255.255"):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert (
"private, internal, reserved, or multicast IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_private_ipv6_addresses(self):
for url in (
"http://[fe80::1]", # link-local
"http://[fc00::1]", # unique local
"http://[fd00::1]", # unique local
):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert (
"private, internal, reserved, or multicast IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_multicast_addresses(self):
for url in ("http://224.0.0.1", "http://[ff02::1]"):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert (
"private, internal, reserved, or multicast IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_internal_tlds(self):
for url in (
"http://server.local",
"http://server.internal",
"http://server.localhost",
):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert "internal domain names are not allowed" in exc_info.value.message
def test_invalid_non_standard_ip_representations(self):
"""Non-standard IPv4 forms (hex, decimal, shorthand) are SSRF bypass vectors."""
cases = [
"http://0x7f000001", # hex 127.0.0.1
"http://2130706433", # decimal 127.0.0.1
"http://127.1", # shorthand 127.0.0.1
"http://0x0a000001", # hex 10.0.0.1
"http://3232235777", # decimal 192.168.1.1
"http://0xa9fea9fe", # hex 169.254.169.254 (cloud metadata)
]
for url in cases:
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request(url, "test_url")
assert (
"private, internal, reserved, or multicast IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_missing_hostname(self):
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://", "test_url")
assert "missing hostname" in exc_info.value.message

View File

@@ -7,7 +7,6 @@ from utils.validation import (
validate_ascii_only,
validate_email,
validate_password,
validate_url_for_http_request,
validate_username,
)
@@ -158,226 +157,3 @@ class TestValidateEmail:
with pytest.raises(ValidationError) as exc_info:
validate_email("résumé@example.com")
assert "ASCII characters" in exc_info.value.message
class TestValidateUrlForHttpRequest:
"""Test URL validation for HTTP requests to prevent SSRF attacks."""
def test_valid_http_urls(self):
"""Test that valid HTTP/HTTPS URLs pass validation."""
validate_url_for_http_request("http://example.com", "test_url")
validate_url_for_http_request("https://example.com", "test_url")
validate_url_for_http_request("http://example.com/path", "test_url")
validate_url_for_http_request("https://example.com/path?query=1", "test_url")
validate_url_for_http_request("http://subdomain.example.com", "test_url")
def test_invalid_empty_url(self):
"""Test that empty URLs fail validation."""
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("", "test_url")
assert "cannot be empty" in exc_info.value.message
def test_invalid_scheme(self):
"""Test that non-HTTP/HTTPS schemes fail validation."""
# FTP scheme
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("ftp://example.com", "test_url")
assert "only http and https schemes are allowed" in exc_info.value.message
# File scheme
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("file:///etc/passwd", "test_url")
assert "only http and https schemes are allowed" in exc_info.value.message
# Data scheme
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("data:text/html,<h1>test</h1>", "test_url")
assert "only http and https schemes are allowed" in exc_info.value.message
# JavaScript scheme (XSS attack vector)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("javascript:alert(1)", "test_url")
assert "only http and https schemes are allowed" in exc_info.value.message
def test_invalid_localhost(self):
"""Test that localhost and reserved hostnames fail validation."""
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://localhost", "test_url")
assert (
"localhost and reserved hostnames are not allowed" in exc_info.value.message
)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://127.0.0.1", "test_url")
assert (
"localhost and reserved hostnames are not allowed" in exc_info.value.message
)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://[::1]", "test_url")
assert (
"localhost and reserved hostnames are not allowed" in exc_info.value.message
)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://0.0.0.0", "test_url")
assert (
"localhost and reserved hostnames are not allowed" in exc_info.value.message
)
def test_invalid_private_ipv4_addresses(self):
"""Test that private IPv4 addresses fail validation."""
# 10.x.x.x range
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://10.0.0.1", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# 192.168.x.x range
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://192.168.1.1", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# 172.16.x.x - 172.31.x.x range
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://172.16.0.1", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://172.31.255.254", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_loopback_addresses(self):
"""Test that loopback addresses fail validation."""
# 127.x.x.x range
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://127.0.0.2", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://127.255.255.255", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_private_ipv6_addresses(self):
"""Test that private/link-local IPv6 addresses fail validation."""
# Link-local IPv6: fe80::/10
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://[fe80::1]", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# Unique local address: fc00::/7
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://[fc00::1]", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://[fd00::1]", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_multicast_addresses(self):
"""Test that multicast addresses fail validation."""
# IPv4 multicast: 224.0.0.0/4
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://224.0.0.1", "test_url")
assert "multicast addresses are not allowed" in exc_info.value.message
# IPv6 multicast: ff00::/8
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://[ff02::1]", "test_url")
assert "multicast addresses are not allowed" in exc_info.value.message
def test_invalid_internal_tlds(self):
"""Test that internal TLDs fail validation."""
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://server.local", "test_url")
assert "internal domain names are not allowed" in exc_info.value.message
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://server.internal", "test_url")
assert "internal domain names are not allowed" in exc_info.value.message
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://server.localhost", "test_url")
assert "internal domain names are not allowed" in exc_info.value.message
def test_invalid_non_standard_ip_representations(self):
"""Test that non-standard IP representations are blocked (SSRF bypass vectors)."""
# Hexadecimal integer for 127.0.0.1
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://0x7f000001", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# Decimal integer for 127.0.0.1
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://2130706433", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# Shorthand dotted for 127.0.0.1
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://127.1", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# Hexadecimal integer for 10.0.0.1
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://0x0a000001", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# Decimal integer for 192.168.1.1
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://3232235777", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
# Hexadecimal integer for 169.254.169.254 (cloud metadata)
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://0xa9fea9fe", "test_url")
assert (
"private, internal, and reserved IP addresses are not allowed"
in exc_info.value.message
)
def test_invalid_missing_hostname(self):
"""Test that URLs without hostnames fail validation."""
with pytest.raises(ValidationError) as exc_info:
validate_url_for_http_request("http://", "test_url")
assert "missing hostname" in exc_info.value.message

View File

@@ -8,6 +8,11 @@ import httpx
from fastapi import Request, Response
from config import has_proxy_env
from utils.ssrf import (
install_async_ssrf_protection,
install_sync_ssrf_protection,
validate_url_for_http_request,
)
_T = TypeVar("_T")
@@ -15,16 +20,34 @@ ctx_aiohttp_session: ContextVar[aiohttp.ClientSession] = ContextVar("aiohttp_ses
ctx_httpx_client: ContextVar[httpx.AsyncClient] = ContextVar("httpx_client")
def _validate_request_url_sync(request: httpx.Request) -> None:
validate_url_for_http_request(str(request.url))
async def _validate_request_url_async(request: httpx.Request) -> None:
validate_url_for_http_request(str(request.url))
def create_aiohttp_session() -> aiohttp.ClientSession:
return aiohttp.ClientSession(trust_env=has_proxy_env())
def create_httpx_async_client() -> httpx.AsyncClient:
return httpx.AsyncClient(trust_env=has_proxy_env())
client = httpx.AsyncClient(
trust_env=has_proxy_env(),
event_hooks={"request": [_validate_request_url_async]},
)
install_async_ssrf_protection(client)
return client
def create_httpx_client() -> httpx.Client:
return httpx.Client(trust_env=has_proxy_env())
client = httpx.Client(
trust_env=has_proxy_env(),
event_hooks={"request": [_validate_request_url_sync]},
)
install_sync_ssrf_protection(client)
return client
@asynccontextmanager

331
backend/utils/ssrf.py Normal file
View File

@@ -0,0 +1,331 @@
"""SSRF defense for outbound HTTP.
Two layers, both wired onto every httpx client built by `utils.context`:
1. `validate_url_for_http_request` — a syntactic fast-fail check
installed as an httpx request event hook. Rejects non-HTTP schemes,
literal IPs in forbidden ranges (including non-standard IPv4 forms),
reserved hostnames, and internal TLDs before any socket opens.
2. `SSRFProtectedAsyncBackend` / `SSRFProtectedSyncBackend` — custom
httpcore network backends that resolve the hostname inside
`connect_tcp`, reject any address in a private/loopback/link-local/
reserved/multicast/unspecified range, then connect to that *same*
validated address. This is what defeats DNS rebinding: the address
used by the OS for the TCP connection is the one we just checked,
not a fresh lookup the attacker can answer differently. Doing this
work in the backend also avoids blocking the event loop, since the
async variant uses `loop.getaddrinfo`.
httpcore calls `start_tls(server_hostname=<URL host>)` after
`connect_tcp` returns, so TLS SNI and certificate verification still
use the original hostname even though we connect by IP.
"""
from __future__ import annotations
import asyncio
import ipaddress
import socket
import typing
from urllib.parse import urlparse
import httpcore
from httpcore._backends.auto import AutoBackend
from httpcore._backends.base import (
SOCKET_OPTION,
AsyncNetworkBackend,
AsyncNetworkStream,
NetworkBackend,
NetworkStream,
)
from httpcore._backends.sync import SyncBackend
from logger.logger import log
from utils.validation import ValidationError
def is_forbidden_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return True if the IP must not be reached from a server-side HTTP request."""
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_reserved
or ip.is_multicast
or ip.is_unspecified
)
def parse_ip_literal(
host: str,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
"""Return the parsed IP if `host` is a literal address, else None.
Accepts non-standard IPv4 forms (hex, decimal, shorthand) via
`socket.inet_aton`, which is what HTTP clients themselves accept.
"""
try:
return ipaddress.ip_address(host)
except ValueError:
pass
try:
packed = socket.inet_aton(host)
except OSError:
return None
return ipaddress.IPv4Address(packed)
def _pick_safe_address(addr_infos: typing.Iterable[typing.Any], host: str) -> str:
"""Validate every resolved address and return the literal IP to connect to.
All returned addresses are checked: if any falls in a forbidden range
we reject the whole name, rather than just skipping that record. A
malicious DNS server can otherwise mix public and private answers
and rely on the client to round-robin.
"""
chosen: str | None = None
for *_, sockaddr in addr_infos:
try:
ip = ipaddress.ip_address(sockaddr[0])
except (ValueError, IndexError):
continue
if is_forbidden_ip(ip):
msg = (
f"SSRF prevention: hostname {host!r} resolves to forbidden " f"IP {ip}"
)
log.error(msg)
raise httpcore.ConnectError(msg)
if chosen is None:
chosen = sockaddr[0]
if chosen is None:
raise httpcore.ConnectError(f"No usable addresses for {host!r}")
return chosen
def _check_literal(host: str) -> bool:
"""Return True if `host` is a literal IP that has been validated as safe.
Raises httpcore.ConnectError if it is a literal IP in a forbidden range.
Returns False if `host` is a hostname (caller must resolve and validate).
"""
literal = parse_ip_literal(host)
if literal is None:
return False
if is_forbidden_ip(literal):
msg = f"SSRF prevention: connection to forbidden IP {literal}"
log.error(msg)
raise httpcore.ConnectError(msg)
return True
class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
"""Async backend that validates resolved IPs before establishing TCP."""
def __init__(self, inner: AsyncNetworkBackend | None = None) -> None:
self._inner = inner or AutoBackend()
# `timeout` parameter is required by AsyncNetworkBackend.connect_tcp;
# ruff/ASYNC109 advises against timeout parameters on async APIs *we*
# author, but we are implementing an external interface here. The
# timeout is consumed via `asyncio.timeout()` below, which is the
# asyncio-native pattern ASYNC109 endorses.
async def connect_tcp(
self,
host: str,
port: int,
timeout: float | None = None, # noqa: ASYNC109
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
"""Validate the resolved IP, then connect via the inner backend.
The DNS lookup is wrapped in `asyncio.timeout()` so a slow
resolver is bounded by the caller's timeout. The previous code
only timed out the TCP connect inside the inner backend, leaving
`loop.getaddrinfo` unbounded.
"""
if _check_literal(host):
return await self._inner.connect_tcp(
host, port, timeout, local_address, socket_options
)
loop = asyncio.get_running_loop()
try:
async with asyncio.timeout(timeout):
addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise httpcore.ConnectError(
f"DNS resolution failed for {host!r}: {exc}"
) from exc
except TimeoutError as exc:
raise httpcore.ConnectTimeout(
f"DNS resolution timed out for {host!r}"
) from exc
pinned_ip = _pick_safe_address(addr_infos, host)
return await self._inner.connect_tcp(
pinned_ip, port, timeout, local_address, socket_options
)
# See note on connect_tcp re: ASYNC109 / interface implementation.
async def connect_unix_socket(
self,
path: str,
timeout: float | None = None, # noqa: ASYNC109
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
return await self._inner.connect_unix_socket(path, timeout, socket_options)
async def sleep(self, seconds: float) -> None:
await self._inner.sleep(seconds)
class SSRFProtectedSyncBackend(NetworkBackend):
"""Sync backend that validates resolved IPs before establishing TCP."""
def __init__(self, inner: NetworkBackend | None = None) -> None:
self._inner = inner if inner is not None else SyncBackend()
def connect_tcp(
self,
host: str,
port: int,
timeout: float | None = None,
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
if _check_literal(host):
return self._inner.connect_tcp(
host, port, timeout, local_address, socket_options
)
try:
addr_infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise httpcore.ConnectError(
f"DNS resolution failed for {host!r}: {exc}"
) from exc
pinned_ip = _pick_safe_address(addr_infos, host)
return self._inner.connect_tcp(
pinned_ip, port, timeout, local_address, socket_options
)
def connect_unix_socket(
self,
path: str,
timeout: float | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> NetworkStream:
return self._inner.connect_unix_socket(path, timeout, socket_options)
def install_async_ssrf_protection(client: typing.Any) -> None:
"""Wrap the client's default transport so SSRF validation runs at connect time.
httpx does not expose `network_backend` through its public transport
API, so we mutate `_pool._network_backend` after construction.
"""
pool = client._transport._pool
if not isinstance(pool._network_backend, SSRFProtectedAsyncBackend):
pool._network_backend = SSRFProtectedAsyncBackend(inner=pool._network_backend)
def install_sync_ssrf_protection(client: typing.Any) -> None:
"""Sync counterpart of `install_async_ssrf_protection`."""
pool = client._transport._pool
if not isinstance(pool._network_backend, SSRFProtectedSyncBackend):
pool._network_backend = SSRFProtectedSyncBackend(inner=pool._network_backend)
RESERVED_HOSTNAMES = [
"localhost",
"127.0.0.1",
"0.0.0.0", # trunk-ignore(bandit/B104)
"::1",
"::",
]
def validate_url_for_http_request(url: str, field_name: str = "URL") -> None:
"""Syntactically validate a URL before passing it to an HTTP client.
Fast-fail check for cases that don't need DNS to detect:
- The URL scheme is http or https only
- If the host is a literal IP address, it is not private/internal/reserved
- The host is not a reserved hostname (localhost, 127.0.0.1, etc.)
- The host does not use internal TLDs (.local, .internal, .localhost)
Wired in as an httpx request event hook by `utils.context.create_httpx_*`,
so every request that goes through a context-provided client runs this
automatically — direct calls from feature code are not required.
Dynamic SSRF protection (rejecting hostnames that resolve to a private
IP, including DNS-rebinding names like `127.0.0.1.nip.io`) happens
inside the HTTP client's connect path via the backends above. Doing
the DNS check here would (a) be defeated by DNS rebinding because
httpx re-resolves at connect time, and (b) block the event loop,
since `socket.getaddrinfo` is synchronous and most callers are async.
Args:
url (str): The URL to validate
field_name (str): The name of the field for error messages
Raises:
ValidationError: If the URL is syntactically invalid or matches one
of the static SSRF deny patterns.
"""
if not url or not url.strip():
msg = f"{field_name} cannot be empty"
log.error(msg)
raise ValidationError(msg, field_name)
try:
parsed = urlparse(url)
except Exception as e:
msg = f"Invalid {field_name}: unable to parse URL"
log.error(f"{msg}: {str(e)}")
raise ValidationError(msg, field_name) from e
# Validate scheme - only allow http and https
if parsed.scheme not in ["http", "https"]:
msg = f"Invalid {field_name}: only http and https schemes are allowed"
log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'")
raise ValidationError(msg, field_name)
# Extract hostname
hostname = parsed.hostname
if not hostname:
msg = f"Invalid {field_name}: missing hostname"
log.error(msg)
raise ValidationError(msg, field_name)
# Block reserved hostnames that are commonly used to refer to internal services.
if hostname.lower() in RESERVED_HOSTNAMES:
msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed"
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
raise ValidationError(msg, field_name)
# Try to parse hostname as a literal IP (standard or non-standard form).
# HTTP clients accept hex (0x7f000001), decimal (2130706433), and
# shorthand-dotted (127.1) integers via inet_aton, so we mirror that.
ip = parse_ip_literal(hostname)
if ip is not None:
if is_forbidden_ip(ip):
msg = (
f"Invalid {field_name}: private, internal, reserved, "
"or multicast IP addresses are not allowed"
)
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
raise ValidationError(msg, field_name)
return
# Block common internal TLDs
hostname_lower = hostname.lower()
internal_tlds = [".local", ".internal", ".localhost"]
if any(hostname_lower.endswith(tld) for tld in internal_tlds):
msg = f"Invalid {field_name}: internal domain names are not allowed"
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
raise ValidationError(msg, field_name)

View File

@@ -1,7 +1,4 @@
import ipaddress
import re
import socket
from urllib.parse import urlparse
from logger.logger import log
from models.user import TEXT_FIELD_LENGTH
@@ -118,112 +115,3 @@ def validate_email(email: str) -> None:
msg = "Invalid email format"
log.error(f"Validation failed: {msg} for email: {email}")
raise ValidationError(msg, "Email")
# Check for localhost and reserved hostnames
RESERVED_HOSTNAMES = [
"localhost",
"127.0.0.1",
"0.0.0.0", # trunk-ignore(bandit/B104)
"::1",
"::",
]
def validate_url_for_http_request(url: str, field_name: str = "URL") -> None:
"""Validate URL to prevent Server-Side Request Forgery (SSRF) attacks.
This function validates that:
- The URL scheme is http or https only
- If the host is a literal IP address, it is not private/internal/reserved
- The host is not a reserved hostname (localhost, 127.0.0.1, etc.)
- The host does not use internal TLDs (.local, .internal, .localhost)
Note: This function does NOT perform DNS resolution. Domain names that resolve
to private IPs will not be detected (DNS rebinding/internal DNS bypass possible).
It only checks literal IP addresses in the hostname.
Args:
url (str): The URL to validate
field_name (str): The name of the field for error messages
Raises:
ValidationError: If the URL is invalid or potentially dangerous
"""
if not url or not url.strip():
msg = f"{field_name} cannot be empty"
log.error(msg)
raise ValidationError(msg, field_name)
try:
parsed = urlparse(url)
except Exception as e:
msg = f"Invalid {field_name}: unable to parse URL"
log.error(f"{msg}: {str(e)}")
raise ValidationError(msg, field_name) from e
# Validate scheme - only allow http and https
if parsed.scheme not in ["http", "https"]:
msg = f"Invalid {field_name}: only http and https schemes are allowed"
log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'")
raise ValidationError(msg, field_name)
# Extract hostname
hostname = parsed.hostname
if not hostname:
msg = f"Invalid {field_name}: missing hostname"
log.error(msg)
raise ValidationError(msg, field_name)
if hostname.lower() in RESERVED_HOSTNAMES:
msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed"
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
raise ValidationError(msg, field_name)
# Try to resolve hostname as IP address
try:
ip = ipaddress.ip_address(hostname)
# Block private/internal/link-local IP addresses
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
msg = f"Invalid {field_name}: private, internal, and reserved IP addresses are not allowed"
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
raise ValidationError(msg, field_name)
# Block multicast addresses
if ip.is_multicast:
msg = f"Invalid {field_name}: multicast addresses are not allowed"
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
raise ValidationError(msg, field_name)
except ValueError as e:
# ipaddress.ip_address() only handles standard notation. HTTP clients
# also accept hex integers (0x7f000001), decimal integers (2130706433),
# shorthand dotted (127.1), and octal (0177.0.0.1). Use socket.inet_aton()
# which handles these non-standard IPv4 representations.
try:
packed = socket.inet_aton(hostname)
ip = ipaddress.IPv4Address(packed)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
msg = f"Invalid {field_name}: private, internal, and reserved IP addresses are not allowed"
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
raise ValidationError(msg, field_name)
if ip.is_multicast:
msg = f"Invalid {field_name}: multicast addresses are not allowed"
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
raise ValidationError(msg, field_name)
except OSError:
pass # Not an IP address at all - fall through to domain name checks
# Additional checks for suspicious domain patterns
hostname_lower = hostname.lower()
# Block common internal TLDs
internal_tlds = [".local", ".internal", ".localhost"]
if any(hostname_lower.endswith(tld) for tld in internal_tlds):
msg = f"Invalid {field_name}: internal domain names are not allowed"
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
raise ValidationError(msg, field_name) from e

View File

@@ -15,7 +15,7 @@ dependencies = [
"SQLAlchemy[mariadb-connector,mysql-connector,postgresql-psycopg] ~= 2.0",
"Unidecode ~= 1.3",
"aiohttp~=3.13",
"asyncssh ~= 2.17",
"asyncssh~=2.23",
"alembic ~= 1.16",
"anyio ~= 4.4",
"authlib~=1.6.12",

10
uv.lock generated
View File

@@ -7,7 +7,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-05-21T08:54:16.815535417Z"
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
exclude-newer-span = "P7D"
[options.exclude-newer-package]
@@ -176,15 +176,15 @@ wheels = [
[[package]]
name = "asyncssh"
version = "2.22.0"
version = "2.23.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fc/d5/957886c316466349d55c4de6a688a10a98295c0b4429deb8db1a17f3eb19/asyncssh-2.22.0.tar.gz", hash = "sha256:c3ce72b01be4f97b40e62844dd384227e5ff5a401a3793007c42f86a5c8eb537", size = 540523, upload-time = "2025-12-21T23:38:30.5Z" }
sdist = { url = "https://files.pythonhosted.org/packages/ee/fd/c34fe7e30838b4b9cc91903da26a62c6d33b673c731b3d951fcd70ab1889/asyncssh-2.23.0.tar.gz", hash = "sha256:8c54760953c1f2cf282591bcba5c8c70efc48d645bbf26bd2307a9c66a0ed1a7", size = 542154, upload-time = "2026-05-09T03:15:01.856Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ed/ae/0da2f2214fc183338af1afe5a103a2052fd03464e8eafbd827abff58a4d0/asyncssh-2.22.0-py3-none-any.whl", hash = "sha256:d16465ccdf1ed20eba1131b14415b155e047f6f5be0d19f39c2e0b61331ee0e7", size = 374938, upload-time = "2025-12-21T23:38:28.976Z" },
{ url = "https://files.pythonhosted.org/packages/ff/b5/b1a3979f4840d1271ca8e0978dbccfb18ad2d33b4ece85cf77122fb46e5f/asyncssh-2.23.0-py3-none-any.whl", hash = "sha256:14108bfdaae17457f0c1841e883ad934271bbfdd46458aa4c4d0973451940ad0", size = 375687, upload-time = "2026-05-09T03:15:00.221Z" },
]
[[package]]
@@ -2255,7 +2255,7 @@ requires-dist = [
{ name = "aiohttp", specifier = "~=3.13" },
{ name = "alembic", specifier = "~=1.16" },
{ name = "anyio", specifier = "~=4.4" },
{ name = "asyncssh", specifier = "~=2.17" },
{ name = "asyncssh", specifier = "~=2.23" },
{ name = "authlib", specifier = "~=1.6.12" },
{ name = "bcrypt", specifier = "<5" },
{ name = "colorama", specifier = "~=0.4" },