Coverage for mcp/tools/inference.py: 96%
110 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
1"""Inference endpoint management MCP tools."""
3import asyncio
4import contextlib
5import json
6from typing import Any
8import cli_runner
9from audit import audit_logged
10from feature_flags import FLAG_DESTRUCTIVE_OPERATIONS, is_enabled
11from server import mcp
async def _ctx_warning(message: str) -> None:
    """Best-effort relay of ``message`` as a warning on the current context.

    Used by the destructive ``delete_inference`` tool, which only needs a
    single warning sent back rather than any progress reporting.

    The helper never raises: if the context cannot be obtained (the import
    or the lookup fails) it returns without doing anything, and any error
    raised while sending the warning is suppressed.

    Args:
        message: Warning text to send.
    """
    try:
        # Imported inside the guard so that a failing import is treated the
        # same way as a missing context.
        from fastmcp.server.dependencies import get_context

        active_context = get_context()
    except Exception:
        # No usable context: silently skip the warning.
        return None
    else:
        # Delivery is optional; a failure here must not break the caller.
        with contextlib.suppress(Exception):
            await active_context.warning(message)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def deploy_inference(
    name: str,
    image: str,
    gpu_count: int = 1,
    replicas: int = 1,
    port: int = 8000,
    region: str | None = None,
    env_vars: list[str] | None = None,
) -> str:
    """Deploy an inference endpoint across regions.

    Args:
        name: Endpoint name (e.g. my-llm).
        image: Container image (e.g. vllm/vllm-openai:v0.22.0).
        gpu_count: GPUs per replica.
        replicas: Number of replicas per region.
        port: Container port.
        region: Target region(s). Omit for all deployed regions.
        env_vars: Environment variables as KEY=VALUE strings.
    """
    # Sizing options are always forwarded, in this order, as strings.
    sizing_flags = {"--gpu-count": gpu_count, "--replicas": replicas, "--port": port}

    command = ["inference", "deploy", name, "-i", image]
    for flag, value in sizing_flags.items():
        command.extend((flag, str(value)))

    # The region flag is only added when a non-empty region was given.
    if region:
        command.extend(("-r", region))

    # One "-e <entry>" pair per requested environment variable.
    command.extend(token for entry in (env_vars or []) for token in ("-e", entry))

    return cli_runner._run_cli(*command)
@mcp.tool(tags={"safe", "inference"})
@audit_logged
def list_inference_endpoints(state: str | None = None, region: str | None = None) -> str:
    """List all inference endpoints.

    Args:
        state: Filter by state (deploying, running, stopped, deleted).
        region: Filter by region.
    """
    # Optional filters, in the order they are appended to the command line.
    # Coverage report note: in the recorded run the state filter was always
    # set and the region filter was never set.
    optional_filters = (("--state", state), ("-r", region))

    command = ["inference", "list"]
    for flag, value in optional_filters:
        if value:
            command.extend((flag, value))
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"safe", "inference"})
@audit_logged
def inference_status(name: str) -> str:
    """Get detailed status of an inference endpoint including per-region breakdown.

    Args:
        name: Endpoint name.
    """
    # Forward to the CLI's "inference status" subcommand and return its output.
    command = ("inference", "status", name)
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def scale_inference(name: str, replicas: int) -> str:
    """Scale an inference endpoint.

    Args:
        name: Endpoint name.
        replicas: Target replica count.
    """
    # The replica count is passed to the CLI as a string argument.
    replica_arg = str(replicas)
    command = ("inference", "scale", name, "--replicas", replica_arg)
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def update_inference_image(name: str, image: str) -> str:
    """Rolling update of an inference endpoint's container image.

    Args:
        name: Endpoint name.
        image: New container image.
    """
    # Forward to the CLI's "inference update-image" subcommand.
    command = ("inference", "update-image", name, "-i", image)
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def stop_inference(name: str) -> str:
    """Stop an inference endpoint (scales to zero, keeps config).

    Args:
        name: Endpoint name.
    """
    # "-y" is always passed along with the stop subcommand.
    command = ("inference", "stop", name, "-y")
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def start_inference(name: str) -> str:
    """Start a stopped inference endpoint.

    Args:
        name: Endpoint name.
    """
    # Forward to the CLI's "inference start" subcommand and return its output.
    command = ("inference", "start", name)
    return cli_runner._run_cli(*command)
