Source code for aws.osml.gdal.gdal_utils

#  Copyright 2023-2024 Amazon.com, Inc. or its affiliates.

import logging
import re
from typing import Dict, List, Optional, Tuple

from defusedxml import ElementTree
from osgeo import gdal, gdalconst

from aws.osml.photogrammetry import SensorModel

from .dynamic_range_adjustment import DRAParameters
from .sensor_model_factory import SensorModelFactory, SensorModelTypes
from .typing import RangeAdjustmentType

logger = logging.getLogger(__name__)


def load_gdal_dataset(image_path: str) -> Tuple[gdal.Dataset, Optional[SensorModel]]:
    """
    Open a raster with GDAL and build the sensor model used to georeference locations on it.

    :param image_path: the path handed to gdal.Open; may be a local path or a virtual file system
                       path (e.g. /vsis3/...)
    :return: the opened raster dataset and the sensor model built by SensorModelFactory
    :raises ValueError: if gdal.Open raises a RuntimeError or returns None
    """
    try:
        logger.info("GDAL attempted to load image: %s", image_path)
        dataset = gdal.Open(image_path)
        # Without UseExceptions GDAL signals failure by returning None; funnel that case into the
        # same handler as the exception case.
        if dataset is None:
            raise RuntimeError("GDAL Unable to load dataset and UseExceptions is not enabled.")
    except RuntimeError:
        logger.info("Skipping: %s - GDAL Unable to Process", image_path)
        raise ValueError("GDAL Unable to Load: {}".format(image_path))

    # Georeferencing inputs GDAL exposes directly: geo transform, GCPs and projection
    affine_transform = dataset.GetGeoTransform(can_return_null=True)
    gcps = dataset.GetGCPs()
    projection_wkt = dataset.GetProjection()

    # NITF TREs, when present, are parsed from the first entry of the xml:TRE metadata domain
    tre_metadata = dataset.GetMetadata("xml:TRE")
    tre_root = ElementTree.fromstring(tre_metadata[0]) if tre_metadata else None

    # NITF DES segments are handed to the factory unparsed
    des_metadata = dataset.GetMetadata("xml:DES")

    # Sensor model types the factory is allowed to choose from
    allowed_model_types = [
        SensorModelTypes.AFFINE,
        SensorModelTypes.PROJECTIVE,
        SensorModelTypes.RPC,
        SensorModelTypes.SICD,
        SensorModelTypes.RSM,
    ]

    # Let the factory build the best sensor model available from the collected metadata
    factory = SensorModelFactory(
        dataset.RasterXSize,
        dataset.RasterYSize,
        xml_tres=tre_root,
        xml_dess=des_metadata,
        geo_transform=affine_transform,
        proj_wkt=projection_wkt,
        ground_control_points=gcps,
        selected_sensor_model_types=allowed_model_types,
    )
    return dataset, factory.build()
def get_minmax_for_type(gdal_type: int) -> Tuple[float, float]:
    """
    Look up the smallest and largest values representable by a GDAL pixel type.

    Types that are not recognized produce a warning and fall back to the [0, 255] range.

    :param gdal_type: the GDAL pixel type constant (gdalconst.GDT_*)
    :return: tuple of min, max values
    """
    # Unsigned integer types start at zero
    if gdal_type == gdalconst.GDT_Byte:
        return 0, 2**8 - 1
    if gdal_type == gdalconst.GDT_UInt16:
        return 0, 2**16 - 1
    # Signed integer types use the two's complement range
    if gdal_type == gdalconst.GDT_Int16:
        return -(2**15), 2**15 - 1
    if gdal_type == gdalconst.GDT_UInt32:
        return 0, 2**32 - 1
    if gdal_type == gdalconst.GDT_Int32:
        return -(2**31), 2**31 - 1
    if gdal_type == gdalconst.GDT_UInt64:
        return 0, 2**64 - 1
    if gdal_type == gdalconst.GDT_Int64:
        return -(2**63), 2**63 - 1
    # Floating point types report the largest finite magnitude in both directions
    if gdal_type == gdalconst.GDT_Float32:
        return -3.4028235e38, 3.4028235e38
    if gdal_type == gdalconst.GDT_Float64:
        return -1.7976931348623157e308, 1.7976931348623157e308

    logger.warning("Image uses unsupported GDAL datatype {}. Defaulting to [0,255] range".format(gdal_type))
    return 0, 255
