Source code for slapo.logger

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Logging."""
import sys
import logging
from logging import getLevelName

import torch.distributed as dist

# Shared record layout: timestamp, level name, and the source location of the
# logging call (file:line:function), followed by the rendered message.
FORMATTER = logging.Formatter(
    "[%(asctime)s][%(levelname)s]"
    "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s"
)

# A stream handler (no explicit stream, so logging defaults it to sys.stderr)
# that renders records with FORMATTER.
STREAM_HANDLER = logging.StreamHandler()
STREAM_HANDLER.setFormatter(FORMATTER)

# Cache of loggers built by get_logger, keyed by logger name.
LOGGER_TABLE = {}

# Syntax sugar: aliases of the standard ``logging`` levels so callers can pass
# e.g. ``DEBUG`` to get_logger without importing ``logging`` themselves.
# As in ``logging``, FATAL is the same value as CRITICAL and WARN as WARNING.
CRITICAL, FATAL, ERROR, WARNING, WARN, INFO, DEBUG, NOTSET = (
    logging.CRITICAL,
    logging.FATAL,
    logging.ERROR,
    logging.WARNING,
    logging.WARN,
    logging.INFO,
    logging.DEBUG,
    logging.NOTSET,
)


def _make_rank_aware_log(orig_log):
    """Wrap a bound ``Logger._log`` so it honors ``ranks``/``group`` kwargs."""
    # ``stacklevel`` is only accepted by ``Logger._log`` on Python 3.8+.
    supports_stacklevel = sys.version_info >= (3, 8)

    def _rank_aware_log(level, msg, *args, **kwargs):
        """Log when distributed is not initialized or when the rank is in the list.

        Note that ranks=None means all ranks.
        """
        # Always pop the custom kwargs so they never reach ``Logger._log``,
        # which would reject them as unexpected keyword arguments.
        ranks = kwargs.pop("ranks", None)
        group = kwargs.pop("group", None)

        # Always log when distributed is unavailable/not initialized or when
        # no ranks are specified. ``dist.is_initialized`` is missing on torch
        # builds without distributed support, hence the ``is_available`` guard.
        should_log = True
        rank_info = ""
        if dist.is_available() and dist.is_initialized():
            my_rank = dist.get_rank(group)
            rank_info = f"[Rank {my_rank}] "
            if ranks is not None:
                # Only log when the current rank is in the requested collection;
                # a single rank is treated as a one-element list.
                if not isinstance(ranks, (list, tuple, set, frozenset, range)):
                    ranks = [ranks]
                should_log = my_rank in ranks
        if not should_log:
            return

        # This function is an extra stack frame outside the ``logging``
        # package. Without the bump, Python 3.11+ reports it (instead of the
        # user's call site) as %(filename)s/%(lineno)d/%(funcName)s, and a
        # caller-supplied stacklevel would be off by one on 3.8-3.10.
        if supports_stacklevel:
            kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
        orig_log(level, f"{rank_info}{msg}", *args, **kwargs)

    return _rank_aware_log


def get_logger(name="Slapo", level=INFO):
    """Attach to the default logger, creating and caching it on first use.

    On the first call for ``name`` the logger is configured with ``level``,
    stops propagating to ancestor loggers, and gets a stdout handler that uses
    the module-level ``FORMATTER``. Subsequent calls return the cached logger;
    if a different ``level`` is requested, a warning is logged and the new
    level is ignored.

    The returned logger's methods (``info``, ``warning``, ...) accept two extra
    keyword arguments:

    * ``ranks``: a rank or a list/tuple/set/range of ranks. When
      ``torch.distributed`` is initialized, the message is only emitted on
      those ranks (as reported by ``dist.get_rank(group)``). ``None`` (the
      default) means all ranks.
    * ``group``: process group passed to ``dist.get_rank``.

    When ``torch.distributed`` is initialized, every emitted message is
    prefixed with ``"[Rank <rank>] "``.

    Parameters
    ----------
    name : str
        Logger name; also the key into ``LOGGER_TABLE``.
    level : int
        Level applied to the new logger and its stdout handler.

    Returns
    -------
    logging.Logger
        The cached or newly configured logger.
    """
    if name in LOGGER_TABLE:
        logger = LOGGER_TABLE[name]
        if logger.level != level:
            logger.warning(
                f"Logger {name} already exists with {getLevelName(logger.level)}. "
                f"The new level {getLevelName(level)} will be ignored."
            )
        return logger

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(FORMATTER)
    logger.addHandler(handler)

    # Logger.debug/info/warning/... all funnel into ``_log``; shadowing it on
    # this instance lets every logging method accept ``ranks``/``group``.
    logger._log = _make_rank_aware_log(logger._log)
    LOGGER_TABLE[name] = logger
    return logger