Files
romm/backend/utils/rate_limiter.py
Georges-Antoine Assi ab9b7bd775 changes from self review
2026-06-07 08:29:49 -04:00

107 lines
3.7 KiB
Python

import asyncio
from collections import deque
class RateLimiter:
    """Async limiter that spaces out calls to at most ``requests_per_second``.

    Each caller books the next free time slot *before* it sleeps, so a crowd of
    concurrent callers is released one interval apart rather than all at once.
    """

    def __init__(self, requests_per_second: float) -> None:
        if requests_per_second <= 0:
            raise ValueError("requests_per_second must be positive")
        # Seconds that must separate two consecutive grants.
        self._min_interval = 1.0 / requests_per_second
        # Event-loop time of the next bookable slot; starting at 0.0 means the
        # first acquire() normally proceeds without waiting.
        self._next_slot = 0.0

    async def acquire(self) -> None:
        """Book the next free slot, then sleep until that slot arrives."""
        current = asyncio.get_running_loop().time()
        # Booking is a read-then-write of _next_slot with no await in between,
        # so on the single-threaded event loop two callers can never claim the
        # same slot and no lock is required.
        granted = self._next_slot if self._next_slot > current else current
        self._next_slot = granted + self._min_interval
        wait_for = granted - current
        if wait_for > 0:
            await asyncio.sleep(wait_for)
class ConcurrencyLimiter:
    """Caps the number of in-flight operations, with a runtime-adjustable capacity.

    It suits APIs that enforce a per-account thread/connection cap (e.g. ScreenScraper)
    rather than a call rate. Because a slot is held for the whole request, slow
    responses can never cause overlapping requests to exceed the cap.

    Slots are granted in FIFO order. A freed slot is handed directly to the
    longest-waiting caller, so a task that releases and immediately re-acquires
    cannot jump ahead of callers that are already queued.

    Use it as an async context manager so the slot is always released, even if the
    wrapped request raises:

        async with limiter:
            await do_request()
    """

    def __init__(self, max_concurrency: int) -> None:
        if max_concurrency < 1:
            raise ValueError("max_concurrency must be at least 1")
        self._max_concurrency = max_concurrency
        # Slots currently held, including slots already handed to a waiter
        # whose task has not resumed yet.
        self._in_flight = 0
        # Futures of callers blocked in acquire(), oldest first.
        self._waiters: deque[asyncio.Future[None]] = deque()

    @property
    def max_concurrency(self) -> int:
        """Current capacity (maximum number of simultaneously held slots)."""
        return self._max_concurrency

    @property
    def in_flight(self) -> int:
        """Slots currently held, or already granted to a waiter about to resume."""
        return self._in_flight

    def set_max_concurrency(self, max_concurrency: int) -> None:
        """Change the capacity at runtime.

        Raising the capacity immediately grants the newly opened slots to queued
        callers, oldest first. Lowering it never preempts current holders:
        ``in_flight`` may stay above the new cap until enough slots are released.

        Raises:
            ValueError: if ``max_concurrency`` is less than 1.
        """
        if max_concurrency < 1:
            raise ValueError("max_concurrency must be at least 1")
        self._max_concurrency = max_concurrency
        self._grant_free_slots()

    async def acquire(self) -> None:
        """Wait until a slot is available and take it.

        Raises:
            asyncio.CancelledError: if cancelled while waiting. A slot that had
                already been granted to this caller is passed on to the next
                waiter, so no capacity is lost.
        """
        # Every operation that frees capacity immediately grants it to queued
        # callers (_grant_free_slots). So free capacity here means nobody is
        # waiting ahead of us, and taking the slot directly keeps FIFO order.
        if self._in_flight < self._max_concurrency:
            self._in_flight += 1
            return

        waiter = asyncio.get_running_loop().create_future()
        self._waiters.append(waiter)
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except asyncio.CancelledError:
            if waiter.done() and not waiter.cancelled():
                # A slot was already counted for us but we will not use it:
                # give it back and hand it to the next caller in line.
                self._in_flight -= 1
                self._grant_free_slots()
            raise
        # Resumed normally: _grant_free_slots already counted our slot.

    def release(self) -> None:
        """Give back one slot and hand any free capacity to queued callers."""
        if self._in_flight > 0:
            self._in_flight -= 1
        self._grant_free_slots()

    def _grant_free_slots(self) -> None:
        """Hand free slots to pending waiters, oldest first.

        The slot is counted in ``_in_flight`` *before* the waiter resumes, so
        no other caller can claim it in the meantime.
        """
        for waiter in self._waiters:
            if self._in_flight >= self._max_concurrency:
                break
            # Done futures are either already granted or cancelled; skip them.
            if not waiter.done():
                self._in_flight += 1
                waiter.set_result(None)

    async def __aenter__(self) -> "ConcurrencyLimiter":
        await self.acquire()
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        self.release()