mirror of
https://github.com/rommapp/romm.git
synced 2026-06-27 22:35:57 +00:00
fix(ssrf): bound DNS lookup by caller timeout; clear lint findings
The async backend's `loop.getaddrinfo` ran without any timeout, so a slow or hanging resolver could outlive the timeout the caller passed — the previous code only bounded the TCP connect inside the inner backend. Wrap the resolution in `asyncio.timeout(timeout)` and surface the timeout as `httpcore.ConnectTimeout`. Also tidy the test stubs (mypy func-returns-value) and add explicit type annotations to the `calls` lists (mypy var-annotated). A targeted `# noqa: ASYNC109` sits on the `timeout` parameter of `connect_tcp` / `connect_unix_socket` with an explanatory comment: the rule advises against `timeout` parameters on async APIs we author, but here we're implementing `AsyncNetworkBackend`, and the timeout is consumed in the asyncio-native pattern the rule endorses. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpcore
|
||||
@@ -16,25 +17,33 @@ from utils.ssrf import (
|
||||
)
|
||||
from utils.validation import ValidationError
|
||||
|
||||
ConnectCall = tuple[tuple[Any, ...], dict[str, Any]]
|
||||
|
||||
def _addr_info(ip: str, port: int):
|
||||
|
||||
def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]:
|
||||
family = socket.AF_INET6 if ":" in ip else socket.AF_INET
|
||||
return [(family, socket.SOCK_STREAM, 0, "", (ip, port))]
|
||||
|
||||
|
||||
def _stub_async_inner(connect_calls: list):
|
||||
def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock:
|
||||
inner = MagicMock()
|
||||
inner.connect_tcp = AsyncMock(
|
||||
side_effect=lambda *a, **kw: connect_calls.append((a, kw)) or MagicMock()
|
||||
)
|
||||
|
||||
def _record(*args: Any, **kwargs: Any) -> MagicMock:
|
||||
connect_calls.append((args, kwargs))
|
||||
return MagicMock()
|
||||
|
||||
inner.connect_tcp = AsyncMock(side_effect=_record)
|
||||
return inner
|
||||
|
||||
|
||||
def _stub_sync_inner(connect_calls: list):
|
||||
def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock:
|
||||
inner = MagicMock()
|
||||
inner.connect_tcp = MagicMock(
|
||||
side_effect=lambda *a, **kw: connect_calls.append((a, kw)) or MagicMock()
|
||||
)
|
||||
|
||||
def _record(*args: Any, **kwargs: Any) -> MagicMock:
|
||||
connect_calls.append((args, kwargs))
|
||||
return MagicMock()
|
||||
|
||||
inner.connect_tcp = MagicMock(side_effect=_record)
|
||||
return inner
|
||||
|
||||
|
||||
@@ -87,7 +96,7 @@ class TestParseIpLiteral:
|
||||
class TestSSRFProtectedAsyncBackend:
|
||||
async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
|
||||
"""Backend resolves once, validates, and passes the pinned IP to inner."""
|
||||
calls = []
|
||||
calls: list[ConnectCall] = []
|
||||
inner = _stub_async_inner(calls)
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
@@ -140,7 +149,7 @@ class TestSSRFProtectedAsyncBackend:
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_literal_public_ip_passes_through(self):
|
||||
calls = []
|
||||
calls: list[ConnectCall] = []
|
||||
inner = _stub_async_inner(calls)
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
await backend.connect_tcp("8.8.8.8", 443)
|
||||
@@ -155,6 +164,26 @@ class TestSSRFProtectedAsyncBackend:
|
||||
await backend.connect_tcp("2130706433", 80) # 127.0.0.1
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_dns_timeout_raises_connect_timeout(self, monkeypatch):
|
||||
"""A resolver that hangs past the caller's timeout must not block forever.
|
||||
|
||||
Regression: an earlier version applied the caller's timeout only to
|
||||
the TCP connect inside the inner backend, leaving `loop.getaddrinfo`
|
||||
unbounded. We now wrap the lookup in `asyncio.timeout()` so a slow
|
||||
resolver is bounded by the same budget the caller specified.
|
||||
"""
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
|
||||
async def hang_forever(*args, **kwargs):
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", hang_forever)
|
||||
|
||||
with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"):
|
||||
await backend.connect_tcp("slow.example.com", 80, timeout=0.05)
|
||||
inner.connect_tcp.assert_not_called()
|
||||
|
||||
async def test_dns_failure_propagates_as_connect_error(self, monkeypatch):
|
||||
inner = _stub_async_inner([])
|
||||
backend = SSRFProtectedAsyncBackend(inner=inner)
|
||||
@@ -171,7 +200,7 @@ class TestSSRFProtectedAsyncBackend:
|
||||
|
||||
class TestSSRFProtectedSyncBackend:
|
||||
def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
|
||||
calls = []
|
||||
calls: list[ConnectCall] = []
|
||||
inner = _stub_sync_inner(calls)
|
||||
backend = SSRFProtectedSyncBackend(inner=inner)
|
||||
|
||||
|
||||
@@ -125,14 +125,26 @@ class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
|
||||
def __init__(self, inner: AsyncNetworkBackend | None = None) -> None:
|
||||
self._inner = inner if inner is not None else AutoBackend()
|
||||
|
||||
# `timeout` parameter is required by AsyncNetworkBackend.connect_tcp;
|
||||
# ruff/ASYNC109 advises against timeout parameters on async APIs *we*
|
||||
# author, but we are implementing an external interface here. The
|
||||
# timeout is consumed via `asyncio.timeout()` below, which is the
|
||||
# asyncio-native pattern ASYNC109 endorses.
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: float | None = None,
|
||||
timeout: float | None = None, # noqa: ASYNC109
|
||||
local_address: str | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> AsyncNetworkStream:
|
||||
"""Validate the resolved IP, then connect via the inner backend.
|
||||
|
||||
The DNS lookup is wrapped in `asyncio.timeout()` so a slow
|
||||
resolver is bounded by the caller's timeout. The previous code
|
||||
only timed out the TCP connect inside the inner backend, leaving
|
||||
`loop.getaddrinfo` unbounded.
|
||||
"""
|
||||
if _check_literal(host):
|
||||
return await self._inner.connect_tcp(
|
||||
host, port, timeout, local_address, socket_options
|
||||
@@ -140,21 +152,27 @@ class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
|
||||
async with asyncio.timeout(timeout):
|
||||
addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as exc:
|
||||
raise httpcore.ConnectError(
|
||||
f"DNS resolution failed for {host!r}: {exc}"
|
||||
) from exc
|
||||
except TimeoutError as exc:
|
||||
raise httpcore.ConnectTimeout(
|
||||
f"DNS resolution timed out for {host!r}"
|
||||
) from exc
|
||||
|
||||
pinned_ip = _pick_safe_address(addr_infos, host)
|
||||
return await self._inner.connect_tcp(
|
||||
pinned_ip, port, timeout, local_address, socket_options
|
||||
)
|
||||
|
||||
# See note on connect_tcp re: ASYNC109 / interface implementation.
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: float | None = None,
|
||||
timeout: float | None = None, # noqa: ASYNC109
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> AsyncNetworkStream:
|
||||
return await self._inner.connect_unix_socket(path, timeout, socket_options)
|
||||
|
||||
Reference in New Issue
Block a user