Coverage for mcp/mission/criteria_scaffold.py: 94%
224 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 15:07 +0000
1"""Helpers for ``gco mission scaffold-criteria``.
3The CLI subcommand turns a natural-language directive into a JSON
4array of Criterion objects that ``mission.validation.validate_criteria``
5accepts. Two paths are exposed:
7* :func:`generate_deterministic_criteria` — pure, no I/O. Keyword-matches
8 the directive against a small template table to pick a kind and shape
9 the criterion. The default fallback is a single ``predicate`` with
10 ``expression: "True"`` so the operator notices and edits before use.
11 Always emits at most ``max_criteria`` entries.
12* :func:`generate_sampled_criteria` — async, drives a resolved
13 :class:`SamplingBackend` to produce JSON. The response is parsed,
14 validated through ``validate_criteria``, and on rejection is retried
15 up to ``retries`` times with a feedback prompt mentioning the
16 rejection ``reason``. After the retry budget is exhausted, the helper
17 raises :class:`ScaffoldSamplingError` so the caller can fall back to
18 the deterministic path.
19* :func:`build_scaffold_prompt` — render the prompt the sampling
20 backend sees. Pure; lives here so tests can pin the exact text.
22The module is import-light: no FastMCP, no boto3, no MCP server. It
23imports the validators (and through them the predicate AST validator)
24and the sampling Protocol type, but nothing that touches a transport.
25The CLI wires the two paths together; this module keeps them
26decoupled so each can be tested in isolation.
27"""
29from __future__ import annotations
31import ast
32import json
33import re
34from dataclasses import dataclass
35from typing import TYPE_CHECKING, Any
37from . import validation as _validation
38from .predicate import PredicateRejected, parse_predicate
39from .validation import MissionValidationError
41# <pyflowchart-code-diagram> BEGIN - auto-inserted, do not edit
42# Flowchart(s) generated from this file:
43# * ``generate_sampled_criteria`` -> ``diagrams/code_diagrams/mcp/mission/criteria_scaffold.generate_sampled_criteria.html``
44# (PNG: ``diagrams/code_diagrams/mcp/mission/criteria_scaffold.generate_sampled_criteria.png``)
45# Regenerate with ``python diagrams/code_diagrams/generate.py``.
46# <pyflowchart-code-diagram> END
49if TYPE_CHECKING: # pragma: no cover - type-checker only
50 from .sampling import SamplingBackend
53__all__ = [
54 "DEFAULT_MAX_CRITERIA",
55 "DEFAULT_RETRIES",
56 "ScaffoldSamplingError",
57 "build_scaffold_prompt",
58 "generate_deterministic_criteria",
59 "generate_sampled_criteria",
60]
# ---------------------------------------------------------------------------
# Tunables
# ---------------------------------------------------------------------------

#: Upper bound on how many criteria one scaffold call produces by default.
DEFAULT_MAX_CRITERIA: int = 5

#: How many times the sampling path re-prompts after a rejected response.
#: Every retry carries a feedback message naming the rejection reason.
DEFAULT_RETRIES: int = 3

# ---------------------------------------------------------------------------
# Keyword templates for the deterministic fallback
# ---------------------------------------------------------------------------

# One compiled, case-insensitive pattern per template. ``_classify_directive``
# probes them in the order they are declared here (lower-is-better,
# higher-is-better, search, event) and stops at the first hit, so the
# declaration order doubles as the precedence order.

# Metrics the operator wants to drive down.
_LOWER_IS_BETTER_RE = re.compile(
    r"\b(loss|error|latency|cost)\b",
    flags=re.IGNORECASE,
)

# Metrics the operator wants to drive up.
_HIGHER_IS_BETTER_RE = re.compile(
    r"\b(accuracy|throughput|f1|recall|precision)\b",
    flags=re.IGNORECASE,
)

# Directives phrased as a search / lookup goal.
_SEARCH_RE = re.compile(
    r"\b(find|search|discover|locate|lookup)\b",
    flags=re.IGNORECASE,
)

# Directives phrased as "something finished / was emitted".
_EVENT_RE = re.compile(
    r"\b(succeed|succeeded|complete|completed|finish|finished|emit)\b",
    flags=re.IGNORECASE,
)
def _slugify(value: str, fallback: str = "criterion") -> str:
    """Derive a ``criterion_id``-friendly slug from free text.

    The text is trimmed and lowercased, every run of characters outside
    ASCII letters / digits is collapsed into a single underscore, and
    leading / trailing underscores are dropped.

    Args:
        value: The text to slugify.
        fallback: Returned when nothing usable survives the clean-up.

    Returns:
        The slug truncated to at most 32 characters, or ``fallback``
        when the cleaned text is empty.
    """
    lowered = value.strip().lower()
    slug = re.sub(r"[^A-Za-z0-9]+", "_", lowered).strip("_")
    return slug[:32] if slug else fallback
