# Source code for slapo.primitives.base
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Schedule primitive base."""
from __future__ import annotations
from abc import abstractmethod
# Module-level registry of schedule primitives. ``register_primitive`` fills it,
# keyed by the value returned from each class's ``name()`` with the class itself
# as the value.
PRIMITIVES = {}
def register_primitive():
    """Register a primitive to the schedule.

    Returns a class decorator. When applied, the decorator checks that the
    decorated object is a subclass of ``Primitive``, stores it in the
    module-level ``PRIMITIVES`` registry under the key returned by
    ``cls.name()``, and hands the class back unchanged.

    The decorator raises ``ValueError`` when the decorated object is not a
    subclass of ``Primitive`` or when a primitive with the same name has
    already been registered.
    """

    def decorator(cls):
        # Check the type before touching ``cls.name()``: an unrelated class
        # (or a non-class object) may not provide ``name`` at all, which would
        # otherwise surface as an AttributeError/TypeError instead of the
        # intended ValueError.
        if not (isinstance(cls, type) and issubclass(cls, Primitive)):
            raise ValueError(f"Class {cls} is not a subclass of Primitive")

        # Resolve the name once and reuse it for the lookup and the insert.
        name = cls.name()
        if name in PRIMITIVES:
            raise ValueError(f"Primitive {name} already registered")

        PRIMITIVES[name] = cls
        return cls

    return decorator
class Primitive:
    """Common interface for schedule primitives.

    Every member is a static method. ``name`` and ``apply`` are the hooks a
    concrete primitive has to supply; ``is_verifiable`` and ``init_metadata``
    come with default implementations.

    NOTE(review): this class does not use ``ABCMeta``, so ``@abstractmethod``
    is not enforced at instantiation time; the ``NotImplementedError`` raised
    by the base implementations is the effective guard.
    """

    @staticmethod
    @abstractmethod
    def name():
        """Return the name of the primitive.

        ``register_primitive`` uses this value as the registry key. The base
        implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def apply(sch, *args, **kwargs):
        """Apply the primitive to the schedule ``sch``.

        Extra positional and keyword arguments are accepted; their meaning is
        left to the overriding implementation. The base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    @staticmethod
    def is_verifiable():
        """Tell whether the primitive is verifiable (``False`` by default)."""
        return False

    @staticmethod
    def init_metadata():
        """(Optional) Initialize the metadata of the primitive.

        The default implementation provides no metadata and returns ``None``.
        """
        return