Source code for demcompare.demcompare_tiles

#!/usr/bin/env python
# coding: utf8
#
# Copyright (c) 2024 Centre National d'Etudes Spatiales (CNES).
#
# This file is part of demcompare
# (see https://github.com/CNES/demcompare).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains a wrapper for performing tiling on datas
"""

import argparse
import json
import logging
import multiprocessing as mp
import os
import shutil
import traceback
from typing import Callable, Tuple

# Third party imports
import argcomplete
import numpy as np
import rasterio
from affine import Affine

from demcompare import log_conf
from demcompare import run as run_demcompare_on_tile
from demcompare.img_tools import convert_pix_to_coord


[docs] def get_parser(): """ ArgumentParser for demcompare_tiles :return: parser """ parser = argparse.ArgumentParser( description="Compare Digital Elevation Models by tiling", fromfile_prefix_chars="@", ) parser.add_argument( "tiles_config", metavar="config.json", help=( "path to a json file containing the paths to " "input and output files, the tiles parameters " "and the algorithm parameters" ), ) parser.add_argument( "--loglevel", default="WARNING", choices=("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"), help="Logger level (default: INFO. Should be one of " "(DEBUG, INFO, WARNING, ERROR, CRITICAL)", ) return parser
[docs] def process_tile(args): """ Function that uses multiprocessing to run `demcompare` on multiple tiles concurrently. :param args: arguments list to compute process_tile with multiprocessing, see list below :type args: list :param row: Row index of the tile :type row: int :param col: Column index of the tile :type col: int :param width: Width of the tile :type width: int :param height: Height of the tile :type height: int :param overlap_size: overlap_size between two tiles :type overlap_size: int :param new_geotransform: geotransform's intersection of two tiles :type new_geotransform: affine :param output_dir: output directory for one tile :type output_dir: str :param dict_config: Tile's demcompare configuration :type dict_config: dict :param loglevel: log level :type loglevel: str """ ( row, col, width, height, overlap_size, new_geotransform, output_dir, dict_config, loglevel, ) = args # Get tile in DEM top_left_col = col * (width - overlap_size) top_left_row = -row * (height - overlap_size) bottom_right_col = top_left_col + width bottom_right_row = top_left_row - height left_point = convert_pix_to_coord( new_geotransform, top_left_row, top_left_col ) right_point = convert_pix_to_coord( new_geotransform, bottom_right_row, bottom_right_col ) roi = { "left": float(left_point[0]), "bottom": float(left_point[1]), "right": float(right_point[0]), "top": float(right_point[1]), } dict_config["input_ref"]["roi"] = roi dict_config["input_sec"]["roi"] = roi saving_dir = os.path.join(output_dir, f"row_{row}/col_{col}/") os.makedirs(saving_dir, exist_ok=True) dict_config["output_dir"] = saving_dir config = os.path.join(output_dir, f"with_roi_{row}_{col}.json") # Create new config with roi with open(config, "w", encoding="utf-8") as f: json.dump(dict_config, f) try: # run demcompare run_demcompare_on_tile(config, loglevel=loglevel) # If gradient function doesn't work on tile except ValueError: logging.info( "Tile (%s, %s) is too small, NaN values are returned", row, col ) shutil.rmtree(saving_dir) # remove tile config from disk os.remove(config)
[docs] def verify_config(dict_config_tiling: dict) -> Tuple[int, int, int, int]: """ Functions that verify tiling configuration :param dict_config_tiling: dictionary containing the tiles parameters :type dict_config_tiling: dict :return: height, width, overlap_size, nb_cpu :rtype: Tuple[int, int, int, int] """ # Function to validate each parameter def validate_param( param_name: str, condition: Callable, error_message: str ) -> int: """ Validate parameter from tile_configuration :param param_name: a tile's parameter :type param_name: str :param condition: a lambda function describing the condition :type condition: Callable :param error_message: Error message for user :type error_message: str :return: The parameter value :rtype: int """ if param_name in dict_config_tiling and condition( dict_config_tiling[param_name] ): return dict_config_tiling[param_name] raise ValueError(error_message) # Define validation conditions conditions = { "height": lambda x: isinstance(x, int) and x > 0, "width": lambda x: isinstance(x, int) and x > 0, "overlap": lambda x: isinstance(x, int) and x >= 0, "nb_cpu": lambda x: isinstance(x, int) and x > 0, } # Validate parameters height = validate_param( "height", conditions["height"], "Height is not consistent" ) width = validate_param( "width", conditions["width"], "Width is not consistent" ) overlap_size = validate_param( "overlap", conditions["overlap"], "Overlap is not consistent" ) # Get actual CPU usage from the environment to automate the default pipeline # and validate CPU load. if "nb_cpu" not in dict_config_tiling: nb_cpu = len(os.sched_getaffinity(0)) else: nb_cpu = validate_param( "nb_cpu", conditions["nb_cpu"], "Number of CPUs is incorrect" ) if nb_cpu > len(os.sched_getaffinity(0)): raise ValueError( "Number of CPUs in the config is more than available CPUs" ) return height, width, overlap_size, nb_cpu
[docs] def get_coreg_results( coreg_results: dict, ) -> Tuple[float, float, float, float, float]: """ Get coregistration results for one tile :param coreg_results: Coregistration dictionary results :type coreg_results: dict :return: x, y and z shifts, percentage of valid data (ref, sec) :rtype: float, float, float, float, float """ x = coreg_results["coregistration_results"]["dx"]["total_offset"] y = coreg_results["coregistration_results"]["dy"]["total_offset"] z = coreg_results["coregistration_results"]["dz"]["total_bias_value"] nb_valid_pts_ref = coreg_results["coregistration_results"][ "reproj_coreg_ref" ]["percentage_valid_points"] nb_valid_pts_sec = coreg_results["coregistration_results"][ "reproj_coreg_sec" ]["percentage_valid_points"] return x, y, z, nb_valid_pts_ref, nb_valid_pts_sec
[docs] def run_tiles(tiles_config, loglevel): # pylint:disable=too-many-locals """ Call demcompare_tiles's main """ # Logging configuration log_conf.setup_logging(default_level=loglevel) with open(tiles_config, "r", encoding="utf-8") as json_file: dict_config = json.load(json_file) # Verify config and get tiles management height, width, overlap_size, nb_cpu = verify_config(dict_config["tiling"]) # Path management output_dir = os.path.abspath(dict_config["output_dir"]) dict_config["input_ref"]["path"] = os.path.abspath( dict_config["input_ref"]["path"] ) dict_config["input_sec"]["path"] = os.path.abspath( dict_config["input_sec"]["path"] ) dict_config["output_dir"] = output_dir # Create output_dir from updated absolute path os.makedirs(dict_config["output_dir"], exist_ok=True) ref_dem = rasterio.open(dict_config["input_ref"]["path"]) sec_dem = rasterio.open(dict_config["input_sec"]["path"]) # Get DEM intersection transformed_sec_bounds = rasterio.warp.transform_bounds( sec_dem.crs, sec_dem.crs, sec_dem.bounds.left, sec_dem.bounds.bottom, sec_dem.bounds.right, sec_dem.bounds.top, ) transformed_ref_bounds = rasterio.warp.transform_bounds( ref_dem.crs, sec_dem.crs, ref_dem.bounds.left, ref_dem.bounds.bottom, ref_dem.bounds.right, ref_dem.bounds.top, ) if rasterio.coords.disjoint_bounds( transformed_sec_bounds, transformed_ref_bounds ): raise NameError("ERROR: ROIs do not intersect") intersection_roi = rasterio.coords.BoundingBox( max(transformed_sec_bounds[0], transformed_ref_bounds[0]), max(transformed_sec_bounds[1], transformed_ref_bounds[1]), min(transformed_sec_bounds[2], transformed_ref_bounds[2]), min(transformed_sec_bounds[3], transformed_ref_bounds[3]), ) # Working on intersection new_geotransform = list( Affine( ref_dem.res[0], 0.0, intersection_roi.left, 0.0, -ref_dem.res[1], intersection_roi.bottom, ).to_gdal() ) # Instance of tiles parameters image_height = abs( int((intersection_roi.top - intersection_roi.bottom) / ref_dem.res[0]) ) image_width = abs( int((intersection_roi.left - intersection_roi.right) / ref_dem.res[0]) ) logging.info( "The intersection DEM size is %s row and %s col", image_height, image_width, ) logging.info("The tile size is %s row %s col", height, width) nb_tiles_row = (image_height - overlap_size) // (height - overlap_size) if (image_height - overlap_size) % (height - overlap_size) != 0: nb_tiles_row += 1 nb_tiles_col = (image_width - overlap_size) // (width - overlap_size) if (image_width - overlap_size) % (width - overlap_size) != 0: nb_tiles_col += 1 logging.info( "There are %s tiles in columns and %s tiles in rows", nb_tiles_col, nb_tiles_row, ) tasks = [ ( row, col, width, height, overlap_size, new_geotransform, output_dir, dict_config, loglevel, ) for row in range(nb_tiles_row) for col in range(nb_tiles_col) ] with mp.Pool(processes=nb_cpu) as pool: _ = pool.map(process_tile, tasks) # Compute matrix to store dx, dy, dz and valid points percentage x_2d = np.full((nb_tiles_row, nb_tiles_col), np.nan) y_2d = np.full((nb_tiles_row, nb_tiles_col), np.nan) z_2d = np.full((nb_tiles_row, nb_tiles_col), np.nan) percentage_valid_points = np.full((2, nb_tiles_row, nb_tiles_col), np.nan) for row in range(nb_tiles_row): for col in range(nb_tiles_col): # Sometimes demcompare is not robust and the directory is removed if os.path.isdir(output_dir + f"/row_{row}/col_{col}"): with open( output_dir + f"/row_{row}/col_{col}/coregistration/" f"coregistration_results.json", "r", encoding="utf-8", ) as json_file: coreg_results = json.load(json_file) x, y, z, valid_point_ref, valid_point_sec = get_coreg_results( coreg_results ) x_2d[row, col] = x y_2d[row, col] = y z_2d[row, col] = z percentage_valid_points[0, row, col] = valid_point_ref percentage_valid_points[1, row, col] = valid_point_sec np.save(os.path.join(output_dir, "coreg_results_x2D.npy"), x_2d) np.save(os.path.join(output_dir, "coreg_results_y2D.npy"), y_2d) np.save(os.path.join(output_dir, "coreg_results_z2D.npy"), z_2d) np.save( os.path.join(output_dir, "percentage_valid_points.npy"), percentage_valid_points, )
[docs] def main(): """ Call demcompare-tile's main """ parser = get_parser() argcomplete.autocomplete(parser) args = parser.parse_args() try: run_tiles(args.tiles_config, args.loglevel) except Exception: # pylint: disable=broad-except logging.error(" Demcompare %s", traceback.format_exc())
if __name__ == "__main__": main()