fix(ssrf): bound DNS lookup by caller timeout; clear lint findings

The async backend's `loop.getaddrinfo` ran without any timeout, so a
slow or hanging resolver could outlive the timeout the caller passed —
the previous code only bounded the TCP connect inside the inner
backend. Wrap the resolution in `asyncio.timeout(timeout)` and surface
the timeout as `httpcore.ConnectTimeout`.

Also tidy the test stubs (mypy func-returns-value) and add explicit
type annotations to the `calls` lists (mypy var-annotated). A targeted
`# noqa: ASYNC109` sits on the `timeout` parameter of `connect_tcp` /
`connect_unix_socket` with an explanatory comment: the rule advises
against `timeout` parameters on async APIs we author, but here we're
implementing `AsyncNetworkBackend`, and the timeout is consumed in the
asyncio-native pattern the rule endorses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Georges-Antoine Assi
2026-05-27 18:31:42 -04:00
parent 30451d5651
commit c3adbd3f71
2 changed files with 62 additions and 15 deletions

View File

@@ -2,6 +2,7 @@
import asyncio
import socket
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpcore
@@ -16,25 +17,33 @@ from utils.ssrf import (
)
from utils.validation import ValidationError
ConnectCall = tuple[tuple[Any, ...], dict[str, Any]]
def _addr_info(ip: str, port: int):
def _addr_info(ip: str, port: int) -> list[tuple[Any, ...]]:
family = socket.AF_INET6 if ":" in ip else socket.AF_INET
return [(family, socket.SOCK_STREAM, 0, "", (ip, port))]
def _stub_async_inner(connect_calls: list):
def _stub_async_inner(connect_calls: list[ConnectCall]) -> MagicMock:
inner = MagicMock()
inner.connect_tcp = AsyncMock(
side_effect=lambda *a, **kw: connect_calls.append((a, kw)) or MagicMock()
)
def _record(*args: Any, **kwargs: Any) -> MagicMock:
connect_calls.append((args, kwargs))
return MagicMock()
inner.connect_tcp = AsyncMock(side_effect=_record)
return inner
def _stub_sync_inner(connect_calls: list):
def _stub_sync_inner(connect_calls: list[ConnectCall]) -> MagicMock:
inner = MagicMock()
inner.connect_tcp = MagicMock(
side_effect=lambda *a, **kw: connect_calls.append((a, kw)) or MagicMock()
)
def _record(*args: Any, **kwargs: Any) -> MagicMock:
connect_calls.append((args, kwargs))
return MagicMock()
inner.connect_tcp = MagicMock(side_effect=_record)
return inner
@@ -87,7 +96,7 @@ class TestParseIpLiteral:
class TestSSRFProtectedAsyncBackend:
async def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
"""Backend resolves once, validates, and passes the pinned IP to inner."""
calls = []
calls: list[ConnectCall] = []
inner = _stub_async_inner(calls)
backend = SSRFProtectedAsyncBackend(inner=inner)
@@ -140,7 +149,7 @@ class TestSSRFProtectedAsyncBackend:
inner.connect_tcp.assert_not_called()
async def test_literal_public_ip_passes_through(self):
calls = []
calls: list[ConnectCall] = []
inner = _stub_async_inner(calls)
backend = SSRFProtectedAsyncBackend(inner=inner)
await backend.connect_tcp("8.8.8.8", 443)
@@ -155,6 +164,26 @@ class TestSSRFProtectedAsyncBackend:
await backend.connect_tcp("2130706433", 80) # 127.0.0.1
inner.connect_tcp.assert_not_called()
async def test_dns_timeout_raises_connect_timeout(self, monkeypatch):
"""A resolver that hangs past the caller's timeout must not block forever.
Regression: an earlier version applied the caller's timeout only to
the TCP connect inside the inner backend, leaving `loop.getaddrinfo`
unbounded. We now wrap the lookup in `asyncio.timeout()` so a slow
resolver is bounded by the same budget the caller specified.
"""
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
async def hang_forever(*args, **kwargs):
await asyncio.sleep(3600)
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", hang_forever)
with pytest.raises(httpcore.ConnectTimeout, match="DNS resolution timed out"):
await backend.connect_tcp("slow.example.com", 80, timeout=0.05)
inner.connect_tcp.assert_not_called()
async def test_dns_failure_propagates_as_connect_error(self, monkeypatch):
inner = _stub_async_inner([])
backend = SSRFProtectedAsyncBackend(inner=inner)
@@ -171,7 +200,7 @@ class TestSSRFProtectedAsyncBackend:
class TestSSRFProtectedSyncBackend:
def test_safe_hostname_connects_to_pinned_ip(self, monkeypatch):
calls = []
calls: list[ConnectCall] = []
inner = _stub_sync_inner(calls)
backend = SSRFProtectedSyncBackend(inner=inner)

View File

@@ -125,14 +125,26 @@ class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
def __init__(self, inner: AsyncNetworkBackend | None = None) -> None:
self._inner = inner if inner is not None else AutoBackend()
# `timeout` parameter is required by AsyncNetworkBackend.connect_tcp;
# ruff/ASYNC109 advises against timeout parameters on async APIs *we*
# author, but we are implementing an external interface here. The
# timeout is consumed via `asyncio.timeout()` below, which is the
# asyncio-native pattern ASYNC109 endorses.
async def connect_tcp(
self,
host: str,
port: int,
timeout: float | None = None,
timeout: float | None = None, # noqa: ASYNC109
local_address: str | None = None,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
"""Validate the resolved IP, then connect via the inner backend.
The DNS lookup is wrapped in `asyncio.timeout()` so a slow
resolver is bounded by the caller's timeout. The previous code
only timed out the TCP connect inside the inner backend, leaving
`loop.getaddrinfo` unbounded.
"""
if _check_literal(host):
return await self._inner.connect_tcp(
host, port, timeout, local_address, socket_options
@@ -140,21 +152,27 @@ class SSRFProtectedAsyncBackend(AsyncNetworkBackend):
loop = asyncio.get_running_loop()
try:
addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
async with asyncio.timeout(timeout):
addr_infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise httpcore.ConnectError(
f"DNS resolution failed for {host!r}: {exc}"
) from exc
except TimeoutError as exc:
raise httpcore.ConnectTimeout(
f"DNS resolution timed out for {host!r}"
) from exc
pinned_ip = _pick_safe_address(addr_infos, host)
return await self._inner.connect_tcp(
pinned_ip, port, timeout, local_address, socket_options
)
# See note on connect_tcp re: ASYNC109 / interface implementation.
async def connect_unix_socket(
self,
path: str,
timeout: float | None = None,
timeout: float | None = None, # noqa: ASYNC109
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
) -> AsyncNetworkStream:
return await self._inner.connect_unix_socket(path, timeout, socket_options)