Coverage for mcp/audit_middleware.py: 93%

57 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

"""
Audit-capture middleware for the GCO MCP server.

Wires up two pieces:

1. ``Context.{warning, info, error, elicit}`` are wrapped at the class level
   (once, at module import time) so every Context instance — including the
   fresh one FastMCP creates for each tool call — appends to the active
   capture buffers.
2. ``AuditCaptureMiddleware`` sets fresh capture buffers in
   ``audit_messages_var`` / ``audit_elicitations_var`` at the start of every
   ``on_call_tool`` invocation and resets them on the way out. The audit
   decorator (``mcp/audit.py::_build_audit_entry``) reads those buffers
   when emitting the entry.

The Context class only has its methods patched once. Idempotency is
enforced by checking for a marker attribute (``_SPY_MARKER``) that is set
on each wrapper function installed on ``Context``; re-imports (test reloads,
hot-reload during dev) detect the marker and skip re-patching. The wrapped
methods fall straight through to the originals, recording nothing, when no
capture buffer is active, so this patch is a no-op for any code that uses
Context outside of a tool call (e.g. unit tests that construct a Context
directly).
"""

24 

25from __future__ import annotations 

26 

27from typing import Any 

28 

29from audit import audit_elicitations_var, audit_messages_var 

30from fastmcp.server.context import Context 

31from fastmcp.server.middleware import Middleware 

32 

# Name of the attribute set to ``True`` on each spy function, so that
# ``_install_context_patches`` can detect that the wrappers are already in
# place. The marker lives on the function object installed on ``Context``
# rather than in a module-level boolean: re-executing this module (a
# re-import or reload) would reset such a boolean, whereas the patched
# methods on ``Context`` — and therefore their markers — survive.
_SPY_MARKER = "_gco_audit_spy"

39 

40 

def _install_context_patches() -> None:
    """Install the class-level ``Context`` method wrappers (once per method).

    ``Context.warning``, ``Context.info`` and ``Context.error`` are wrapped so
    each call first appends ``{"level": <method name>, "message": str(message)}``
    to the list held in ``audit_messages_var`` (if one is active), then
    delegates to the original method.

    ``Context.elicit`` is wrapped so that, after the original returns, a dict
    with ``"message"``, ``"action"`` and (when present) ``"data"`` describing
    the result is appended to the list held in ``audit_elicitations_var``
    (if one is active).

    When no list is active, the wrappers simply delegate.

    Each method is checked for ``_SPY_MARKER`` individually. Repeated calls
    therefore never wrap a method twice, even if other code (e.g. a test
    monkeypatch) has replaced or restored only some of the four methods in
    between. Wrappers are built with ``functools.wraps``, so they keep the
    original's ``__name__``/``__doc__`` and expose it as ``__wrapped__``.
    """
    import functools

    def _make_message_spy(level: str, orig: Any) -> Any:
        """Wrap a log-style Context method so its messages are recorded at ``level``."""

        @functools.wraps(orig)
        async def _spy(self: Context, message: str, *args: Any, **kwargs: Any) -> Any:
            # Explicit ``None`` default: outside a tool call no capture list
            # is set, and the wrapper must never raise LookupError there.
            captured = audit_messages_var.get(None)
            if captured is not None:
                captured.append({"level": level, "message": str(message)})
            return await orig(self, message, *args, **kwargs)

        return _spy

    def _make_elicit_spy(orig: Any) -> Any:
        """Wrap ``Context.elicit`` so each completed elicitation is recorded."""

        @functools.wraps(orig)
        async def _spy(self: Context, message: str, *args: Any, **kwargs: Any) -> Any:
            # Record only after the original returns; if it raises, nothing
            # is captured and the exception propagates unchanged.
            result = await orig(self, message, *args, **kwargs)
            captured = audit_elicitations_var.get(None)
            if captured is not None:
                entry: dict[str, Any] = {
                    "message": str(message),
                    "action": getattr(result, "action", None),
                }
                data = getattr(result, "data", None)
                if data is not None:
                    # Stringify to avoid leaking arbitrary user objects through
                    # the audit log; the audit log is a JSON-line stream.
                    entry["data"] = (
                        data if isinstance(data, (str, int, float, bool)) else str(data)
                    )
                captured.append(entry)
            return result

        return _spy

    for name in ("warning", "info", "error", "elicit"):
        orig = getattr(Context, name)
        if getattr(orig, _SPY_MARKER, False):
            # Already wrapped; wrapping again would record every call twice.
            continue
        # For the three log-style methods the captured level is the method name.
        spy = _make_elicit_spy(orig) if name == "elicit" else _make_message_spy(name, orig)
        # Tag the spy before attaching it so a concurrent re-entry observes
        # the marker as soon as the assignment lands.
        setattr(spy, _SPY_MARKER, True)
        setattr(Context, name, spy)

94 

95 

class AuditCaptureMiddleware(Middleware):
    """FastMCP middleware that activates per-invocation audit capture buffers.

    On every ``on_call_tool`` call, sets fresh empty lists into
    ``audit_messages_var`` and ``audit_elicitations_var`` so the patched
    ``Context`` methods append into them. Both ContextVars are reset on the
    way out, including when the tool raises, so concurrent calls don't see
    each other's captures.
    """

    def __init__(self) -> None:
        # Chain to the base initializer so any state ``Middleware`` (or a
        # mixin) sets up is not silently skipped.
        super().__init__()
        # Safe to call repeatedly: the patch installer checks its marker and
        # does not re-wrap methods that are already patched.
        _install_context_patches()

    async def on_call_tool(self, context: Any, call_next: Any) -> Any:
        """Run the rest of the pipeline with fresh audit capture buffers active.

        Args:
            context: Middleware context for this tool call; passed through
                unchanged to ``call_next``.
            call_next: Continuation that invokes the next middleware or tool.

        Returns:
            Whatever ``await call_next(context)`` returns.
        """
        messages: list[dict[str, str]] = []
        elicitations: list[dict[str, object]] = []
        msg_token = audit_messages_var.set(messages)
        elic_token = audit_elicitations_var.set(elicitations)
        try:
            return await call_next(context)
        finally:
            # Reset in reverse order of the ``set`` calls (LIFO).
            audit_elicitations_var.reset(elic_token)
            audit_messages_var.reset(msg_token)

118 

119 

# Install patches eagerly at module import so callers that build their own
# pipelines (or tests that bypass middleware wiring) still get the capture
# behaviour, provided they set fresh lists into ``audit_messages_var`` /
# ``audit_elicitations_var`` themselves. ``AuditCaptureMiddleware.__init__``
# calls this again; the ``_SPY_MARKER`` check makes that repeat a no-op.
_install_context_patches()