# Source code for gpm.utils.collocation

# -----------------------------------------------------------------------------.
# MIT License

# Copyright (c) 2024 GPM-API developers
#
# This file is part of GPM-API.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# -----------------------------------------------------------------------------.
"""This module contains utilities for GPM product collocation."""
import datetime

import numpy as np
import pyproj
import xarray as xr

import gpm
from gpm.utils.manipulations import get_spatial_2d_datarray_template


def _get_collocation_defaults_args(product, variables, groups, version, scan_modes):
    """Resolve the default scan modes and variables used for collocation.

    Parameters
    ----------
    product : str
        GPM product acronym.
    variables : list or None
        Requested variables. Defaults are selected only when both
        ``variables`` and ``groups`` are ``None``.
    groups : list or None
        Requested groups. Returned unchanged.
    version : optional
        Product version forwarded to ``gpm.available_scan_modes``.
    scan_modes : str, list or None
        Requested scan mode(s). If ``None``, the scan modes reported by
        ``gpm.available_scan_modes`` are used. A string is wrapped in a list.

    Returns
    -------
    tuple
        ``(scan_modes, variables, groups)``.

    Raises
    ------
    ValueError
        If more than one scan mode is requested for a non-PMW product.
    """
    # Normalize the scan modes to a list
    if scan_modes is None:
        scan_modes = gpm.available_scan_modes(product=product, version=version)
    if isinstance(scan_modes, str):
        scan_modes = [scan_modes]

    # Only PMW products can be collocated over several scan modes
    is_pmw_product = product in gpm.available_products(product_categories="PMW")
    if not is_pmw_product and len(scan_modes) > 1:
        raise ValueError("Multiple scan modes can be specified only for PMW products!")

    # An explicit selection by the user always wins over the defaults
    if variables is not None or groups is not None:
        return scan_modes, variables, groups

    # PMW defaults
    if product in gpm.available_products(product_levels="2A", product_categories="PMW"):
        variables = [
            "surfacePrecipitation",
            "mostLikelyPrecipitation",
            "cloudWaterPath",  # kg/m2
            "rainWaterPath",  # kg/m2
            "iceWaterPath",  # kg/m2
        ]
    elif product.startswith("1C"):
        variables = ["Tc"]
    elif product.startswith("1B"):
        variables = ["Tb"]
    return scan_modes, variables, groups


def collocate_product(
    ds,
    product,
    product_type="RS",
    version=None,
    storage="GES_DISC",
    scan_modes=None,
    variables=None,
    groups=None,
    verbose=True,
    decode_cf=True,
    chunks={},
):
    """Collocate a product on the provided dataset.

    It is assumed that an approximately collocated product exists along the
    whole extent of the input dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Target dataset. Its ``gpm`` accessor provides the time range to search
        and it is the grid onto which the product is remapped.
    product : str
        GPM product acronym to collocate.
    product_type : str, optional
        Forwarded to ``gpm.download`` and ``gpm.open_datatree``.
    version : optional
        Forwarded to ``gpm.download`` and ``gpm.open_datatree``.
    storage : str, optional
        Forwarded to ``gpm.download``.
    scan_modes : str or list, optional
        Scan mode(s) to collocate. Defaults are resolved by
        ``_get_collocation_defaults_args``.
    variables, groups : list, optional
        Variables and groups to read. Defaults are resolved by
        ``_get_collocation_defaults_args``.
    verbose : bool, optional
        Forwarded to ``gpm.download``.
    decode_cf : bool, optional
        Forwarded to ``gpm.open_datatree``.
    chunks : optional
        Forwarded as-is to ``gpm.open_datatree``. The default dictionary is
        never modified by this function.

    Returns
    -------
    xarray.Dataset
        The product remapped on ``ds``, carrying the ``time`` of ``ds``, the
        attributes of the first scan mode and a ``ScanMode`` attribute.
    """
    # Get default collocation arguments
    scan_modes, variables, groups = _get_collocation_defaults_args(
        product=product,
        variables=variables,
        groups=groups,
        version=version,
        scan_modes=scan_modes,
    )

    # Define start_time, end_time around input dataset
    time_margin = datetime.timedelta(minutes=5)
    start_time = ds.gpm.start_time - time_margin
    end_time = ds.gpm.end_time + time_margin

    # Download PMW products (if necessary)
    gpm.download(
        product=product,
        product_type=product_type,
        start_time=start_time,
        end_time=end_time,
        version=version,
        storage=storage,
        force_download=False,
        verbose=verbose,
    )

    # Read datasets
    dt = gpm.open_datatree(
        product=product,
        start_time=start_time,
        end_time=end_time,
        # Optional
        version=version,
        variables=variables,
        groups=groups,
        product_type=product_type,
        scan_modes=scan_modes,
        chunks=chunks,
        decode_cf=decode_cf,
    )

    # Remap each scan mode onto the input dataset
    list_remapped = []
    for scan_mode in scan_modes:
        ds_scan_mode = dt[scan_mode].to_dataset()
        list_remapped.append(ds_scan_mode.gpm.remap_on(ds))

    # Concatenate if necessary (PMW case)
    if len(list_remapped) > 1:
        output_ds = xr.concat(list_remapped, dim="pmw_frequency")
    else:
        output_ds = list_remapped[0]

    # Add time of dst dataset
    dst_time = ds.reset_coords()["time"]
    output_ds = output_ds.assign_coords({"time": dst_time})

    # Assign attributes
    output_ds.attrs = dt[scan_modes[0]].attrs
    output_ds.attrs["ScanMode"] = scan_modes
    return output_ds