@dataclass(frozen=True)
class _DirectiveMatch:
    """Internal result of classifying a directive against the templates."""

    # Template identifier chosen by ``_classify_directive``.
    kind: str
    # Keyword associated with the match; feeds the slug and metric name.
    captured: str
def _classify_directive(directive: str) -> _DirectiveMatch | None:
    """Choose the template that applies to ``directive``.

    The patterns are tried in a fixed precedence order: lower-is-better
    metrics, higher-is-better metrics, search phrasing, then event
    phrasing. The first pattern that matches decides the result.

    Returns:
        A :class:`_DirectiveMatch` for the winning template, or ``None``
        when no pattern matches (the caller then emits the placeholder
        predicate).
    """
    lower_hit = _LOWER_IS_BETTER_RE.search(directive)
    if lower_hit is not None:
        return _DirectiveMatch(
            kind="metric_threshold_lower",
            captured=lower_hit.group(1).lower(),
        )

    higher_hit = _HIGHER_IS_BETTER_RE.search(directive)
    if higher_hit is not None:
        return _DirectiveMatch(
            kind="metric_threshold_higher",
            captured=higher_hit.group(1).lower(),
        )

    # The search and event templates use fixed tokens rather than the
    # matched keyword itself.
    if _SEARCH_RE.search(directive) is not None:
        return _DirectiveMatch(kind="predicate_search", captured="search")

    if _EVENT_RE.search(directive) is not None:
        return _DirectiveMatch(kind="event", captured="job_succeeded")

    return None
def _build_metric_threshold(directive: str, captured: str, op: str) -> dict[str, Any]:
    """Assemble a ``metric_threshold`` criterion for a matched keyword.

    Args:
        directive: The original directive. Not consulted by this
            builder; kept in the signature for its callers.
        captured: The metric keyword that matched (already lowercased
            by the classifier).
        op: Comparison operator to emit. ``"<"`` / ``"<="`` select the
            lower-is-better placeholder target.

    Returns:
        A criterion dict whose ``metric`` is a dot-path under the
        Observation's ``metrics`` key. ``loss`` and ``accuracy`` are
        spelled ``val_loss`` / ``val_accuracy``; every other keyword is
        used verbatim. The ``target`` is a placeholder the operator is
        expected to override: ``0.1`` for lower-is-better operators,
        ``0.9`` otherwise.
    """
    lower_is_better = op in ("<", "<=")
    if captured in ("loss", "accuracy"):
        metric_name = f"val_{captured}"
    else:
        metric_name = captured

    criterion: dict[str, Any] = {
        "criterion_id": f"{_slugify(captured, fallback='metric')}_target",
        "kind": "metric_threshold",
        "required": True,
        # The ``metrics.`` prefix makes the path resolve from the
        # Observation root rather than as a bare metric name.
        "metric": f"metrics.{metric_name}",
        "op": op,
        "target": 0.1 if lower_is_better else 0.9,
    }
    return criterion
def _build_predicate_search() -> dict[str, Any]:
    """Return the generic search criterion: the iteration produced results.

    The expression reads the Observation through subscript notation
    (``obs['tool_results']``), which is the form the scaffold prompt in
    this module presents as the preferred way to read from ``obs``.
    """
    criterion: dict[str, Any] = dict(
        criterion_id="results_present",
        kind="predicate",
        required=True,
        expression="len(obs['tool_results']) > 0",
    )
    return criterion
def _build_tool_call_succeeded(tool_name: str) -> dict[str, Any]:
    """Assemble a ``tool_call_succeeded`` criterion for ``tool_name``.

    The ``criterion_id`` is ``<slug>_called`` where the slug comes from
    the tool name (``"tool"`` when the name slugifies to nothing).
    ``min_count`` is deliberately not emitted, so the operator can add
    one later without first removing an explicit value.
    """
    criterion_id = _slugify(tool_name, fallback="tool") + "_called"
    criterion: dict[str, Any] = {
        "criterion_id": criterion_id,
        "kind": "tool_call_succeeded",
        "required": True,
        "tool_name": tool_name,
    }
    return criterion
def _build_event(captured: str) -> dict[str, Any]:
    """Assemble an ``event`` criterion whose ``event_name`` is ``captured``.

    The ``criterion_id`` is always ``expected_event``.
    """
    criterion: dict[str, Any] = dict(
        criterion_id="expected_event",
        kind="event",
        required=True,
        event_name=captured,
    )
    return criterion
