diff --git a/hotdata/_auth.py b/hotdata/_auth.py index 6c47ba3..7ab2882 100644 --- a/hotdata/_auth.py +++ b/hotdata/_auth.py @@ -28,6 +28,11 @@ * **Refresh, then re-mint** -- prefer the refresh token when available; on refresh failure, re-mint from the held API token (always possible since the SDK holds it). Matches the CLI. +* **Transient-failure retry** -- a momentary ``5xx`` or a transport error on the + token endpoint is retried with bounded exponential backoff + jitter + (``_MAX_ATTEMPTS`` total) before giving up, so a brief server-side blip does + not fail the caller (#113). A ``4xx`` is never retried -- a bad/expired + credential is not transient. * **TLS/proxy reuse** -- the exchange call reuses the SDK's configured TLS, client cert and proxy settings (see :func:`_pool_from_config`) so it behaves like every other SDK request, with a bounded timeout so a stalled token @@ -36,6 +41,7 @@ import json import os +import random import ssl import threading import time @@ -47,6 +53,15 @@ _TIMEOUT = 30.0 # seconds -- never let a stalled token endpoint hang every request _CLIENT_ID = "hotdata-python-sdk" +# Bounded retry of *transient* token-exchange failures (#113). A momentary 5xx +# or a transport error (connection/read failure) on the token endpoint should +# not fail the caller outright -- an immediate re-attempt typically succeeds. +# We retry those, but never a 4xx (a bad/expired credential is not transient). +_MAX_ATTEMPTS = 3 # one initial attempt + up to two retries +_BACKOFF_BASE = 0.1 # seconds -- first retry waits ~this, doubling thereafter +_BACKOFF_MAX = 2.0 # cap on a single backoff so a flapping host can't stall us +_BACKOFF_JITTER = 0.5 # additive jitter fraction (delay in [base, 1.5*base]) to spread retries out + # Env var that disables exchange entirely. Used as a hard escape hatch during # the rollout window and for local/dev setups. Only affirmative values opt out # (see _DISABLE_VALUES) so that ``=0`` / ``=false`` do NOT silently disable it. 
@@ -109,8 +124,11 @@ def _pool_from_config(configuration): pool_args["server_hostname"] = configuration.tls_server_name if configuration.socket_options is not None: pool_args["socket_options"] = configuration.socket_options - # `retries`/`maxsize` are intentionally not mirrored: the exchange is a - # single bounded-timeout request that fails fast rather than retrying. + # `retries`/`maxsize` are intentionally not mirrored: urllib3's own Retry is + # left at urllib3's default (it does not retry POST status codes anyway), so + # we do not inherit the SDK's connection-reset retry here. Transient-failure + # retry for the exchange is handled explicitly in `_TokenManager._exchange` + # (5xx + transport errors, bounded backoff), where we can keep 4xx fatal. if configuration.proxy: if _is_socks_proxy_url(configuration.proxy): @@ -125,6 +143,29 @@ def _pool_from_config(configuration): return urllib3.PoolManager(**pool_args) +def _is_transient_status(status): + """True for HTTP statuses worth retrying (server-side, likely momentary). + + Only 5xx is transient: the request reached the server but it failed to + handle it (e.g. a brief ``500``/``503``). A 4xx -- including ``400``/``401`` + from a bad or expired credential -- is a definitive rejection that a retry + will not fix, so it is never retried. + """ + return 500 <= status < 600 + + +def _backoff_delay(attempt): + """Seconds to sleep before retry number ``attempt`` (0 = first retry). + + Exponential growth from ``_BACKOFF_BASE`` (doubling per attempt) capped at + ``_BACKOFF_MAX``, plus additive jitter in ``[0, _BACKOFF_JITTER * base]`` so + concurrent clients retrying the same blip don't resynchronize into a thundering + herd. Mirrors the Rust SDK's ``backoff_delay``. + """ + base = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + return base * (1 + _BACKOFF_JITTER * random.random()) + + class _TokenManager: """Exchanges an API token for short-lived JWTs and keeps them fresh. 
@@ -134,10 +175,11 @@ class _TokenManager: exchanged. """ - def __init__(self, credential, configuration, pool=None): + def __init__(self, credential, configuration, pool=None, sleep=None): self._credential = credential self._config = configuration # read host + TLS lazily at mint time self._pool = pool # injected in tests; else built from config TLS + self._sleep = sleep or time.sleep # injected in tests so retry backoff is instant self._lock = threading.Lock() self._jwt = None self._exp = 0.0 @@ -168,8 +210,18 @@ def bearer_value(self): """ if not self._needs_exchange: return self._credential # already a JWT (or opt-out) -> unchanged + # Lock-free fast path: a still-valid cached JWT needs no mint and must + # not block behind an in-flight (possibly retrying) mint that holds the + # lock for up to several timeouts. Each attribute read is atomic under the + # GIL and _LEEWAY keeps a token read here valid on the wire, but jwt and exp + # are read separately. NOTE(review): a mint landing between the two reads + # could pair a stale jwt with the fresh exp -- confirm this is acceptable. + jwt = self._jwt + if jwt and time.time() < self._exp - _LEEWAY: + return jwt with self._lock: - # Fast path: a still-valid cached JWT, no network call. + # Re-check under the lock: another thread may have minted a fresh JWT + # while we waited (double-checked locking). if self._jwt and time.time() < self._exp - _LEEWAY: return self._jwt # Prefer the refresh token; on failure, drop it and re-mint below. @@ -187,24 +239,12 @@ def _mint(self, params): # -- a non-200, a transport error, or a malformed/missing-token body -- # returns False so the caller re-mints from the held API token. An # api_token mint instead raises TokenExchangeError on any failure, since - # there is no further fallback. + # there is no further fallback. Transient failures (5xx + transport + # errors) are retried inside _exchange before either outcome.
params["client_id"] = _CLIENT_ID is_refresh = params["grant_type"] == "refresh_token" try: - pool = self._pool or _pool_from_config(self._config) # reuses ssl_ca_cert/cert/proxy - host = self._config.host.rstrip("/") # read host lazily -- may be set post-construct - resp = pool.request( - "POST", - f"{host}/v1/auth/jwt", - body=urlencode(params), - headers={"Content-Type": "application/x-www-form-urlencoded"}, - timeout=_TIMEOUT, - ) - if resp.status != 200: - raise TokenExchangeError( - f"token exchange failed: {resp.status} {resp.data[:200]!r}" - ) - data = json.loads(resp.data) + data = self._exchange(params) token = data["access_token"] expires_in = float(data.get("expires_in", 300)) except ( @@ -224,6 +264,57 @@ def _mint(self, params): self._refresh = data.get("refresh_token") or self._refresh return True + def _exchange(self, params): + # POST the token-exchange request, retrying transient failures, and + # return the parsed JSON body of the 200 response. + # + # Retries 5xx responses and transport errors (urllib3 HTTPError, e.g. a + # connection/read failure) with bounded exponential backoff + jitter, up + # to _MAX_ATTEMPTS total. A 4xx is raised immediately as a fatal + # TokenExchangeError -- bad/expired credentials are not transient. Once + # the budget is exhausted, the last failure is surfaced: a 5xx as a + # TokenExchangeError preserving the status/body, a transport error as the + # raised HTTPError (which _mint wraps). JSON/missing-token errors from + # parsing a 200 body propagate unretried.
+ pool = self._pool or _pool_from_config(self._config) # reuses ssl_ca_cert/cert/proxy + host = self._config.host.rstrip("/") # read host lazily -- may be set post-construct + url = f"{host}/v1/auth/jwt" + body = urlencode(params) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + for attempt in range(_MAX_ATTEMPTS): + last = attempt == _MAX_ATTEMPTS - 1 + try: + resp = pool.request( + "POST", + url, + body=body, + headers=headers, + timeout=_TIMEOUT, + # Disable urllib3's own per-request retries so this loop is + # the sole arbiter of the attempt budget. Otherwise urllib3's + # default (Retry(3)) would retry connection errors *inside* + # each attempt, multiplying the effective transport-attempt + # count well past _MAX_ATTEMPTS. + retries=False, + ) + except urllib3.exceptions.HTTPError: + # Transport-level failure (connection/read error): transient, but + # don't retry past the budget -- re-raise for _mint to handle. + if last: + raise + self._sleep(_backoff_delay(attempt)) + continue + if resp.status == 200: + return json.loads(resp.data) + if _is_transient_status(resp.status) and not last: + self._sleep(_backoff_delay(attempt)) + continue + # A 4xx, or a 5xx with the retry budget exhausted: fatal, surfacing + # the last status/body. 
+ raise TokenExchangeError( + f"token exchange failed: {resp.status} {resp.data[:200]!r}" + ) + __all__ = [ "TokenExchangeError", diff --git a/tests/test_auth.py b/tests/test_auth.py index 5c181c0..c77f43d 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -34,11 +34,13 @@ from urllib.parse import parse_qs import pytest +import urllib3 from hotdata import Configuration from hotdata._auth import ( _CLIENT_ID, _LEEWAY, + _MAX_ATTEMPTS, TokenExchangeError, _TokenManager, _pool_from_config, @@ -87,6 +89,7 @@ def request( body: Optional[Any] = None, headers: Optional[Dict[str, str]] = None, timeout: Optional[float] = None, + retries: Any = None, ) -> _FakeResponse: if self._pre_request is not None: self._pre_request() @@ -98,11 +101,18 @@ def request( "body": body, "headers": dict(headers or {}), "timeout": timeout, + "retries": retries, } ) if len(self._responses) > 1: - return self._responses.pop(0) - return self._responses[0] + item = self._responses.pop(0) + else: + item = self._responses[0] + # A scripted *exception* is raised (models a transport error); anything + # else is returned as the response. 
+ if isinstance(item, BaseException): + raise item + return item def _config(host: str = "https://api.hotdata.test") -> Configuration: @@ -295,6 +305,162 @@ def test_non_200_api_token_mint_raises_token_exchange_error() -> None: assert _form(pool.calls[0]["body"])["grant_type"] == ["api_token"] +# -------------------------------------------------------------------------- +# Transient-failure retry (#113) +# -------------------------------------------------------------------------- + + +def test_transient_5xx_is_retried_then_succeeds() -> None: + """A momentary 500 on the token endpoint must be retried, not surfaced -- + an immediate re-attempt succeeds (the CI failure mode from #113).""" + sleeps: List[float] = [] + pool = _FakePool( + [ + _FakeResponse(500, b"upstream hiccup"), + _mint_response(access_token="eyJ.recovered.jwt"), + ] + ) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=sleeps.append) + + assert mgr.bearer_value() == "eyJ.recovered.jwt" + # Two pool hits: the failed 500 and the successful retry. + assert len(pool.calls) == 2 + # Backoff slept exactly once, between the two attempts. + assert len(sleeps) == 1 + assert sleeps[0] > 0 + + +def test_transport_error_is_retried_then_succeeds() -> None: + """A transport error (e.g. 
a dropped connection) before any response is + transient and must be retried.""" + sleeps: List[float] = [] + pool = _FakePool( + [ + urllib3.exceptions.ProtocolError("Connection aborted."), + _mint_response(access_token="eyJ.recovered.jwt"), + ] + ) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=sleeps.append) + + assert mgr.bearer_value() == "eyJ.recovered.jwt" + assert len(pool.calls) == 2 + assert len(sleeps) == 1 + + +def test_exchange_disables_urllib3_internal_retries() -> None: + """The explicit loop must be the *sole* arbiter of the attempt budget, so + the request passes ``retries=False`` to stop urllib3 from retrying + connection errors inside each attempt (which would multiply the effective + transport-attempt count past _MAX_ATTEMPTS).""" + pool = _FakePool([_mint_response()]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=lambda _: None) + + mgr.bearer_value() + + assert pool.calls[0]["retries"] is False + + +def test_persistent_transport_error_exhausts_then_raises() -> None: + """Three persistent transport errors exhaust the budget and surface as a + TokenExchangeError (the raised HTTPError is wrapped by _mint).""" + sleeps: List[float] = [] + pool = _FakePool([urllib3.exceptions.ProtocolError("Connection aborted.")]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=sleeps.append) + + with pytest.raises(TokenExchangeError): + mgr.bearer_value() + + assert len(pool.calls) == _MAX_ATTEMPTS + assert len(sleeps) == _MAX_ATTEMPTS - 1 + + +def test_refresh_path_retries_transport_errors_before_remint() -> None: + """The retry budget wraps transport errors on the refresh grant too: every + refresh attempt drops the connection, so the budget is exhausted and the + manager falls back to a re-mint rather than failing.""" + sleeps: List[float] = [] + short_lived = _mint_response( + access_token="eyJ.short.jwt", + refresh_token="rt_first", + expires_in=_LEEWAY - 5, + ) + remint = 
_mint_response(access_token="eyJ.reminted.jwt", expires_in=300) + pool = _FakePool( + [short_lived] + + [urllib3.exceptions.ProtocolError("Connection aborted.")] * _MAX_ATTEMPTS + + [remint] + ) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=sleeps.append) + + assert mgr.bearer_value() == "eyJ.short.jwt" + assert mgr.bearer_value() == "eyJ.reminted.jwt" + + # 1 initial mint + _MAX_ATTEMPTS refresh tries (all transport errors) + the + # successful re-mint. + assert len(pool.calls) == 1 + _MAX_ATTEMPTS + 1 + grants = [_form(c["body"])["grant_type"][0] for c in pool.calls] + assert grants == ["api_token"] + ["refresh_token"] * _MAX_ATTEMPTS + ["api_token"] + + +def test_4xx_is_not_retried() -> None: + """A 4xx (bad/expired credential) is not transient -- it must fail on the + first attempt with no retry.""" + sleeps: List[float] = [] + pool = _FakePool([_FakeResponse(401, {"error": "invalid_grant"})]) + mgr = _TokenManager("hd_bad_token", _config(), pool=pool, sleep=sleeps.append) + + with pytest.raises(TokenExchangeError): + mgr.bearer_value() + + assert len(pool.calls) == 1 + assert sleeps == [] + + +def test_retries_are_bounded_then_surface_last_error() -> None: + """When 5xx persists, retries stop at the bounded budget and the final + error preserves the last status/body.""" + sleeps: List[float] = [] + pool = _FakePool([_FakeResponse(503, b"still overloaded")]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=sleeps.append) + + with pytest.raises(TokenExchangeError) as excinfo: + mgr.bearer_value() + + assert len(pool.calls) == _MAX_ATTEMPTS + assert len(sleeps) == _MAX_ATTEMPTS - 1 + msg = str(excinfo.value) + assert "503" in msg + assert "still overloaded" in msg + + +def test_refresh_path_retries_transient_failures_before_remint() -> None: + """The retry budget wraps the refresh grant too: a transient 5xx on refresh + is retried, and only a fully-exhausted refresh falls back to a re-mint.""" + sleeps: List[float] = [] 
+ short_lived = _mint_response( + access_token="eyJ.short.jwt", + refresh_token="rt_first", + expires_in=_LEEWAY - 5, + ) + # Every refresh attempt 500s (budget exhausted) -> fall back to api_token. + refresh_500 = _FakeResponse(500, b"refresh upstream error") + remint = _mint_response(access_token="eyJ.reminted.jwt", expires_in=300) + pool = _FakePool( + [short_lived] + + [refresh_500] * _MAX_ATTEMPTS + + [remint] + ) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool, sleep=sleeps.append) + + assert mgr.bearer_value() == "eyJ.short.jwt" + assert mgr.bearer_value() == "eyJ.reminted.jwt" + + # 1 initial mint + _MAX_ATTEMPTS refresh tries + 1 successful re-mint. + assert len(pool.calls) == 1 + _MAX_ATTEMPTS + 1 + grants = [_form(c["body"])["grant_type"][0] for c in pool.calls] + assert grants == ["api_token"] + ["refresh_token"] * _MAX_ATTEMPTS + ["api_token"] + + # -------------------------------------------------------------------------- # Opt-out # -------------------------------------------------------------------------- @@ -378,6 +544,33 @@ def worker() -> None: assert len(pool.calls) == 1 +def test_valid_cached_jwt_served_without_blocking_on_mint_lock() -> None: + """A caller holding a still-valid cached JWT must not block behind an + in-flight (possibly retrying) mint that holds the single-flight lock. The + fast path reads the cache lock-free, so a degraded token endpoint stalling + one minting thread cannot serialize callers that need no mint.""" + pool = _FakePool([_mint_response(access_token="eyJ.cached.jwt", expires_in=300)]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + # Prime the cache with a long-lived JWT. + assert mgr.bearer_value() == "eyJ.cached.jwt" + + # Simulate a mint in flight on another thread by holding the lock. 
+ mgr._lock.acquire() + try: + result: List[str] = [] + t = threading.Thread(target=lambda: result.append(mgr.bearer_value())) + t.start() + t.join(timeout=2.0) + assert not t.is_alive(), "fast path must not block on the held mint lock" + assert result == ["eyJ.cached.jwt"] + finally: + mgr._lock.release() + + # The cached JWT was reused; no second mint happened. + assert len(pool.calls) == 1 + + # -------------------------------------------------------------------------- # Deepcopy round-trip (the lock + pool gotcha) # --------------------------------------------------------------------------