Coverage for gco / services / auth_middleware.py: 99%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 21:47 +0000

"""
Authentication middleware for validating requests from API Gateway.

This middleware ensures all requests (except health checks) contain a valid
X-GCO-Auth-Token header that matches the secret stored in AWS Secrets Manager.
This proves the request came through the authenticated API Gateway path.

Security Flow:
    1. API Gateway validates IAM credentials (SigV4)
    2. Lambda proxy adds secret token header
    3. This middleware validates the token
    4. Invalid tokens result in 403 Forbidden

Secret Rotation Support:
    During rotation, the middleware validates against both the AWSCURRENT and
    AWSPENDING versions of the secret, so there is no downtime during the
    rotation window. Cached tokens are re-read from Secrets Manager once they
    are older than CACHE_TTL_SECONDS, which picks up rotated secrets.

Environment Variables:
    AUTH_SECRET_ARN: ARN of the Secrets Manager secret holding the token. The
        secret value is a JSON document with a "token" field. The Secrets
        Manager client's region is taken from the ARN.
    GCO_DEV_MODE: Set to "true" (case-insensitive) to allow unauthenticated
        requests when no secret is configured. Without this flag, a missing
        AUTH_SECRET_ARN causes 503 errors (fail-closed). This prevents
        accidental unauthenticated deployments caused by misconfiguration.
"""

26 

from __future__ import annotations

import hmac
import json
import logging
import os
import time
from collections.abc import Awaitable, Callable
from typing import Any

import boto3
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp

41 

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Module-level state (one copy per process)
# ---------------------------------------------------------------------------
# Tokens currently accepted: AWSCURRENT, plus AWSPENDING while a rotation is
# in progress.
_cached_tokens: set[str] = set()
# time.time() of the last refresh (or stale-cache extension); 0 means "never".
_cache_timestamp: float = 0
# Secrets Manager client, created lazily by get_secrets_client().
_secrets_client = None

# Maximum age of the token cache before Secrets Manager is queried again.
# Five minutes keeps API calls low while still picking up rotated secrets.
CACHE_TTL_SECONDS = 5 * 60

# Request paths that skip authentication entirely. They are load-balancer
# probes plus /api/v1/health, which Global Accelerator polls over HTTP (without
# the secret header) for its intelligent routing health checks.
UNAUTHENTICATED_PATHS = frozenset({"/healthz", "/readyz", "/metrics", "/api/v1/health"})

56 

57 

def get_secrets_client() -> Any:
    """
    Return the shared Secrets Manager client, creating it on first use.

    The client's region is the region field of the AUTH_SECRET_ARN
    environment variable. That region may differ from the process's default
    AWS region. When the ARN is absent or too short to contain a region,
    boto3's default region resolution is used (``region_name=None``).

    Returns:
        boto3 Secrets Manager client instance (cached in a module global)
    """
    global _secrets_client
    if _secrets_client is not None:
        return _secrets_client

    # ARN layout: arn:aws:secretsmanager:REGION:ACCOUNT:secret:NAME
    # Field index 3 is the region.
    arn_fields = os.environ.get("AUTH_SECRET_ARN", "").split(":")
    region = arn_fields[3] if len(arn_fields) >= 4 else None

    _secrets_client = boto3.client("secretsmanager", region_name=region)
    return _secrets_client

80 

81 

def _is_cache_valid() -> bool:
    """Return True when tokens are cached and younger than CACHE_TTL_SECONDS."""
    if not _cached_tokens:
        return False
    age_seconds = time.time() - _cache_timestamp
    return age_seconds < CACHE_TTL_SECONDS

85 

86 

def _fetch_token(client: Any, secret_arn: str, version_stage: str) -> str | None:
    """Read one version stage of the secret and return its ``token`` field.

    Returns None when the field is empty or not a string, so that a blank
    secret can never match a request that simply omits the auth header.
    Errors from Secrets Manager, JSON decoding, or a missing ``token`` key
    propagate to the caller.
    """
    response = client.get_secret_value(SecretId=secret_arn, VersionStage=version_stage)
    token = json.loads(response["SecretString"])["token"]
    if isinstance(token, str) and token:
        return token
    return None


