# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: BSD
import io
import logging
import os
import urllib.parse
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Union, Optional
from s3torchconnectorclient._mountpoint_s3_client import S3Exception
from tenacity import (
retry,
stop_after_attempt,
retry_if_exception_type,
before_sleep_log,
after_log,
wait_random_exponential,
)
from torch.distributed.checkpoint.filesystem import (
FileSystemReader,
FileSystemWriter,
FileSystemBase,
)
import torch
from s3torchconnector._s3client import S3Client
from s3torchconnector._s3dataset_common import parse_s3_uri
from .._user_agent import UserAgent
logger = logging.getLogger(__name__)
class S3FileSystem(FileSystemBase):
def __init__(self, region: str, s3_client: Optional[S3Client] = None) -> None:
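"""Initialize the S3 filesystem with an S3 client for the given region; an existing ``S3Client`` may be supplied, otherwise one is created with a DCP-specific user agent."""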
self._path: Union[str, os.PathLike] = ""
user_agent = UserAgent(["dcp", torch.__version__])
self._client = (
s3_client
if s3_client is not None
else S3Client(region=region, user_agent=user_agent)
)
@contextmanager
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
"""
Create a stream for reading or writing to S3.
Args:
path (Union[str, os.PathLike]): The S3 path to read or write.
mode (str): The mode for the stream. Supports 'rb' for read mode and 'wb' for write mode.
Yields:
io.IOBase: A stream for reading from or writing to S3.
Raises:
ValueError: If the mode is not 'rb' or 'wb'.
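Example:
    A minimal usage sketch (the region, bucket, and key below are illustrative placeholders):

    fs = S3FileSystem(region="us-east-1")
    with fs.create_stream("s3://my-bucket/checkpoint/data.pt", "wb") as stream:
        stream.write(b"hello")
    with fs.create_stream("s3://my-bucket/checkpoint/data.pt", "rb") as stream:
        data = stream.read()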
"""
path_str = _path_or_str_to_str(path)
bucket, key = parse_s3_uri(path_str)
if mode == "wb": # write mode
logger.debug("create_stream writable for %s", path_str)
with self._client.put_object(bucket, key) as stream:
yield stream
elif mode == "rb": # read mode
logger.debug("create_stream readable for %s", path_str)
with self._client.get_object(bucket, key) as stream:
yield stream
else:
raise ValueError(
f"Invalid {mode=} mode argument: create_stream only supports rb (read mode) & wb (write mode)"
)
def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> str:
"""
Concatenate a suffix to the given path.
Args:
path (Union[str, os.PathLike]): The base path.
suffix (str): The suffix to concatenate.
Returns:
str: The concatenated path.
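Example:
    For instance (illustrative values), ``concat_path("s3://my-bucket/prefix", "file.pt")``
    returns ``"s3://my-bucket/prefix/file.pt"`` (built with ``os.path.join``).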
"""
logger.debug("concat paths %s and %s", path, suffix)
path_str = os.fspath(path)
result = os.path.join(path_str, suffix)
return result
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
"""
Initialize the path for the filesystem.
Args:
path (Union[str, os.PathLike]): The path to initialize.
Returns:
Union[str, os.PathLike]: The initialized path.
"""
logger.debug("init_path for %s", path)
self._path = path
return self._path
def rename(
self, old_path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
"""Rename an object in S3.
S3 has no native rename, so this is emulated by copying the object to the new path and then deleting the old one. The deletion is retried (see
:func:`S3FileSystem._delete_with_retry`).
Args:
old_path (Union[str, os.PathLike]): The current path of the object.
new_path (Union[str, os.PathLike]): The new path for the object.
Raises:
ValueError: If the old and new paths point to different buckets.
S3Exception: If there is an error with the S3 client.
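Example:
    A minimal usage sketch (URIs are illustrative placeholders); both URIs must refer to the same bucket:

    fs.rename("s3://my-bucket/checkpoint/old-key", "s3://my-bucket/checkpoint/new-key")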
"""
logger.debug("rename %s to %s", old_path, new_path)
old_path_str = _path_or_str_to_str(old_path)
new_path_str = _path_or_str_to_str(new_path)
old_bucket, old_key = parse_s3_uri(old_path_str)
escaped_old_key = self._escape_path(old_key)
logger.debug("rename: escaped version of the source key: %s", escaped_old_key)
new_bucket, new_key = parse_s3_uri(new_path_str)
if old_bucket != new_bucket:
raise ValueError(
f"Source and destination buckets cannot be different (rename does not support cross-buckets operations)"
)
self._client.copy_object(
src_bucket=old_bucket,
src_key=escaped_old_key,
dst_bucket=new_bucket,
dst_key=new_key,
)
logger.debug("rename: copied %s to %s successfully", old_path_str, new_path_str)
self._delete_with_retry(old_bucket, old_key)
logger.debug("rename: s3://%s/%s successfully", old_bucket, old_key)
def mkdir(self, path: Union[str, os.PathLike]) -> None:
"""No-op method for creating directories in S3 (not needed)."""
pass
def exists(self, path: Union[str, os.PathLike]) -> bool:
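"""Return True if an object exists at the given S3 path (checked via ``head_object``); return False if the object is not found."""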
logger.debug("exists %s", path)
path_str = _path_or_str_to_str(path)
bucket, key = parse_s3_uri(path_str)
try:
self._client.head_object(bucket, key)
except S3Exception as e:
if str(e) != "Service error: The object was not found":
raise
return False
return True
def rm_file(self, path: Union[str, os.PathLike]) -> None:
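"""Delete the object at the given S3 path; S3 client errors are logged rather than propagated."""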
logger.debug("remove %s", path)
path_str = _path_or_str_to_str(path)
bucket, key = parse_s3_uri(path_str)
try:
self._client.delete_object(bucket, key)
except S3Exception:
logger.exception("Failed to remove object from S3")
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
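"""Return True if ``checkpoint_id`` is a local ``Path`` or a parsable S3 URI, False otherwise."""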
logger.debug("validate_checkpoint_id for %s", checkpoint_id)
if isinstance(checkpoint_id, Path):
return True
try:
parse_s3_uri(_path_or_str_to_str(checkpoint_id))
except ValueError:
return False
return True
@retry(
retry=retry_if_exception_type(S3Exception),
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=5),
before_sleep=before_sleep_log(logger, logging.WARNING),
after=after_log(logger, logging.ERROR),
reraise=True,
)
def _delete_with_retry(self, bucket_name: str, old_key: str):
"""Wrapper around :func:`S3Client.delete_object` to retry the deletion.
Retries a maximum of 3 times, only for `S3Exception`, waiting between attempts with random exponential backoff. Retries and the
final error (if any) are logged, and the caught exception is re-raised."""
self._client.delete_object(bucket_name, old_key)
@staticmethod
def _escape_path(string):
"""URL-encodes path segments while preserving '/' separators using urllib.parse.quote().
Args:
string (str): URL path string to escape
Returns:
str: Path string with each segment percent-encoded, separators preserved
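Example:
    For instance, ``_escape_path("my prefix/my key")`` returns ``"my%20prefix/my%20key"``.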
"""
if not string:
return string
parts = []
for part in string.split("/"):
parts.append(urllib.parse.quote(part, safe=""))
return "/".join(parts)
class S3StorageWriter(FileSystemWriter):
def __init__(
self,
region: str,
path: str,
**kwargs,
) -> None:
"""
Initialize an S3 writer for distributed checkpointing.
Args:
region (str): The AWS region for S3.
path (str): The S3 URI to write checkpoints to.
kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`.
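Example:
    A minimal sketch of saving a checkpoint with ``torch.distributed.checkpoint``
    (the region, URI, and ``state_dict`` below are illustrative placeholders):

    import torch.distributed.checkpoint as dcp

    writer = S3StorageWriter(region="us-east-1", path="s3://my-bucket/checkpoint/")
    dcp.save(state_dict, storage_writer=writer)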
"""
super().__init__(
path=path,
sync_files=False,  # FIXME: setting this to True makes the run fail (L#333: `os.fsync(stream.fileno())`)
**kwargs,
)
self.fs = S3FileSystem(region) # type: ignore
self.path = self.fs.init_path(path)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return S3FileSystem.validate_checkpoint_id(checkpoint_id)
class S3StorageReader(FileSystemReader):
def __init__(self, region: str, path: Union[str, os.PathLike]) -> None:
"""
Initialize an S3 reader for distributed checkpointing.
Args:
region (str): The AWS region for S3.
path (Union[str, os.PathLike]): The S3 path to read checkpoints from.
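Example:
    A minimal sketch of loading a checkpoint with ``torch.distributed.checkpoint``
    (the region, URI, and ``state_dict`` below are illustrative placeholders;
    ``dcp.load`` restores tensors into ``state_dict`` in place):

    import torch.distributed.checkpoint as dcp

    reader = S3StorageReader(region="us-east-1", path="s3://my-bucket/checkpoint/")
    dcp.load(state_dict, storage_reader=reader)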
"""
super().__init__(path)
self.fs = S3FileSystem(region) # type: ignore
self.path = self.fs.init_path(path)
self.sync_files = False
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return S3FileSystem.validate_checkpoint_id(checkpoint_id)
def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
return path if isinstance(path, str) else str(path)