Source code for slapo.framework_dialect.registry

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Framework dialect registration."""

# Registry of framework dialects: maps each dialect category (``cls_type``)
# to its own ``{target: dialect}`` table. Every category starts out empty and
# is filled in through ``register_framework_dialect``.
DIALECTS = {
    category: {}
    for category in (
        "pipeline_stage",
        "pipeline_engine",
        "runtime_engine",
        "log_parser",
    )
}

# "runtime_engine" is the only category with a built-in default, stored under
# the ``None`` target: it hands the model back untouched, paired with ``None``.
DIALECTS["runtime_engine"][None] = lambda model, **kwargs: (model, None)


def register_framework_dialect(target, cls_type):
    """Register a framework dialect.

    Returns a decorator that stores the decorated object in
    ``DIALECTS[cls_type]`` under the key ``target``.

    Parameters
    ----------
    target
        Key under which the dialect is stored (looked up later by
        ``get_dialect_cls``).
    cls_type
        Dialect category; must be one of the keys of ``DIALECTS``.

    Returns
    -------
    callable
        A decorator that registers the decorated object and returns it
        unchanged.

    Raises
    ------
    ValueError
        If ``cls_type`` is not a known category (raised immediately), or if
        ``target`` is already registered for ``cls_type`` (raised when the
        decorator is applied). This includes any default pre-registered
        under the ``None`` target.
    """
    if cls_type not in DIALECTS:
        raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")

    def decorator(dialect_cls):
        # Check the actual ``target`` value, not the literal string "target",
        # so that duplicate registrations are detected instead of silently
        # overwriting the earlier entry.
        if target in DIALECTS[cls_type]:
            raise ValueError(
                f"Target {target} already registered for {cls_type} dialects"
            )
        DIALECTS[cls_type][target] = dialect_cls
        return dialect_cls

    return decorator


def get_all_dialects(cls_type):
    """Return the table of registered dialects for one category.

    Parameters
    ----------
    cls_type
        Dialect category; must be one of the keys of ``DIALECTS``.

    Returns
    -------
    dict
        The ``{target: dialect}`` mapping stored in ``DIALECTS`` for
        ``cls_type`` (the registry's own dict object, not a copy).

    Raises
    ------
    ValueError
        If ``cls_type`` is not a known category.
    """
    if cls_type in DIALECTS:
        return DIALECTS[cls_type]
    raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")


def get_dialect_cls(cls_type, target, allow_none=False):
    """Look up the dialect registered for ``target`` in a category.

    Parameters
    ----------
    cls_type
        Dialect category; must be one of the keys of ``DIALECTS``.
    target
        Key the dialect was registered under.
    allow_none
        When true and ``target`` is not registered, fall back to the
        category's default entry (the one stored under ``None``).

    Returns
    -------
    object
        The entry registered for ``target``, or the category default when
        the fallback applies.

    Raises
    ------
    ValueError
        If ``cls_type`` is unknown; if ``target`` is not registered and
        ``allow_none`` is false; or if the fallback was requested but the
        category has no ``None`` default.
    """
    if cls_type not in DIALECTS:
        raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")

    registered = DIALECTS[cls_type]

    # Fast path: an exact match always wins over the default.
    if target in registered:
        return registered[target]

    if not allow_none:
        raise ValueError(f"Target {target} not registered for {cls_type} dialects")

    # Fallback requested: only possible if the category carries a default.
    if None not in registered:
        raise ValueError(
            f"Target {target} does not register default dialect for {cls_type}"
        )
    return registered[None]