def _build_default_placeholder() -> dict[str, Any]:
    """Return the placeholder criterion used when no template matched.

    The expression is the literal ``True``, so the criterion is met
    immediately. That is intentional: the TODO text in ``description``
    is the operator's cue to replace it before running a mission.
    """
    placeholder: dict[str, Any] = {}
    placeholder["criterion_id"] = "todo_placeholder"
    placeholder["kind"] = "predicate"
    placeholder["required"] = True
    placeholder["expression"] = "True"
    # ``description`` is an extra, informational key that is carried
    # into the emitted JSON for the operator to read.
    placeholder["description"] = (
        "TODO: replace this placeholder with a real success condition."
    )
    return placeholder
def _build_allowlist_criteria(allowlist: list[str], limit: int) -> list[dict[str, Any]]:
    """Build up to ``limit`` ``tool_call_succeeded`` criteria, one per distinct tool."""
    criteria: list[dict[str, Any]] = []
    seen_tools: set[str] = set()
    seen_ids: set[str] = set()
    for name in allowlist:
        if len(criteria) >= limit:
            break
        # A blank name can never identify a tool and would yield an
        # empty ``tool_name``; a repeated name would yield a duplicate
        # criterion. Skip both.
        if not name.strip() or name in seen_tools:
            continue
        seen_tools.add(name)
        criterion = _build_tool_call_succeeded(name)
        # Distinct tool names can still slugify to the same id (for
        # example ``find-docs`` and ``find_docs``, or names that only
        # differ beyond the slug's length cap). Suffix later ones so
        # every ``criterion_id`` in the list stays unique.
        base_id = criterion["criterion_id"]
        candidate_id = base_id
        suffix = 2
        while candidate_id in seen_ids:
            candidate_id = f"{base_id}_{suffix}"
            suffix += 1
        criterion["criterion_id"] = candidate_id
        seen_ids.add(candidate_id)
        criteria.append(criterion)
    return criteria


def generate_deterministic_criteria(
    directive: str,
    *,
    allowlist: list[str] | None = None,
    max_criteria: int = DEFAULT_MAX_CRITERIA,
) -> list[dict[str, Any]]:
    """Build a criteria list deterministically from a directive.

    The keyword-template lookup is naive on purpose: the result is
    guidance for the operator, not a substitute for thinking about the
    goal. When no template matches, the placeholder predicate is
    returned as the explicit signal that the operator must edit it.

    Args:
        directive: The natural-language goal.
        allowlist: Optional list of tool names. For a search-flavoured
            directive with a usable allowlist, one
            ``tool_call_succeeded`` criterion is emitted per distinct
            tool (capped at ``max_criteria``) instead of the loose
            ``len(obs['tool_results']) > 0`` predicate. Blank and
            repeated names are skipped, and ids that would collide
            after slugification are suffixed (``_2``, ``_3``, ...) so
            every ``criterion_id`` is unique. When no usable name
            remains, the predicate is emitted as if no allowlist had
            been supplied.
        max_criteria: Cap on the number of entries returned. Values
            below 1 are clamped to 1.

    Returns:
        A non-empty list of criterion dicts containing at most
        ``max_criteria`` entries.
    """
    max_criteria = max(max_criteria, 1)
    match = _classify_directive(directive)
    if match is None:
        return [_build_default_placeholder()]
    if match.kind == "metric_threshold_lower":
        return [_build_metric_threshold(directive, match.captured, "<=")]
    if match.kind == "metric_threshold_higher":
        return [_build_metric_threshold(directive, match.captured, ">=")]
    if match.kind == "predicate_search":
        # Prefer per-tool criteria when the operator named the tools
        # they intend to allowlist; otherwise keep the loose predicate
        # so the no-allowlist call shape is unchanged.
        per_tool = _build_allowlist_criteria(allowlist or [], max_criteria)
        return per_tool or [_build_predicate_search()]
    if match.kind == "event":
        return [_build_event(match.captured)]
    # Defensive fallback for a template kind this function does not know.
    return [_build_default_placeholder()]  # pragma: no cover
297# ---------------------------------------------------------------------------
298# Sampling path
299# ---------------------------------------------------------------------------
class ScaffoldSamplingError(Exception):
    """Signal that the sampling path could not produce usable criteria.

    Callers are expected to catch this and fall back to the
    deterministic generator.

    Attributes:
        last_reason: Short rejection token describing the final failure,
            suitable for a one-line warning.
    """

    def __init__(self, last_reason: str, message: str | None = None) -> None:
        # An absent or empty ``message`` falls back to the reason token
        # so ``str(exc)`` is never blank for a non-empty reason.
        text = message if message else last_reason
        super().__init__(text)
        self.last_reason: str = last_reason
