Coverage for mcp/tools/models.py: 95%

36 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:07 +0000

1"""Model weight management MCP tools.""" 

2 

3import asyncio 

4import contextlib 

5 

6import cli_runner 

7from audit import audit_logged 

8from feature_flags import FLAG_DESTRUCTIVE_OPERATIONS, FLAG_MODEL_UPLOAD, is_enabled 

9from server import mcp 

10 

11 

@mcp.tool(tags={"safe", "models"})
@audit_logged
def list_models() -> str:
    """List all uploaded model weights in the S3 bucket.

    Thin wrapper over ``gco models list``. Unlike the async gated tools in
    this module, the CLI call is made directly rather than via a worker
    thread. The CLI's output string is returned unchanged.
    """
    cli_args = ("models", "list")
    return cli_runner._run_cli(*cli_args)

17 

18 

@mcp.tool(tags={"safe", "models"})
@audit_logged
def get_model_uri(model_name: str) -> str:
    """Get the S3 URI for a model (for use with --model-source).

    Thin wrapper over ``gco models uri <model_name>``; the CLI's output
    string is returned unchanged.

    Args:
        model_name: Name of the model.
    """
    command = ["models", "uri"]
    command.append(model_name)
    return cli_runner._run_cli(*command)

28 

29 

30async def _ctx_warning(message: str) -> None: 

31 """Emit ``ctx.warning(...)`` from inside a tool body, no-op when no Context.""" 

32 try: 

33 from fastmcp.server.dependencies import get_context 

34 

35 ctx = get_context() 

36 except Exception: 

37 return 

38 with contextlib.suppress(Exception): 

39 await ctx.warning(message) 

40 

41 

42# ============================================================================= 

43# Model upload — gated by GCO_ENABLE_MODEL_UPLOAD 

44# ============================================================================= 

45 

46 

if is_enabled(FLAG_MODEL_UPLOAD):

    @mcp.tool(tags={"data-upload", "models"})
    @audit_logged
    async def models_upload(
        model_name: str,
        source_path: str,
        region: str | None = None,
    ) -> str:
        """[gated by GCO_ENABLE_MODEL_UPLOAD] data-upload.

        `gco models upload` — upload model weights from a local path to the
        central S3 bucket. The CLI handles multipart uploads and progress
        reporting; this tool surface returns the final result JSON.

        Registered only when the model-upload feature flag is enabled at
        import time.

        Args:
            model_name: Model name in the registry.
            source_path: Local file or directory to upload.
            region: Optional region override. Omitted from the CLI call when
                falsy (``None`` or ``""``).
        """
        region_flags = ["-r", region] if region else []
        cli_argv = [
            "models",
            "upload",
            source_path,
            "--name",
            model_name,
            *region_flags,
        ]
        # Offload the CLI call to a worker thread so the event loop is not
        # held while it runs.
        return await asyncio.to_thread(cli_runner._run_cli, *cli_argv)

71 

72 

73# ============================================================================= 

74# Destructive tools — gated by GCO_ENABLE_DESTRUCTIVE_OPERATIONS 

75# ============================================================================= 

76 

77 

if is_enabled(FLAG_DESTRUCTIVE_OPERATIONS):

    @mcp.tool(tags={"destructive", "models"})
    @audit_logged
    async def delete_model(model_name: str) -> str:
        """[gated by GCO_ENABLE_DESTRUCTIVE_OPERATIONS] destructive.

        `gco models delete` — delete a model from the central S3 bucket.
        Cannot be undone — every file under the model's S3 prefix is
        permanently removed.

        Registered only when the destructive-operations feature flag is
        enabled at import time.

        Args:
            model_name: Name of the model to delete. Must contain at least one
                non-whitespace character.

        Raises:
            ValueError: If ``model_name`` is empty, whitespace-only, or None.
                Nothing is deleted and no CLI call is made in that case.
        """
        # Guard before anything irreversible happens. The CLI is invoked with
        # ``-y`` (presumably skipping its own confirmation prompt), so this is
        # the last check. A blank name could resolve to a much broader S3
        # prefix than a single model's.
        # NOTE(review): confirm how the CLI treats an empty name.
        if not model_name or not model_name.strip():
            raise ValueError("model_name must be a non-empty model name")
        await _ctx_warning(f"Deleting model {model_name!r} — this cannot be undone.")
        # Offload the CLI call to a worker thread so the event loop is not
        # held while it runs.
        return await asyncio.to_thread(
            cli_runner._run_cli, "models", "delete", model_name, "-y"
        )