Coverage for gco / services / auth_middleware.py: 99%
106 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 21:47 +0000
1"""
2Authentication middleware for validating requests from API Gateway.
4This middleware ensures all requests (except health checks) contain a valid
5X-GCO-Auth-Token header that matches the secret stored in AWS Secrets Manager.
6This proves the request came through the authenticated API Gateway path.
8Security Flow:
9 1. API Gateway validates IAM credentials (SigV4)
10 2. Lambda proxy adds secret token header
11 3. This middleware validates the token
12 4. Invalid tokens result in 403 Forbidden
14Secret Rotation Support:
15 During rotation, the middleware validates against both AWSCURRENT and AWSPENDING
16 versions of the secret. This ensures zero-downtime during the rotation window.
17 The cache is refreshed periodically to pick up rotated secrets.
19Environment Variables:
20 AUTH_SECRET_ARN: ARN of the Secrets Manager secret containing the token
21 GCO_DEV_MODE: Set to "true" to allow unauthenticated requests when no
22 secret is configured. Without this flag, missing AUTH_SECRET_ARN
23 causes 503 errors (fail-closed). This prevents accidental
24 unauthenticated deployments due to misconfiguration.
25"""
from __future__ import annotations

import hmac
import json
import logging
import os
import time
from collections.abc import Awaitable, Callable
from typing import Any

import boto3
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.types import ASGIApp

logger = logging.getLogger(__name__)

# Process-wide state shared by every request: the set of accepted tokens, the
# time.time() stamp of the last refresh, and the lazily created boto3 client.
_cached_tokens: set[str] = set()
_cache_timestamp: float = 0
_secrets_client = None

# How long (in seconds) a loaded token set is trusted before it is re-read:
# five minutes, short enough to pick up a rotated secret.
CACHE_TTL_SECONDS = 5 * 60

# Request paths served without the auth header: health probes for load
# balancers and Global Accelerator. /api/v1/health is listed so GA can run
# HTTP health checks for intelligent routing without the secret header.
UNAUTHENTICATED_PATHS = frozenset({"/healthz", "/readyz", "/metrics", "/api/v1/health"})


def get_secrets_client() -> Any:
    """
    Return the shared Secrets Manager client, creating it on first use.

    The client's region is taken from the AUTH_SECRET_ARN environment
    variable (the fourth colon-separated ARN field), which may differ from
    the default region. When the variable is unset, the ARN has fewer than
    four fields, or the region field is empty, ``region_name=None`` is passed
    so that boto3 applies its own default region resolution.

    The client is created once and stored in the module-level
    ``_secrets_client``; later changes to AUTH_SECRET_ARN do not rebuild it.

    Returns:
        boto3 Secrets Manager client instance
    """
    global _secrets_client
    if _secrets_client is None:
        # ARN format: arn:aws:secretsmanager:REGION:ACCOUNT:secret:NAME
        arn_fields = os.environ.get("AUTH_SECRET_ARN", "").split(":")
        # An empty region field is normalised to None rather than being
        # handed to boto3 as an empty string.
        region = (arn_fields[3] or None) if len(arn_fields) >= 4 else None
        _secrets_client = boto3.client("secretsmanager", region_name=region)
    return _secrets_client


def _is_cache_valid() -> bool:
    """Return True when tokens are cached and were stamped within the TTL.

    An empty cache is never valid, regardless of its timestamp.
    """
    if not _cached_tokens:
        return False
    age_seconds = time.time() - _cache_timestamp
    return age_seconds < CACHE_TTL_SECONDS


def _fetch_stage_token(client: Any, secret_arn: str, version_stage: str) -> str:
    """Fetch one version stage of the secret and return its ``token`` field.

    Args:
        client: Secrets Manager client exposing ``get_secret_value``.
        secret_arn: Identifier passed as ``SecretId``.
        version_stage: Stage label passed as ``VersionStage``.

    Returns:
        The non-empty token string stored under the ``token`` key.

    Raises:
        ValueError: If the ``token`` value is empty or not a string. An empty
            token must never be cached: the middleware treats a missing auth
            header as the empty string, which would then match.
        Exception: Anything raised by the client call, the JSON parsing or
            the key lookups is propagated to the caller unchanged.
    """
    response = client.get_secret_value(SecretId=secret_arn, VersionStage=version_stage)
    token = json.loads(response["SecretString"])["token"]
    if not isinstance(token, str) or not token:
        raise ValueError(f"{version_stage} secret has an empty or non-string 'token' field")
    return token


