Module sagemaker_defect_detection.utils
Shared utilities for the defect detection package: logging setup, boolean flag parsing, parameter freezing, and checkpoint loading.
from typing import Optional, Union
from pathlib import Path
import logging
from logging.config import fileConfig

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_logger(config_path: str) -> logging.Logger:
    # configure logging from an INI-style file and return the root logger
    fileConfig(config_path, disable_existing_loggers=False)
    return logging.getLogger()


def str2bool(flag: Union[str, bool]) -> bool:
    # argparse-friendly conversion: pass booleans through, parse strings case-insensitively
    if isinstance(flag, bool):
        return flag
    if flag.lower() == "false":
        return False
    if flag.lower() == "true":
        return True
    raise ValueError(f"Cannot interpret {flag!r} as a boolean")


def freeze(m: nn.Module) -> None:
    # disable gradient computation for all parameters of the module
    assert isinstance(m, nn.Module), "freeze can only be applied to nn.Module instances"
    for param in m.parameters():
        param.requires_grad = False


def load_checkpoint(model: nn.Module, path: str, prefix: Optional[str]) -> nn.Module:
    ckpt_path = Path(path)
    logger.info(f"checkpoint path: {ckpt_path}")
    if ckpt_path.is_dir():
        # when given a directory, use the first .ckpt file found under it
        ckpt_path = next(ckpt_path.rglob("*.ckpt"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # the checkpoint is expected to store weights under the "state_dict" key
    # (PyTorch Lightning's .ckpt layout)
    state_dict = torch.load(str(ckpt_path), map_location=torch.device(device))["state_dict"]
    if prefix is not None:
        if not prefix.endswith("."):
            prefix += "."
        # keep only the keys under `prefix` and strip it so they match the bare model
        state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}
    model.load_state_dict(state_dict, strict=True)
    return model
Sub-modules
- sagemaker_defect_detection.utils.coco_eval
- sagemaker_defect_detection.utils.coco_utils
- sagemaker_defect_detection.utils.visualize
Variables
logger : module-level logging.Logger named after the module (__name__), set to level INFO
Functions
freeze
def freeze(
    m: torch.nn.modules.module.Module
) -> None
def freeze(m: nn.Module) -> None:
    # disable gradient computation for all parameters of the module
    assert isinstance(m, nn.Module), "freeze can only be applied to nn.Module instances"
    for param in m.parameters():
        param.requires_grad = False
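A minimal usage sketch: freeze a pretrained backbone and fine-tune only a new head. The torchvision model and the optimizer setup here are illustrative, not part of this module:

import torch
import torch.nn as nn
import torchvision.models as models

from sagemaker_defect_detection.utils import freeze

backbone = models.resnet50()
freeze(backbone)                  # no backbone parameter will receive gradients
backbone.fc = nn.Linear(2048, 2)  # the replacement head stays trainable

# hand only the still-trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    (p for p in backbone.parameters() if p.requires_grad), lr=1e-3
)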
get_logger
def get_logger(
    config_path: str
) -> logging.Logger
def get_logger(config_path: str) -> logging.Logger:
    # configure logging from an INI-style file and return the root logger
    fileConfig(config_path, disable_existing_loggers=False)
    return logging.getLogger()
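A sketch of how this might be called. The file name logging.conf is hypothetical; its contents follow the standard logging.config.fileConfig INI schema:

from sagemaker_defect_detection.utils import get_logger

# logging.conf (standard fileConfig INI schema), e.g.:
#
#   [loggers]
#   keys=root
#   [handlers]
#   keys=console
#   [formatters]
#   keys=simple
#   [logger_root]
#   level=INFO
#   handlers=console
#   [handler_console]
#   class=StreamHandler
#   formatter=simple
#   args=(sys.stdout,)
#   [formatter_simple]
#   format=%(asctime)s %(name)s %(levelname)s %(message)s

logger = get_logger("logging.conf")
logger.info("training started")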
load_checkpoint
def load_checkpoint(
    model: torch.nn.modules.module.Module,
    path: str,
    prefix: Optional[str]
) -> torch.nn.modules.module.Module
def load_checkpoint(model: nn.Module, path: str, prefix: Optional[str]) -> nn.Module:
    ckpt_path = Path(path)
    logger.info(f"checkpoint path: {ckpt_path}")
    if ckpt_path.is_dir():
        # when given a directory, use the first .ckpt file found under it
        ckpt_path = next(ckpt_path.rglob("*.ckpt"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # the checkpoint is expected to store weights under the "state_dict" key
    # (PyTorch Lightning's .ckpt layout)
    state_dict = torch.load(str(ckpt_path), map_location=torch.device(device))["state_dict"]
    if prefix is not None:
        if not prefix.endswith("."):
            prefix += "."
        # keep only the keys under `prefix` and strip it so they match the bare model
        state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}
    model.load_state_dict(state_dict, strict=True)
    return model
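A usage sketch; the network, path, and "net" prefix below are illustrative. The prefix argument matters when the checkpoint was written by a wrapper (for example a PyTorch Lightning module holding the network as self.net), which stores keys like "net.0.weight":

import torch.nn as nn

from sagemaker_defect_detection.utils import load_checkpoint

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# accepts a .ckpt file or a directory containing one; stripping the "net"
# prefix makes the checkpoint keys match the bare module
net = load_checkpoint(net, "checkpoints/", prefix="net")
net.eval()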
str2bool
def str2bool(
    flag: Union[str, bool]
) -> bool
def str2bool(flag: Union[str, bool]) -> bool:
    # argparse-friendly conversion: pass booleans through, parse strings case-insensitively
    if isinstance(flag, bool):
        return flag
    if flag.lower() == "false":
        return False
    if flag.lower() == "true":
        return True
    raise ValueError(f"Cannot interpret {flag!r} as a boolean")
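A common use is as an argparse type, so that --pretrained False actually parses as False instead of a truthy string; a minimal sketch:

import argparse

from sagemaker_defect_detection.utils import str2bool

parser = argparse.ArgumentParser()
# argparse hands the raw string to `type`, so "true"/"False" etc. convert correctly
parser.add_argument("--pretrained", type=str2bool, default=True)

args = parser.parse_args(["--pretrained", "False"])
assert args.pretrained is False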