# Source code for slapo.pattern

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from torch import nn


class Pattern(nn.Module):
    """Base ``nn.Module`` for patterns.

    This class defines no behavior of its own: its ``forward`` only raises
    ``NotImplementedError``, so a usable pattern has to override it.
    """

    def forward(self, *args):
        """Placeholder forward pass.

        Args:
            *args: Accepted positionally and ignored.

        Raises:
            NotImplementedError: Always; subclasses are expected to override
                this method.
        """
        raise NotImplementedError


class ModulePattern(nn.Module):
    """``nn.Module`` placeholder that carries only a ``name``.

    NOTE(review): ``name`` presumably identifies the module this pattern
    stands for — confirm against callers.

    Args:
        name: Stored unchanged on the instance as ``self.name``.
    """

    def __init__(self, name):
        # nn.Module.__init__ must run before any attribute is assigned.
        super().__init__()
        self.name = name

    def forward(self, *args):
        """Placeholder forward pass.

        Args:
            *args: Accepted positionally and ignored.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


def call_module(mod_name, *args):
    """Placeholder for calling the module named ``mod_name`` with ``args``.

    NOTE(review): the body only raises, so this function is presumably
    intercepted or substituted elsewhere rather than executed — confirm.

    Args:
        mod_name: Name of the module to call (unused here).
        *args: Positional arguments for that module (unused here).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError