Module sagemaker_defect_detection.dataset.neu

Dataset processing and loading for the NEU surface-defect datasets:
NEU-CLS (classification, via torchvision ImageFolder) and NEU-DET
(detection, images paired with Pascal-VOC-style XML annotations).

View Source
from typing import List, Tuple, Optional, Callable
import os
from pathlib import Path
from collections import namedtuple
from xml.etree.ElementTree import ElementTree
import numpy as np
import cv2
import torch
from torch.utils.data.dataset import Dataset
from torchvision.datasets import ImageFolder
class NEUCLS(ImageFolder):
    """
    NEU-CLS dataset processing and loading.

    Wraps torchvision ``ImageFolder`` and restricts ``self.samples`` to a
    deterministic 60/20/20 train/val/test split drawn from a seeded permutation.
    """

    def __init__(
        self,
        root: str,
        split: str,
        augmentation: Optional[Callable] = None,
        preprocessing: Optional[Callable] = None,
        seed: int = 123,
        **kwargs,
    ) -> None:
        """
        NEU-CLS dataset

        Parameters
        ----------
        root : str
            Dataset root path
        split : str
            Data split from train, val and test
        augmentation : Optional[Callable], optional
            Image augmentation function, by default None
        preprocessing : Optional[Callable], optional
            Image preprocessing function, by default None
        seed : int, optional
            Random number generator seed, by default 123

        Raises
        ------
        ValueError
            If unsupported split is used
        """
        super().__init__(root, **kwargs)
        self.samples: List[Tuple[str, int]]
        self.split = split
        self.augmentation = augmentation
        self.preprocessing = preprocessing

        n_items = len(self.samples)
        # Seed the global NumPy RNG so every split instance of the same dataset
        # sees the same permutation and the splits are disjoint.
        np.random.seed(seed)
        perm = np.random.permutation(n_items)
        # TODO: add split ratios as parameters
        train_end = int(0.6 * n_items)
        val_end = int(0.2 * n_items) + train_end
        if split == "train":
            self.samples = [self.samples[i] for i in perm[:train_end]]
        elif split == "val":
            self.samples = [self.samples[i] for i in perm[train_end:val_end]]
        elif split == "test":
            self.samples = [self.samples[i] for i in perm[val_end:]]
        else:
            raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}")
