Source code for s3torchconnector.s3iterable_dataset

#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#  SPDX-License-Identifier: BSD
from functools import partial
from typing import Iterator, Any, Union, Iterable, Callable, Optional
import logging

import torch.utils.data
import torch

from . import S3Reader
from ._s3bucket_key_data import S3BucketKeyData
from ._s3client import S3Client, S3ClientConfig
from ._s3dataset_common import (
    identity,
    get_objects_from_uris,
    get_objects_from_prefix,
)

log = logging.getLogger(__name__)

class S3IterableDataset(torch.utils.data.IterableDataset):
    """An iterable-style dataset created from S3 objects.

    To create an instance of S3IterableDataset, use the `from_prefix` or
    `from_objects` class methods.
    """

    def __init__(
        self,
        region: str,
        get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
        endpoint: Optional[str] = None,
        transform: Callable[[S3Reader], Any] = identity,
        s3client_config: Optional[S3ClientConfig] = None,
        enable_sharding: bool = False,
    ):
        self._get_dataset_objects = get_dataset_objects
        self._transform = transform
        self._region = region
        self._endpoint = endpoint
        self._s3client_config = s3client_config
        self._client = None
        self._enable_sharding = enable_sharding

        self._rank = 0
        self._world_size = 1
        if torch.distributed.is_initialized():
            self._rank = torch.distributed.get_rank()
            self._world_size = torch.distributed.get_world_size()

    @property
    def region(self):
        return self._region

    @property
    def endpoint(self):
        return self._endpoint
    @classmethod
    def from_objects(
        cls,
        object_uris: Union[str, Iterable[str]],
        *,
        region: str,
        endpoint: Optional[str] = None,
        transform: Callable[[S3Reader], Any] = identity,
        s3client_config: Optional[S3ClientConfig] = None,
        enable_sharding: bool = False,
    ):
        """Returns an instance of S3IterableDataset using the S3 URI(s) provided.

        Args:
            object_uris(str | Iterable[str]): S3 URI(s) of the desired object(s).
            region(str): AWS region of the S3 bucket where the objects are stored.
            endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
            transform: Optional callable used to transform an S3Reader into the desired type.
            s3client_config: Optional S3ClientConfig with parameters for the S3 client.
            enable_sharding: If True, shard the dataset across multiple workers for
                parallel data loading. If False (default), each worker loads the
                entire dataset independently.

        Returns:
            S3IterableDataset: An iterable-style dataset created from S3 objects.

        Raises:
            S3Exception: An error occurred accessing S3.
        """
        log.info(f"Building {cls.__name__} from_objects")
        return cls(
            region,
            partial(get_objects_from_uris, object_uris),
            endpoint,
            transform=transform,
            s3client_config=s3client_config,
            enable_sharding=enable_sharding,
        )
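
    # A minimal usage sketch (not part of the module), assuming a hypothetical
    # bucket "my-bucket" in "us-east-1":
    #
    #   from s3torchconnector import S3IterableDataset
    #
    #   uris = [
    #       "s3://my-bucket/images/0001.jpg",
    #       "s3://my-bucket/images/0002.jpg",
    #   ]
    #   dataset = S3IterableDataset.from_objects(uris, region="us-east-1")
    #   for reader in dataset:
    #       print(reader.key, len(reader.read()))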
    @classmethod
    def from_prefix(
        cls,
        s3_uri: str,
        *,
        region: str,
        endpoint: Optional[str] = None,
        transform: Callable[[S3Reader], Any] = identity,
        s3client_config: Optional[S3ClientConfig] = None,
        enable_sharding: bool = False,
    ):
        """Returns an instance of S3IterableDataset using the S3 URI provided.

        Args:
            s3_uri(str): An S3 URI (prefix) for the desired object(s). All objects
                matching the prefix are included in the returned dataset.
            region(str): AWS region of the S3 bucket where the objects are stored.
            endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
            transform: Optional callable used to transform an S3Reader into the desired type.
            s3client_config: Optional S3ClientConfig with parameters for the S3 client.
            enable_sharding: If True, shard the dataset across multiple workers for
                parallel data loading. If False (default), each worker loads the
                entire dataset independently.

        Returns:
            S3IterableDataset: An iterable-style dataset created from S3 objects.

        Raises:
            S3Exception: An error occurred accessing S3.
        """
        log.info(f"Building {cls.__name__} from_prefix {s3_uri=}")
        return cls(
            region,
            partial(get_objects_from_prefix, s3_uri),
            endpoint,
            transform=transform,
            s3client_config=s3client_config,
            enable_sharding=enable_sharding,
        )
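
    # A minimal usage sketch (not part of the module), assuming a hypothetical
    # prefix; the transform turns each S3Reader into raw bytes:
    #
    #   from s3torchconnector import S3IterableDataset
    #
    #   dataset = S3IterableDataset.from_prefix(
    #       "s3://my-bucket/train/",
    #       region="us-east-1",
    #       transform=lambda reader: reader.read(),
    #   )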
    def _get_client(self):
        if self._client is None:
            self._client = S3Client(
                self.region,
                endpoint=self.endpoint,
                s3client_config=self._s3client_config,
            )
        return self._client

    def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
        return self._transform(
            self._get_client().get_object(
                bucket_key.bucket, bucket_key.key, object_info=bucket_key.object_info
            )
        )

    def __iter__(self) -> Iterator[Any]:
        worker_id = 0
        num_workers = 1
        if self._enable_sharding:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker_id = worker_info.id
                num_workers = worker_info.num_workers

        if not self._enable_sharding or (self._world_size == 1 and num_workers == 1):
            # Sharding is disabled, or only one shard exists, so return the entire dataset.
            return map(
                self._get_transformed_object,
                self._get_dataset_objects(self._get_client()),
            )

        # In a multi-process setting (e.g., distributed training), the dataset must
        # be sharded across processes. Two variables control this sharding:
        #   _rank: the rank (index) of the current process within the world
        #       (group of processes).
        #   _world_size: the total number of processes in the world (group).
        # Within each process, the dataset may be further sharded across multiple
        # worker threads or processes (e.g., for data loading). Two variables
        # control this intra-process sharding:
        #   worker_id: the ID of the current worker thread/process within the process.
        #   num_workers: the total number of worker threads/processes within the process.

        # First, distribute objects across ranks.
        rank_sharded_objects = (
            obj
            for idx, obj in enumerate(self._get_dataset_objects(self._get_client()))
            if idx % self._world_size == self._rank
        )
        # Then, within each rank, distribute objects across workers.
        worker_sharded_objects = (
            obj
            for idx, obj in enumerate(rank_sharded_objects)
            if idx % num_workers == worker_id
        )
        return map(self._get_transformed_object, worker_sharded_objects)
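
# A usage sketch (not part of the module) showing enable_sharding with a
# multi-worker DataLoader; with sharding disabled (the default), every worker
# would iterate the full dataset. Bucket and region names are hypothetical.
#
#   from torch.utils.data import DataLoader
#   from s3torchconnector import S3IterableDataset
#
#   dataset = S3IterableDataset.from_prefix(
#       "s3://my-bucket/train/",
#       region="us-east-1",
#       transform=lambda reader: reader.read(),
#       enable_sharding=True,
#   )
#   loader = DataLoader(dataset, num_workers=4, batch_size=None)
#   for item in loader:
#       ...  # each object is delivered to exactly one worker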