Skip to content

Module neu

Dependencies: the `unzip` and `unrar` command-line tools, plus the Python packages `patool` and `pyunpack`:

python -m pip install patool pyunpack

View Source
"""
Unpack a NEU-CLS or NEU-DET dataset and reorganise it into one output
directory per defect class (NEU-DET classes additionally get `images/` and
`annotations/` sub-directories).

Dependencies: unzip unrar

python -m pip install patool pyunpack

If `patool`/`pyunpack` are missing they are pip-installed (pinned versions)
the first time this module is imported; `unzip`/`unrar` are command-line
tools that this module does not install.
"""

from pathlib import Path

import shutil

import re

import os

try:
    from pyunpack import Archive
except ModuleNotFoundError:
    # Best-effort bootstrap: install the pinned versions and retry the import.
    # If the install fails, the retried import below raises ModuleNotFoundError
    # exactly as before.
    print("installing the dependencies `patool` and `pyunpack` for unzipping the data")
    import subprocess
    import sys

    # Run pip through the *current* interpreter (sys.executable) rather than
    # whatever "python" happens to be first on PATH, which may belong to a
    # different environment; pass an argument list instead of a shell string.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "patool==1.12", "pyunpack==0.2.1", "-q"],
    )
    from pyunpack import Archive

# Long class name -> two-letter abbreviation.
# In `main`, the keys are the file-name prefixes searched in NEU-DET, the values
# are the file-name prefixes searched in NEU-CLS, and the values also name the
# per-class output directories for both datasets. Insertion order is preserved.
CLASSES = dict(
    (
        ("crazing", "Cr"),
        ("inclusion", "In"),
        ("pitted_surface", "PS"),
        ("patches", "Pa"),
        ("rolled-in_scale", "RS"),
        ("scratches", "Sc"),
    )
)

def unpack(path: str) -> None:
    """Extract the archive at ``path`` into the directory that contains it."""
    archive_file = Path(path)
    destination = archive_file.parent
    Archive(str(archive_file)).extractall(str(destination))

def cp_class_images(data_path: Path, class_name: str, class_path_dest: Path) -> None:
    """
    Copy every file named ``{class_name}_*`` found anywhere under ``data_path``
    into ``class_path_dest``, flattening sub-directories.

    Parameters
    ----------
    data_path : Path
        Root directory, searched recursively.
    class_name : str
        File-name prefix of the class (e.g. ``"Cr"``).
    class_path_dest : Path
        Existing destination directory.

    Raises
    ------
    AssertionError
        If two matched files share a name (one would overwrite the other in the
        flat destination) or a copied file is missing afterwards.
    """
    # Only regular files can be copied; rglob would also yield matching directories.
    sources = sorted(p for p in data_path.rglob(f"{class_name}_*") if p.is_file())
    names = [p.name for p in sources]
    assert len(set(names)) == len(names), (
        f"duplicate file names for class {class_name} under {data_path} would overwrite each other"
    )

    for src in sources:
        shutil.copy2(str(src), str(class_path_dest / src.name))

    # Verify what this call copied rather than counting the whole destination,
    # so pre-existing files (e.g. from an earlier run) do not cause a failure.
    missing = [name for name in names if not (class_path_dest / name).is_file()]
    assert not missing, f"failed to copy {missing} to {class_path_dest}"

def cp_image_annotation(data_path: Path, class_name: str, image_path_dest: Path, annotation_path_dest: Path) -> None:
    """
    Copy one class's images and annotation files out of a NEU-DET tree.

    Files named ``{class_name}_*`` are searched recursively under
    ``data_path / "IMAGES"`` and ``data_path / "ANNOTATIONS"`` and copied,
    flattened, into ``image_path_dest`` and ``annotation_path_dest``.

    Parameters
    ----------
    data_path : Path
        Root of the extracted NEU-DET dataset.
    class_name : str
        File-name prefix of the class (e.g. ``"crazing"``).
    image_path_dest : Path
        Existing destination directory for images.
    annotation_path_dest : Path
        Existing destination directory for annotations.

    Raises
    ------
    AssertionError
        If the image and annotation counts differ, if two matched files of the
        same kind share a name, or if a copied file is missing afterwards.
    """
    # Only regular files can be copied; rglob would also yield matching directories.
    img_lst = sorted(p for p in (data_path / "IMAGES").rglob(f"{class_name}_*") if p.is_file())
    ann_lst = sorted(p for p in (data_path / "ANNOTATIONS").rglob(f"{class_name}_*") if p.is_file())

    assert len(img_lst) == len(
        ann_lst
    ), f"images count {len(img_lst)} does not match with annotations count {len(ann_lst)} for class {class_name}"

    for sources, dest in ((img_lst, image_path_dest), (ann_lst, annotation_path_dest)):
        names = [p.name for p in sources]
        assert len(set(names)) == len(names), (
            f"duplicate file names for class {class_name} would overwrite each other in {dest}"
        )
        for src in sources:
            shutil.copy2(str(src), str(dest / src.name))
        # Verify what this call copied rather than counting the whole destination,
        # so pre-existing files (e.g. from an earlier run) do not cause a failure.
        missing = [name for name in names if not (dest / name).is_file()]
        assert not missing, f"failed to copy {missing} to {dest}"

