# Source code for slapo.framework_dialect.registry
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Framework dialect registration."""
def _passthrough_runtime_engine(model, **kwargs):
    # Fallback runtime engine: hands the model back unchanged and ignores any
    # extra keyword arguments. The meaning of the second tuple element is
    # defined by callers elsewhere (presumably "no engine") — verify there.
    return model, None


# Registry of framework dialects, keyed first by dialect category and then by
# target. Categories are fixed here; targets are added at import time through
# ``register_framework_dialect``. A ``None`` target key acts as the category's
# default, which ``get_dialect_cls(..., allow_none=True)`` falls back to.
DIALECTS = dict(
    pipeline_stage={},
    pipeline_engine={},
    runtime_engine={None: _passthrough_runtime_engine},
    log_parser={},
)
def register_framework_dialect(target, cls_type):
    """Register a framework dialect.

    Returns a decorator that stores the decorated object in
    ``DIALECTS[cls_type][target]`` and returns the object unchanged, so it can
    be stacked on a class definition::

        @register_framework_dialect("my_target", "pipeline_engine")
        class MyEngine:
            ...

    Parameters
    ----------
    target : Hashable
        Key identifying the target this dialect serves. It is used as a dict
        key, so it must be hashable.
    cls_type : str
        Dialect category; must be one of the keys of ``DIALECTS``.

    Returns
    -------
    Callable
        Decorator that performs the registration.

    Raises
    ------
    ValueError
        Immediately, if ``cls_type`` is not a known category. When the
        decorator is applied, if ``target`` is already registered for
        ``cls_type``.
    """
    if cls_type not in DIALECTS:
        raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")

    def decorator(dialect_cls):
        # Check the actual ``target`` key. The previous code tested the string
        # literal "target", so duplicate registrations silently overwrote the
        # earlier dialect.
        if target in DIALECTS[cls_type]:
            raise ValueError(
                f"Target {target} already registered for {cls_type} dialects"
            )
        DIALECTS[cls_type][target] = dialect_cls
        return dialect_cls

    return decorator
def get_all_dialects(cls_type):
    """Return the registry mapping ``target -> dialect`` for ``cls_type``.

    The returned dict is the live registry object (not a copy), so later
    registrations are visible through it.

    Raises
    ------
    ValueError
        If ``cls_type`` is not a known dialect category.
    """
    if cls_type in DIALECTS:
        return DIALECTS[cls_type]
    raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")
def get_dialect_cls(cls_type, target, allow_none=False):
    """Look up the dialect registered for ``target`` under ``cls_type``.

    Parameters
    ----------
    cls_type : str
        Dialect category; must be one of the keys of ``DIALECTS``.
    target : Hashable
        Registered target key to look up.
    allow_none : bool
        When True and ``target`` is not registered, fall back to the dialect
        registered under the ``None`` key (the category's default).

    Raises
    ------
    ValueError
        If ``cls_type`` is unknown, if ``target`` is unregistered and
        ``allow_none`` is False, or if the fallback is requested but the
        category has no ``None`` default.
    """
    if cls_type not in DIALECTS:
        raise ValueError(f"Only support {DIALECTS.keys()}, but got {cls_type}")
    registry = DIALECTS[cls_type]

    # Exact match wins regardless of ``allow_none``.
    if target in registry:
        return registry[target]

    if not allow_none:
        raise ValueError(f"Target {target} not registered for {cls_type} dialects")

    # Fallback path: use the category-wide default keyed by ``None``.
    if None not in registry:
        raise ValueError(
            f"Target {target} does not register default dialect for {cls_type}"
        )
    return registry[None]