# def _expand_with_scan_mode(ds, scan_mode):
#     # Squeeze dataset
#     if "incidence_angle" in ds.dims:
#         ds = ds.squeeze(dim="incidence_angle")
#     # Define variables to exclude from expansion (always keep as is)
#     variables_to_exclude = [
#         "lon",
#         "lat",
#         "longitude",
#         "latitude",
#         "crsWGS84",
#         "gpm_id",
#         "gpm_time",
#         "gpm_granule_id",
#         "gpm_along_track_id",
#         "gpm_cross_track_id",
#     ]
#     # Define variables to preprocess
#     variables = set(ds.variables) - set(variables_to_exclude)
#     coords_to_expand = [
#         var
#         for var in variables
#         if (
#             var in ds.coords
#             and "pmw_frequency" not in ds[var].dims
#             and np.all(np.isin(("along_track", "cross_track"), ds[var].dims))
#         )
#     ]
#     variables_to_expand = [
#         var
#         for var in variables
#         if (
#             var not in ds.coords
#             and "pmw_frequency" not in ds[var].dims
#             and np.all(np.isin(("along_track", "cross_track"), ds[var].dims))
#         )
#     ]
#     variables_not_to_expand = [var for var in variables if "pmw_frequency" in ds[var].dims]
#     # If nothing to expand, return input dataset
#     if not variables_to_expand and not coords_to_expand:
#         return ds
#     # Expand dataset
#     if variables_to_expand:
#         ds_expanded = ds[variables_to_expand]
#         if coords_to_expand:
#             ds_expanded = ds_expanded.reset_coords(coords_to_expand)
#     else:  # only coords to expand
#         ds_expanded = ds.reset_coords(coords_to_expand)[coords_to_expand]
#     ds_expanded = ds_expanded.expand_dims(dim={"scan_mode": 1}, axis=-1)
#     ds_expanded = ds_expanded.assign_coords({"scan_mode": [scan_mode]})
#     # Add variables to not be expanded
#     ds_expanded.update(ds.reset_coords()[variables_not_to_expand])
#     return ds_expanded