def build_scaffold_prompt(
    directive: str,
    *,
    allowlist: list[str] | None = None,
    max_criteria: int = DEFAULT_MAX_CRITERIA,
    feedback: str | None = None,
) -> str:
    """Render the text prompt handed to the sampling backend.

    The prompt requests a strict JSON array with one object per
    criterion and spells out the Observation shape, the output schema,
    and the predicate vocabulary inline.

    Args:
        directive: The operator's natural-language goal, embedded verbatim.
        allowlist: Tool names joined with ``", "``; ``None`` or an empty
            list renders as ``(none specified)``.
        max_criteria: The cap quoted to the model.
        feedback: Rejection details from a previous attempt. When
            truthy, a trailing feedback section is appended.

    Returns:
        The prompt sections joined with newlines.
    """
    allowlist_block = ", ".join(allowlist) if allowlist else "(none specified)"

    intro = (
        "You are drafting Success_Criteria for a Mission goal-directed "
        "iteration loop. The operator's directive and the tool "
        "allowlist follow. Produce a JSON array of criterion objects "
        "the operator can hand to `gco mission start --criteria-file`."
    )

    observation_shape = (
        "Each iteration's Observation is a dict with these fields:\n"
        ' - "tool_results": list[dict] — every tool the iteration\n'
        " called returns one entry. Each entry is whatever the\n"
        " tool itself returned, plus a top-level ``_status`` flag.\n"
        ' - "metrics": dict[str, Any] — numeric / scalar values\n'
        " surfaced by tools that emit them. The dot-path for a\n"
        " metric_threshold criterion against ``val_loss`` is\n"
        ' ``"metrics.val_loss"`` (NOT ``"val_loss"``); the engine\n'
        " walks the path against the Observation root and a bare\n"
        " name will land as ``inconclusive: metric_path_missing``\n"
        " on every iteration.\n"
        ' - "events": list[dict] — emitted events, each with an\n'
        " ``event_name`` key.\n"
        ' - "errors" (optional): list[dict] — errors any tool raised.\n'
        ' - "phase_started_at" / "phase_ended_at": ISO-8601 strings.'
    )

    output_schema = (
        "Return a single JSON array. Each entry is an object with "
        "these required keys:\n"
        ' - "criterion_id": unique non-empty string\n'
        ' - "kind": one of "metric_threshold" / "event" / '
        '"predicate" / "tool_call_succeeded"\n'
        ' - "required": JSON boolean\n'
        "Plus the kind-specific keys:\n"
        ' metric_threshold -> "metric" (DOT-PATH into the\n'
        " Observation, e.g.\n"
        " ``metrics.val_loss``,\n"
        " ``tool_results.0.score``), "
        '"op" (one of <, <=, >, >=, ==, !=), "target" (number)\n'
        ' event -> "event_name" (non-empty string;\n'
        ' matched against entries in obs["events"])\n'
        ' tool_call_succeeded -> "tool_name" (non-empty string;\n'
        " matched against entries in\n"
        ' ``obs["tool_results"]`` whose\n'
        ' ``_status`` equals ``"ok"``).\n'
        ' Optional: "min_count" (positive\n'
        " int, default 1).\n"
        " PREFER this kind over a predicate\n"
        ' when the goal is "this tool ran\n'
        ' and succeeded" — it is server-\n'
        " evaluated and never goes through\n"
        " the predicate AST sandbox.\n"
        ' predicate -> "expression" (a Python expression\n'
        " evaluated against `obs` — see\n"
        " the predicate vocabulary section\n"
        " below for the exact surface)"
    )

    predicate_vocabulary = (
        "Predicate expressions run inside a tight AST sandbox. The\n"
        "allowed surface:\n"
        "\n"
        "Names: ``obs`` (the Observation dict).\n"
        "Top-level callables (twelve, all pure stdlib):\n"
        " ``len``, ``min``, ``max``, ``sum``, ``abs``,\n"
        " ``any``, ``all``, ``sorted``,\n"
        " ``str``, ``int``, ``float``, ``bool`` (type coercions).\n"
        "Read-only method calls on any value (seven, all pure):\n"
        " ``.get(key[, default])``, ``.keys()``, ``.values()``,\n"
        " ``.items()``, ``.lower()``, ``.upper()``, ``.strip()``\n"
        "Operators: arithmetic, comparisons (<, <=, >, >=, ==, !=,\n"
        " is, is not, in, not in), boolean (and, or, not), ternary\n"
        " (a if b else c).\n"
        "Containers: list/tuple/dict/set literals, list / set / dict\n"
        " / generator comprehensions (the comprehension target may\n"
        " not shadow ``obs`` or any callable name).\n"
        "Subscripts: ``obs['key']``, ``obs['k']['nested']``,\n"
        " ``obs['list'][0]``, etc.\n"
        "Attribute access: ONLY single-level on ``obs`` (e.g. ``obs.events``\n"
        " for read-only access; subscript form is preferred). Nested\n"
        " walks like ``obs.a.b`` are rejected — use ``obs['a']['b']``.\n"
        "\n"
        "Method calls outside the seven pure-accessor names are\n"
        "rejected (no ``.append``, ``.update``, ``.pop``, ``.count``,\n"
        "``.startswith``, ``.split``, etc.). Calls to non-allowlisted\n"
        "names (``list``, ``dict``, ``getattr``, ``isinstance``, ...)\n"
        "are rejected."
    )

    predicate_examples = (
        "ACCEPTED predicate expressions:\n"
        " len(obs['tool_results']) > 0\n"
        " obs['metrics']['val_loss'] < 0.1\n"
        " any(e['event_name'] == 'goal_reached' for e in obs['events'])\n"
        " any(r.get('_status') == 'ok' for r in obs['tool_results'])\n"
        " all(r.get('_status') == 'ok' for r in obs['tool_results'])\n"
        " any(r.get('_status') == 'ok' and r.get('tool_name') == 'find_docs'\n"
        " for r in obs['tool_results'])\n"
        " any('inference' in str(r).lower() for r in obs['tool_results'])\n"
        " len(obs.get('errors', [])) == 0\n"
        " any(k == 'val_loss' for k in obs['metrics'].keys())\n"
        "\n"
        "REJECTED predicate expressions (will fail validation):\n"
        " obs.metrics.val_loss < 0.1 # nested attribute walk; use obs['metrics']['val_loss']\n"  # noqa: E501
        " obs['tool_results'].count('ok') # ``.count`` is not on the method allowlist\n"
        " obs['tool_results'].append(1) # ``.append`` mutates and is not allowed\n"
        " any(r.split(',') for r in obs['tool_results']) # ``.split`` not on method allowlist\n"
        " any(k.startswith('val_') for k in obs['metrics'].keys()) # ``.startswith`` not on method allowlist\n"  # noqa: E501
        " getattr(obs, 'tool_results') # ``getattr`` not on callable allowlist\n"
        " obs['x'].y.z # attribute walk after subscript"
    )

    # Sections are laid out in reading order; empty strings become the
    # blank lines separating them once joined.
    sections: list[str] = [
        intro,
        "",
        "=== Directive ===",
        directive,
        "",
        "=== Tool allowlist ===",
        allowlist_block,
        "",
        f"=== Cap: at most {max_criteria} criterion entries ===",
        "",
        "=== Observation shape (read by predicates and metric paths) ===",
        observation_shape,
        "",
        "=== Output schema ===",
        output_schema,
        "",
        "=== Predicate vocabulary ===",
        predicate_vocabulary,
        "",
        "=== Predicate examples (do NOT use rejected forms) ===",
        predicate_examples,
        "",
        "Output only the JSON array. No prose, no markdown fences.",
    ]
    if feedback:
        # Retry prompt: tell the model why its last answer was rejected.
        sections.extend(["", "=== Feedback on previous attempt ===", feedback])
    return "\n".join(sections)
