Coverage for mcp/tools/inference.py: 96%

110 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

1"""Inference endpoint management MCP tools.""" 

2 

3import asyncio 

4import contextlib 

5import json 

6from typing import Any 

7 

8import cli_runner 

9from audit import audit_logged 

10from feature_flags import FLAG_DESTRUCTIVE_OPERATIONS, is_enabled 

11from server import mcp 

12 

13 

async def _ctx_warning(message: str) -> None:
    """Send ``message`` as a warning on the active FastMCP request context.

    Intended for short tool bodies such as the destructive ``delete_inference``
    tool, which only need to surface an audited warning to the operator rather
    than the full long-task progress machinery.

    This is strictly best-effort: if the FastMCP dependency helper cannot be
    imported, if looking up the context fails, or if sending the warning
    raises, the function returns ``None`` without propagating the error.
    """
    try:
        # The import sits inside the try so an ImportError is swallowed too.
        from fastmcp.server.dependencies import get_context as _lookup_context

        active_ctx = _lookup_context()
    except Exception:
        # No usable context (e.g. called outside a request): nothing to do.
        return

    # A failure to deliver the warning must never abort the calling tool.
    with contextlib.suppress(Exception):
        await active_ctx.warning(message)

29 

30 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def deploy_inference(
    name: str,
    image: str,
    gpu_count: int = 1,
    replicas: int = 1,
    port: int = 8000,
    region: str | None = None,
    env_vars: list[str] | None = None,
) -> str:
    """Deploy an inference endpoint across regions.

    Args:
        name: Endpoint name (e.g. my-llm).
        image: Container image (e.g. vllm/vllm-openai:v0.22.0).
        gpu_count: GPUs per replica.
        replicas: Number of replicas per region.
        port: Container port.
        region: Target region(s). Omit for all deployed regions.
        env_vars: Environment variables as KEY=VALUE strings.
    """
    # Flags that are always sent; numeric values are stringified for the CLI.
    cli_args: list[str] = [
        "inference", "deploy", name,
        "-i", image,
        "--gpu-count", str(gpu_count),
        "--replicas", str(replicas),
        "--port", str(port),
    ]
    # A falsy region (None or "") adds no -r flag at all.
    if region:
        cli_args.extend(("-r", region))
    # Each KEY=VALUE entry becomes its own repeated "-e" flag, in input order.
    cli_args.extend(
        token for pair in (env_vars or ()) for token in ("-e", pair)
    )
    return cli_runner._run_cli(*cli_args)

71 

72 

@mcp.tool(tags={"safe", "inference"})
@audit_logged
def list_inference_endpoints(state: str | None = None, region: str | None = None) -> str:
    """List all inference endpoints.

    Args:
        state: Filter by state (deploying, running, stopped, deleted).
        region: Filter by region.
    """
    cli_args = ["inference", "list"]
    # Optional filters are forwarded only when truthy, state before region.
    for flag, value in (("--state", state), ("-r", region)):
        if value:
            cli_args.extend((flag, value))
    return cli_runner._run_cli(*cli_args)

88 

89 

@mcp.tool(tags={"safe", "inference"})
@audit_logged
def inference_status(name: str) -> str:
    """Get detailed status of an inference endpoint including per-region breakdown.

    Args:
        name: Endpoint name.
    """
    # Thin pass-through to `inference status <name>`.
    cli_args = ("inference", "status", name)
    return cli_runner._run_cli(*cli_args)

99 

100 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def scale_inference(name: str, replicas: int) -> str:
    """Scale an inference endpoint.

    Args:
        name: Endpoint name.
        replicas: Target replica count.
    """
    # The CLI takes every argument as a string, so the count is stringified.
    replica_flag = ("--replicas", str(replicas))
    return cli_runner._run_cli("inference", "scale", name, *replica_flag)

111 

112 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def update_inference_image(name: str, image: str) -> str:
    """Rolling update of an inference endpoint's container image.

    Args:
        name: Endpoint name.
        image: New container image.
    """
    # The new image is passed through the CLI's -i flag.
    image_flag = ("-i", image)
    return cli_runner._run_cli("inference", "update-image", name, *image_flag)

123 

124 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def stop_inference(name: str) -> str:
    """Stop an inference endpoint (scales to zero, keeps config).

    Args:
        name: Endpoint name.
    """
    # "-y" is always appended; presumably it auto-confirms the CLI prompt so
    # the tool never blocks waiting for input — confirm against the CLI.
    cli_args = ("inference", "stop", name, "-y")
    return cli_runner._run_cli(*cli_args)

134 

135 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def start_inference(name: str) -> str:
    """Start a stopped inference endpoint.

    Args:
        name: Endpoint name.
    """
    # Unlike stop/delete, no "-y" flag is passed here.
    cli_args = ("inference", "start", name)
    return cli_runner._run_cli(*cli_args)

145 

146 

if is_enabled(FLAG_DESTRUCTIVE_OPERATIONS):
    # The tool is defined and registered only when the destructive-operations
    # flag is enabled at import time; otherwise ``delete_inference`` does not
    # exist in this module at all.

    @mcp.tool(tags={"destructive", "inference"})
    @audit_logged
    async def delete_inference(name: str) -> str:
        """[gated by GCO_ENABLE_DESTRUCTIVE_OPERATIONS] destructive.

        Delete an inference endpoint. Cannot be undone — the endpoint, its
        DynamoDB record, and the underlying Kubernetes resources are removed.

        Args:
            name: Endpoint name.
        """
        # Warn the operator first (best-effort; never blocks the delete).
        warning_text = f"Deleting inference endpoint {name!r} — this cannot be undone."
        await _ctx_warning(warning_text)
        # _run_cli is synchronous, so run it in a worker thread rather than
        # on the event loop.
        cli_args = ("inference", "delete", name, "-y")
        return await asyncio.to_thread(cli_runner._run_cli, *cli_args)