# def _remap_pmw_datatree(dt, scan_modes, scan_mode_reference, radius_of_influence=20_000):
#     # Define grid template
#     ds_template = dt[scan_mode_reference].to_dataset()
#     # Remap each scan_mode dataset onto the template grid
#     list_remapped = [
#         dt[scan_mode].to_dataset().gpm.remap_on(ds_template, radius_of_influence=radius_of_influence)
#         for scan_mode in scan_modes
#     ]
#     # Insert the template dataset as the first element in the list
#     list_remapped.insert(0, ds_template)
#     # Concatenate common variables along pmw_frequency or scan_mode
#     vars_pmw_freq = [var for var in list_remapped[1].data_vars if "pmw_frequency" in ds_template[var].dims]
#     vars_scan_mode = [var for var in list_remapped[1].data_vars if "scan_mode" in ds_template[var].dims]
#     ds_pmw = xr.concat([ds[vars_pmw_freq] for ds in list_remapped], dim="pmw_frequency")
#     ds_scan_mode = xr.concat([ds[vars_scan_mode] for ds in list_remapped], dim="scan_mode")
#     output_ds = xr.merge([ds_pmw, ds_scan_mode])
#     # Add back some variables
#     # - TODO: SClatitude, SClongitude, SCaltitude, FractionalGranuleNumber
#     # - TODO: incidenceAngleIndex
#     return output_ds


def _remap_pmw_datatree(dt_expanded, scan_modes, scan_mode_reference, radius_of_influence=20_000):
    """Remap the scan modes of a preprocessed DataTree onto the reference scan mode grid.

    Parameters
    ----------
    dt_expanded : xarray.DataTree
        DataTree whose nodes are the scan modes to combine.
    scan_modes : list
        Scan modes to remap onto the reference grid.
    scan_mode_reference : str
        Scan mode whose dataset is used as remapping template. It is kept as
        the first dataset of the combination.
    radius_of_influence : optional
        Forwarded to ``gpm.remap_on``.

    Returns
    -------
    xarray.Dataset
        Dataset where variables with the ``pmw_frequency`` dimension are
        concatenated along ``pmw_frequency``, variables with the ``scan_mode``
        dimension are concatenated along ``scan_mode`` and the remaining
        variables are merged.
    """

    def _has_pmw_frequency(dims):
        return "pmw_frequency" in dims

    def _has_scan_mode(dims):
        return "scan_mode" in dims

    def _has_no_concat_dim(dims):
        return not np.isin(["scan_mode", "pmw_frequency"], dims).any()

    def _select_variables(ds, predicate):
        # Subset of the data variables whose dimensions satisfy the predicate
        return ds[[var for var in ds.data_vars if predicate(ds[var].dims)]]

    # Define grid template
    ds_template = dt_expanded[scan_mode_reference].to_dataset()

    # Remap each scan_mode dataset onto the template grid (template goes first)
    list_remapped = [
        ds_template,
        *(
            dt_expanded[scan_mode].to_dataset().gpm.remap_on(ds_template, radius_of_influence=radius_of_influence)
            for scan_mode in scan_modes
        ),
    ]

    # Split each dataset by the dimension along which its variables are combined
    list_pmw_freq = [_select_variables(ds, _has_pmw_frequency) for ds in list_remapped]
    list_scan_mode = [_select_variables(ds, _has_scan_mode) for ds in list_remapped]
    list_unique = [_select_variables(ds, _has_no_concat_dim) for ds in list_remapped]

    # Concatenate common variables along pmw_frequency or scan_mode, then add unique variables
    # NOTE(review): the extent of this context manager was reconstructed from a
    # flattened source; confirm that all four combine calls belong inside it.
    concat_kwargs = {"coords": "minimal", "compat": "override", "combine_attrs": "override"}
    with xr.set_options(use_new_combine_kwarg_defaults=True):
        ds_unique = xr.merge(list_unique, compat="no_conflicts", join="outer", combine_attrs="override")
        ds_pmw = xr.concat(list_pmw_freq, dim="pmw_frequency", **concat_kwargs)
        ds_scan_mode = xr.concat(list_scan_mode, dim="scan_mode", **concat_kwargs)
        ds_combined = xr.merge(
            [ds_pmw, ds_scan_mode, ds_unique],
            compat="override",
            join="outer",
            combine_attrs="override",
        )

    # Add back missing variables of dt[scan_mode_reference]
    # TODO: SClatitude, SClongitude, SCaltitude, FractionalGranuleNumber
    return ds_combined