def main(data_path: str, output_path: str, archived: bool = True) -> None:
    """
    Data preparation

    Optionally unpack a NEU-CLS or NEU-DET dataset, then copy its files into
    ``output_path``, one sub-directory per class abbreviation in ``CLASSES``.

    Parameters
    ----------
    data_path : str
        Raw data path. The dataset directory is taken to be the sibling of this
        path named after the part of its file name before the first ``.``
        (e.g. ``NEU-DET.zip`` -> ``NEU-DET``).
    output_path : str
        Output data path; created (including parents) if it does not exist.
    archived : bool
        Whether the file is archived or not (for testing)

    Raises
    ------
    ValueError
        If the packed data file is different from NEU-CLS or NEU-DET
    """
    data_path = Path(data_path)

    if archived:
        unpack(data_path)

    # Drop everything from the first "." on, e.g. "NEU-DET.zip" -> "NEU-DET".
    data_path = data_path.parent / re.search(r"^[^.]*", str(data_path.name)).group(0)

    # Remove a top-level Thumbs.db (Windows Explorer artifact) if present;
    # its absence is not an error, any other failure is reported and re-raised.
    try:
        os.remove(str(data_path / "Thumbs.db"))
    except FileNotFoundError:
        print("Thumbs.db is not found. Continuing ...")
    except Exception as e:
        print(f"{e}: Unknown error!")
        raise

    output_path = Path(output_path)

    if data_path.name == "NEU-CLS":
        # NEU-CLS files are matched by the abbreviation prefix (e.g. "Cr_*").
        for cls_ in CLASSES.values():
            cls_path = output_path / cls_
            # parents=True so a not-yet-existing output_path works, as in the NEU-DET branch.
            cls_path.mkdir(parents=True, exist_ok=True)
            cp_class_images(data_path, cls_, cls_path)

    elif data_path.name == "NEU-DET":
        # NEU-DET files are matched by the long class-name prefix (e.g. "crazing_*").
        for cls_, abbrev in CLASSES.items():
            cls_path = output_path / abbrev
            image_path = cls_path / "images"
            image_path.mkdir(parents=True, exist_ok=True)
            annotation_path = cls_path / "annotations"
            annotation_path.mkdir(exist_ok=True)
            cp_image_annotation(data_path, cls_, image_path, annotation_path)

    else:
        raise ValueError(f"Unknown data. Choose between `NEU-CLS` and `NEU-DET`. Given {data_path.name}")

if __name__ == "__main__":
    import sys

    # Command line: <data_path> <output_path>; extra arguments are ignored.
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Provide `data_path` and `output_path`")
        sys.exit(1)

    raw_data_arg, output_arg = cli_args[0], cli_args[1]
    main(raw_data_arg, output_arg)
    print("Done")

Variables

CLASSES

Functions

cp_class_images

def cp_class_images(
    data_path: pathlib.Path,
    class_name: str,
    class_path_dest: pathlib.Path
) -> None
View Source
def cp_class_images(data_path: Path, class_name: str, class_path_dest: Path) -> None:
    """
    Copy every file named ``{class_name}_*`` found anywhere under ``data_path``
    into ``class_path_dest``, flattening sub-directories.

    Parameters
    ----------
    data_path : Path
        Root directory, searched recursively.
    class_name : str
        File-name prefix of the class (e.g. ``"Cr"``).
    class_path_dest : Path
        Existing destination directory.

    Raises
    ------
    AssertionError
        If two matched files share a name (one would overwrite the other in the
        flat destination) or a copied file is missing afterwards.
    """
    # Only regular files can be copied; rglob would also yield matching directories.
    sources = sorted(p for p in data_path.rglob(f"{class_name}_*") if p.is_file())
    names = [p.name for p in sources]
    assert len(set(names)) == len(names), (
        f"duplicate file names for class {class_name} under {data_path} would overwrite each other"
    )

    for src in sources:
        shutil.copy2(str(src), str(class_path_dest / src.name))

    # Verify what this call copied rather than counting the whole destination,
    # so pre-existing files (e.g. from an earlier run) do not cause a failure.
    missing = [name for name in names if not (class_path_dest / name).is_file()]
    assert not missing, f"failed to copy {missing} to {class_path_dest}"

cp_image_annotation

