Module sagemaker_defect_detection.dataset.neu
View Source
from typing import List, Tuple, Optional, Callable
import os
from pathlib import Path
from collections import namedtuple
from xml.etree.ElementTree import ElementTree
import numpy as np
import cv2
import torch
from torch.utils.data.dataset import Dataset
from torchvision.datasets import ImageFolder
class NEUCLS(ImageFolder):
    """
    NEU-CLS dataset processing and loading.

    Loads all images through ``torchvision.datasets.ImageFolder`` and then keeps
    only the subset of ``self.samples`` that belongs to the requested split. The
    split is drawn from a seeded random permutation, so the same seed always
    produces the same train/val/test partition.

    NOTE(review): only ``self.samples`` is filtered; ImageFolder's derived
    attributes (``imgs``, ``targets``) still reflect the full dataset — confirm
    no caller relies on them.
    """

    def __init__(
        self,
        root: str,
        split: str,
        augmentation: Optional[Callable] = None,
        preprocessing: Optional[Callable] = None,
        seed: int = 123,
        train_ratio: float = 0.6,
        val_ratio: float = 0.2,
        **kwargs,
    ) -> None:
        """
        NEU-CLS dataset

        Parameters
        ----------
        root : str
            Dataset root path
        split : str
            Data split from train, val and test
        augmentation : Optional[Callable], optional
            Image augmentation function, by default None
        preprocessing : Optional[Callable], optional
            Image preprocessing function, by default None
        seed : int, optional
            Random number generator seed, by default 123
        train_ratio : float, optional
            Fraction of samples assigned to the train split, by default 0.6
        val_ratio : float, optional
            Fraction of samples assigned to the val split, by default 0.2
            (the remainder goes to the test split)

        Raises
        ------
        ValueError
            If unsupported split is used
        """
        super().__init__(root, **kwargs)
        self.samples: List[Tuple[str, int]]
        if split not in ("train", "val", "test"):
            raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}")
        self.split = split
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        n_items = len(self.samples)
        # Use a dedicated RandomState so we do not clobber NumPy's global RNG.
        # RandomState(seed).permutation produces exactly the same permutation as
        # the legacy np.random.seed(seed) + np.random.permutation pair, so the
        # resulting splits are unchanged.
        perm = np.random.RandomState(seed).permutation(n_items)
        train_end = int(train_ratio * n_items)
        val_end = int(val_ratio * n_items) + train_end
        if split == "train":
            indices = perm[:train_end]
        elif split == "val":
            indices = perm[train_end:val_end]
        else:  # "test" — validated above
            indices = perm[val_end:]
        self.samples = [self.samples[i] for i in indices]