def preprocess_datatree(dt, exclude_vars=None, fixed_vars=None):
    """Preprocess DataTree for remapping by handling variables consistently.

    The input DataTree is left untouched: all modifications are applied to the
    new DataTree returned by ``map_over_datasets``.

    Parameters
    ----------
    dt : xarray.DataTree
        DataTree with multiple scan modes.
    exclude_vars : list, optional
        Variables to exclude from processing (dropped from every node).
        Defaults to ``SClatitude``, ``SClongitude``, ``SCaltitude`` and
        ``FractionalGranuleNumber``.
    fixed_vars : list, optional
        Variables to preserve as-is (never broadcasted).
        Defaults to the geolocation, CRS and ``gpm_*`` identifier variables.

    Returns
    -------
    xarray.DataTree
        Preprocessed DataTree ready for remapping.
    """
    # Prepare datasets for remapping
    # - Remove 'exclude_vars' variables from all datatree nodes
    #   --> SClatitude, SClongitude, SCaltitude, FractionalGranuleNumber
    # - Variables with only (along_track) dimension should be broadcasted to (along_track, cross_track)
    #   unless being 'fixed_vars'
    #   --> Quality and sunLocalTime
    #   --> If user broadcast such variables to have the pmw_frequency dimension
    #       --> The final dataset will be concatenate along pmw_frequency dimension
    #       --> Otherwise such variables will be concatenated along the scan_mode dimension
    # - Remove all variables without (along_track, cross-track) dimensions unless being in 'fixed_vars'

    # Set default values
    if exclude_vars is None:
        exclude_vars = [
            "SClatitude",
            "SClongitude",
            "SCaltitude",
            "FractionalGranuleNumber",
        ]
    if fixed_vars is None:
        fixed_vars = [
            "lon",
            "lat",
            "longitude",
            "latitude",
            "crsWGS84",
            "gpm_id",
            "gpm_time",
            "gpm_granule_id",
            "gpm_along_track_id",
            "gpm_cross_track_id",
        ]

    # List scan modes
    scan_modes = list(dt)

    # - Remove 'exclude_vars' variables from all datatree nodes
    # --> This is done first because it yields a new DataTree: the node
    #     assignments below then never modify the DataTree of the caller
    dt = dt.map_over_datasets(lambda ds: ds.drop_vars(exclude_vars, errors="ignore"))

    # - Variables with only (along_track) or (cross-track) dimension should be broadcasted
    #   to (along_track, cross_track) unless being 'fixed_vars'
    def _broadcast_along_track_only(ds, fixed_vars):
        ds_spatial_2d_template = get_spatial_2d_datarray_template(ds)
        variables = [var for var in ds.data_vars if var not in fixed_vars]
        for var in variables:
            if len(ds[var].gpm.spatial_dimensions) == 1:
                ds[var] = ds[var].broadcast_like(ds_spatial_2d_template)
        return ds

    coords_variables = ["sunLocalTime", "Quality"]
    for scan_mode in scan_modes:
        ds = dt[scan_mode].to_dataset()
        # Move coordinates to be remapped as variables
        ds = ds.reset_coords([var for var in coords_variables if var in ds.coords], drop=False)
        dt[scan_mode] = _broadcast_along_track_only(ds, fixed_vars=fixed_vars)

    # -------------------------------------------------------------------------
    # Preprocess variables with (along_track, cross-track) dimensions
    # - Unique variables:
    #   - Variables with (along_track, cross-track) dimensions not shared across any datatree dataset
    #     are kept as they are
    #   - Can be just added to the final dataset
    # - Partial variables:
    #   - Variables with (along_track, cross-track) dimensions present in some but not all datatree dataset
    #   - NaN DataArray variable should be added to dataset where missing --> Then treated as shared variables
    # - Shared variables with pmw_frequency:
    #   - Variables with (along_track, cross-track, pmw_frequency) dimensions shared across all datatree dataset
    #     are kept as they are
    #   - Such variables are concatenated along pmw_frequency in the final dataset
    # - Shared variables without pmw_frequency:
    #   - Variables with (along_track, cross-track) dimensions shared across all datatree dataset
    #     are expand with the scan_mode dimension.
    #   - Such variables are concatenated along scan_mode in the final dataset

    # List variables across datatree
    dict_vars = {scan_mode: list(dt[scan_mode].data_vars) for scan_mode in scan_modes}

    # Classify variables as shared, unique or partial
    # - Unique and partial variables are not used yet (see TODO below)
    shared_variables, _unique_variables, _partial_variables = _classify_variables_by_occurrence(dict_vars)

    # Add dummy NaN array in scan_modes where partial_variables missing
    # TODO: TODO

    # Expand shared variables without pmw_frequency dimension
    for scan_mode in scan_modes:
        ds = dt[scan_mode].to_dataset()
        # Iterate over a snapshot of the names since 'ds' is rebound and updated in the loop
        for var in list(ds.data_vars):
            if var in shared_variables and "pmw_frequency" not in ds[var].dims:
                if "scan_mode" not in ds:
                    ds = ds.assign_coords({"scan_mode": scan_mode})
                ds[var] = ds[var].expand_dims(dim={"scan_mode": 1}, axis=-1)
        dt[scan_mode] = ds
    return dt