def _scan_for_object_array(text: str, start: int) -> list[dict[str, Any]] | None:
    """Return the first non-empty JSON array of objects found at or after ``start``."""
    decoder = json.JSONDecoder()
    position = start
    while position != -1:
        try:
            candidate, _end = decoder.raw_decode(text, position)
        except json.JSONDecodeError:
            candidate = None
        # Require at least one entry so a stray ``[]`` in prose is not
        # mistaken for the model's answer.
        if (
            isinstance(candidate, list)
            and candidate
            and all(isinstance(entry, dict) for entry in candidate)
        ):
            return candidate
        position = text.find("[", position + 1)
    return None


def _parse_response(text: str) -> list[dict[str, Any]]:
    """Extract a JSON array of objects from a model response.

    Tolerates a surrounding markdown fence and prose before or after
    the array. The span from the first ``[`` to the last ``]`` is tried
    first. If that span is not valid JSON (typically because the prose
    itself contains brackets), the text is scanned for the first
    bracket position that decodes to a non-empty array of objects.

    Args:
        text: Raw text returned by the sampling backend.

    Returns:
        The decoded list of criterion dicts.

    Raises:
        ValueError: When no JSON array is recoverable, or when the
            decoded array contains a non-object entry.
            ``json.JSONDecodeError`` is a ``ValueError`` subclass and is
            raised unchanged when decoding fails.
    """
    stripped = text.strip()
    # Strip markdown fences if present.
    if stripped.startswith("```"):
        # Drop the opening fence line (which may carry a language tag).
        first_newline = stripped.find("\n")
        if first_newline != -1:
            stripped = stripped[first_newline + 1 :]
        if stripped.endswith("```"):
            stripped = stripped[:-3].rstrip()

    start = stripped.find("[")
    end = stripped.rfind("]")
    if start == -1 or end == -1 or end < start:
        raise ValueError("no JSON array found in response")

    try:
        parsed = json.loads(stripped[start : end + 1])
    except json.JSONDecodeError:
        # The outermost bracket span includes prose; look for an array
        # embedded in it before giving up with the original error.
        recovered = _scan_for_object_array(stripped, start)
        if recovered is None:
            raise
        return recovered

    if not isinstance(parsed, list):
        raise ValueError("JSON payload is not a list")
    out: list[dict[str, Any]] = []
    for entry in parsed:
        if not isinstance(entry, dict):
            raise ValueError("array entry is not an object")
        out.append(entry)
    return out