def _refresh_cache() -> None:
    """Refresh the token cache from Secrets Manager.

    Loads the AWSCURRENT token and, when a rotation is in progress, the
    AWSPENDING token. Resilience rules:

    * If AWSCURRENT cannot be loaded but AWSPENDING can, the previously cached
      tokens are kept alongside the pending one. This stops a transient
      failure mid-rotation from evicting the still-valid current token for a
      whole TTL.
    * If nothing can be loaded, the existing (stale) cache is kept and its
      timestamp extended. This avoids rejecting every request during a
      transient Secrets Manager outage. The next call after
      CACHE_TTL_SECONDS retries the refresh.
    * Empty or non-string tokens are ignored.

    Does nothing when AUTH_SECRET_ARN is not set.
    """
    global _cached_tokens, _cache_timestamp

    secret_arn = os.environ.get("AUTH_SECRET_ARN")
    if not secret_arn:
        return

    try:
        client = get_secrets_client()
        new_tokens: set[str] = set()

        # AWSCURRENT should always exist; failing to load it is an error.
        current_token: str | None = None
        try:
            current_token = _fetch_token(client, secret_arn, "AWSCURRENT")
        except Exception as e:
            logger.error("Failed to load AWSCURRENT secret: %s", e)
        else:
            if current_token is None:
                logger.error("AWSCURRENT secret has an empty or non-string 'token' field")
            else:
                new_tokens.add(current_token)
                logger.debug("Loaded AWSCURRENT token")

        # AWSPENDING only exists during rotation, so its absence is normal.
        pending_token: str | None = None
        try:
            pending_token = _fetch_token(client, secret_arn, "AWSPENDING")
        except client.exceptions.ResourceNotFoundException:
            pass
        except Exception as e:
            # Log but don't fail - AWSPENDING is optional
            logger.debug("No AWSPENDING version available: %s", e)
        if pending_token is not None:
            new_tokens.add(pending_token)
            logger.debug("Loaded AWSPENDING token (rotation in progress)")

        if current_token is None and new_tokens and _cached_tokens:
            # Only the pending token loaded. Keep what we already trusted so
            # clients still using the current token aren't rejected.
            new_tokens |= _cached_tokens
            logger.warning(
                "AWSCURRENT token unavailable; keeping previously cached token(s) "
                "alongside AWSPENDING"
            )

        if new_tokens:
            _cached_tokens = new_tokens
            _cache_timestamp = time.time()
            logger.info("Token cache refreshed with %d valid token(s)", len(new_tokens))
        elif _cached_tokens:
            # Couldn't load any new tokens but have stale ones — extend the cache
            # to avoid rejecting all traffic during a transient SM outage
            _cache_timestamp = time.time()
            logger.warning("Token refresh returned empty set, keeping stale cache")

    except Exception as e:
        logger.error("Failed to refresh token cache: %s", e)
        if _cached_tokens:
            # Extend stale cache on total failure — better to accept slightly-old
            # tokens than to reject everything
            _cache_timestamp = time.time()
            logger.warning("Extending stale token cache due to refresh failure")

149 

150 

def get_valid_tokens() -> set[str]:
    """
    Return the set of tokens currently accepted for authentication.

    The set holds the AWSCURRENT token and, during a rotation, the AWSPENDING
    token, so that rotation causes no downtime. Values come from a
    module-level cache that is refreshed from Secrets Manager once it is empty
    or older than CACHE_TTL_SECONDS.

    The returned object is the cache itself, not a copy. Callers should
    treat it as read-only.

    Returns:
        Set of valid token strings (empty if nothing is configured or loadable)
    """
    if _is_cache_valid():
        return _cached_tokens
    _refresh_cache()
    return _cached_tokens

166 

167 

def get_secret_token() -> str | None:
    """
    Return a single accepted authentication token (legacy helper).

    Kept for backward compatibility. It yields one element of
    get_valid_tokens(). Tokens are stored in a set, so during a rotation
    either the current or the pending token may be returned. New code that
    needs rotation support should use get_valid_tokens().

    Returns:
        A token string, or None when no token is available
    """
    for token in get_valid_tokens():
        return token
    return None

180 

181 

def clear_token_cache() -> None:
    """
    Drop all cached tokens so the next lookup reloads them from Secrets Manager.

    Intended for tests, or for callers that know the secret was just rotated.
    """
    global _cache_timestamp, _cached_tokens
    # Both resets are needed: an empty set alone already invalidates the
    # cache, and a zero timestamp marks it as never refreshed.
    _cache_timestamp = 0
    _cached_tokens = set()
    logger.info("Token cache cleared")

192 

193 