162 

163 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def canary_deploy(name: str, image: str, weight: int = 10) -> str:
    """Start a canary deployment (A/B test a new image version).

    Args:
        name: Endpoint name.
        image: New image to canary.
        weight: Percentage of traffic to send to canary (1-99).
    """
    # NOTE(review): the documented 1-99 range is not checked here; the value
    # is forwarded as-is and any validation is left to the CLI.
    cli_args = (
        "inference", "canary", name,
        "-i", image,
        "--weight", str(weight),
    )
    return cli_runner._run_cli(*cli_args)

175 

176 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def promote_canary(name: str) -> str:
    """Promote canary to primary (100% traffic to new version).

    Args:
        name: Endpoint name.
    """
    # "-y" is always sent, as for the other state-changing canary/stop calls.
    cli_args = ("inference", "promote", name, "-y")
    return cli_runner._run_cli(*cli_args)

186 

187 

@mcp.tool(tags={"low-risk", "inference"})
@audit_logged
def rollback_canary(name: str) -> str:
    """Rollback canary (remove canary, 100% traffic to primary).

    Args:
        name: Endpoint name.
    """
    # "-y" is always sent, matching promote_canary.
    cli_args = ("inference", "rollback", name, "-y")
    return cli_runner._run_cli(*cli_args)

197 

198 

@mcp.tool(tags={"safe", "inference"})
@audit_logged
def invoke_inference(
    name: str,
    prompt: str,
    max_tokens: int = 100,
    api_path: str | None = None,
    stream: bool = False,
    region: str | None = None,
) -> str:
    """Send a prompt to an inference endpoint and return the generated text.

    Automatically discovers the endpoint's ingress path, detects the serving
    framework (vLLM, TGI, Triton), and routes the request through the API
    Gateway with SigV4 authentication.

    Use this for single-turn text completions. For multi-turn conversations
    with chat models, use chat_inference instead.

    Args:
        name: Endpoint name (e.g. my-llm).
        prompt: Text prompt to send to the model.
        max_tokens: Maximum tokens to generate (default: 100).
        api_path: Override the API sub-path (default: auto-detect from framework).
        stream: Enable streaming for lower time-to-first-token (default: false).
        region: Target region for the request (default: nearest via Global Accelerator).
    """
    cli_args = ["inference", "invoke", name, "-p", prompt, "--max-tokens", str(max_tokens)]
    # Each optional flag is appended only when its value is truthy, in the
    # fixed order --path, --stream, -r.
    cli_args.extend(("--path", api_path) if api_path else ())
    cli_args.extend(("--stream",) if stream else ())
    cli_args.extend(("-r", region) if region else ())
    return cli_runner._run_cli(*cli_args)

234 

235 

@mcp.tool(tags={"safe", "inference"})
@audit_logged
def chat_inference(
    name: str,
    messages: list[dict[str, str]],
    max_tokens: int = 256,
    temperature: float | None = None,
    stream: bool = False,
    region: str | None = None,
) -> str:
    """Send a multi-turn chat conversation to an inference endpoint.

    Sends an OpenAI-compatible /v1/chat/completions request. Works with
    vLLM, TGI (with --api-protocol openai), and any OpenAI-compatible server.

    Each message in the list should have 'role' (system/user/assistant) and
    'content' keys.

    Args:
        name: Endpoint name (e.g. my-llm).
        messages: List of chat messages, e.g. [{"role": "user", "content": "Hello"}].
        max_tokens: Maximum tokens to generate (default: 256).
        temperature: Sampling temperature (optional, server default if omitted).
        stream: Enable streaming for lower time-to-first-token (default: false).
        region: Target region for the request.
    """
    # Request body; key order (messages, max_tokens, stream[, temperature])
    # determines the serialized JSON text.
    payload: dict[str, Any] = {
        "messages": messages,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    # Omit temperature entirely (rather than sending null) so the server
    # default applies.
    if temperature is not None:
        payload["temperature"] = temperature

    cli_args = [
        "inference", "invoke", name,
        "-d", json.dumps(payload),
        "--path", "/v1/chat/completions",
    ]
    # Streaming is requested both in the JSON body and via the CLI flag.
    if stream:
        cli_args.append("--stream")
    if region:
        cli_args.extend(("-r", region))
    return cli_runner._run_cli(*cli_args)

272 

273 

@mcp.tool(tags={"safe", "inference"})
@audit_logged
def inference_health(name: str, region: str | None = None) -> str:
    """Check if an inference endpoint is healthy and ready to serve requests.

    Hits the endpoint's health check path and returns status and latency.
    Useful to verify readiness before sending inference requests.

    Args:
        name: Endpoint name.
        region: Target region to check (default: nearest via Global Accelerator).
    """
    # A falsy region adds no -r flag.
    region_flag = ("-r", region) if region else ()
    return cli_runner._run_cli("inference", "health", name, *region_flag)

290 

291 

@mcp.tool(tags={"safe", "inference"})
@audit_logged
def list_endpoint_models(name: str, region: str | None = None) -> str:
    """List models loaded on an inference endpoint.

    Queries the endpoint's /v1/models path (OpenAI-compatible) to discover
    which models are loaded, their context length, and other metadata.
    Works with vLLM and other OpenAI-compatible servers.

    Args:
        name: Endpoint name.
        region: Target region to query.
    """
    # A falsy region adds no -r flag.
    region_flag = ("-r", region) if region else ()
    return cli_runner._run_cli("inference", "models", name, *region_flag)
308 return cli_runner._run_cli(*args)