def _refresh_cache() -> None:
    """Refresh the token cache from Secrets Manager.

    Loads the AWSCURRENT token and, when present, the AWSPENDING token, and
    replaces the module-level cache with whatever was loaded. Does nothing
    when AUTH_SECRET_ARN is unset.

    On failure, keeps the existing (stale) cache and re-stamps it to avoid
    rejecting all requests during a transient Secrets Manager outage; the
    refresh is then retried once CACHE_TTL_SECONDS has elapsed again. If
    there is no stale cache, the cache stays empty and nothing is stamped.

    This function never raises: every failure is logged instead.
    """
    global _cached_tokens, _cache_timestamp

    secret_arn = os.environ.get("AUTH_SECRET_ARN")
    if not secret_arn:
        return

    try:
        client = get_secrets_client()
        new_tokens: set[str] = set()

        # AWSCURRENT is always expected to exist, so any failure is an error.
        try:
            new_tokens.add(_fetch_stage_token(client, secret_arn, "AWSCURRENT"))
            logger.debug("Loaded AWSCURRENT token")
        except Exception as e:
            logger.error("Failed to load AWSCURRENT secret: %s", e)

        # AWSPENDING only exists during rotation, so its absence is normal.
        try:
            new_tokens.add(_fetch_stage_token(client, secret_arn, "AWSPENDING"))
            logger.debug("Loaded AWSPENDING token (rotation in progress)")
        except client.exceptions.ResourceNotFoundException:
            # No pending version - not in rotation, this is normal
            pass
        except Exception as e:
            # Log but don't fail - AWSPENDING is optional
            logger.debug("No AWSPENDING version available: %s", e)

        # NOTE(review): if AWSCURRENT fails but AWSPENDING loads, the cache is
        # replaced by the pending token alone — confirm this is intended.
        if new_tokens:
            _cached_tokens = new_tokens
            _cache_timestamp = time.time()
            logger.info("Token cache refreshed with %d valid token(s)", len(new_tokens))
        elif _cached_tokens:
            # Couldn't load any new tokens but have stale ones — extend the cache
            # to avoid rejecting all traffic during a transient SM outage
            _cache_timestamp = time.time()
            logger.warning("Token refresh returned empty set, keeping stale cache")

    except Exception as e:
        logger.error("Failed to refresh token cache: %s", e)
        if _cached_tokens:
            # Extend stale cache on total failure — better to accept slightly-old
            # tokens than to reject everything
            _cache_timestamp = time.time()
            logger.warning("Extending stale token cache due to refresh failure")


def get_valid_tokens() -> set[str]:
    """
    Return the set of currently accepted authentication tokens.

    Serves the module-level cache while it is within its TTL; otherwise
    triggers a refresh from AWS Secrets Manager first. The cache may hold
    both the AWSCURRENT and AWSPENDING tokens so that rotation causes no
    downtime.

    The returned object is the shared cache set itself, not a copy.

    Returns:
        Set of valid token strings, or empty set if not configured
    """
    if _is_cache_valid():
        return _cached_tokens

    _refresh_cache()
    return _cached_tokens


def get_secret_token() -> str | None:
    """
    Return a single authentication token, for callers that expect only one.

    Compatibility wrapper around get_valid_tokens(). The tokens are held in
    a set, so the token returned is whichever one iteration yields first;
    during rotation this is not guaranteed to be the AWSCURRENT token.
    For rotation support, use get_valid_tokens() instead.

    Returns:
        The secret token string, or None if not configured
    """
    for token in get_valid_tokens():
        return token
    return None


