mirror of
https://github.com/rommapp/romm.git
synced 2026-06-30 07:45:52 +00:00
Merge branch 'master' into feat/soundtrack-support
This commit is contained in:
@@ -15,7 +15,6 @@ from models.collection import Collection
|
||||
from models.rom import Rom
|
||||
from tasks.scheduled.convert_images_to_webp import ImageConverter
|
||||
from utils.context import ctx_httpx_client
|
||||
from utils.validation import validate_url_for_http_request
|
||||
|
||||
from .base_handler import CoverSize, FSHandler
|
||||
|
||||
@@ -145,8 +144,6 @@ class FSResourcesHandler(FSHandler):
|
||||
return None
|
||||
else:
|
||||
# Handle HTTP URLs
|
||||
validate_url_for_http_request(url_cover, "url_cover")
|
||||
|
||||
httpx_client = ctx_httpx_client.get()
|
||||
try:
|
||||
async with httpx_client.stream(
|
||||
@@ -314,8 +311,6 @@ class FSResourcesHandler(FSHandler):
|
||||
return None
|
||||
else:
|
||||
# Handle HTTP URLs
|
||||
validate_url_for_http_request(url_screenhot, "url_screenshot")
|
||||
|
||||
httpx_client = ctx_httpx_client.get()
|
||||
try:
|
||||
async with httpx_client.stream(
|
||||
@@ -430,8 +425,6 @@ class FSResourcesHandler(FSHandler):
|
||||
return None
|
||||
else:
|
||||
# Handle HTTP URL
|
||||
validate_url_for_http_request(url_manual, "url_manual")
|
||||
|
||||
httpx_client = ctx_httpx_client.get()
|
||||
try:
|
||||
async with httpx_client.stream(
|
||||
@@ -501,8 +494,6 @@ class FSResourcesHandler(FSHandler):
|
||||
|
||||
# Retroachievements
|
||||
async def store_ra_badge(self, url: str, path: str) -> None:
|
||||
validate_url_for_http_request(url, "url_badge")
|
||||
|
||||
httpx_client = ctx_httpx_client.get()
|
||||
directory, filename = os.path.split(path)
|
||||
|
||||
@@ -569,8 +560,6 @@ class FSResourcesHandler(FSHandler):
|
||||
return None
|
||||
else:
|
||||
# Handle HTTP URLs
|
||||
validate_url_for_http_request(url_media, "url_media")
|
||||
|
||||
httpx_client = ctx_httpx_client.get()
|
||||
try:
|
||||
async with httpx_client.stream(
|
||||
|
||||
466
backend/tests/utils/test_ssrf.py
Normal file
466
backend/tests/utils/test_ssrf.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Tests for SSRF defense: URL validator + httpcore network backends."""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
from utils.ssrf import (
|
||||
SSRFProtectedAsyncBackend,
|
||||
SSRFProtectedSyncBackend,
|
||||
is_forbidden_ip,
|
||||
parse_ip_literal,
|
||||
validate_url_for_http_request,
|
||||
)
|
||||
from utils.validation import ValidationError
|
||||
|
||||
ConnectCall = tuple[tuple[Any, ...], dict[str, Any]]
|
||||
|
||||
|
||||
def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]:
|
||||
family = socket.AF_INET6 if ":" in ip else socket.AF_INET
|
||||
return [(family, socket.SOCK_STREAM, 0, "", (ip, port))]
|
||||
|
||||
|
||||
def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock:
|
||||
inner = MagicMock()
|
||||
|
||||
def _record(*args: Any, **kwargs: Any) -> MagicMock:
|
||||
connect_calls.append((args, kwargs))
|
||||
return MagicMock()
|
||||
|
||||
inner.connect_tcp = AsyncMock(side_effect=_record)
|
||||
return inner
|
||||
|
||||
|
||||
def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock:
|
||||
inner = MagicMock()
|
||||
|
||||
def _record(*args: Any, **kwargs: Any) -> MagicMock:
|
||||
connect_calls.append((args, kwargs))
|
||||
return MagicMock()
|
||||
|
||||
inner.connect_tcp = MagicMock(side_effect=_record)
|
||||
return inner
|
||||
|
||||
|
||||
class TestIsForbiddenIp:
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
"127.0.0.1",
|
||||
"10.0.0.1",
|
||||
"192.168.1.1",
|
||||
"172.16.0.1",
|
||||
"169.254.169.254",
|
||||
"0.0.0.0",
|
||||
"224.0.0.1",
|
||||
"::1",
|
||||
"fc00::1",
|
||||
"fe80::1",
|
||||
"ff02::1",
|
||||
],
|
||||
)
|
||||
def test_forbidden(self, ip):
|
||||
import ipaddress
|
||||
|
||||
assert is_forbidden_ip(ipaddress.ip_address(ip)) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip", ["8.8.8.8", "1.1.1.1", "93.184.216.34", "2001:4860:4860::8888"]
|
||||
)
|
||||
def test_allowed(self, ip):
|
||||
import ipaddress
|
||||
|
||||
assert is_forbidden_ip(ipaddress.ip_address(ip)) is False
|
||||
|
||||
|
||||
class TestParseIpLiteral:
|
||||
def test_standard(self):
|
||||
assert str(parse_ip_literal("127.0.0.1")) == "127.0.0.1"
|
||||
assert str(parse_ip_literal("::1")) == "::1"
|
||||
|
||||
def test_non_standard_ipv4(self):
|
||||
# Hex, decimal, shorthand - all map to 127.0.0.1
|
||||
assert str(parse_ip_literal("0x7f000001")) == "127.0.0.1"
|
||||
assert str(parse_ip_literal("2130706433")) == "127.0.0.1"
|
||||
assert str(parse_ip_literal("127.1")) == "127.0.0.1"
|
||||
|
||||
def test_hostname_returns_none(self):
|
||||
assert parse_ip_literal("example.com") is None
|
||||
|
||||
|
||||
class TestSSRFProtectedAsyncBackend:
|
||||
async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
|
||||
"""Backend resolves once, validates, and passes the pinned IP to inner."""
|
||||
calls: list[ConnectCall] = []
|
||||
inner = _stub_async_inner(calls)
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
async def fake_getaddrinfo(host, port, *args, **kwargs):
|
||||
return _addr_info("93.184.216.34", port)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
monkeypatch.setattr(loop, "getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
await backend.connect_tcp("example.com", 443)
|
||||
|
||||
# inner must receive the resolved IP, not the original hostname.
|
||||
# That is what pins the address against DNS rebinding.
|
||||
assert calls[0][0][0] == "93.184.216.34"
|
||||
assert calls[0][0][1] == 443
|
||||
|
||||
async def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch):
|
||||
"""DNS rebinding case: hostname resolves to 127.0.0.1 must fail."""
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
async def fake_getaddrinfo(host, port, *args, **kwargs):
|
||||
return _addr_info("127.0.0.1", port)
|
||||
|
||||
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
|
||||
await backend.connect_tcp("127.0.0.1.nip.io", 80)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_mixed_resolution_rejected(self, monkeypatch):
|
||||
"""If any returned address is forbidden, reject - don't trust round-robin."""
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
async def fake_getaddrinfo(host, port, *args, **kwargs):
|
||||
return _addr_info("93.184.216.34", port) + _addr_info("10.0.0.1", port)
|
||||
|
||||
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
|
||||
await backend.connect_tcp("mixed.example.com", 80)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_literal_forbidden_ip_rejected(self):
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
|
||||
await backend.connect_tcp("169.254.169.254", 80)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_literal_public_ip_passes_through(self):
|
||||
calls: list[ConnectCall] = []
|
||||
inner = _stub_async_inner(calls)
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
await backend.connect_tcp("8.8.8.8", 443)
|
||||
# Literal public IPs are passed through unchanged.
|
||||
assert calls[0][0][0] == "8.8.8.8"
|
||||
|
||||
async def test_non_standard_ipv4_literal_blocked(self):
|
||||
"""Hex/decimal IPv4 forms must be blocked, matching httpx's parsing."""
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
|
||||
await backend.connect_tcp("2130706433", 80) # 127.0.0.1
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_dns_timeout_raises_connect_timeout(self, monkeypatch):
|
||||
"""A resolver that hangs past the caller's timeout must not block forever.
|
||||
|
||||
Regression: an earlier version applied the caller's timeout only to
|
||||
the TCP connect inside the inner backend, leaving `loop.getaddrinfo`
|
||||
unbounded. We now wrap the lookup in `asyncio.timeout()` so a slow
|
||||
resolver is bounded by the same budget the caller specified.
|
||||
"""
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
async def hang_forever(*args, **kwargs):
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", hang_forever)
|
||||
|
||||
with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"):
|
||||
await backend.connect_tcp("slow.example.com", 80, timeout=0.05)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_dns_failure_propagates_as_connect_error(self, monkeypatch):
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
async def fake_getaddrinfo(*args, **kwargs):
|
||||
raise socket.gaierror("Name or service not known")
|
||||
|
||||
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
with pytest.raises(httpcore.ConnectError, match="DNS resolution failed"):
|
||||
await backend.connect_tcp("nonexistent.invalid", 80)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
|
||||
class TestSSRFProtectedSyncBackend:
|
||||
def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
|
||||
calls: list[ConnectCall] = []
|
||||
inner = _stub_sync_inner(calls)
|
||||
backend = SSRFProtectedSyncBackend(inner=inner)
|
||||
|
||||
monkeypatch.setattr(
|
||||
socket,
|
||||
"getaddrinfo",
|
||||
lambda host, port, *a, **kw: _addr_info("93.184.216.34", port),
|
||||
)
|
||||
|
||||
backend.connect_tcp("example.com", 443)
|
||||
assert calls[0][0][0] == "93.184.216.34"
|
||||
|
||||
def test_hostname_resolving_to_private_ip_is_rejected(self, monkeypatch):
|
||||
inner = _stub_sync_inner([])
|
||||
backend = SSRFProtectedSyncBackend(inner=inner)
|
||||
monkeypatch.setattr(
|
||||
socket,
|
||||
"getaddrinfo",
|
||||
lambda host, port, *a, **kw: _addr_info("127.0.0.1", port),
|
||||
)
|
||||
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
|
||||
backend.connect_tcp("127.0.0.1.nip.io", 80)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
def test_literal_forbidden_ip_rejected(self):
|
||||
inner = _stub_sync_inner([])
|
||||
backend = SSRFProtectedSyncBackend(inner=inner)
|
||||
with pytest.raises(httpcore.ConnectError, match="forbidden IP"):
|
||||
backend.connect_tcp("10.0.0.1", 80)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
|
||||
class TestRequestEventHook:
|
||||
"""Verify the syntactic URL validator is wired as a request event hook.
|
||||
|
||||
With the hook in place, every request through a context-provided client
|
||||
is validated automatically; feature code does not need to call
|
||||
`validate_url_for_http_request` itself.
|
||||
"""
|
||||
|
||||
async def test_async_hook_rejects_bad_scheme(self):
|
||||
from utils.context import create_httpx_async_client
|
||||
from utils.validation import ValidationError
|
||||
|
||||
client = create_httpx_async_client()
|
||||
try:
|
||||
with pytest.raises(ValidationError, match="only http and https"):
|
||||
await client.get("file:///etc/passwd")
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
async def test_async_hook_rejects_internal_tld(self):
|
||||
from utils.context import create_httpx_async_client
|
||||
from utils.validation import ValidationError
|
||||
|
||||
client = create_httpx_async_client()
|
||||
try:
|
||||
with pytest.raises(ValidationError, match="internal domain names"):
|
||||
await client.get("http://printer.local/status")
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
def test_sync_hook_rejects_literal_private_ip(self):
|
||||
from utils.context import create_httpx_client
|
||||
from utils.validation import ValidationError
|
||||
|
||||
with create_httpx_client() as client:
|
||||
with pytest.raises(ValidationError, match="private, internal"):
|
||||
client.get("http://10.0.0.1/")
|
||||
|
||||
|
||||
class TestInstallation:
|
||||
"""Verify the backend is actually wired onto httpx clients we create."""
|
||||
|
||||
def test_create_httpx_async_client_installs_backend(self):
|
||||
from utils.context import create_httpx_async_client
|
||||
from utils.ssrf import SSRFProtectedAsyncBackend as Async
|
||||
|
||||
client = create_httpx_async_client()
|
||||
try:
|
||||
assert isinstance(client._transport._pool._network_backend, Async)
|
||||
finally:
|
||||
asyncio.run(client.aclose())
|
||||
|
||||
def test_create_httpx_client_installs_backend(self):
|
||||
from utils.context import create_httpx_client
|
||||
from utils.ssrf import SSRFProtectedSyncBackend as Sync
|
||||
|
||||
with create_httpx_client() as client:
|
||||
assert isinstance(client._transport._pool._network_backend, Sync)
|
||||
|
||||
def test_proxy_transports_are_not_wrapped(self):
|
||||
"""Proxy mounts must keep their stock backend.
|
||||
|
||||
A common deployment pattern is `HTTPS_PROXY=http://sidecar:9050`,
|
||||
where `sidecar` resolves to a docker-bridge private IP. If we
|
||||
wrapped the proxy transport, our SSRF backend would refuse to
|
||||
connect to the operator's chosen proxy. SSRF protection at the
|
||||
proxy hop is the operator's responsibility; the destination URL
|
||||
is still validated by the request event hook on the client.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from utils.ssrf import SSRFProtectedAsyncBackend as Async
|
||||
from utils.ssrf import (
|
||||
install_async_ssrf_protection,
|
||||
)
|
||||
|
||||
client = httpx.AsyncClient(proxy="http://proxy.invalid:3128")
|
||||
try:
|
||||
install_async_ssrf_protection(client)
|
||||
assert isinstance(client._transport._pool._network_backend, Async)
|
||||
for mount in client._mounts.values():
|
||||
if mount is None:
|
||||
continue
|
||||
# Proxy mount must NOT have been wrapped.
|
||||
assert not isinstance(mount._pool._network_backend, Async)
|
||||
finally:
|
||||
asyncio.run(client.aclose())
|
||||
|
||||
|
||||
class TestValidateUrlForHttpRequest:
|
||||
"""Test URL validation for HTTP requests to prevent SSRF attacks."""
|
||||
|
||||
def test_valid_http_urls(self):
|
||||
"""Valid HTTP/HTTPS URLs pass syntactic validation without DNS lookups.
|
||||
|
||||
DNS-based SSRF checks live in the HTTP client's connect path, so
|
||||
this layer must not call DNS.
|
||||
"""
|
||||
validate_url_for_http_request("http://example.com", "test_url")
|
||||
validate_url_for_http_request("https://example.com", "test_url")
|
||||
validate_url_for_http_request("http://example.com/path", "test_url")
|
||||
validate_url_for_http_request("https://example.com/path?query=1", "test_url")
|
||||
validate_url_for_http_request("http://subdomain.example.com", "test_url")
|
||||
|
||||
def test_validator_does_not_perform_dns_lookup(self, monkeypatch):
|
||||
"""Regression: validator must not block the event loop on DNS.
|
||||
|
||||
Earlier implementations called `socket.getaddrinfo`, which both
|
||||
blocked the running event loop in async media-download callers and
|
||||
was defeated by DNS rebinding (the value seen here did not match
|
||||
the IP httpx later connected to). We patch getaddrinfo to a poison
|
||||
function so any accidental reintroduction fails this test.
|
||||
"""
|
||||
|
||||
def _explode(*_args, **_kwargs):
|
||||
raise AssertionError(
|
||||
"validate_url_for_http_request must not call DNS; "
|
||||
"SSRF DNS protection lives in the HTTP client backend"
|
||||
)
|
||||
|
||||
monkeypatch.setattr(socket, "getaddrinfo", _explode)
|
||||
validate_url_for_http_request("http://example.com", "test_url")
|
||||
|
||||
def test_invalid_empty_url(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("", "test_url")
|
||||
assert "cannot be empty" in exc_info.value.message
|
||||
|
||||
def test_invalid_scheme(self):
|
||||
for url in (
|
||||
"ftp://example.com",
|
||||
"file:///etc/passwd",
|
||||
"data:text/html,<h1>test</h1>",
|
||||
"javascript:alert(1)", # XSS attack vector
|
||||
):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert "only http and https schemes are allowed" in exc_info.value.message
|
||||
|
||||
def test_invalid_localhost(self):
|
||||
for url in (
|
||||
"http://localhost",
|
||||
"http://127.0.0.1",
|
||||
"http://[::1]",
|
||||
"http://0.0.0.0",
|
||||
):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert (
|
||||
"localhost and reserved hostnames are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_private_ipv4_addresses(self):
|
||||
for url in (
|
||||
"http://10.0.0.1",
|
||||
"http://192.168.1.1",
|
||||
"http://172.16.0.1",
|
||||
"http://172.31.255.254",
|
||||
):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert (
|
||||
"private, internal, reserved, or multicast IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_loopback_addresses(self):
|
||||
# 127.0.0.1 itself is in RESERVED_HOSTNAMES; these cover the rest of 127/8.
|
||||
for url in ("http://127.0.0.2", "http://127.255.255.255"):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert (
|
||||
"private, internal, reserved, or multicast IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_private_ipv6_addresses(self):
|
||||
for url in (
|
||||
"http://[fe80::1]", # link-local
|
||||
"http://[fc00::1]", # unique local
|
||||
"http://[fd00::1]", # unique local
|
||||
):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert (
|
||||
"private, internal, reserved, or multicast IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_multicast_addresses(self):
|
||||
for url in ("http://224.0.0.1", "http://[ff02::1]"):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert (
|
||||
"private, internal, reserved, or multicast IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_internal_tlds(self):
|
||||
for url in (
|
||||
"http://server.local",
|
||||
"http://server.internal",
|
||||
"http://server.localhost",
|
||||
):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert "internal domain names are not allowed" in exc_info.value.message
|
||||
|
||||
def test_invalid_non_standard_ip_representations(self):
|
||||
"""Non-standard IPv4 forms (hex, decimal, shorthand) are SSRF bypass vectors."""
|
||||
cases = [
|
||||
"http://0x7f000001", # hex 127.0.0.1
|
||||
"http://2130706433", # decimal 127.0.0.1
|
||||
"http://127.1", # shorthand 127.0.0.1
|
||||
"http://0x0a000001", # hex 10.0.0.1
|
||||
"http://3232235777", # decimal 192.168.1.1
|
||||
"http://0xa9fea9fe", # hex 169.254.169.254 (cloud metadata)
|
||||
]
|
||||
for url in cases:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request(url, "test_url")
|
||||
assert (
|
||||
"private, internal, reserved, or multicast IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_missing_hostname(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://", "test_url")
|
||||
assert "missing hostname" in exc_info.value.message
|
||||
@@ -7,7 +7,6 @@ from utils.validation import (
|
||||
validate_ascii_only,
|
||||
validate_email,
|
||||
validate_password,
|
||||
validate_url_for_http_request,
|
||||
validate_username,
|
||||
)
|
||||
|
||||
@@ -158,226 +157,3 @@ class TestValidateEmail:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_email("résumé@example.com")
|
||||
assert "ASCII characters" in exc_info.value.message
|
||||
|
||||
|
||||
class TestValidateUrlForHttpRequest:
|
||||
"""Test URL validation for HTTP requests to prevent SSRF attacks."""
|
||||
|
||||
def test_valid_http_urls(self):
|
||||
"""Test that valid HTTP/HTTPS URLs pass validation."""
|
||||
validate_url_for_http_request("http://example.com", "test_url")
|
||||
validate_url_for_http_request("https://example.com", "test_url")
|
||||
validate_url_for_http_request("http://example.com/path", "test_url")
|
||||
validate_url_for_http_request("https://example.com/path?query=1", "test_url")
|
||||
validate_url_for_http_request("http://subdomain.example.com", "test_url")
|
||||
|
||||
def test_invalid_empty_url(self):
|
||||
"""Test that empty URLs fail validation."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("", "test_url")
|
||||
assert "cannot be empty" in exc_info.value.message
|
||||
|
||||
def test_invalid_scheme(self):
|
||||
"""Test that non-HTTP/HTTPS schemes fail validation."""
|
||||
# FTP scheme
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("ftp://example.com", "test_url")
|
||||
assert "only http and https schemes are allowed" in exc_info.value.message
|
||||
|
||||
# File scheme
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("file:///etc/passwd", "test_url")
|
||||
assert "only http and https schemes are allowed" in exc_info.value.message
|
||||
|
||||
# Data scheme
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("data:text/html,<h1>test</h1>", "test_url")
|
||||
assert "only http and https schemes are allowed" in exc_info.value.message
|
||||
|
||||
# JavaScript scheme (XSS attack vector)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("javascript:alert(1)", "test_url")
|
||||
assert "only http and https schemes are allowed" in exc_info.value.message
|
||||
|
||||
def test_invalid_localhost(self):
|
||||
"""Test that localhost and reserved hostnames fail validation."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://localhost", "test_url")
|
||||
assert (
|
||||
"localhost and reserved hostnames are not allowed" in exc_info.value.message
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://127.0.0.1", "test_url")
|
||||
assert (
|
||||
"localhost and reserved hostnames are not allowed" in exc_info.value.message
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://[::1]", "test_url")
|
||||
assert (
|
||||
"localhost and reserved hostnames are not allowed" in exc_info.value.message
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://0.0.0.0", "test_url")
|
||||
assert (
|
||||
"localhost and reserved hostnames are not allowed" in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_private_ipv4_addresses(self):
|
||||
"""Test that private IPv4 addresses fail validation."""
|
||||
# 10.x.x.x range
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://10.0.0.1", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# 192.168.x.x range
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://192.168.1.1", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# 172.16.x.x - 172.31.x.x range
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://172.16.0.1", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://172.31.255.254", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_loopback_addresses(self):
|
||||
"""Test that loopback addresses fail validation."""
|
||||
# 127.x.x.x range
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://127.0.0.2", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://127.255.255.255", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_private_ipv6_addresses(self):
|
||||
"""Test that private/link-local IPv6 addresses fail validation."""
|
||||
# Link-local IPv6: fe80::/10
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://[fe80::1]", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# Unique local address: fc00::/7
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://[fc00::1]", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://[fd00::1]", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_multicast_addresses(self):
|
||||
"""Test that multicast addresses fail validation."""
|
||||
# IPv4 multicast: 224.0.0.0/4
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://224.0.0.1", "test_url")
|
||||
assert "multicast addresses are not allowed" in exc_info.value.message
|
||||
|
||||
# IPv6 multicast: ff00::/8
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://[ff02::1]", "test_url")
|
||||
assert "multicast addresses are not allowed" in exc_info.value.message
|
||||
|
||||
def test_invalid_internal_tlds(self):
|
||||
"""Test that internal TLDs fail validation."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://server.local", "test_url")
|
||||
assert "internal domain names are not allowed" in exc_info.value.message
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://server.internal", "test_url")
|
||||
assert "internal domain names are not allowed" in exc_info.value.message
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://server.localhost", "test_url")
|
||||
assert "internal domain names are not allowed" in exc_info.value.message
|
||||
|
||||
def test_invalid_non_standard_ip_representations(self):
|
||||
"""Test that non-standard IP representations are blocked (SSRF bypass vectors)."""
|
||||
# Hexadecimal integer for 127.0.0.1
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://0x7f000001", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# Decimal integer for 127.0.0.1
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://2130706433", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# Shorthand dotted for 127.0.0.1
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://127.1", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# Hexadecimal integer for 10.0.0.1
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://0x0a000001", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# Decimal integer for 192.168.1.1
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://3232235777", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
# Hexadecimal integer for 169.254.169.254 (cloud metadata)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://0xa9fea9fe", "test_url")
|
||||
assert (
|
||||
"private, internal, and reserved IP addresses are not allowed"
|
||||
in exc_info.value.message
|
||||
)
|
||||
|
||||
def test_invalid_missing_hostname(self):
|
||||
"""Test that URLs without hostnames fail validation."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
validate_url_for_http_request("http://", "test_url")
|
||||
assert "missing hostname" in exc_info.value.message
|
||||
|
||||
@@ -8,6 +8,11 @@ import httpx
|
||||
from fastapi import Request, Response
|
||||
|
||||
from config import has_proxy_env
|
||||
from utils.ssrf import (
|
||||
install_async_ssrf_protection,
|
||||
install_sync_ssrf_protection,
|
||||
validate_url_for_http_request,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@@ -15,16 +20,34 @@ ctx_aiohttp_session: ContextVar[aiohttp.ClientSession] = ContextVar("aiohttp_ses
|
||||
ctx_httpx_client: ContextVar[httpx.AsyncClient] = ContextVar("httpx_client")
|
||||
|
||||
|
||||
def _validate_request_url_sync(request: httpx.Request) -> None:
|
||||
validate_url_for_http_request(str(request.url))
|
||||
|
||||
|
||||
async def _validate_request_url_async(request: httpx.Request) -> None:
|
||||
validate_url_for_http_request(str(request.url))
|
||||
|
||||
|
||||
def create_aiohttp_session() -> aiohttp.ClientSession:
|
||||
return aiohttp.ClientSession(trust_env=has_proxy_env())
|
||||
|
||||
|
||||
def create_httpx_async_client() -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(trust_env=has_proxy_env())
|
||||
client = httpx.AsyncClient(
|
||||
trust_env=has_proxy_env(),
|
||||
event_hooks={"request": [_validate_request_url_async]},
|
||||
)
|
||||
install_async_ssrf_protection(client)
|
||||
return client
|
||||
|
||||
|
||||
def create_httpx_client() -> httpx.Client:
|
||||
return httpx.Client(trust_env=has_proxy_env())
|
||||
client = httpx.Client(
|
||||
trust_env=has_proxy_env(),
|
||||
event_hooks={"request": [_validate_request_url_sync]},
|
||||
)
|
||||
install_sync_ssrf_protection(client)
|
||||
return client
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
331
backend/utils/ssrf.py
Normal file
331
backend/utils/ssrf.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""SSRF defense for outbound HTTP.
|
||||
|
||||
Two layers, both wired onto every httpx client built by `utils.context`:
|
||||
|
||||
1. `validate_url_for_http_request` — a syntactic fast-fail check
|
||||
installed as an httpx request event hook. Rejects non-HTTP schemes,
|
||||
literal IPs in forbidden ranges (including non-standard IPv4 forms),
|
||||
reserved hostnames, and internal TLDs before any socket opens.
|
||||
|
||||
2. `SSRFProtectedAsyncBackend` / `SSRFProtectedSyncBackend` — custom
|
||||
httpcore network backends that resolve the hostname inside
|
||||
`connect_tcp`, reject any address in a private/loopback/link-local/
|
||||
reserved/multicast/unspecified range, then connect to that *same*
|
||||
validated address. This is what defeats DNS rebinding: the address
|
||||
used by the OS for the TCP connection is the one we just checked,
|
||||
not a fresh lookup the attacker can answer differently. Doing this
|
||||
work in the backend also avoids blocking the event loop, since the
|
||||
async variant uses `loop.getaddrinfo`.
|
||||
|
||||
httpcore calls `start_tls(server_hostname=<URL host>)` after
|
||||
`connect_tcp` returns, so TLS SNI and certificate verification still
|
||||
use the original hostname even though we connect by IP.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
import typing
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpcore
|
||||
from httpcore._backends.auto import AutoBackend
|
||||
from httpcore._backends.base import (
|
||||
SOCKET_OPTION,
|
||||
AsyncNetworkBackend,
|
||||
AsyncNetworkStream,
|
||||
NetworkBackend,
|
||||
NetworkStream,
|
||||
)
|
||||
from httpcore._backends.sync import SyncBackend
|
||||
|
||||
from logger.logger import log
|
||||
from utils.validation import ValidationError
|
||||
|
||||
|
||||
def is_forbidden_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
"""Return True if the IP must not be reached from a server-side HTTP request."""
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_reserved
|
||||
or ip.is_multicast
|
||||
or ip.is_unspecified
|
||||
)
|
||||
|
||||
|
||||
def parse_ip_literal(
|
||||
host: str,
|
||||
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
|
||||
"""Return the parsed IP if `host` is a literal address, else None.
|
||||
|
||||
Accepts non-standard IPv4 forms (hex, decimal, shorthand) via
|
||||
`socket.inet_aton`, which is what HTTP clients themselves accept.
|
||||
"""
|
||||
try:
|
||||
return ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
packed = socket.inet_aton(host)
|
||||
except OSError:
|
||||
return None
|
||||
return ipaddress.IPv4Address(packed)
|
||||
|
||||
|
||||
def _pick_safe_address(addr_infos: typing.Iterable[typing.Any], host: str) -> str:
|
||||
"""Validate every resolved address and return the literal IP to connect to.
|
||||
|
||||
All returned addresses are checked: if any falls in a forbidden range
|
||||
we reject the whole name, rather than just skipping that record. A
|
||||
malicious DNS server can otherwise mix public and private answers
|
||||
and rely on the client to round-robin.
|
||||
"""
|
||||
chosen: str | None = None
|
||||
for *_, sockaddr in addr_infos:
|
||||
try:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
if is_forbidden_ip(ip):
|
||||
msg = (
|
||||
f"SSRF prevention: hostname {host!r} resolves to forbidden " f"IP {ip}"
|
||||
)
|
||||
log.error(msg)
|
||||
raise httpcore.ConnectError(msg)
|
||||
if chosen is None:
|
||||
chosen = sockaddr[0]
|
||||
if chosen is None:
|
||||
raise httpcore.ConnectError(f"No usable addresses for {host!r}")
|
||||
return chosen
|
||||
|
||||
|
||||
def _check_literal(host: str) -> bool:
|
||||
"""Return True if `host` is a literal IP that has been validated as safe.
|
||||
|
||||
Raises httpcore.ConnectError if it is a literal IP in a forbidden range.
|
||||
Returns False if `host` is a hostname (caller must resolve and validate).
|
||||
"""
|
||||
literal = parse_ip_literal(host)
|
||||
if literal is None:
|
||||
return False
|
||||
if is_forbidden_ip(literal):
|
||||
msg = f"SSRF prevention: connection to forbidden IP {literal}"
|
||||
log.error(msg)
|
||||
raise httpcore.ConnectError(msg)
|
||||
return True
|
||||
|
||||
|
||||
class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
|
||||
"""Async backend that validates resolved IPs before establishing TCP."""
|
||||
|
||||
def __init__(self, inner: AsyncNetworkBackend | None = None) -> None:
|
||||
self._inner = inner or AutoBackend()
|
||||
|
||||
# `timeout` parameter is required by AsyncNetworkBackend.connect_tcp;
|
||||
# ruff/ASYNC109 advises against timeout parameters on async APIs *we*
|
||||
# author, but we are implementing an external interface here. The
|
||||
# timeout is consumed via `asyncio.timeout()` below, which is the
|
||||
# asyncio-native pattern ASYNC109 endorses.
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: float | None = None, # noqa: ASYNC109
|
||||
local_address: str | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> AsyncNetworkStream:
|
||||
"""Validate the resolved IP, then connect via the inner backend.
|
||||
|
||||
The DNS lookup is wrapped in `asyncio.timeout()` so a slow
|
||||
resolver is bounded by the caller's timeout. The previous code
|
||||
only timed out the TCP connect inside the inner backend, leaving
|
||||
`loop.getaddrinfo` unbounded.
|
||||
"""
|
||||
if _check_literal(host):
|
||||
return await self._inner.connect_tcp(
|
||||
host, port, timeout, local_address, socket_options
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as exc:
|
||||
raise httpcore.ConnectError(
|
||||
f"DNS resolution failed for {host!r}: {exc}"
|
||||
) from exc
|
||||
except TimeoutError as exc:
|
||||
raise httpcore.ConnectTimeout(
|
||||
f"DNS resolution timed out for {host!r}"
|
||||
) from exc
|
||||
|
||||
pinned_ip = _pick_safe_address(addr_infos, host)
|
||||
return await self._inner.connect_tcp(
|
||||
pinned_ip, port, timeout, local_address, socket_options
|
||||
)
|
||||
|
||||
# See note on connect_tcp re: ASYNC109 / interface implementation.
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: float | None = None, # noqa: ASYNC109
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> AsyncNetworkStream:
|
||||
return await self._inner.connect_unix_socket(path, timeout, socket_options)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await self._inner.sleep(seconds)
|
||||
|
||||
|
||||
class SSRFProtectedSyncBackend(NetworkBackend):
|
||||
"""Sync backend that validates resolved IPs before establishing TCP."""
|
||||
|
||||
def __init__(self, inner: NetworkBackend | None = None) -> None:
|
||||
self._inner = inner if inner is not None else SyncBackend()
|
||||
|
||||
def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: float | None = None,
|
||||
local_address: str | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> NetworkStream:
|
||||
if _check_literal(host):
|
||||
return self._inner.connect_tcp(
|
||||
host, port, timeout, local_address, socket_options
|
||||
)
|
||||
|
||||
try:
|
||||
addr_infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as exc:
|
||||
raise httpcore.ConnectError(
|
||||
f"DNS resolution failed for {host!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
pinned_ip = _pick_safe_address(addr_infos, host)
|
||||
return self._inner.connect_tcp(
|
||||
pinned_ip, port, timeout, local_address, socket_options
|
||||
)
|
||||
|
||||
def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: float | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> NetworkStream:
|
||||
return self._inner.connect_unix_socket(path, timeout, socket_options)
|
||||
|
||||
|
||||
def install_async_ssrf_protection(client: typing.Any) -> None:
|
||||
"""Wrap the client's default transport so SSRF validation runs at connect time.
|
||||
|
||||
httpx does not expose `network_backend` through its public transport
|
||||
API, so we mutate `_pool._network_backend` after construction.
|
||||
"""
|
||||
pool = client._transport._pool
|
||||
if not isinstance(pool._network_backend, SSRFProtectedAsyncBackend):
|
||||
pool._network_backend = SSRFProtectedAsyncBackend(inner=pool._network_backend)
|
||||
|
||||
|
||||
def install_sync_ssrf_protection(client: typing.Any) -> None:
|
||||
"""Sync counterpart of `install_async_ssrf_protection`."""
|
||||
pool = client._transport._pool
|
||||
if not isinstance(pool._network_backend, SSRFProtectedSyncBackend):
|
||||
pool._network_backend = SSRFProtectedSyncBackend(inner=pool._network_backend)
|
||||
|
||||
|
||||
RESERVED_HOSTNAMES = [
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0", # trunk-ignore(bandit/B104)
|
||||
"::1",
|
||||
"::",
|
||||
]
|
||||
|
||||
|
||||
def validate_url_for_http_request(url: str, field_name: str = "URL") -> None:
|
||||
"""Syntactically validate a URL before passing it to an HTTP client.
|
||||
|
||||
Fast-fail check for cases that don't need DNS to detect:
|
||||
|
||||
- The URL scheme is http or https only
|
||||
- If the host is a literal IP address, it is not private/internal/reserved
|
||||
- The host is not a reserved hostname (localhost, 127.0.0.1, etc.)
|
||||
- The host does not use internal TLDs (.local, .internal, .localhost)
|
||||
|
||||
Wired in as an httpx request event hook by `utils.context.create_httpx_*`,
|
||||
so every request that goes through a context-provided client runs this
|
||||
automatically — direct calls from feature code are not required.
|
||||
|
||||
Dynamic SSRF protection (rejecting hostnames that resolve to a private
|
||||
IP, including DNS-rebinding names like `127.0.0.1.nip.io`) happens
|
||||
inside the HTTP client's connect path via the backends above. Doing
|
||||
the DNS check here would (a) be defeated by DNS rebinding because
|
||||
httpx re-resolves at connect time, and (b) block the event loop,
|
||||
since `socket.getaddrinfo` is synchronous and most callers are async.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
field_name (str): The name of the field for error messages
|
||||
|
||||
Raises:
|
||||
ValidationError: If the URL is syntactically invalid or matches one
|
||||
of the static SSRF deny patterns.
|
||||
"""
|
||||
if not url or not url.strip():
|
||||
msg = f"{field_name} cannot be empty"
|
||||
log.error(msg)
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except Exception as e:
|
||||
msg = f"Invalid {field_name}: unable to parse URL"
|
||||
log.error(f"{msg}: {str(e)}")
|
||||
raise ValidationError(msg, field_name) from e
|
||||
|
||||
# Validate scheme - only allow http and https
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
msg = f"Invalid {field_name}: only http and https schemes are allowed"
|
||||
log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
# Extract hostname
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
msg = f"Invalid {field_name}: missing hostname"
|
||||
log.error(msg)
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
# Block reserved hostnames that are commonly used to refer to internal services.
|
||||
if hostname.lower() in RESERVED_HOSTNAMES:
|
||||
msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
# Try to parse hostname as a literal IP (standard or non-standard form).
|
||||
# HTTP clients accept hex (0x7f000001), decimal (2130706433), and
|
||||
# shorthand-dotted (127.1) integers via inet_aton, so we mirror that.
|
||||
ip = parse_ip_literal(hostname)
|
||||
if ip is not None:
|
||||
if is_forbidden_ip(ip):
|
||||
msg = (
|
||||
f"Invalid {field_name}: private, internal, reserved, "
|
||||
"or multicast IP addresses are not allowed"
|
||||
)
|
||||
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
return
|
||||
|
||||
# Block common internal TLDs
|
||||
hostname_lower = hostname.lower()
|
||||
internal_tlds = [".local", ".internal", ".localhost"]
|
||||
if any(hostname_lower.endswith(tld) for tld in internal_tlds):
|
||||
msg = f"Invalid {field_name}: internal domain names are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
@@ -1,7 +1,4 @@
|
||||
import ipaddress
|
||||
import re
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from logger.logger import log
|
||||
from models.user import TEXT_FIELD_LENGTH
|
||||
@@ -118,112 +115,3 @@ def validate_email(email: str) -> None:
|
||||
msg = "Invalid email format"
|
||||
log.error(f"Validation failed: {msg} for email: {email}")
|
||||
raise ValidationError(msg, "Email")
|
||||
|
||||
|
||||
# Check for localhost and reserved hostnames
|
||||
RESERVED_HOSTNAMES = [
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0", # trunk-ignore(bandit/B104)
|
||||
"::1",
|
||||
"::",
|
||||
]
|
||||
|
||||
|
||||
def validate_url_for_http_request(url: str, field_name: str = "URL") -> None:
|
||||
"""Validate URL to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||
|
||||
This function validates that:
|
||||
- The URL scheme is http or https only
|
||||
- If the host is a literal IP address, it is not private/internal/reserved
|
||||
- The host is not a reserved hostname (localhost, 127.0.0.1, etc.)
|
||||
- The host does not use internal TLDs (.local, .internal, .localhost)
|
||||
|
||||
Note: This function does NOT perform DNS resolution. Domain names that resolve
|
||||
to private IPs will not be detected (DNS rebinding/internal DNS bypass possible).
|
||||
It only checks literal IP addresses in the hostname.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
field_name (str): The name of the field for error messages
|
||||
|
||||
Raises:
|
||||
ValidationError: If the URL is invalid or potentially dangerous
|
||||
"""
|
||||
if not url or not url.strip():
|
||||
msg = f"{field_name} cannot be empty"
|
||||
log.error(msg)
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except Exception as e:
|
||||
msg = f"Invalid {field_name}: unable to parse URL"
|
||||
log.error(f"{msg}: {str(e)}")
|
||||
raise ValidationError(msg, field_name) from e
|
||||
|
||||
# Validate scheme - only allow http and https
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
msg = f"Invalid {field_name}: only http and https schemes are allowed"
|
||||
log.error(f"SSRF prevention: {msg} - got scheme '{parsed.scheme}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
# Extract hostname
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
msg = f"Invalid {field_name}: missing hostname"
|
||||
log.error(msg)
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
if hostname.lower() in RESERVED_HOSTNAMES:
|
||||
msg = f"Invalid {field_name}: localhost and reserved hostnames are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
# Try to resolve hostname as IP address
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
|
||||
# Block private/internal/link-local IP addresses
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
msg = f"Invalid {field_name}: private, internal, and reserved IP addresses are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
# Block multicast addresses
|
||||
if ip.is_multicast:
|
||||
msg = f"Invalid {field_name}: multicast addresses are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
except ValueError as e:
|
||||
# ipaddress.ip_address() only handles standard notation. HTTP clients
|
||||
# also accept hex integers (0x7f000001), decimal integers (2130706433),
|
||||
# shorthand dotted (127.1), and octal (0177.0.0.1). Use socket.inet_aton()
|
||||
# which handles these non-standard IPv4 representations.
|
||||
try:
|
||||
packed = socket.inet_aton(hostname)
|
||||
ip = ipaddress.IPv4Address(packed)
|
||||
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
msg = f"Invalid {field_name}: private, internal, and reserved IP addresses are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
if ip.is_multicast:
|
||||
msg = f"Invalid {field_name}: multicast addresses are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - IP '{ip}'")
|
||||
raise ValidationError(msg, field_name)
|
||||
|
||||
except OSError:
|
||||
pass # Not an IP address at all - fall through to domain name checks
|
||||
|
||||
# Additional checks for suspicious domain patterns
|
||||
hostname_lower = hostname.lower()
|
||||
|
||||
# Block common internal TLDs
|
||||
internal_tlds = [".local", ".internal", ".localhost"]
|
||||
if any(hostname_lower.endswith(tld) for tld in internal_tlds):
|
||||
msg = f"Invalid {field_name}: internal domain names are not allowed"
|
||||
log.error(f"SSRF prevention: {msg} - hostname '{hostname}'")
|
||||
raise ValidationError(msg, field_name) from e
|
||||
|
||||
@@ -15,7 +15,7 @@ dependencies = [
|
||||
"SQLAlchemy[mariadb-connector,mysql-connector,postgresql-psycopg] ~= 2.0",
|
||||
"Unidecode ~= 1.3",
|
||||
"aiohttp~=3.13",
|
||||
"asyncssh ~= 2.17",
|
||||
"asyncssh~=2.23",
|
||||
"alembic ~= 1.16",
|
||||
"anyio ~= 4.4",
|
||||
"authlib~=1.6.12",
|
||||
|
||||
10
uv.lock
generated
10
uv.lock
generated
@@ -7,7 +7,7 @@ resolution-markers = [
|
||||
]
|
||||
|
||||
[options]
|
||||
exclude-newer = "2026-05-21T08:54:16.815535417Z"
|
||||
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
|
||||
exclude-newer-span = "P7D"
|
||||
|
||||
[options.exclude-newer-package]
|
||||
@@ -176,15 +176,15 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "asyncssh"
|
||||
version = "2.22.0"
|
||||
version = "2.23.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fc/d5/957886c316466349d55c4de6a688a10a98295c0b4429deb8db1a17f3eb19/asyncssh-2.22.0.tar.gz", hash = "sha256:c3ce72b01be4f97b40e62844dd384227e5ff5a401a3793007c42f86a5c8eb537", size = 540523, upload-time = "2025-12-21T23:38:30.5Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ee/fd/c34fe7e30838b4b9cc91903da26a62c6d33b673c731b3d951fcd70ab1889/asyncssh-2.23.0.tar.gz", hash = "sha256:8c54760953c1f2cf282591bcba5c8c70efc48d645bbf26bd2307a9c66a0ed1a7", size = 542154, upload-time = "2026-05-09T03:15:01.856Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/ae/0da2f2214fc183338af1afe5a103a2052fd03464e8eafbd827abff58a4d0/asyncssh-2.22.0-py3-none-any.whl", hash = "sha256:d16465ccdf1ed20eba1131b14415b155e047f6f5be0d19f39c2e0b61331ee0e7", size = 374938, upload-time = "2025-12-21T23:38:28.976Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/b5/b1a3979f4840d1271ca8e0978dbccfb18ad2d33b4ece85cf77122fb46e5f/asyncssh-2.23.0-py3-none-any.whl", hash = "sha256:14108bfdaae17457f0c1841e883ad934271bbfdd46458aa4c4d0973451940ad0", size = 375687, upload-time = "2026-05-09T03:15:00.221Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2255,7 +2255,7 @@ requires-dist = [
|
||||
{ name = "aiohttp", specifier = "~=3.13" },
|
||||
{ name = "alembic", specifier = "~=1.16" },
|
||||
{ name = "anyio", specifier = "~=4.4" },
|
||||
{ name = "asyncssh", specifier = "~=2.17" },
|
||||
{ name = "asyncssh", specifier = "~=2.23" },
|
||||
{ name = "authlib", specifier = "~=1.6.12" },
|
||||
{ name = "bcrypt", specifier = "<5" },
|
||||
{ name = "colorama", specifier = "~=0.4" },
|
||||
|
||||
Reference in New Issue
Block a user