# Source code for s3torchconnector.dcp.s3_prefix_strategy

#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#  // SPDX-License-Identifier: BSD
from abc import ABC, abstractmethod
from typing import List, Optional

import torch.distributed as dist


class S3PrefixStrategyBase(ABC):
    """Abstract base for strategies that turn a process rank into a storage prefix.

    Subclasses implement :meth:`generate_prefix`; calling an instance directly
    (``strategy(rank)``) forwards to that method.
    """

    def __init__(self):
        """Create the strategy; the base class holds no state of its own."""

    def __call__(self, rank: int) -> str:
        """Return the prefix for ``rank`` by delegating to :meth:`generate_prefix`."""
        prefix = self.generate_prefix(rank)
        return prefix

    @abstractmethod
    def generate_prefix(self, rank: int) -> str:
        """Return the storage prefix for ``rank``; concrete subclasses must override this."""


class DefaultPrefixStrategy(S3PrefixStrategyBase):
    """Default strategy for generating S3 prefixes: no directory part, only the rank marker."""

    def generate_prefix(self, rank: int) -> str:
        """Return ``__<rank>_`` for the given rank, without any leading prefix."""
        return "__{}_".format(rank)


class NumericPrefixStrategy(S3PrefixStrategyBase):
    """Base class for numeric prefix generation strategies.

    Ranks are spread over ``prefix_count`` prefixes. Prefix ``i`` is the number
    ``i`` rendered in ``base`` by :meth:`_format_number`, zero-padded to a common
    width and then reversed. After the reversal the zero padding sits at the end
    of the prefix and the lowest-order digit comes first.
    """

    def __init__(
        self,
        base: int,
        epoch_num: Optional[int] = None,
        min_prefix_length: int = 10,
        prefix_count: Optional[int] = None,
    ):
        """
        Initialize numeric prefix strategy.

        Args:
            base (int): The numeric base for the prefix (e.g., 2 for binary,
                16 for hex). Must be an integer >= 2.
            epoch_num (int, optional): Epoch number for checkpoint ordering.
                If None, no epoch information will be included in the prefix.
                Defaults to None.
            min_prefix_length (int): Minimum length of the generated prefix.
                Prefix will be padded with trailing zeros if necessary.
                Must be a positive integer. Defaults to 10.
            prefix_count (int, optional): Number of unique prefixes to generate.
                If not provided, the world size is used when torch.distributed
                is initialized, otherwise 1. Defaults to None.

        Raises:
            ValueError: If base, epoch_num, min_prefix_length, or prefix_count
                are invalid.
        """
        # Reject non-integers up front; otherwise they only fail later inside
        # format() with an unrelated "format specifier" error.
        if not isinstance(min_prefix_length, int):
            raise ValueError(
                f"Minimum prefix length must be an integer, got {min_prefix_length}"
            )
        if min_prefix_length < 1:
            raise ValueError(
                f"Minimum prefix length must be positive, got {min_prefix_length}"
            )
        if epoch_num is not None and not isinstance(epoch_num, int):
            raise ValueError(
                f"Epoch number must be None or an integer, got {epoch_num}"
            )
        if prefix_count is not None and (
            not isinstance(prefix_count, int) or prefix_count < 1
        ):
            raise ValueError(
                f"Prefix count must be a positive integer, got {prefix_count}"
            )
        # A base below 2 would make _calculate_prefix_length loop forever,
        # because repeated multiplication by 0 or 1 never grows the size.
        if not isinstance(base, int) or base < 2:
            raise ValueError(f"Base must be an integer >= 2, got {base}")
        super().__init__()
        self.base = base
        self.epoch_num = epoch_num
        self.min_prefix_len = min_prefix_length
        # Fall back to one prefix when neither an explicit count nor an
        # initialized process group is available.
        self.prefix_count = 1
        if prefix_count is not None:
            self.prefix_count = prefix_count
        elif dist.is_initialized():
            self.prefix_count = dist.get_world_size()
        self.prefix_map = self._generate_prefix_map()

    def generate_prefix(self, rank: int) -> str:
        """
        Generate numeric-based prefix with optional epoch number.

        Args:
            rank: Process rank in the distributed environment.

        Returns:
            Prefix string in format: <pattern>/epoch_<num>/__<rank>_ or
            <pattern>/__<rank>_ if no epoch number is provided.
        """
        epoch_suffix = f"epoch_{self.epoch_num}/" if self.epoch_num is not None else ""
        # Ranks beyond prefix_count wrap around onto the existing prefixes.
        return f"{self.prefix_map[rank % len(self.prefix_map)]}/{epoch_suffix}__{rank}_"

    def _generate_prefix_map(self) -> List[str]:
        """Generate mapping of ranks to numeric-based prefixes."""
        minimum_required_length = self._calculate_prefix_length()
        adjusted_prefix_length = max(minimum_required_length, self.min_prefix_len)
        # Reverse each zero-padded number so the lowest-order digit leads and
        # the padding zeros trail.
        return [
            self._format_number(i, adjusted_prefix_length)[::-1]
            for i in range(self.prefix_count)
        ]

    def _calculate_prefix_length(self) -> int:
        """Calculate minimum prefix length needed for unique combinations.

        Returns the smallest ``L >= 1`` with ``base ** L >= prefix_count``. The
        loop terminates because ``__init__`` enforces ``base >= 2``.
        """
        prefix_length = 1
        size = self.base
        while size < self.prefix_count:
            prefix_length += 1
            size *= self.base
        return prefix_length

    @abstractmethod
    def _format_number(self, number: int, length: int) -> str:
        """Format a number to the appropriate base representation."""
        pass