# Closed alias map for ``_normalize_kind_name``. Every entry here was
# observed in the captured fixture corpus under
# ``tests/fixtures/scaffold_responses/``. Add a new entry only when a
# captured model emits a near-miss that the rejection-feedback retry
# does not recover from on the next attempt.
_KIND_ALIASES: dict[str, str] = {
    # Llama 4 Scout pluralises the kind name in its first emission.
    "tool_calls_succeeded": "tool_call_succeeded",
    # Hyphenated forms occasionally surface from JSON-schema-trained
    # smaller models that map ``snake_case`` onto ``kebab-case``.
    "tool-call-succeeded": "tool_call_succeeded",
    "metric-threshold": "metric_threshold",
}


def _normalize_kind_name(criterion: dict[str, Any]) -> dict[str, Any]:
    """Map a known near-miss ``kind`` spelling onto its canonical name.

    Only the spellings listed in ``_KIND_ALIASES`` are rewritten: the
    pluralised ``tool_calls_succeeded`` and the hyphenated
    ``tool-call-succeeded`` / ``metric-threshold``. Rewriting them here
    saves a retry round-trip for typos that are purely mechanical.

    Returns:
        The same dict object when ``kind`` is missing, not a string,
        not a known alias, or already canonical. Otherwise a shallow
        copy with ``kind`` replaced, leaving the input unmodified.
    """
    kind = criterion.get("kind")
    canonical = _KIND_ALIASES.get(kind) if isinstance(kind, str) else None
    if canonical is None or canonical == kind:
        return criterion
    return {**criterion, "kind": canonical}
def _normalize_metric_path(criterion: dict[str, Any]) -> dict[str, Any]:
    """Prefix a bare ``metric_threshold`` metric name with ``metrics.``.

    Models tend to emit a bare name such as ``"val_loss"`` where the
    dot-path ``"metrics.val_loss"`` is expected. Instead of spending a
    retry on that, the prefix is injected when all of these hold:

    * ``kind`` is ``"metric_threshold"``;
    * ``metric`` is a non-empty string;
    * ``metric`` contains no ``.`` (already-qualified paths such as
      ``tool_results.0.score`` or ``metrics.foo`` pass through, which
      also makes the rewrite idempotent).

    Returns:
        The same dict object when no rewrite applies; otherwise a
        shallow copy with the prefixed ``metric``.
    """
    metric = criterion.get("metric")
    needs_prefix = (
        criterion.get("kind") == "metric_threshold"
        and isinstance(metric, str)
        and bool(metric)
        and "." not in metric
    )
    if not needs_prefix:
        return criterion
    return {**criterion, "metric": f"metrics.{metric}"}