def _classify_variables_by_occurrence(dict_vars):
    """Split variable names into shared, unique and partial sets.

    Parameters
    ----------
    dict_vars : dict
        Mapping from scan mode name to the list of its variable names.

    Returns
    -------
    tuple of set
        ``(shared, unique, partial)`` where ``shared`` variables occur in every
        scan mode, ``unique`` variables occur in exactly one scan mode (and are
        not shared) and ``partial`` variables occur in several but not all
        scan modes.
    """
    n_scan_modes = len(dict_vars)
    # Count in how many scan modes each variable occurs
    occurrences = {}
    for vars_list in dict_vars.values():
        for var in set(vars_list):
            occurrences[var] = occurrences.get(var, 0) + 1
    shared = {var for var, count in occurrences.items() if count == n_scan_modes}
    unique = {var for var, count in occurrences.items() if count == 1} - shared
    partial = set(occurrences) - shared - unique
    return shared, unique, partial
def _define_time_blocks(dt, scan_mode_reference, window_duration=3600, overlap_duration=60):
    """Define overlapping time periods.

    Parameters
    ----------
    dt : xarray.DataTree or dict
        Mapping from scan mode name to the corresponding dataset/node.
    scan_mode_reference : str
        Reference scan mode key.
    window_duration : int
        Duration of the time window in seconds (default: 1 hour).
    overlap_duration : int
        Duration of the overlap in seconds (default: 1 minute).

    Returns
    -------
    list
        List of ``[start_time, end_time, gpm_id_start, gpm_id_stop]`` lists,
        one per time block.

    Raises
    ------
    ValueError
        If ``overlap_duration`` is not smaller than ``window_duration``.
    """
    # The step between blocks must be positive to generate the block start times
    step_duration = window_duration - overlap_duration
    if step_duration <= 0:
        raise ValueError("'overlap_duration' must be smaller than 'window_duration'.")
    step = np.timedelta64(step_duration, "s")
    window = np.timedelta64(window_duration, "s")

    # Retrieve reference dataset
    ds = dt[scan_mode_reference]
    gpm_ids = ds["gpm_id"].data

    # Extract the reference time dimension
    start_time = np.datetime64(ds["time"].gpm.start_time)
    end_time = np.datetime64(ds["time"].gpm.end_time)

    # Generate reference time blocks
    start_times = np.arange(start_time, end_time, step)

    # A zero-length time range (i.e. a single scan) generates no start times.
    # --> Return a single block spanning the whole dataset
    if start_times.size == 0:
        return [[start_time, end_time, gpm_ids[0].item(), gpm_ids[-1].item()]]

    end_times = np.minimum(start_times + step, end_time)

    # Now define window time blocks and gpm_id at each time block
    # - The window is clipped to the time range covered by the blocks
    # - The gpm_id are the ones nearest to the (unextended) block boundaries
    time_blocks = [
        [
            np.maximum(block_start - window, start_times[0]),
            np.minimum(block_end + window, end_times[-1]),
            ds["gpm_id"].gpm.sel(time=block_start, method="nearest").to_numpy().item(),
            ds["gpm_id"].gpm.sel(time=block_end, method="nearest").to_numpy().item(),
        ]
        for block_start, block_end in zip(start_times, end_times, strict=False)
    ]

    # Enforce first and last gpm_id
    time_blocks[0][2] = gpm_ids[0].item()
    time_blocks[-1][3] = gpm_ids[-1].item()

    # Ensure no missing gpm_id between blocks and no repeated gpm_id
    # - With a single block the loop does not run
    # NOTE(review): if a block already stops at the last gpm_id, 'expected_idx'
    # points past the end of 'gpm_ids' and the indexing below fails - confirm
    # whether that situation can occur.
    for i in range(1, len(time_blocks)):
        # Retrieve gpm_id position
        next_idx = np.where(gpm_ids == time_blocks[i][2])[0]
        previous_idx = np.where(gpm_ids == time_blocks[i - 1][3])[0]
        expected_idx = previous_idx + 1
        # Start right after the previous block if its last gpm_id is repeated
        # or if some gpm_id would be skipped
        if next_idx == previous_idx or next_idx > expected_idx:
            time_blocks[i][2] = gpm_ids[expected_idx].item()
    return time_blocks
