diff --git a/backend/tests/utils/test_ssrf.py b/backend/tests/utils/test_ssrf.py
index c9ab74c51..b65aecb31 100644
--- a/backend/tests/utils/test_ssrf.py
+++ b/backend/tests/utils/test_ssrf.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import socket
+from typing import Any
 from unittest.mock import AsyncMock, MagicMock
 
 import httpcore
@@ -16,25 +17,33 @@ from utils.ssrf import (
 )
 from utils.validation import ValidationError
 
+ConnectCall = tuple[tuple[Any, ...], dict[str, Any]]  # (args, kwargs) of one call
 
-def _addr_info(ip: str, port: int):
+
+def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]:
     family = socket.AF_INET6 if ":" in ip else socket.AF_INET
     return [(family, socket.SOCK_STREAM, 0, "", (ip, port))]
 
 
-def _stub_async_inner(connect_calls: list):
+def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock:
     inner = MagicMock()
-    inner.connect_tcp = AsyncMock(
-        side_effect=lambda *a, **kw: connect_calls.append((a, kw)) or MagicMock()
-    )
+
+    def _record(*args: Any, **kwargs: Any) -> MagicMock:  # log call, return new mock
+        connect_calls.append((args, kwargs))
+        return MagicMock()
+
+    inner.connect_tcp = AsyncMock(side_effect=_record)
     return inner
 
 
-def _stub_sync_inner(connect_calls: list):
+def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock:
     inner = MagicMock()
-    inner.connect_tcp = MagicMock(
-        side_effect=lambda *a, **kw: connect_calls.append((a, kw)) or MagicMock()
-    )
+
+    def _record(*args: Any, **kwargs: Any) -> MagicMock:  # log call, return new mock
+        connect_calls.append((args, kwargs))
+        return MagicMock()
+
+    inner.connect_tcp = MagicMock(side_effect=_record)
     return inner
 
 
@@ -87,7 +96,7 @@ class TestParseIpLiteral:
 class TestSSRFProtectedAsyncBackend:
     async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
         """Backend resolves once, validates, and passes the pinned IP to inner."""
-        calls = []
+        calls: list[ConnectCall] = []
         inner = _stub_async_inner(calls)
         backend = SSRFProtectedAsyncBackend(inner=inner)
 
@@ -140,7 +149,7 @@ class TestSSRFProtectedAsyncBackend:
         inner.connect_tcp.assert_not_called()
 
     async def test_literal_public_ip_passes_through(self):
-        calls = []
+        calls: list[ConnectCall] = []
         inner = _stub_async_inner(calls)
         backend = SSRFProtectedAsyncBackend(inner=inner)
         await backend.connect_tcp("8.8.8.8", 443)
@@ -155,6 +164,26 @@ class TestSSRFProtectedAsyncBackend:
             await backend.connect_tcp("2130706433", 80)  # 127.0.0.1
         inner.connect_tcp.assert_not_called()
 
+    async def test_dns_timeout_raises_connect_timeout(self, monkeypatch):
+        """A resolver that hangs past the caller's timeout must not block forever.
+
+        Regression: an earlier version applied the caller's timeout only to
+        the TCP connect inside the inner backend, leaving `loop.getaddrinfo`
+        unbounded. The lookup is now wrapped in `asyncio.timeout()`, so a hung
+        resolver surfaces as `httpcore.ConnectTimeout` before any connect.
+        """
+        inner = _stub_async_inner([])
+        backend = SSRFProtectedAsyncBackend(inner=inner)
+
+        async def hang_forever(*args, **kwargs):
+            await asyncio.sleep(3600)  # cancelled by the 0.05s timeout
+
+        monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", hang_forever)
+
+        with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"):
+            await backend.connect_tcp("slow.example.com", 80, timeout=0.05)
+        inner.connect_tcp.assert_not_called()
+
     async def test_dns_failure_propagates_as_connect_error(self, monkeypatch):
         inner = _stub_async_inner([])
         backend = SSRFProtectedAsyncBackend(inner=inner)
@@ -171,7 +200,7 @@ class TestSSRFProtectedAsyncBackend:
 
 class TestSSRFProtectedSyncBackend:
     def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
-        calls = []
+        calls: list[ConnectCall] = []
         inner = _stub_sync_inner(calls)
         backend = SSRFProtectedSyncBackend(inner=inner)
 
diff --git a/backend/utils/ssrf.py b/backend/utils/ssrf.py
index 8c8c5266b..0631a0aed 100644
--- a/backend/utils/ssrf.py
+++ b/backend/utils/ssrf.py
@@ -125,14 +125,26 @@ class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
     def __init__(self, inner: AsyncNetworkBackend | None = None) ->
None:
         self._inner = inner if inner is not None else AutoBackend()
 
+    # The `timeout` parameter is required by AsyncNetworkBackend.connect_tcp.
+    # ruff/ASYNC109 advises against timeout parameters on async APIs *we*
+    # author, but this method implements an external interface. The value
+    # is applied to the DNS lookup via `asyncio.timeout()` below (the
+    # asyncio-native pattern ASYNC109 endorses) and forwarded to the inner backend.
     async def connect_tcp(
         self,
         host: str,
         port: int,
-        timeout: float | None = None,
+        timeout: float | None = None,  # noqa: ASYNC109
         local_address: str | None = None,
         socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
     ) -> AsyncNetworkStream:
+        """Validate the resolved IP, then connect via the inner backend.
+
+        The DNS lookup runs under `asyncio.timeout(timeout)`; previously only
+        the inner backend's TCP connect was timed, leaving `loop.getaddrinfo`
+        unbounded. NOTE(review): the same `timeout` is then passed to the
+        inner connect, so lookup + connect may take up to ~2x it - confirm.
+        """
         if _check_literal(host):
             return await self._inner.connect_tcp(
                 host, port, timeout, local_address, socket_options
@@ -140,21 +152,27 @@ class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
 
         loop = asyncio.get_running_loop()
         try:
-            addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
+            async with asyncio.timeout(timeout):
+                addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
         except socket.gaierror as exc:
             raise httpcore.ConnectError(
                 f"DNS resolution failed for {host!r}: {exc}"
             ) from exc
+        except TimeoutError as exc:  # asyncio.timeout() expired
+            raise httpcore.ConnectTimeout(
+                f"DNS resolution timed out for {host!r}"
+            ) from exc
 
         pinned_ip = _pick_safe_address(addr_infos, host)
         return await self._inner.connect_tcp(
             pinned_ip, port, timeout, local_address, socket_options
         )
 
+    # See note on connect_tcp re: ASYNC109 / interface implementation.
     async def connect_unix_socket(
         self,
         path: str,
-        timeout: float | None = None,
+        timeout: float | None = None,  # noqa: ASYNC109
         socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
     ) -> AsyncNetworkStream:
         return await self._inner.connect_unix_socket(path, timeout, socket_options)