# One detection example: path to the image file, its integer class index, and
# the path to the matching XML annotation file (paired up in NEUDET._make_dataset).
DetectionSample = namedtuple("DetectionSample", ["image_path", "class_idx", "annotations"])
class NEUDET(Dataset):
    """
    NEU-DET dataset processing and loading
    """
    def __init__(
        self,
        root: str,
        split: str,
        augmentation: Optional[Callable] = None,
        preprocess: Optional[Callable] = None,
        seed: int = 123,
    ):
        """
        NEU-DET dataset
        Parameters
        ----------
        root : str
            Dataset root path
        split : str
            Data split from train, val and test
        augmentation : Optional[Callable], optional
            Image augmentation function, by default None
        preprocess : Optional[Callable], optional
            Image preprocessing function, by default None
        seed : int, optional
            Random number generator seed, by default 123
        Raises
        ------
        ValueError
            If unsupported split is used
        """
        super().__init__()
        self.root = Path(root)
        self.split = split
        self.classes, self.class_to_idx = self._find_classes()
        self.samples: List[DetectionSample] = self._make_dataset()
        self.augmentation = augmentation
        self.preprocess = preprocess
        n_items = len(self.samples)
        np.random.seed(seed)
        perm = np.random.permutation(list(range(n_items)))
        train_end = int(0.6 * n_items)
        val_end = int(0.2 * n_items) + train_end
        if split == "train":
            self.samples = [self.samples[i] for i in perm[:train_end]]
        elif split == "val":
            self.samples = [self.samples[i] for i in perm[train_end:val_end]]
        elif split == "test":
            self.samples = [self.samples[i] for i in perm[val_end:]]
        else:
            raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}")
    def _make_dataset(self) -> List[DetectionSample]:
        instances = []
        base_dir = self.root.expanduser()
        for target_cls in sorted(self.class_to_idx.keys()):
            cls_idx = self.class_to_idx[target_cls]
            target_dir = base_dir / target_cls
            if not target_dir.is_dir():
                continue
            images = sorted(list((target_dir / "images").glob("*.jpg")))
            annotations = sorted(list((target_dir / "annotations").glob("*.xml")))
            assert len(images) == len(annotations), f"something is wrong. Mismatched number of images and annotations"
            for path, ann in zip(images, annotations):
                instances.append(DetectionSample(str(path), int(cls_idx), str(ann)))
        return instances
    def _find_classes(self):
        classes = sorted([d.name for d in os.scandir(str(self.root)) if d.is_dir()])
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes, 1)}  # no bg label in NEU
        return classes, class_to_idx
    @staticmethod
    def _get_bboxes(ann: str) -> List[List[int]]:
        tree = ElementTree().parse(ann)
        bboxes = []
        for bndbox in tree.iterfind("object/bndbox"):
            # should subtract 1 like coco?
            bbox = [int(bndbox.findtext(t)) - 1 for t in ("xmin", "ymin", "xmax", "ymax")]  # type: ignore
            assert bbox[2] > bbox[0] and bbox[3] > bbox[1], f"box size error, given {bbox}"
            bboxes.append(bbox)
        return bboxes
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx: int):
        # Note: images are grayscaled BUT resnet needs 3 channels
        image = cv2.imread(self.samples[idx].image_path)  # BGR channel last
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        boxes = self._get_bboxes(self.samples[idx].annotations)
        num_objs = len(boxes)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor([self.samples[idx].class_idx] * num_objs, dtype=torch.int64)
        image_id = torch.tensor([idx], dtype=torch.int64)
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["iscrowd"] = iscrowd
        if self.augmentation is not None:
            sample = self.augmentation(**{"image": image, "bboxes": boxes, "labels": labels})
            image = sample["image"]
            target["boxes"] = torch.as_tensor(sample["bboxes"], dtype=torch.float32)
            # guards against crops that don't pass the min_visibility augmentation threshold
            if not target["boxes"].numel():
                return None
            target["labels"] = torch.as_tensor(sample["labels"], dtype=torch.int64)
        if self.preprocess is not None:
            image = self.preprocess(image=image)["image"]
        boxes = target["boxes"]
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        return image, target, image_id
    def collate_fn(self, batch):
        batch = filter(lambda x: x is not None, batch)
        return tuple(zip(*batch))
Classes
DetectionSample
class DetectionSample(
    image_path,
    class_idx,
    annotations
)
Ancestors (in MRO)
- builtins.tuple
Class variables
annotations
class_idx
image_path
Methods
count
def count(
    self,
    value,
    /
)
Return number of occurrences of value.
index
def index(
    self,
    value,
    start=0,
    stop=9223372036854775807,
    /
)
Return first index of value.
Raises ValueError if the value is not present.
NEUCLS
class NEUCLS(
    root: str,
    split: str,
    augmentation: Union[Callable, NoneType] = None,
    preprocessing: Union[Callable, NoneType] = None,
    seed: int = 123,
    **kwargs
)
Ancestors (in MRO)
- torchvision.datasets.folder.ImageFolder
- torchvision.datasets.folder.DatasetFolder
- torchvision.datasets.vision.VisionDataset
- torch.utils.data.dataset.Dataset
Methods
extra_repr
def extra_repr(
    self
)
View Source
    def extra_repr(self):
        return ""
NEUDET
class NEUDET(
    root: str,
    split: str,
    augmentation: Union[Callable, NoneType] = None,
    preprocess: Union[Callable, NoneType] = None,
    seed: int = 123
)
Ancestors (in MRO)
- torch.utils.data.dataset.Dataset
Methods
collate_fn
def collate_fn(
    self,
    batch
)
View Source
    def collate_fn(self, batch):
        batch = filter(lambda x: x is not None, batch)
        return tuple(zip(*batch))