class _AttributeToSubscriptRewriter(ast.NodeTransformer):
    """Rewrite ``obs.<attr>`` walks as ``obs['<attr>']`` subscripts.

    Models routinely emit attribute walks such as
    ``obs.metrics.val_loss`` because they are the obvious Pythonic
    idiom. This transformer turns every attribute read whose innermost
    base is the name ``obs`` into the equivalent subscript chain.

    Two shapes are deliberately left alone:

    * Attribute reads rooted anywhere else (a comprehension target, a
      call result, a literal), so they still reach the validator as
      written.
    * The method name of a call. In ``obs.metrics.get('k')`` only the
      receiver is rewritten, giving ``obs['metrics'].get('k')``;
      turning ``.get`` into ``['get']`` would replace the method call
      with a call on a subscript, which means something different.
    """

    def visit_Call(self, node: ast.Call) -> ast.AST:  # noqa: N802 - ast hook name
        func = node.func
        if not isinstance(func, ast.Attribute):
            return self.generic_visit(node)
        # Method call: rewrite the receiver and the arguments, but keep
        # the ``.method`` attribute itself as an attribute.
        func.value = self.visit(func.value)
        node.args = [self.visit(arg) for arg in node.args]
        node.keywords = [self.visit(keyword) for keyword in node.keywords]
        return node

    def visit_Attribute(self, node: ast.Attribute) -> ast.AST:  # noqa: N802 - ast hook name
        # Rewrite bottom-up: for ``obs.metrics.val_loss`` the inner
        # ``obs.metrics`` becomes ``obs['metrics']`` first, then this
        # level wraps it as ``obs['metrics']['val_loss']``.
        self.generic_visit(node)
        base = node.value
        # Peel already-rewritten subscripts to find what the chain is
        # ultimately rooted at.
        innermost = base
        while isinstance(innermost, ast.Subscript):
            innermost = innermost.value
        if not (isinstance(innermost, ast.Name) and innermost.id == "obs"):
            return node
        return ast.Subscript(
            value=base,
            slice=ast.Constant(value=node.attr),
            ctx=node.ctx,
        )
def _autofix_predicate(criterion: dict[str, Any]) -> dict[str, Any]:
    """Try to rescue a rejected predicate by rewriting attribute walks.

    The criterion is returned untouched (same object) when any of the
    following holds:

    * ``kind`` is not ``"predicate"``;
    * ``expression`` is missing, empty, or not a string;
    * the expression is already accepted by ``parse_predicate``;
    * the expression is not syntactically valid Python;
    * the rewritten source cannot be produced, or is still rejected by
      ``parse_predicate`` — the retry-with-feedback path should then
      see what the model actually emitted, not a half-fixed variant.

    Returns:
        The input dict, or a shallow copy whose ``expression`` is the
        rewritten source when that source clears ``parse_predicate``.
    """
    expression = criterion.get("expression")
    if (
        criterion.get("kind") != "predicate"
        or not isinstance(expression, str)
        or not expression
    ):
        return criterion

    # Happy path first: an already-valid expression needs no AST
    # round-trip.
    try:
        parse_predicate(expression)
    except PredicateRejected:
        pass
    else:
        return criterion

    try:
        tree = ast.parse(expression, mode="eval")
    except SyntaxError:
        return criterion

    # New nodes created by the rewriter have no positions yet.
    candidate_tree = ast.fix_missing_locations(
        _AttributeToSubscriptRewriter().visit(tree)
    )
    try:
        candidate_src = ast.unparse(candidate_tree)
    except Exception:  # noqa: BLE001 - unparse failure leaves us no better off
        return criterion

    # Only adopt the rewrite when it passes the same check the original
    # expression failed.
    try:
        parse_predicate(candidate_src)
    except PredicateRejected:
        return criterion

    return {**criterion, "expression": candidate_src}