# Lightweight immutable record tying one image file to its class index and
# its XML annotation file.
DetectionSample = namedtuple("DetectionSample", "image_path class_idx annotations")
class NEUDET(Dataset):
    """
    NEU-DET dataset processing and loading.

    Pairs each class directory's ``images/*.jpg`` with its ``annotations/*.xml``
    and serves ``(image, target, image_id)`` triples in the torchvision
    detection format (``target`` holds boxes/labels/image_id/iscrowd/area).
    """

    def __init__(
        self,
        root: str,
        split: str,
        augmentation: Optional[Callable] = None,
        preprocess: Optional[Callable] = None,
        seed: int = 123,
    ):
        """
        NEU-DET dataset

        Parameters
        ----------
        root : str
            Dataset root path; one sub-directory per defect class, each
            containing ``images`` and ``annotations`` folders
        split : str
            Data split from train, val and test
        augmentation : Optional[Callable], optional
            Image augmentation function taking image/bboxes/labels keyword
            arguments (albumentations-style), by default None
        preprocess : Optional[Callable], optional
            Image preprocessing function, by default None
        seed : int, optional
            Random number generator seed, by default 123

        Raises
        ------
        ValueError
            If unsupported split is used
        """
        super().__init__()
        self.root = Path(root)
        self.split = split
        self.classes, self.class_to_idx = self._find_classes()
        self.samples: List[DetectionSample] = self._make_dataset()
        self.augmentation = augmentation
        self.preprocess = preprocess

        n_items = len(self.samples)
        # Seed the global NumPy RNG so every split instance of the same dataset
        # sees the same permutation and the splits are disjoint.
        np.random.seed(seed)
        perm = np.random.permutation(n_items)
        # TODO: add split ratios as parameters
        train_end = int(0.6 * n_items)
        val_end = int(0.2 * n_items) + train_end
        if split == "train":
            self.samples = [self.samples[i] for i in perm[:train_end]]
        elif split == "val":
            self.samples = [self.samples[i] for i in perm[train_end:val_end]]
        elif split == "test":
            self.samples = [self.samples[i] for i in perm[val_end:]]
        else:
            raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}")

    def _make_dataset(self) -> List["DetectionSample"]:
        """Collect one ``DetectionSample`` per (image, annotation) pair per class."""
        instances = []
        base_dir = self.root.expanduser()
        for target_cls in sorted(self.class_to_idx.keys()):
            cls_idx = self.class_to_idx[target_cls]
            target_dir = base_dir / target_cls
            if not target_dir.is_dir():
                continue
            images = sorted((target_dir / "images").glob("*.jpg"))
            annotations = sorted((target_dir / "annotations").glob("*.xml"))
            # Pairing below relies on both folders sorting to matching stems.
            assert len(images) == len(annotations), (
                f"Mismatched number of images ({len(images)}) and annotations "
                f"({len(annotations)}) in {target_dir}"
            )
            for path, ann in zip(images, annotations):
                instances.append(DetectionSample(str(path), int(cls_idx), str(ann)))
        return instances

    def _find_classes(self) -> Tuple[List[str], dict]:
        """Map each class-directory name to a 1-based index (no background label in NEU)."""
        classes = sorted([d.name for d in os.scandir(str(self.root)) if d.is_dir()])
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes, 1)}  # no bg label in NEU
        return classes, class_to_idx

    @staticmethod
    def _get_bboxes(ann: str) -> List[List[int]]:
        """Parse ``[xmin, ymin, xmax, ymax]`` boxes (each coordinate minus 1) from a VOC-style XML file."""
        tree = ElementTree().parse(ann)
        bboxes = []
        for bndbox in tree.iterfind("object/bndbox"):
            # should subtract 1 like coco?
            bbox = [int(bndbox.findtext(t)) - 1 for t in ("xmin", "ymin", "xmax", "ymax")]  # type: ignore
            assert bbox[2] > bbox[0] and bbox[3] > bbox[1], f"box size error, given {bbox}"
            bboxes.append(bbox)
        return bboxes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Return ``(image, target, image_id)``, or ``None`` when augmentation
        removed every box (filtered out again by ``collate_fn``).
        """
        # Note: images are grayscaled BUT resnet needs 3 channels
        image = cv2.imread(self.samples[idx].image_path)  # BGR channel last
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        boxes = self._get_bboxes(self.samples[idx].annotations)
        num_objs = len(boxes)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # Every box in an image carries that image's single defect class.
        labels = torch.tensor([self.samples[idx].class_idx] * num_objs, dtype=torch.int64)
        image_id = torch.tensor([idx], dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["iscrowd"] = torch.zeros((num_objs,), dtype=torch.int64)
        if self.augmentation is not None:
            sample = self.augmentation(image=image, bboxes=boxes, labels=labels)
            image = sample["image"]
            target["boxes"] = torch.as_tensor(sample["bboxes"], dtype=torch.float32)
            # guards against crops that don't pass the min_visibility augmentation threshold
            if not target["boxes"].numel():
                return None
            target["labels"] = torch.as_tensor(sample["labels"], dtype=torch.int64)
            # BUG FIX: iscrowd must track the boxes that survived augmentation,
            # otherwise its length disagrees with boxes/labels downstream.
            target["iscrowd"] = torch.zeros((len(target["boxes"]),), dtype=torch.int64)
        if self.preprocess is not None:
            image = self.preprocess(image=image)["image"]

        boxes = target["boxes"]
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        return image, target, image_id

    def collate_fn(self, batch):
        """Drop ``None`` samples (all boxes augmented away) and transpose the batch."""
        batch = filter(lambda x: x is not None, batch)
        return tuple(zip(*batch))
Classes
DetectionSample
class DetectionSample(
/,
*args,
**kwargs
)
Ancestors (in MRO)
- builtins.tuple
Class variables
annotations
class_idx
image_path
Methods
count
def count(
self,
value,
/
)
Return number of occurrences of value.
index
def index(
self,
value,
start=0,
stop=9223372036854775807,
/
)
Return first index of value.
Raises ValueError if the value is not present.
NEUCLS
class NEUCLS(
root: str,
split: str,
augmentation: Union[Callable, NoneType] = None,
preprocessing: Union[Callable, NoneType] = None,
seed: int = 123,
**kwargs
)
Ancestors (in MRO)
- torchvision.datasets.folder.ImageFolder
- torchvision.datasets.folder.DatasetFolder
- torchvision.datasets.vision.VisionDataset
- torch.utils.data.dataset.Dataset
Methods
extra_repr
def extra_repr(
self
)
View Source
def extra_repr(self):
return ""
NEUDET
class NEUDET(
root: str,
split: str,
augmentation: Union[Callable, NoneType] = None,
preprocess: Union[Callable, NoneType] = None,
seed: int = 123
)
Ancestors (in MRO)
- torch.utils.data.dataset.Dataset
Methods
collate_fn
def collate_fn(
self,
batch
)
View Source
def collate_fn(self, batch):
batch = filter(lambda x: x is not None, batch)
return tuple(zip(*batch))