def regrid_pmw_l1(dt, scan_mode_reference="S1", radius_of_influence=20_000):
    """Regrid the scan modes of a PMW Level 1 product into a common grid.

    Parameters
    ----------
    dt : xarray.DataTree
        DataTree containing multiple scan modes (nodes).
    scan_mode_reference : str, optional
        The scan mode/node with the spatial coordinates to use as reference grid.
    radius_of_influence : optional
        Forwarded to the remapping of each scan mode onto the reference grid.

    Returns
    -------
    xarray.Dataset
        The collocated dataset, with PMW channels concatenated along a
        'pmw_frequency' dimension. If the DataTree contains only the reference
        scan mode, its dataset is returned as it is.

    Raises
    ------
    ValueError
        If ``scan_mode_reference`` is not in the DataTree, or if the
        ``gpm_api_product`` attribute of the reference node is not a 1B or 1C
        PMW product.
    """
    # Check template in datatree
    if scan_mode_reference not in dt:
        raise ValueError(f"The 'scan_mode_reference' '{scan_mode_reference}' is not found in the provided DataTree.")

    # Retrieve the scan modes to collocate (all but the template scan mode)
    scan_modes = [scan_mode for scan_mode in dt if scan_mode != scan_mode_reference]

    # Ensure at least one scan mode to collocate
    if not scan_modes:
        return dt[scan_mode_reference].to_dataset()

    # Retrieve datatree attributes
    attrs = dt[scan_mode_reference].attrs.copy()
    product = attrs.get("gpm_api_product")
    valid_products = gpm.available_products(
        product_categories="PMW",
        product_levels=["1B", "1C"],
    )
    if product not in valid_products:
        raise ValueError("The DataTree does not contain a 1B or 1C PMW product.")

    # - TODO: Variables without (along_track, cross-track) dimensions are currently not remapped !

    # Prepare DataTree to remap
    # - Expand the required variables
    # - Infill missing variables
    # dict_scan_modes = {
    #     scan_mode: _expand_with_scan_mode(dt[scan_mode].to_dataset(), scan_mode=scan_mode) for scan_mode in list(dt)
    # }
    # dt_expanded = xr.DataTree.from_dict(dict_scan_modes)
    dt_expanded = preprocess_datatree(dt, exclude_vars=None, fixed_vars=None)

    def _remap(datatree):
        # Remap all scan modes of the given tree onto the reference scan mode
        return _remap_pmw_datatree(
            datatree,
            scan_mode_reference=scan_mode_reference,
            scan_modes=scan_modes,
            radius_of_influence=radius_of_influence,
        )

    if "gpm_id" not in dt[scan_mode_reference]:
        # If i.e. TCPRIMED product, remap full datatree directly
        ds = _remap(dt_expanded)
    else:
        # If GPM product, remap by blocks to avoid orbit intersections
        # - Define time blocks over which to remap
        time_blocks = _define_time_blocks(
            dt_expanded,
            scan_mode_reference=scan_mode_reference,
            window_duration=60 * 30,
            overlap_duration=120,
        )
        # - Loop and remap over block of time (to avoid orbit intersection)
        list_ds = []
        for start_time, end_time, gpm_id_start, gpm_id_stop in time_blocks:
            time_slice = slice(start_time, end_time)
            dt_subset = xr.DataTree.from_dict(
                {
                    scan_mode: dt_expanded[scan_mode].to_dataset().gpm.sel(time=time_slice)
                    for scan_mode in list(dt_expanded)
                },
            )
            ds_block = _remap(dt_subset)
            # Keep only the scans assigned to this block
            ds_block = ds_block.gpm.sel(gpm_id=slice(gpm_id_start, gpm_id_stop))
            list_ds.append(ds_block)
        # - Concatenate dataset
        ds = xr.concat(list_ds, dim="along_track", coords="minimal", compat="override")

    # Assign attributes
    ds.attrs = attrs
    ds.attrs["ScanModes"] = sorted([*scan_modes, scan_mode_reference])
    return ds