async def generate_sampled_criteria(
    backend: SamplingBackend,
    directive: str,
    *,
    allowlist: list[str] | None = None,
    max_criteria: int = DEFAULT_MAX_CRITERIA,
    retries: int = DEFAULT_RETRIES,
) -> list[dict[str, Any]]:
    """Drive a sampling backend to produce a validated criteria list.

    Each attempt builds the prompt, asks the backend for a response,
    parses it as a JSON array, applies the best-effort normalisers, and
    validates the result. A parse or validation rejection triggers a
    retry whose prompt carries feedback describing the rejection.

    Args:
        backend: Object satisfying the ``SamplingBackend`` protocol.
            It is only ever passed to ``_call_backend``, so tests can
            substitute a stub.
        directive: The natural-language goal.
        allowlist: Optional tool names quoted in the prompt.
        max_criteria: Cap on the entries kept from a response. Values
            below 1 are clamped to 1, matching
            ``generate_deterministic_criteria``.
        retries: Number of extra attempts after the first one.
            Negative values are clamped to 0, so at least one attempt
            is always made.

    Returns:
        The validated criteria with every key starting with ``_``
        removed, so the result is JSON-safe.

    Raises:
        ScaffoldSamplingError: With ``last_reason == "transport_error"``
            as soon as the backend call raises, or with the final
            rejection reason once every attempt has been rejected.
    """
    # Without these clamps a non-positive ``max_criteria`` would make
    # the truncation slice below drop entries from the wrong end (or
    # all of them), and a negative ``retries`` would skip the loop.
    max_criteria = max(max_criteria, 1)
    retries = max(retries, 0)

    feedback: str | None = None
    last_reason = "no_attempts"
    # ``retries + 1`` total attempts: the first one plus each retry.
    for _ in range(retries + 1):
        prompt_str = build_scaffold_prompt(
            directive,
            allowlist=allowlist,
            max_criteria=max_criteria,
            feedback=feedback,
        )
        try:
            raw = await _call_backend(backend, prompt_str)
        except Exception as exc:  # noqa: BLE001 - transport-agnostic catch
            # Backend failures are not retried here; surface them so
            # the caller can fall back to the deterministic path.
            raise ScaffoldSamplingError(
                "transport_error",
                message=f"sampling backend raised {type(exc).__name__}: {exc}",
            ) from exc

        try:
            parsed = _parse_response(raw)
        except ValueError as exc:  # includes json.JSONDecodeError
            last_reason = "json_parse"
            feedback = (
                "Your previous response could not be parsed as a JSON "
                "array. Return a single JSON array, no prose, no "
                f"markdown fences. ({exc})"
            )
            continue

        # Truncate rather than reject when the model returned more
        # entries than requested.
        parsed = parsed[:max_criteria]
        # Kind names are canonicalised first so the two passes below
        # branch on the canonical ``kind``.
        parsed = [_normalize_kind_name(c) for c in parsed]
        # A bare metric name would otherwise evaluate as
        # ``inconclusive: metric_path_missing`` on every iteration.
        parsed = [_normalize_metric_path(c) for c in parsed]
        # Rescue attribute-walk predicates; anything the autofix cannot
        # repair is left as emitted so the validator feedback below
        # reaches the model.
        parsed = [_autofix_predicate(c) for c in parsed]

        try:
            validated = _validation.validate_criteria(parsed)
        except MissionValidationError as exc:
            details = exc.details or {}
            last_reason = str(details.get("reason") or exc.code)
            feedback = (
                "Your previous response was rejected by the validator. "
                f"Rejection reason: {last_reason}. Details: {details!r}. "
                "Re-emit a corrected JSON array."
            )
            continue

        # Drop private keys (such as the cached ``_parsed_ast``) so the
        # JSON written to disk round-trips.
        return [
            {k: v for k, v in entry.items() if not str(k).startswith("_")}
            for entry in validated
        ]
    raise ScaffoldSamplingError(last_reason)
async def _call_backend(backend: SamplingBackend, prompt_str: str) -> str:
    """Send a pre-rendered prompt string through a sampling backend.

    The :class:`SamplingBackend` protocol's ``sample`` takes a structured
    :class:`SamplingPrompt`, but criteria scaffolding is a one-off request
    with no Mission round-trip behind it. Rather than building a full
    ``SamplingPrompt``, the string is wrapped in a :class:`_PromptAdapter`
    whose ``assemble()`` returns exactly ``prompt_str``.

    Any object exposing an awaitable ``sample`` works as ``backend``, which
    is what lets tests hand in a stub that returns a canned string.

    Args:
        backend: Object whose ``sample`` coroutine is awaited with the
            adapter as its only argument.
        prompt_str: The fully rendered prompt text.

    Returns:
        Whatever ``backend.sample`` resolves to (annotated as ``str``).
    """
    # Function-scope import: the accompanying note in this module says a
    # top-level import would close a cycle (sampling -> validation -> here).
    from .sampling import SamplingPrompt  # noqa: PLC0415

    # The name is imported only to link this adapter to the type it stands
    # in for; it is never used, so release it straight away.
    del SamplingPrompt

    # Per the module's notes, both MCPSamplingBackend and
    # BedrockSamplingBackend obtain the prompt text via ``prompt.assemble()``,
    # which is the single method the adapter provides.
    adapter = _PromptAdapter(prompt_str)
    response = await backend.sample(adapter)  # type: ignore[arg-type]
    return response
class _PromptAdapter:
    """Duck-typed substitute for :class:`SamplingPrompt` holding a fixed string.

    The sampling backends render their prompt by calling
    ``prompt.assemble()``. This class offers that one method and nothing
    else, so a free-form scaffolding prompt can travel through the same
    backend surface the engine uses without fabricating a full
    ``SamplingPrompt`` (and the iteration history a one-off call does not
    have).
    """

    def __init__(self, text: str) -> None:
        """Store ``text`` as the already-rendered prompt."""
        self._text: str = text

    def assemble(self) -> str:
        """Return the stored prompt text unchanged."""
        rendered = self._text
        return rendered