class AuthenticationMiddleware(BaseHTTPMiddleware):
    """
    FastAPI middleware to validate the X-GCO-Auth-Token header.

    This middleware ensures all API requests came through the authenticated
    API Gateway by validating a secret token header. Paths in
    UNAUTHENTICATED_PATHS (health checks and metrics) are excluded so that
    load balancer and Global Accelerator probes work without the header.

    During secret rotation, both AWSCURRENT and AWSPENDING tokens are accepted
    to ensure zero-downtime rotation.

    Rejected requests get a JSON body ``{"detail": ...}`` with status 403
    (missing or invalid token) or 503 (authentication unavailable). The
    response is *returned* rather than raised. An exception escaping
    ``BaseHTTPMiddleware.dispatch`` never reaches the app's ``HTTPException``
    handlers, because user middleware runs outside Starlette's
    ``ExceptionMiddleware``. A raised 403/503 would therefore reach the client
    as a generic 500.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)
        # Startup-time configuration check — surface misconfigurations early.
        # This only logs; the per-request decision is made in _authenticate().
        if os.environ.get("AUTH_SECRET_ARN"):
            return
        if self._dev_mode_enabled():
            logger.warning(
                "GCO_DEV_MODE=true with no AUTH_SECRET_ARN — "
                "authentication is bypassed. Do NOT use in production."
            )
        else:
            logger.error(
                "AUTH_SECRET_ARN is not configured and GCO_DEV_MODE is not enabled. "
                "All non-health-check requests will be denied with 503."
            )

    @staticmethod
    def _dev_mode_enabled() -> bool:
        """Return True when GCO_DEV_MODE is set to "true" (case-insensitive)."""
        return os.environ.get("GCO_DEV_MODE", "").lower() == "true"

    @staticmethod
    def _token_matches(candidate: str, valid_tokens: set[str]) -> bool:
        """
        Return True if *candidate* equals one of *valid_tokens*.

        Uses hmac.compare_digest on UTF-8 bytes, so each comparison takes
        constant time. Header values may be non-ASCII, and compare_digest
        rejects non-ASCII str, hence the encoding. Every token is checked
        without short-circuiting. An empty candidate never matches, so a
        request without the header cannot authenticate.
        """
        if not candidate:
            return False
        candidate_bytes = candidate.encode("utf-8")
        matched = False
        for token in valid_tokens:
            if isinstance(token, str) and hmac.compare_digest(
                candidate_bytes, token.encode("utf-8")
            ):
                matched = True
        return matched

    def _authenticate(self, request: Request) -> None:
        """
        Decide whether *request* may proceed.

        Returns normally when the request is authenticated, or when dev mode
        bypass applies.

        Raises:
            HTTPException: 503 if authentication is not configured or tokens
                could not be loaded; 403 if the token header is missing or invalid.
        """
        # NOTE: on a cache miss this performs synchronous Secrets Manager calls.
        valid_tokens = get_valid_tokens()

        # No tokens available — determine whether to fail open or closed
        if not valid_tokens:
            if not os.environ.get("AUTH_SECRET_ARN"):
                # No secret configured. Only allow requests if the operator
                # explicitly opted into dev mode. This prevents accidental
                # unauthenticated deployments due to misconfiguration.
                if self._dev_mode_enabled():
                    logger.warning(
                        "Authentication bypassed - GCO_DEV_MODE=true, no secret configured"
                    )
                    return
                # Fail closed: no secret + no dev mode = deny
                logger.error(
                    "No AUTH_SECRET_ARN configured and GCO_DEV_MODE is not enabled. "
                    "Set AUTH_SECRET_ARN for production or GCO_DEV_MODE=true for local development."
                )
                raise HTTPException(
                    status_code=503,
                    detail="Service unavailable - authentication not configured",
                )
            # Secret configured but couldn't load - deny access
            logger.error("Failed to load authentication tokens")
            raise HTTPException(
                status_code=503,
                detail="Service temporarily unavailable - authentication error",
            )

        # Validate the auth header against all valid tokens
        auth_header = request.headers.get("x-gco-auth-token", "")
        if not self._token_matches(auth_header, valid_tokens):
            client_ip = request.client.host if request.client else "unknown"
            logger.warning("Invalid auth token from %s for %s", client_ip, request.url.path)
            raise HTTPException(
                status_code=403,
                detail="Forbidden - requests must come through authenticated API Gateway",
            )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process an incoming request and enforce authentication.

        Args:
            request: The incoming FastAPI request
            call_next: The next middleware/handler in the chain

        Returns:
            The downstream response if the request is authenticated (or exempt).
            Otherwise a JSONResponse with status 403 or 503 and a ``detail``
            message, in the same shape FastAPI uses for HTTPException.
        """
        # Skip authentication for health check endpoints
        if request.url.path in UNAUTHENTICATED_PATHS:
            return await call_next(request)

        try:
            self._authenticate(request)
        except HTTPException as exc:
            # Convert to a response here: raising from middleware bypasses
            # the app's exception handlers and would surface as a 500.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=exc.headers,
            )

        return await call_next(request)