def remap_era5(ds, variables):
    """Remap ERA5 variables onto the input dataset using nearest neighbour.

    Parameters
    ----------
    ds : xarray.Dataset
        Target dataset. Its ``gpm`` accessor provides the time range to
        retrieve and its ``time`` is used to select the closest ERA5 timestep.
    variables : str or list
        Name(s) of the ERA5 variables to remap.

    Returns
    -------
    xarray.Dataset
        ERA5 variables remapped on ``ds``, selected at the ERA5 timestep
        closest to each ``time`` of ``ds``. A ``delta_time`` variable holds
        the difference between the ERA5 time and the time of ``ds``.

    Raises
    ------
    ValueError
        If any of the requested variables is not available in the ERA5 archive.
    """
    from gpm.dataset.crs import set_dataset_crs

    # Open ERA5 archive on Google Cloud Bucket
    # - Currently available from 1940 to May 2023
    # - x axis longitudes goes from 0 to 360 (pm=0)
    # - y axis is decreasing !
    ds_era5 = xr.open_zarr(
        "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
        chunks=None,
        storage_options={"token": "anon"},
    )
    # ds_era5["2m_temperature"].sel(time="2000-01-01 12:00:00").gpm.plot_map()

    # Check variables
    # - Even a single invalid variable must be reported here, otherwise it
    #   surfaces later as a KeyError when subsetting the archive
    if isinstance(variables, str):
        variables = [variables]
    available_variables = list(ds_era5.data_vars)
    invalid_variables = [var for var in variables if var not in available_variables]
    if invalid_variables:
        raise ValueError(
            f"{invalid_variables} are invalid ERA5 variables. Available variables are {available_variables}.",
        )

    # Define time window over which to retrieve ERA5 data
    start_time = ds.gpm.start_time - datetime.timedelta(minutes=60)
    end_time = ds.gpm.end_time + datetime.timedelta(minutes=60)

    # Retrieve ERA5 data globally at given time period
    ds_era5 = ds_era5[variables].sel(time=slice(start_time, end_time))
    ds_era5 = ds_era5.compute()

    # Set CRS to dataset
    crs_wgs84 = pyproj.CRS(proj="longlat", ellps="WGS84")
    ds_era5 = set_dataset_crs(ds_era5, crs=crs_wgs84, grid_mapping_name="spatial_ref", inplace=False)

    # Map data to swath
    ds_env = ds_era5.gpm.remap_on(ds)

    # Select data closest to sensor observation
    ds_env["delta_time"] = ds_env["time"] - ds["time"]
    idx_time = np.abs(ds_env["delta_time"]).argmin(dim="time")
    ds_env = ds_env.isel(time=idx_time)
    ds_env = ds_env.assign_coords({"time": (*ds["time"].dims, ds["time"].data)})
    return ds_env