Coverage for mcp/audit_middleware.py: 93%
57 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
1"""
Audit-capture middleware for the GCO MCP server.

Wires up two pieces:

1. ``Context.{warning, info, error, elicit}`` are wrapped at the class level
   (once, at module import time) so every Context instance — including the
   fresh one FastMCP creates for each tool call — appends to the active
   capture buffers.
2. ``AuditCaptureMiddleware`` sets fresh capture buffers in
   ``audit_messages_var`` / ``audit_elicitations_var`` at the start of every
   ``on_call_tool`` invocation and resets them on the way out. The audit
   decorator (``mcp/audit.py::_build_audit_entry``) reads those buffers
   when emitting the entry.

The Context class only has its methods patched once. Idempotency is
enforced by inspecting a marker attribute set on the installed wrapper
functions; re-imports (test reloads, hot-reload during dev) detect the
marker and skip re-patching. When no capture buffer is active, the
wrappers skip recording and simply delegate to the original methods, so
the patch is behaviorally a no-op for any code that uses Context outside
of a tool call (e.g. unit tests that construct a Context directly).
23"""
from __future__ import annotations

import functools
from typing import Any

from audit import audit_elicitations_var, audit_messages_var
from fastmcp.server.context import Context
from fastmcp.server.middleware import Middleware
33# Attribute we tag each spy function with so ``_install_context_patches``
34# can detect that the wrappers are already in place. Reading an
35# attribute on the live class method is more robust than a separate
36# module-level boolean: if a re-import re-runs this module, the same
37# marker survives because the patched method on ``Context`` survives.
38_SPY_MARKER = "_gco_audit_spy"
def _install_context_patches() -> None:
    """Install the class-level Context method wrappers (once).

    Replaces ``Context.warning``, ``Context.info``, ``Context.error`` and
    ``Context.elicit`` with async wrappers that record each call into the
    list currently held by ``audit_messages_var`` / ``audit_elicitations_var``
    (only when that ContextVar holds a list, i.e. is not ``None``) and always
    delegate to the original method.

    Each wrapper is built with ``functools.wraps`` so the patched methods keep
    the original ``__name__``, ``__qualname__`` and ``__doc__``, and expose the
    original via ``__wrapped__``.

    Idempotent: if ``Context.warning`` already carries the ``_SPY_MARKER``
    attribute, the wrappers are assumed to be in place and nothing is done.

    NOTE(review): installed wrappers keep reading the ContextVars visible to
    the module instance that created them. If a test discards and re-imports
    ``audit`` as a fresh module (new ContextVar objects), the marker check
    skips re-patching and the old wrappers will not see the new buffers —
    confirm whether any reload path does this.
    """
    if getattr(Context.warning, _SPY_MARKER, False):
        return

    def _make_message_spy(level: str, original: Any) -> Any:
        """Build a wrapper that records ``{"level", "message"}`` then delegates."""

        @functools.wraps(original)
        async def _spy(self: Context, message: str, *args: Any, **kwargs: Any) -> Any:
            buffer = audit_messages_var.get()
            if buffer is not None:
                # Recorded before delegating, so the message is captured even
                # if the original method raises.
                buffer.append({"level": level, "message": str(message)})
            return await original(self, message, *args, **kwargs)

        return _spy

    _orig_elicit = Context.elicit

    @functools.wraps(_orig_elicit)
    async def _spy_elicit(self: Context, message: str, *args: Any, **kwargs: Any) -> Any:
        # Recorded after the call because the entry includes the response;
        # if the original raises, nothing is recorded.
        result = await _orig_elicit(self, message, *args, **kwargs)
        buffer = audit_elicitations_var.get()
        if buffer is not None:
            entry: dict[str, Any] = {
                "message": str(message),
                "action": getattr(result, "action", None),
            }
            data = getattr(result, "data", None)
            if data is not None:
                # Stringify to avoid leaking arbitrary user objects through
                # the audit log; the audit log is a JSON-line stream.
                entry["data"] = data if isinstance(data, (str, int, float, bool)) else str(data)
            buffer.append(entry)
        return result

    # All originals are read here, before any attribute on Context is replaced.
    spies: dict[str, Any] = {
        "warning": _make_message_spy("warning", Context.warning),
        "info": _make_message_spy("info", Context.info),
        "error": _make_message_spy("error", Context.error),
        "elicit": _spy_elicit,
    }

    # Tag every spy with the marker before attaching any of them, so the
    # marker is visible the moment a wrapper lands on the class.
    for spy in spies.values():
        setattr(spy, _SPY_MARKER, True)

    # ``warning`` is attached first; it is the method the idempotency probe reads.
    for method_name, spy in spies.items():
        setattr(Context, method_name, spy)
class AuditCaptureMiddleware(Middleware):
    """Bind a fresh pair of audit capture buffers around every tool call.

    For the duration of each ``on_call_tool`` invocation, new empty lists are
    bound to ``audit_messages_var`` (log messages recorded by the patched
    Context methods) and ``audit_elicitations_var`` (elicitation records).
    The previous ContextVar values are restored afterwards — even when the
    downstream call raises — so concurrent calls never observe each other's
    captures.
    """

    def __init__(self) -> None:
        # Ensure the Context wrappers exist; repeated calls are no-ops once
        # the marker is present on the patched methods.
        _install_context_patches()

    async def on_call_tool(self, context: Any, call_next: Any) -> Any:
        """Run ``call_next(context)`` with fresh capture buffers bound."""
        message_buffer: list[dict[str, str]] = []
        elicitation_buffer: list[dict[str, object]] = []
        saved_messages = audit_messages_var.set(message_buffer)
        saved_elicitations = audit_elicitations_var.set(elicitation_buffer)
        try:
            outcome = await call_next(context)
        finally:
            # Undo both bindings in the same order they were made.
            audit_messages_var.reset(saved_messages)
            audit_elicitations_var.reset(saved_elicitations)
        return outcome
# Patch Context as a side effect of importing this module, rather than only
# when an AuditCaptureMiddleware is constructed. Code that assembles its own
# pipeline, or tests that bypass the middleware, therefore still get messages
# and elicitations recorded — provided they bind fresh lists to the audit
# ContextVars themselves.
_install_context_patches()