def get_type_and_scales(
    raster_dataset: gdal.Dataset,
    desired_output_type: Optional[int] = None,
    range_adjustment: RangeAdjustmentType = RangeAdjustmentType.NONE,
) -> Tuple[int, List[List[int]]]:
    """
    Determine the output pixel type and the per-band scaling parameters for a raster dataset.

    One entry is produced per band, in band order, of the form
    [selected_min, selected_max, output_min, output_max]. The selected values are the source pixel
    range that will be mapped onto the full output range.

    :param raster_dataset: the raster dataset containing the region
    :param desired_output_type: type to be output after dynamic range adjustments; when None each
                                band keeps its own type
    :param range_adjustment: the type of pixel range adjustment to apply when choosing the source
                             range; NONE keeps the full range of the band's pixel type
    :return: a tuple containing the output type (as set by the last band processed, GDT_Byte if
             the dataset has no bands) and the list of per-band scale parameters
    """
    output_type = gdalconst.GDT_Byte
    scale_params = []
    floating_point_types = (gdalconst.GDT_Float32, gdalconst.GDT_Float64)
    adjustment_requested = range_adjustment is not RangeAdjustmentType.NONE

    # GDAL band numbers are 1-based
    for band_index in range(raster_dataset.RasterCount):
        band = raster_dataset.GetRasterBand(band_index + 1)
        band_type = band.DataType
        type_min, type_max = get_minmax_for_type(band_type)

        if desired_output_type is not None:
            output_type = desired_output_type
            output_min, output_max = get_minmax_for_type(desired_output_type)
        else:
            output_type = band_type
            output_min, output_max = type_min, type_max

        # Without a range adjustment the whole range of the band's type is mapped to the output.
        selected_min, selected_max = type_min, type_max

        if adjustment_requested:
            # GetStatistics(1,1) means it is ok to approximate but force computation if stats are
            # not already available
            stats = band.GetStatistics(1, 1)
            observed_min, observed_max = stats[0], stats[1]

            # One bucket per unit of observed range; floating point bands use a fixed 255 buckets.
            # NOTE(review): when the reported min and max are equal this yields 0 buckets -
            # confirm GetHistogram / DRAParameters handle that case.
            num_buckets = int(observed_max - observed_min)
            if band_type in floating_point_types:
                num_buckets = 255

            histogram = band.GetHistogram(
                min=observed_min,
                max=observed_max,
                buckets=num_buckets,
                include_out_of_range=1,
                approx_ok=1,
            )
            dra = DRAParameters.from_counts(
                counts=histogram,
                first_bucket_value=observed_min,
                last_bucket_value=observed_max,
            )

            if range_adjustment == RangeAdjustmentType.DRA:
                selected_min, selected_max = dra.suggested_min_value, dra.suggested_max_value
            elif range_adjustment == RangeAdjustmentType.MINMAX:
                selected_min, selected_max = dra.actual_min_value, dra.actual_max_value
            else:
                logger.warning(f"Unknown range adjustment selected {range_adjustment}. Skipping.")

        scale_params.append([selected_min, selected_max, output_min, output_max])

    return output_type, scale_params
def get_image_extension(image_path: str) -> str:
    """
    Determine the normalized image extension for the provided image path.

    The candidate extensions come from the GDAL driver reported for the image; one of them is
    selected against the path and then normalized.

    :param image_path: an image path
    :return: image extension
    """
    driver_extensions = get_extensions_from_driver(image_path)
    image_extension = normalize_extension(select_extension(image_path, driver_extensions))
    logger.info("Image extension: {}".format(image_extension))
    return image_extension
def select_extension(image_path: str, possible_extensions: List[str]) -> str:
    """
    Pick the extension from a list of candidates that best matches an image path.

    The first candidate is the default. Any later candidate that appears in the path (as
    ".<extension>", case-insensitively, anywhere in the path) replaces the current selection, so
    the last matching candidate wins. With no candidates the result is "UNKNOWN".

    :param image_path: an image path
    :param possible_extensions: list of possible extensions
    :return: the selected extension in upper case
    """
    remaining = iter(possible_extensions)
    try:
        default_extension = next(remaining)
    except StopIteration:
        return "UNKNOWN"

    chosen = default_extension.upper()
    for candidate in remaining:
        dotted_candidate = f".{candidate}".upper()
        if dotted_candidate in image_path.upper():
            chosen = candidate.upper()
    return chosen
def normalize_extension(unnormalized_extension: str) -> str:
    """
    Map an extension onto its canonical upper-case name.

    Variants of NITF, TIFF and JPEG extensions collapse to "NITF", "TIFF" and "JPEG" respectively;
    anything else is returned upper-cased.

    :param unnormalized_extension: an unnormalized extension
    :return: normalized extension
    """
    upper_extension = unnormalized_extension.upper()

    # Checked in order; the first pattern found anywhere in the extension decides the result.
    canonical_names = (
        (r"ni?tf", "NITF"),
        (r"tif{1,2}", "TIFF"),
        (r"jpe?g", "JPEG"),
    )
    for pattern, canonical_name in canonical_names:
        if re.search(pattern, upper_extension, re.IGNORECASE):
            return canonical_name
    return upper_extension
def get_extensions_from_driver(image_path: str) -> List[str]:
    """
    Return the file extensions registered for the GDAL driver that handles an image.

    :param image_path: an image path
    :return: the extensions registered for the image's driver, or an empty list when the driver's
             long name has no entry in the extension lookup
    """
    extensions_by_driver = get_gdal_driver_extensions()
    # gdal.Info in JSON form reports the driver under "driverLongName", the same key the
    # extension lookup is indexed by.
    long_name = gdal.Info(image_path, format="json").get("driverLongName")
    return extensions_by_driver.get(long_name, [])
def get_gdal_driver_extensions() -> Dict[str, List]:
    """
    Build a lookup from GDAL driver long name to the file extensions that driver registers.

    Drivers that do not advertise any extensions are left out of the lookup.

    :return: dictionary mapping each driver's long name to its list of extensions
    """
    driver_lookup = {}
    for driver_index in range(gdal.GetDriverCount()):
        driver = gdal.GetDriver(driver_index)
        raw_extensions = driver.GetMetadataItem(gdal.DMD_EXTENSIONS)
        if not raw_extensions:
            continue

        # split() with no separator tolerates repeated or surrounding whitespace. Splitting on a
        # single space would yield empty-string extensions, which downstream matching would treat
        # as matching any path that contains a ".".
        extension_list = raw_extensions.split()
        if extension_list:
            driver_lookup[driver.GetMetadataItem(gdal.DMD_LONGNAME)] = extension_list
    return driver_lookup