def clear_token_cache() -> None:
    """
    Empty the token cache so the next lookup reloads from Secrets Manager.

    Resets both the cached token set and its timestamp; the cached boto3
    client is left in place. Useful for testing or when you know the secret
    has been rotated.
    """
    global _cached_tokens, _cache_timestamp
    _cached_tokens, _cache_timestamp = set(), 0
    logger.info("Token cache cleared")


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """
    FastAPI middleware to validate X-GCO-Auth-Token header.

    This middleware ensures all API requests came through the authenticated
    API Gateway by validating a secret token header. Paths listed in
    UNAUTHENTICATED_PATHS are excluded to allow load balancer health probes.

    During secret rotation, both AWSCURRENT and AWSPENDING tokens are accepted
    to ensure zero-downtime rotation.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Wrap *app* and log the authentication configuration at startup.

        Nothing is enforced here; this only surfaces misconfiguration early.
        The same checks are applied per request in dispatch().
        """
        super().__init__(app)
        if os.environ.get("AUTH_SECRET_ARN"):
            return
        if self._dev_mode_enabled():
            logger.warning(
                "GCO_DEV_MODE=true with no AUTH_SECRET_ARN — "
                "authentication is bypassed. Do NOT use in production."
            )
        else:
            logger.error(
                "AUTH_SECRET_ARN is not configured and GCO_DEV_MODE is not enabled. "
                "All non-health-check requests will be denied with 503."
            )

    @staticmethod
    def _dev_mode_enabled() -> bool:
        """Return True when GCO_DEV_MODE is set to "true" (case-insensitive)."""
        return os.environ.get("GCO_DEV_MODE", "").lower() == "true"

    @staticmethod
    def _token_is_valid(presented: str, valid_tokens: set[str]) -> bool:
        """Compare *presented* against every valid token in constant time.

        An empty header value never matches. Every candidate is compared with
        hmac.compare_digest and the loop does not stop at the first match, so
        the time taken does not reveal which token matched or how many
        leading characters were correct.
        """
        if not presented:
            return False
        # Compare as bytes: compare_digest only accepts ASCII-only str.
        # "surrogatepass" keeps encoding total for any str value.
        presented_bytes = presented.encode("utf-8", "surrogatepass")
        matched = False
        for token in valid_tokens:
            if not isinstance(token, str) or not token:
                continue
            matched |= hmac.compare_digest(
                presented_bytes, token.encode("utf-8", "surrogatepass")
            )
        return matched

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process incoming request and validate authentication.

        Args:
            request: The incoming FastAPI request
            call_next: The next middleware/handler in the chain

        Returns:
            Response from the next handler if the path is unauthenticated,
            the token is valid, or dev mode is active with no secret configured

        Raises:
            HTTPException: 403 if the token header is missing or invalid;
                503 if no secret is configured (and dev mode is off) or the
                configured secret's tokens could not be loaded
        """
        # NOTE(review): Starlette runs user middleware outside its
        # ExceptionMiddleware, so an HTTPException raised here may reach
        # clients as a 500 rather than the intended status unless it is
        # handled further out — confirm how this middleware is mounted, and
        # consider returning a Response instead.

        # Skip authentication for health check endpoints
        if request.url.path in UNAUTHENTICATED_PATHS:
            return await call_next(request)

        valid_tokens = get_valid_tokens()

        # No tokens available — determine whether to fail open or closed
        if not valid_tokens:
            if not os.environ.get("AUTH_SECRET_ARN"):
                # No secret configured. Only allow requests if the operator
                # explicitly opted into dev mode. This prevents accidental
                # unauthenticated deployments due to misconfiguration.
                if self._dev_mode_enabled():
                    logger.warning(
                        "Authentication bypassed - GCO_DEV_MODE=true, no secret configured"
                    )
                    return await call_next(request)
                # Fail closed: no secret + no dev mode = deny
                logger.error(
                    "No AUTH_SECRET_ARN configured and GCO_DEV_MODE is not enabled. "
                    "Set AUTH_SECRET_ARN for production or GCO_DEV_MODE=true for local development."
                )
                raise HTTPException(
                    status_code=503,
                    detail="Service unavailable - authentication not configured",
                )
            # Secret configured but couldn't load - deny access
            logger.error("Failed to load authentication tokens")
            raise HTTPException(
                status_code=503,
                detail="Service temporarily unavailable - authentication error",
            )

        # Validate the auth header against all valid tokens
        auth_header = request.headers.get("x-gco-auth-token", "")

        if not self._token_is_valid(auth_header, valid_tokens):
            client_ip = request.client.host if request.client else "unknown"
            logger.warning("Invalid auth token from %s for %s", client_ip, request.url.path)
            raise HTTPException(
                status_code=403,
                detail="Forbidden - requests must come through authenticated API Gateway",
            )

        return await call_next(request)