Source code for s3torchconnector.dcp.s3_file_system

#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#  // SPDX-License-Identifier: BSD

import io
import logging
import os
import urllib.parse
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Union, Optional

from s3torchconnectorclient._mountpoint_s3_client import S3Exception
from tenacity import (
    retry,
    stop_after_attempt,
    retry_if_exception_type,
    before_sleep_log,
    after_log,
    wait_random_exponential,
)
from torch.distributed.checkpoint.filesystem import (
    FileSystemReader,
    FileSystemWriter,
    FileSystemBase,
)
import torch

from s3torchconnector._s3client import S3Client
from s3torchconnector._s3dataset_common import parse_s3_uri
from .._user_agent import UserAgent

logger = logging.getLogger(__name__)


class S3FileSystem(FileSystemBase):
    def __init__(self, region: str, s3_client: Optional[S3Client] = None) -> None:
        self._path: Union[str, os.PathLike] = ""

        user_agent = UserAgent(["dcp", torch.__version__])
        self._client = (
            s3_client
            if s3_client is not None
            else S3Client(region=region, user_agent=user_agent)
        )

    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        """
        Create a stream for reading or writing to S3.

        Args:
            path (Union[str, os.PathLike]): The S3 path to read or write.
            mode (str): The mode for the stream. Supports 'rb' for read mode and 'wb' for write mode.

        Yields:
            io.BufferedIOBase: A stream for reading or writing to S3.

        Raises:
            ValueError: If the mode is not 'rb' or 'wb'.
        """
        path_str = _path_or_str_to_str(path)
        bucket, key = parse_s3_uri(path_str)

        if mode == "wb":  # write mode
            logger.debug("create_stream writable for %s", path_str)
            with self._client.put_object(bucket, key) as stream:
                yield stream
        elif mode == "rb":  # read mode
            logger.debug("create_stream readable for %s", path_str)
            with self._client.get_object(bucket, key) as stream:
                yield stream
        else:
            raise ValueError(
                f"Invalid {mode=} mode argument: create_stream only supports rb (read mode) & wb (write mode)"
            )

    def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> str:
        """
        Concatenate a suffix to the given path.

        Args:
            path (Union[str, os.PathLike]): The base path.
            suffix (str): The suffix to concatenate.

        Returns:
            str: The concatenated path.
        """
        logger.debug("concat paths %s and %s", path, suffix)
        path_str = os.fspath(path)
        result = os.path.join(path_str, suffix)
        return result

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """
        Initialize the path for the filesystem.

        Args:
            path (Union[str, os.PathLike]): The path to initialize.

        Returns:
            Union[str, os.PathLike]: The initialized path.
        """
        logger.debug("init_path for %s", path)
        self._path = path
        return self._path

    def rename(
        self, old_path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        """Rename an object in S3.

        This is emulated by copying the object to the new path and deleting the old path. The deletion part is
        retried (see also :func:`S3FileSystem._delete_with_retry`).

        Args:
            old_path (Union[str, os.PathLike]): The current path of the object.
            new_path (Union[str, os.PathLike]): The new path for the object.

        Raises:
            ValueError: If the old and new paths point to different buckets.
            S3Exception: If there is an error with the S3 client.
        """
        logger.debug("rename %s to %s", old_path, new_path)

        old_path_str = _path_or_str_to_str(old_path)
        new_path_str = _path_or_str_to_str(new_path)

        old_bucket, old_key = parse_s3_uri(old_path_str)
        escaped_old_key = self._escape_path(old_key)
        logger.debug("rename: escaped version of the source key: %s", escaped_old_key)
        new_bucket, new_key = parse_s3_uri(new_path_str)

        if old_bucket != new_bucket:
            raise ValueError(
                "Source and destination buckets cannot be different (rename does not support cross-bucket operations)"
            )

        self._client.copy_object(
            src_bucket=old_bucket,
            src_key=escaped_old_key,
            dst_bucket=new_bucket,
            dst_key=new_key,
        )
        logger.debug("rename: copied %s to %s successfully", old_path_str, new_path_str)

        self._delete_with_retry(old_bucket, old_key)
        logger.debug("rename: deleted s3://%s/%s successfully", old_bucket, old_key)

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """No-op method for creating directories in S3 (not needed)."""
        pass

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        logger.debug("exists %s", path)

        path_str = _path_or_str_to_str(path)
        bucket, key = parse_s3_uri(path_str)
        try:
            self._client.head_object(bucket, key)
        except S3Exception as e:
            if str(e) != "Service error: The object was not found":
                raise
            return False
        return True

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        logger.debug("remove %s", path)

        path_str = _path_or_str_to_str(path)
        bucket, key = parse_s3_uri(path_str)
        try:
            self._client.delete_object(bucket, key)
        except S3Exception:
            logger.exception("Failed to remove object from S3")

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        logger.debug("validate_checkpoint_id for %s", checkpoint_id)

        if isinstance(checkpoint_id, Path):
            return True

        try:
            parse_s3_uri(_path_or_str_to_str(checkpoint_id))
        except ValueError:
            return False
        return True

    @retry(
        retry=retry_if_exception_type(S3Exception),
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=1, max=5),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        after=after_log(logger, logging.ERROR),
        reraise=True,
    )
    def _delete_with_retry(self, bucket_name: str, old_key: str):
        """Wrapper around :func:`S3Client.delete_object` that retries the deletion.

        Retries a maximum of 3 times, only for `S3Exception`s, waiting between attempts; the caught exception is
        re-raised after the final attempt, and retries and the final error (if any) are logged."""
        self._client.delete_object(bucket_name, old_key)

    @staticmethod
    def _escape_path(string):
        """URL-encode path segments while preserving '/' separators, using urllib.parse.quote().

        Args:
            string (str): URL path string to escape

        Returns:
            str: Path string with each segment percent-encoded, separators preserved
        """
        if not string:
            return string
        parts = []
        for part in string.split("/"):
            parts.append(urllib.parse.quote(part, safe=""))
        return "/".join(parts)
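
A rough usage sketch of the filesystem above (not part of the module): the region, bucket, and key are made up for illustration, and it assumes AWS credentials are already configured in the environment:

    from s3torchconnector.dcp.s3_file_system import S3FileSystem

    # Hypothetical region and S3 URI, for illustration only.
    fs = S3FileSystem(region="us-east-1")
    uri = "s3://my-bucket/example/greeting.txt"

    # Write a small object, then read it back through the same create_stream API.
    with fs.create_stream(uri, "wb") as stream:
        stream.write(b"hello from S3FileSystem")

    with fs.create_stream(uri, "rb") as stream:
        print(stream.read())

    print(fs.exists(uri))  # expected: True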

class S3StorageWriter(FileSystemWriter):
    def __init__(
        self,
        region: str,
        path: str,
        **kwargs,
    ) -> None:
        """
        Initialize an S3 writer for distributed checkpointing.

        Args:
            region (str): The AWS region for S3.
            path (str): The S3 URI to write checkpoints to.
            kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`.
        """
        super().__init__(
            path=path,
            sync_files=False,  # FIXME: setting this to True makes the run fail (L#333: `os.fsync(stream.fileno())`)
            **kwargs,
        )
        self.fs = S3FileSystem(region)  # type: ignore
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)


class S3StorageReader(FileSystemReader):
    def __init__(self, region: str, path: Union[str, os.PathLike]) -> None:
        """
        Initialize an S3 reader for distributed checkpointing.

        Args:
            region (str): The AWS region for S3.
            path (Union[str, os.PathLike]): The S3 path to read checkpoints from.
        """
        super().__init__(path)
        self.fs = S3FileSystem(region)  # type: ignore
        self.path = self.fs.init_path(path)
        self.sync_files = False

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)


def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
    return path if isinstance(path, str) else str(path)
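
A brief usage sketch of the writer and reader above (not part of the module): the region and checkpoint URI are hypothetical, and it assumes a PyTorch version that exposes torch.distributed.checkpoint.save/load as well as AWS credentials configured in the environment:

    import torch
    import torch.distributed.checkpoint as dcp

    from s3torchconnector.dcp import S3StorageWriter, S3StorageReader

    REGION = "us-east-1"                                  # hypothetical region
    CHECKPOINT_URI = "s3://my-bucket/checkpoints/run-0/"  # hypothetical URI

    model = torch.nn.Linear(8, 2)

    # Save the model's state_dict through the S3-backed writer.
    dcp.save(
        state_dict={"model": model.state_dict()},
        storage_writer=S3StorageWriter(region=REGION, path=CHECKPOINT_URI),
    )

    # Load it back through the S3-backed reader; tensors are restored in place.
    restored = {"model": model.state_dict()}
    dcp.load(
        state_dict=restored,
        storage_reader=S3StorageReader(region=REGION, path=CHECKPOINT_URI),
    )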