# The delete tool is only defined (and therefore only registered) when the
# destructive-operations feature flag is enabled at import time.
if is_enabled(FLAG_DESTRUCTIVE_OPERATIONS):

    @mcp.tool(tags={"destructive", "inference"})
    @audit_logged
    async def delete_inference(name: str) -> str:
        """[gated by GCO_ENABLE_DESTRUCTIVE_OPERATIONS] destructive.

        Delete an inference endpoint. Cannot be undone — the endpoint, its
        DynamoDB record, and the underlying Kubernetes resources are removed.

        Args:
            name: Endpoint name.
        """
        # Send the warning first; the helper is best-effort and never raises.
        notice = f"Deleting inference endpoint {name!r} — this cannot be undone."
        await _ctx_warning(notice)

        # Run the synchronous CLI call in a worker thread so this coroutine
        # does not block while it executes.
        command = ("inference", "delete", name, "-y")
        return await asyncio.to_thread(cli_runner._run_cli, *command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def canary_deploy(name: str, image: str, weight: int = 10) -> str:
    """Start a canary deployment (A/B test a new image version).

    Args:
        name: Endpoint name.
        image: New image to canary.
        weight: Percentage of traffic to send to canary (1-99).
    """
    # The weight is forwarded as-is (stringified); no range check happens here.
    command = ["inference", "canary", name]
    command.extend(("-i", image))
    command.extend(("--weight", str(weight)))
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def promote_canary(name: str) -> str:
    """Promote canary to primary (100% traffic to new version).

    Args:
        name: Endpoint name.
    """
    # "-y" is always passed along with the promote subcommand.
    command = ("inference", "promote", name, "-y")
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def rollback_canary(name: str) -> str:
    """Rollback canary (remove canary, 100% traffic to primary).

    Args:
        name: Endpoint name.
    """
    # "-y" is always passed along with the rollback subcommand.
    command = ("inference", "rollback", name, "-y")
    return cli_runner._run_cli(*command)
@mcp.tool(tags={"safe", "inference"})
@audit_logged
def invoke_inference(
    name: str,
    prompt: str,
    max_tokens: int = 100,
    api_path: str | None = None,
    stream: bool = False,
    region: str | None = None,
) -> str:
    """Send a prompt to an inference endpoint and return the generated text.

    Automatically discovers the endpoint's ingress path, detects the serving
    framework (vLLM, TGI, Triton), and routes the request through the API
    Gateway with SigV4 authentication.

    Use this for single-turn text completions. For multi-turn conversations
    with chat models, use chat_inference instead.

    Args:
        name: Endpoint name (e.g. my-llm).
        prompt: Text prompt to send to the model.
        max_tokens: Maximum tokens to generate (default: 100).
        api_path: Override the API sub-path (default: auto-detect from framework).
        stream: Enable streaming for lower time-to-first-token (default: false).
        region: Target region for the request (default: nearest via Global Accelerator).
    """
    # Mandatory part: subcommand, endpoint, prompt and token budget.
    command = ["inference", "invoke", name]
    command.extend(("-p", prompt))
    command.extend(("--max-tokens", str(max_tokens)))

    # Optional parts are appended only when set, in a fixed order:
    # path override, streaming switch, then region.
    if api_path:
        command.extend(("--path", api_path))
    if stream:
        command += ["--stream"]
    if region:
        command.extend(("-r", region))

    return cli_runner._run_cli(*command)
@mcp.tool(tags={"safe", "inference"})
@audit_logged
def chat_inference(
    name: str,
    messages: list[dict[str, str]],
    max_tokens: int = 256,
    temperature: float | None = None,
    stream: bool = False,
    region: str | None = None,
) -> str:
    """Send a multi-turn chat conversation to an inference endpoint.

    Sends an OpenAI-compatible /v1/chat/completions request. Works with
    vLLM, TGI (with --api-protocol openai), and any OpenAI-compatible server.

    Each message in the list should have 'role' (system/user/assistant) and
    'content' keys.

    Args:
        name: Endpoint name (e.g. my-llm).
        messages: List of chat messages, e.g. [{"role": "user", "content": "Hello"}].
        max_tokens: Maximum tokens to generate (default: 256).
        temperature: Sampling temperature (optional, server default if omitted).
        stream: Enable streaming for lower time-to-first-token (default: false).
        region: Target region for the request.
    """
    # Request payload; "temperature" is only included when explicitly given,
    # so a value of 0.0 is still forwarded.
    payload: dict[str, Any] = dict(messages=messages, max_tokens=max_tokens, stream=stream)
    if temperature is not None:
        payload.update(temperature=temperature)

    # The payload travels as a JSON string via "-d", with the path fixed to
    # the chat completions route.
    command = [
        "inference",
        "invoke",
        name,
        "-d",
        json.dumps(payload),
        "--path",
        "/v1/chat/completions",
    ]

    # Optional trailing arguments: streaming switch first, then region.
    trailing: list[str] = []
    if stream:
        trailing.append("--stream")
    if region:
        trailing.extend(("-r", region))

    return cli_runner._run_cli(*command, *trailing)
@mcp.tool(tags={"safe", "inference"})
@audit_logged
def inference_health(name: str, region: str | None = None) -> str:
    """Check if an inference endpoint is healthy and ready to serve requests.

    Hits the endpoint's health check path and returns status and latency.
    Useful to verify readiness before sending inference requests.

    Args:
        name: Endpoint name.
        region: Target region to check (default: nearest via Global Accelerator).
    """
    # The region flag is only passed when a non-empty region was supplied.
    region_args = ["-r", region] if region else []
    return cli_runner._run_cli("inference", "health", name, *region_args)
@mcp.tool(tags={"safe", "inference"})
@audit_logged
def list_endpoint_models(name: str, region: str | None = None) -> str:
    """List models loaded on an inference endpoint.

    Queries the endpoint's /v1/models path (OpenAI-compatible) to discover
    which models are loaded, their context length, and other metadata.
    Works with vLLM and other OpenAI-compatible servers.

    Args:
        name: Endpoint name.
        region: Target region to query.
    """
    # The region flag is only passed when a non-empty region was supplied.
    region_args = ["-r", region] if region else []
    return cli_runner._run_cli("inference", "models", name, *region_args)