class BinaryPrefixStrategy(NumericPrefixStrategy):
    """Binary (Base2) prefix generation strategy using only 0 and 1."""

    def __init__(
        self,
        epoch_num: Optional[int] = None,
        min_prefix_length: int = 10,
        prefix_count: Optional[int] = None,
    ):
        """Configure a base-2 numeric prefix strategy.

        Args:
            epoch_num: Optional epoch number to include in the prefix.
            min_prefix_length: Minimum prefix length (defaults to 10).
            prefix_count: Optional number of distinct prefixes.
        """
        # Positional order matches the parent signature: base, epoch_num,
        # min_prefix_length, prefix_count.
        super().__init__(2, epoch_num, min_prefix_length, prefix_count)

    def _format_number(self, number: int, length: int) -> str:
        """Render ``number`` in binary, left-padded with zeros to ``length`` digits."""
        return f"{number:0{length}b}"


class HexPrefixStrategy(NumericPrefixStrategy):
    """Hexadecimal-based prefix generation strategy."""

    def __init__(
        self,
        epoch_num: Optional[int] = None,
        min_prefix_length: int = 10,
        prefix_count: Optional[int] = None,
    ):
        """Configure a base-16 numeric prefix strategy.

        Args:
            epoch_num: Optional epoch number to include in the prefix.
            min_prefix_length: Minimum prefix length (defaults to 10).
            prefix_count: Optional number of distinct prefixes.
        """
        # Positional order matches the parent signature: base, epoch_num,
        # min_prefix_length, prefix_count.
        super().__init__(16, epoch_num, min_prefix_length, prefix_count)

    def _format_number(self, number: int, length: int) -> str:
        """Render ``number`` in lowercase hex, left-padded with zeros to ``length`` digits."""
        return f"{number:0{length}x}"


class RoundRobinPrefixStrategy(S3PrefixStrategyBase):
    """Strategy that distributes ranks across user-provided prefixes in round-robin fashion."""

    def __init__(self, user_prefixes: List[str], epoch_num: Optional[int] = None):
        """
        Initialize round-robin prefix strategy.

        Args:
            user_prefixes: Prefixes to distribute ranks across. Must not be
                empty. Any iterable is accepted; it is copied into a list.
            epoch_num: Epoch number for checkpoint ordering.

        Raises:
            ValueError: If user_prefixes is empty (or None).
        """
        super().__init__()
        # Snapshot the prefixes: later mutation of the caller's list can no
        # longer change (or empty) the mapping, and one-shot iterables such as
        # generators are materialized so len() and indexing work. A falsy
        # input (None, empty sequence) maps to [] so it still hits the
        # ValueError below.
        prefixes = list(user_prefixes) if user_prefixes else []
        if not prefixes:
            raise ValueError("user_prefixes must not be empty")
        self.user_prefixes = prefixes
        self.epoch_num = epoch_num

    def generate_prefix(self, rank: int) -> str:
        """
        Generate prefix for given rank using round-robin distribution.

        Args:
            rank: Process rank in the distributed environment.

        Returns:
            Prefix string in format: <user_prefix>/epoch_<num>/__<rank>_ or
            <user_prefix>/__<rank>_ if no epoch number is provided.
        """
        epoch_suffix = f"epoch_{self.epoch_num}/" if self.epoch_num is not None else ""
        return f"{self.user_prefixes[rank % len(self.user_prefixes)]}/{epoch_suffix}__{rank}_"