def cp_image_annotation(
    data_path: pathlib.Path,
    class_name: str,
    image_path_dest: pathlib.Path,
    annotation_path_dest: pathlib.Path
) -> None
View Source
def cp_image_annotation(data_path: Path, class_name: str, image_path_dest: Path, annotation_path_dest: Path) -> None:
    """
    Copy one class's images and annotation files out of a NEU-DET tree.

    Files named ``{class_name}_*`` are searched recursively under
    ``data_path / "IMAGES"`` and ``data_path / "ANNOTATIONS"`` and copied,
    flattened, into ``image_path_dest`` and ``annotation_path_dest``.

    Parameters
    ----------
    data_path : Path
        Root of the extracted NEU-DET dataset.
    class_name : str
        File-name prefix of the class (e.g. ``"crazing"``).
    image_path_dest : Path
        Existing destination directory for images.
    annotation_path_dest : Path
        Existing destination directory for annotations.

    Raises
    ------
    AssertionError
        If the image and annotation counts differ, if two matched files of the
        same kind share a name, or if a copied file is missing afterwards.
    """
    # Only regular files can be copied; rglob would also yield matching directories.
    img_lst = sorted(p for p in (data_path / "IMAGES").rglob(f"{class_name}_*") if p.is_file())
    ann_lst = sorted(p for p in (data_path / "ANNOTATIONS").rglob(f"{class_name}_*") if p.is_file())

    assert len(img_lst) == len(
        ann_lst
    ), f"images count {len(img_lst)} does not match with annotations count {len(ann_lst)} for class {class_name}"

    for sources, dest in ((img_lst, image_path_dest), (ann_lst, annotation_path_dest)):
        names = [p.name for p in sources]
        assert len(set(names)) == len(names), (
            f"duplicate file names for class {class_name} would overwrite each other in {dest}"
        )
        for src in sources:
            shutil.copy2(str(src), str(dest / src.name))
        # Verify what this call copied rather than counting the whole destination,
        # so pre-existing files (e.g. from an earlier run) do not cause a failure.
        missing = [name for name in names if not (dest / name).is_file()]
        assert not missing, f"failed to copy {missing} to {dest}"

main

def main(
    data_path: str,
    output_path: str,
    archived: bool = True
) -> None

Data preparation

Parameters:

Name Type Description Default
data_path str Raw data path required
output_path str Output data path required
archived bool Whether the file is archived or not (for testing) True

Raises:

Type Description
ValueError If the packed data file is different from NEU-CLS or NEU-DET
View Source
def main(data_path: str, output_path: str, archived: bool = True) -> None:
    """
    Data preparation

    Optionally unpack a NEU-CLS or NEU-DET dataset, then copy its files into
    ``output_path``, one sub-directory per class abbreviation in ``CLASSES``.

    Parameters
    ----------
    data_path : str
        Raw data path. The dataset directory is taken to be the sibling of this
        path named after the part of its file name before the first ``.``
        (e.g. ``NEU-DET.zip`` -> ``NEU-DET``).
    output_path : str
        Output data path; created (including parents) if it does not exist.
    archived : bool
        Whether the file is archived or not (for testing)

    Raises
    ------
    ValueError
        If the packed data file is different from NEU-CLS or NEU-DET
    """
    data_path = Path(data_path)

    if archived:
        unpack(data_path)

    # Drop everything from the first "." on, e.g. "NEU-DET.zip" -> "NEU-DET".
    data_path = data_path.parent / re.search(r"^[^.]*", str(data_path.name)).group(0)

    # Remove a top-level Thumbs.db (Windows Explorer artifact) if present;
    # its absence is not an error, any other failure is reported and re-raised.
    try:
        os.remove(str(data_path / "Thumbs.db"))
    except FileNotFoundError:
        print("Thumbs.db is not found. Continuing ...")
    except Exception as e:
        print(f"{e}: Unknown error!")
        raise

    output_path = Path(output_path)

    if data_path.name == "NEU-CLS":
        # NEU-CLS files are matched by the abbreviation prefix (e.g. "Cr_*").
        for cls_ in CLASSES.values():
            cls_path = output_path / cls_
            # parents=True so a not-yet-existing output_path works, as in the NEU-DET branch.
            cls_path.mkdir(parents=True, exist_ok=True)
            cp_class_images(data_path, cls_, cls_path)

    elif data_path.name == "NEU-DET":
        # NEU-DET files are matched by the long class-name prefix (e.g. "crazing_*").
        for cls_, abbrev in CLASSES.items():
            cls_path = output_path / abbrev
            image_path = cls_path / "images"
            image_path.mkdir(parents=True, exist_ok=True)
            annotation_path = cls_path / "annotations"
            annotation_path.mkdir(exist_ok=True)
            cp_image_annotation(data_path, cls_, image_path, annotation_path)

    else:
        raise ValueError(f"Unknown data. Choose between `NEU-CLS` and `NEU-DET`. Given {data_path.name}")

unpack

def unpack(
    path: str
) -> None
View Source
def unpack(path: str) -> None:
    """Extract the archive at ``path`` into the directory that contains it."""
    archive_file = Path(path)
    destination = archive_file.parent
    Archive(str(archive_file)).extractall(str(destination))