# -----------------------------------------------------------------------------.
# MIT License
# Copyright (c) 2024 GPM-API developers
#
# This file is part of GPM-API.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----------------------------------------------------------------------------.
"""This module defines functions providing GPM-API Dataset information."""
from itertools import chain
import numpy as np
import xarray as xr
from gpm.dataset.dimensions import (
FREQUENCY_DIMS,
GRID_SPATIAL_DIMS,
ORBIT_SPATIAL_DIMS,
SPATIAL_DIMS,
VERTICAL_DIMS,
)
from gpm.utils.xarray import (
check_is_xarray,
get_dataset_variables,
)
# Refactor Notes
# - GRID_SPATIAL_DIMS, ORBIT_SPATIAL_DIMS to be refactored
# - is_grid and is_orbit currently also depend on gpm.dataset.crs (_get_proj_dim_coords and _get_swath_dim_coords, respectively)
# - Code could be generalized to work with any satellite data format ???
# - GPM ORBIT = pyresample SwathDefinition
# - GPM GRID = pyresample AreaDefinition
# - GPM ORBIT dimensions: (cross-track, along-track)
# - GPM GRID dimensions: (lon, lat)
# - satpy dimensions (y, x) (for both ORBIT and GRID)
# --> Accept (lat, lon), (latitude, longitude), (y, x), ... coordinates
# --> Adapt plotting and cropping utilities to deal with different coordinate names
# --> Then these functions can be used with any satellite product
####-----------------------------------------------------------------------------------------------------------------.
####################
#### Dimensions ####
####################
[docs]
def get_frequency_dimension(xr_obj):
    """Return the frequency dimension(s) found in ``xr_obj``.

    Candidate names come from ``FREQUENCY_DIMS``. The matches are returned as a
    (possibly empty) list, in the order they appear in ``FREQUENCY_DIMS``.
    """
    candidates = np.array(FREQUENCY_DIMS)
    is_present = np.isin(candidates, list(xr_obj.dims))
    return candidates[is_present].tolist()
[docs]
def get_vertical_dimension(xr_obj):
    """Return the vertical dimension found in ``xr_obj`` as a list.

    Candidate names come from ``VERTICAL_DIMS``. The list is empty when none is present.

    Raises
    ------
    ValueError
        If more than one vertical dimension is present.
    """
    candidates = np.array(VERTICAL_DIMS)
    found = candidates[np.isin(candidates, list(xr_obj.dims))].tolist()
    if len(found) <= 1:
        return found
    raise ValueError(f"Only one vertical dimension is allowed. Got {found}.")
[docs]
def get_spatial_dimensions(xr_obj):
    """Return the horizontal spatial dimensions found in ``xr_obj``.

    ``SPATIAL_DIMS`` is flattened into a single list of candidate names. The
    candidates present among ``xr_obj.dims`` are returned in that flattened order.

    Raises
    ------
    ValueError
        If more than two horizontal spatial dimensions are found.
    """
    candidates = np.array(list(chain.from_iterable(SPATIAL_DIMS)))
    is_present = np.isin(candidates, list(xr_obj.dims))
    spatial_dimensions = candidates[is_present].tolist()
    if len(spatial_dimensions) <= 2:
        return spatial_dimensions
    raise ValueError(f"Only two horizontal spatial dimensions are allowed. Got {spatial_dimensions}.")
def _has_spatial_dim_dataarray(da, strict):
    """Return whether the xarray.DataArray has horizontal spatial dimensions.

    With ``strict=True``, every dimension of ``da`` must also be a spatial one.
    """
    spatial_dims = get_spatial_dimensions(da)
    if spatial_dims and strict:
        # Strict mode: no non-spatial dimension may be present.
        return bool(np.all(np.isin(da.dims, spatial_dims)))
    return bool(spatial_dims)
def _has_vertical_dim_dataarray(da, strict):
    """Return whether the xarray.DataArray has a vertical dimension.

    With ``strict=True``, the vertical dimension must be the only dimension of ``da``.
    """
    if not get_vertical_dimension(da):
        return False
    return not strict or len(da.dims) == 1
def _has_frequency_dim_dataarray(da, strict):
    """Return whether the xarray.DataArray has a frequency dimension.

    With ``strict=True``, the frequency dimension must be the only dimension of ``da``.
    """
    if not get_frequency_dimension(da):
        return False
    return not strict or len(da.dims) == 1
def _has_vertical_dim_dataset(ds, strict):
    """Return whether at least one variable of the xarray.Dataset has a vertical dimension."""
    # Build the full list (no short-circuit) so that every variable is checked and any
    # ValueError raised by get_vertical_dimension still propagates.
    per_variable = [_has_vertical_dim_dataarray(ds[name], strict=strict) for name in get_dataset_variables(ds)]
    return bool(np.any(per_variable))
def _has_spatial_dim_dataset(ds, strict):
    """Return whether at least one variable of the xarray.Dataset has a horizontal spatial dimension."""
    # Build the full list (no short-circuit) so that every variable is checked and any
    # ValueError raised by get_spatial_dimensions still propagates.
    per_variable = [_has_spatial_dim_dataarray(ds[name], strict=strict) for name in get_dataset_variables(ds)]
    return bool(np.any(per_variable))
def _has_frequency_dim_dataset(ds, strict):
    """Check if at least one xarray.DataArray of a xarray.Dataset has a frequency dimension."""
    # Local renamed from the copy-pasted ``has_spatial``: this check is about frequency.
    has_frequency = np.any(
        [_has_frequency_dim_dataarray(ds[var], strict=strict) for var in get_dataset_variables(ds)],
    ).item()
    return bool(has_frequency)
def _check_xarray_conditions(da_condition, ds_condition, xr_obj, strict, squeeze):
    """Dispatch a dimension check to its xarray.DataArray or xarray.Dataset implementation.

    Parameters
    ----------
    da_condition : callable
        Called as ``da_condition(obj, strict=strict)`` when the object is not an xarray.Dataset.
    ds_condition : callable
        Called as ``ds_condition(obj, strict=strict)`` when the object is an xarray.Dataset.
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to test. It is first validated with ``check_is_xarray``.
    strict : bool
        Forwarded unchanged to the selected condition.
    squeeze : bool
        If True, dimensions of size 1 are removed before testing.

    Returns
    -------
    The value returned by the selected condition.
    """
    check_is_xarray(xr_obj)
    obj = xr_obj.squeeze() if squeeze else xr_obj
    condition = ds_condition if isinstance(obj, xr.Dataset) else da_condition
    return condition(obj, strict=strict)
[docs]
def has_spatial_dim(xr_obj, strict=False, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset has a horizontal spatial dimension.

    For an xarray.Dataset, the check passes if at least one variable satisfies it.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True, the xarray.DataArray may have only spatial dimensions.
        If False (default), it may also have other dimensions.
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_has_spatial_dim_dataarray,
        ds_condition=_has_spatial_dim_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
[docs]
def has_vertical_dim(xr_obj, strict=False, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset has a vertical dimension.

    For an xarray.Dataset, the check passes if at least one variable satisfies it.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True, the xarray.DataArray must have only the vertical dimension.
        If False (default), it may also have additional dimensions.
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_has_vertical_dim_dataarray,
        ds_condition=_has_vertical_dim_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
[docs]
def has_frequency_dim(xr_obj, strict=False, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset has a frequency dimension.

    For an xarray.Dataset, the check passes if at least one variable satisfies it.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True, the xarray.DataArray must have only the frequency dimension.
        If False (default), it may also have additional dimensions.
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_has_frequency_dim_dataarray,
        ds_condition=_has_frequency_dim_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
####-------------------------------------------------------------------------------------
#######################
#### GRID vs ORBIT ####
#######################
def _is_grid_expected_spatial_dims(spatial_dims):
    """Return whether ``spatial_dims`` form one of the accepted GRID dimension pairs.

    The comparison is order-insensitive. Accepted pairs are ``GRID_SPATIAL_DIMS``,
    ``{"latitude", "longitude"}`` and ``{"y", "x"}``.
    """
    dims = set(spatial_dims)
    accepted_pairs = (set(GRID_SPATIAL_DIMS), {"latitude", "longitude"}, {"y", "x"})
    return any(dims == pair for pair in accepted_pairs)
def _is_orbit_expected_spatial_dims(spatial_dims):
    """Return whether ``spatial_dims`` have names expected for a GPM ORBIT.

    ``spatial_dims`` must be a non-empty subset of either ``ORBIT_SPATIAL_DIMS``
    or ``{"y", "x"}``. A single dimension (e.g. only cross_track or only
    along_track) is therefore accepted.
    """
    dims = set(spatial_dims)
    for allowed in (ORBIT_SPATIAL_DIMS, {"y", "x"}):
        if dims.issubset(allowed) and bool(spatial_dims):
            return True
    return False
def _is_expected_spatial_dims(spatial_dims):
    """Check whether ``spatial_dims`` have names expected for a GPM ORBIT or GRID object.

    ORBIT naming accepts one or two dimensions (see ``_is_orbit_expected_spatial_dims``).
    GRID naming requires a complete pair (see ``_is_grid_expected_spatial_dims``).
    Callers that need exactly two dimensions must check the length themselves.
    """
    is_orbit = _is_orbit_expected_spatial_dims(spatial_dims)
    is_grid = _is_grid_expected_spatial_dims(spatial_dims)
    return bool(is_orbit or is_grid)
[docs]
def is_orbit(xr_obj):
    """Check whether the xarray object is a GPM ORBIT.

    ORBIT cross-sections (nadir view) and transects are also considered ORBIT
    objects. The swath coordinates must be available.
    """
    from gpm.dataset.crs import _get_swath_dim_coords

    if not _is_orbit_expected_spatial_dims(get_spatial_dimensions(xr_obj)):
        return False
    # Swath objects are identified by 1D (nadir-looking) or 2D coordinates;
    # both the x and y coordinates must be resolvable.
    x_coord, y_coord = _get_swath_dim_coords(xr_obj)
    return x_coord is not None and y_coord is not None
[docs]
def is_grid(xr_obj):
    """Check whether the xarray object is a GPM GRID.

    A GRID slice (i.e. without a complete pair of spatial dimensions) is not
    considered a GRID object. A GRID object must have its coordinates available.
    """
    from gpm.dataset.crs import _get_proj_dim_coords

    if not _is_grid_expected_spatial_dims(get_spatial_dimensions(xr_obj)):
        return False
    # Area objects can be described by 1D projection coordinates or by 2D
    # lon/lat coordinates of each pixel; both x and y must be resolvable.
    x_coord, y_coord = _get_proj_dim_coords(xr_obj)
    return x_coord is not None and y_coord is not None
####-------------------------------------------------------------------------------------
#######################
#### ORBIT TYPES ####
#######################
def _is_spatial_2d_dataarray(da, strict):
    """Return whether the xarray.DataArray is a spatial 2D field.

    It requires two expected horizontal dimensions and no vertical dimension.
    With ``strict=True``, ``da`` must have exactly two dimensions.
    """
    spatial_dims = get_spatial_dimensions(da)
    if not _is_expected_spatial_dims(spatial_dims) or len(spatial_dims) != 2:
        return False
    if get_vertical_dimension(da):
        return False
    return not strict or len(da.dims) == 2
def _is_spatial_3d_dataarray(da, strict):
    """Return whether the xarray.DataArray is a spatial 3D field.

    It requires two expected horizontal dimensions plus a vertical dimension.
    With ``strict=True``, ``da`` must have exactly three dimensions.
    """
    spatial_dims = get_spatial_dimensions(da)
    if not _is_expected_spatial_dims(spatial_dims) or len(spatial_dims) != 2:
        return False
    if not get_vertical_dimension(da):
        return False
    return not strict or len(da.dims) == 3
def _is_cross_section_dataarray(da, strict):
    """Return whether the xarray.DataArray is a cross-section.

    It requires exactly one horizontal dimension and a vertical dimension.
    With ``strict=True``, ``da`` must have exactly two dimensions.
    """
    if len(get_spatial_dimensions(da)) != 1:
        return False
    if not get_vertical_dimension(da):
        return False
    return not strict or len(da.dims) == 2
def _is_transect_dataarray(da, strict):
    """Return whether the xarray.DataArray is a transect.

    It requires exactly one horizontal dimension.
    With ``strict=True``, that must be the only dimension of ``da``.
    """
    if len(get_spatial_dimensions(da)) != 1:
        return False
    return not strict or len(da.dims) == 1
def _check_dataarrays_condition(condition, ds, strict):
    """Return whether ``condition`` holds for every variable of the xarray.Dataset ``ds``.

    An empty Dataset (no variables) yields False rather than a vacuous True.
    ``condition`` is called as ``condition(ds[var], strict=strict)`` on every variable.
    """
    if not ds:
        return False
    results = [condition(ds[name], strict=strict) for name in get_dataset_variables(ds)]
    return bool(np.all(results))
def _is_spatial_2d_dataset(ds, strict):
    """Return whether every variable of the xarray.Dataset is a spatial 2D field."""
    return _check_dataarrays_condition(_is_spatial_2d_dataarray, ds, strict)
def _is_spatial_3d_dataset(ds, strict):
    """Return whether every variable of the xarray.Dataset is a spatial 3D field."""
    return _check_dataarrays_condition(_is_spatial_3d_dataarray, ds, strict)
def _is_cross_section_dataset(ds, strict):
    """Return whether every variable of the xarray.Dataset is a cross-section."""
    return _check_dataarrays_condition(_is_cross_section_dataarray, ds, strict)
def _is_transect_dataset(ds, strict):
    """Return whether every variable of the xarray.Dataset is a transect."""
    return _check_dataarrays_condition(_is_transect_dataarray, ds, strict)
[docs]
def is_spatial_2d(xr_obj, strict=True, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset is a spatial 2D object.

    For an xarray.Dataset, every variable must satisfy the check, and an empty
    Dataset is rejected.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True (default), the xarray.DataArray must have only the two spatial dimensions.
        If False, additional dimensions are allowed (except a vertical one).
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_is_spatial_2d_dataarray,
        ds_condition=_is_spatial_2d_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
[docs]
def is_spatial_3d(xr_obj, strict=True, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset is a spatial 3D object.

    For an xarray.Dataset, every variable must satisfy the check, and an empty
    Dataset is rejected.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True (default), the xarray.DataArray must have only the two spatial
        dimensions and the vertical dimension.
        If False, additional dimensions are allowed.
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_is_spatial_3d_dataarray,
        ds_condition=_is_spatial_3d_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
[docs]
def is_cross_section(xr_obj, strict=True, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset is a cross-section object.

    For an xarray.Dataset, every variable must satisfy the check, and an empty
    Dataset is rejected.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True (default), the xarray.DataArray must have only the vertical
        dimension and one horizontal dimension.
        If False, additional dimensions are allowed, but there must still be
        a single horizontal dimension and a vertical dimension.
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_is_cross_section_dataarray,
        ds_condition=_is_cross_section_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
[docs]
def is_transect(xr_obj, strict=True, squeeze=True):
    """Check whether the xarray.DataArray or xarray.Dataset is a transect object.

    For an xarray.Dataset, every variable must satisfy the check, and an empty
    Dataset is rejected.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to inspect.
    strict : bool, optional
        If True (default), the xarray.DataArray must have only a horizontal dimension.
        If False, additional dimensions are allowed, but only a single horizontal one.
    squeeze : bool, optional
        If True (default), dimensions of size 1 are removed prior to testing.

    Returns
    -------
    bool
    """
    return _check_xarray_conditions(
        da_condition=_is_transect_dataarray,
        ds_condition=_is_transect_dataset,
        xr_obj=xr_obj,
        squeeze=squeeze,
        strict=strict,
    )
####-------------------------------------------------------------------------------------------------------------.
#################
#### Checks ####
#################
[docs]
def check_is_orbit(xr_obj):
    """Raise a ValueError unless ``xr_obj`` is a GPM ORBIT object."""
    if is_orbit(xr_obj):
        return
    raise ValueError("Expecting a GPM ORBIT object.")
[docs]
def check_is_grid(xr_obj):
    """Raise a ValueError unless ``xr_obj`` is a GPM GRID object."""
    if is_grid(xr_obj):
        return
    raise ValueError("Expecting a GPM GRID object.")
[docs]
def check_is_gpm_object(xr_obj):
    """Raise a ValueError unless ``xr_obj`` is a GPM ORBIT or GPM GRID object.

    The GRID test is evaluated only when the ORBIT test fails.
    """
    if is_orbit(xr_obj) or is_grid(xr_obj):
        return
    raise ValueError("Unrecognized GPM xarray object.")
[docs]
def check_has_cross_track_dim(xr_obj, dim="cross_track"):
    """Raise a ValueError if the cross-track dimension is missing from ``xr_obj``.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object whose ``dims`` are inspected.
    dim : str, optional
        Name of the cross-track dimension. Defaults to ``"cross_track"``.
    """
    available_dims = xr_obj.dims
    if dim in available_dims:
        return
    raise ValueError(f"The 'cross-track' dimension {dim} is not available.")
[docs]
def check_has_along_track_dim(xr_obj, dim="along_track"):
    """Raise a ValueError if the along-track dimension is missing from ``xr_obj``.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object whose ``dims`` are inspected.
    dim : str, optional
        Name of the along-track dimension. Defaults to ``"along_track"``.
    """
    available_dims = xr_obj.dims
    if dim in available_dims:
        return
    raise ValueError(f"The 'along_track' dimension {dim} is not available.")
[docs]
def check_is_spatial_2d(xr_obj, strict=True, squeeze=True):
    """Raise a ValueError unless the xarray.DataArray or xarray.Dataset is a spatial 2D field.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=True`` (default), the xarray.DataArray must have only the 2D spatial dimensions.
    If ``strict=False``, additional dimensions are allowed (except a vertical one).
    """
    if is_spatial_2d(xr_obj, strict=strict, squeeze=squeeze):
        return
    raise ValueError("Expecting a 2D GPM field.")
[docs]
def check_is_spatial_3d(xr_obj, strict=True, squeeze=True):
    """Raise a ValueError unless the xarray.DataArray or xarray.Dataset is a spatial 3D field.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=True`` (default), the xarray.DataArray must have only the 3D spatial dimensions.
    If ``strict=False``, additional dimensions are allowed.
    """
    if is_spatial_3d(xr_obj, strict=strict, squeeze=squeeze):
        return
    raise ValueError("Expecting a 3D GPM field.")
[docs]
def check_is_cross_section(xr_obj, strict=True, squeeze=True):
    """Raise a ValueError unless the xarray.DataArray or xarray.Dataset is a cross-section.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=True`` (default), the xarray.DataArray must have only the vertical
    dimension and one horizontal dimension.
    If ``strict=False``, additional dimensions are allowed, but there must still be
    a single vertical and a single horizontal dimension.
    """
    if is_cross_section(xr_obj, strict=strict, squeeze=squeeze):
        return
    raise ValueError("Expecting a cross-section extracted from a 3D GPM field.")
[docs]
def check_is_transect(xr_obj, strict=True, squeeze=True):
    """Raise a ValueError unless the xarray.DataArray or xarray.Dataset is a transect.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=True`` (default), the xarray.DataArray must have only a horizontal dimension.
    If ``strict=False``, additional dimensions are allowed, but only a single horizontal one.
    """
    if is_transect(xr_obj, strict=strict, squeeze=squeeze):
        return
    raise ValueError("Expecting a transect object.")
[docs]
def check_has_vertical_dim(xr_obj, strict=False, squeeze=True):
    """Raise a ValueError unless the xarray.DataArray or xarray.Dataset has a vertical dimension.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=False`` (default), the xarray.DataArray may also have additional dimensions.
    If ``strict=True``, the xarray.DataArray must have only the vertical dimension.
    """
    if has_vertical_dim(xr_obj, strict=strict, squeeze=squeeze):
        return
    qualifier = "only " if strict else ""
    raise ValueError(f"Expecting an xarray object with {qualifier}a vertical dimension.")
[docs]
def check_has_spatial_dim(xr_obj, strict=False, squeeze=True):
    """Raise a ValueError unless the xarray.DataArray or xarray.Dataset has a horizontal spatial dimension.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=False`` (default), the xarray.DataArray may also have additional dimensions.
    If ``strict=True``, the xarray.DataArray must have only spatial dimensions.
    """
    if has_spatial_dim(xr_obj, strict=strict, squeeze=squeeze):
        return
    qualifier = "only " if strict else ""
    raise ValueError(f"Expecting an xarray object with {qualifier}spatial dimensions.")
[docs]
def check_has_frequency_dim(xr_obj, strict=False, squeeze=True):
    """Check if the xarray.DataArray or xarray.Dataset has a frequency dimension.

    If ``squeeze=True`` (default), dimensions of size 1 are removed prior to testing.
    If ``strict=False`` (default), the xarray.DataArray can also have additional dimensions.
    If ``strict=True``, the xarray.DataArray must have just the frequency dimension.

    Raises
    ------
    ValueError
        If the check performed by ``has_frequency_dim`` fails.
    """
    if not has_frequency_dim(xr_obj, strict=strict, squeeze=squeeze):
        only = "only " if strict else ""
        raise ValueError(f"Expecting an xarray object with {only}a frequency dimension.")
####-----------------------------------------------------------------------------------------------------------------.
###############################
#### Variables information ####
###############################
[docs]
def get_spatial_2d_variables(ds, strict=False, squeeze=True):
    """Return the sorted names of the xarray.Dataset variables that are 2D spatial fields.

    If ``strict=False`` (default), include the variables from which a 2D spatial field can be derived.
    If ``strict=True``, include only the variables that already are a 2D spatial field.
    """

    def _qualifies(name):
        return is_spatial_2d(ds[name], strict=strict, squeeze=squeeze)

    return sorted(filter(_qualifies, get_dataset_variables(ds)))
[docs]
def get_spatial_3d_variables(ds, strict=False, squeeze=True):
    """Return the sorted names of the xarray.Dataset variables that are 3D spatial fields.

    If ``strict=False`` (default), include the variables from which a 3D spatial field can be derived.
    If ``strict=True``, include only the variables that already are a 3D spatial field.
    """

    def _qualifies(name):
        return is_spatial_3d(ds[name], strict=strict, squeeze=squeeze)

    return sorted(filter(_qualifies, get_dataset_variables(ds)))
[docs]
def get_cross_section_variables(ds, strict=False, squeeze=True):
    """Return the sorted names of the xarray.Dataset variables that are cross-sections.

    If ``strict=False`` (default), include the variables from which a strict cross-section can be derived.
    If ``strict=True``, include only the variables that already are a cross-section.
    """

    def _qualifies(name):
        return is_cross_section(ds[name], strict=strict, squeeze=squeeze)

    return sorted(filter(_qualifies, get_dataset_variables(ds)))
# def get_transect_variables(ds, strict=False, squeeze=True):
# """Get list of xarray.Dataset transect variables.
# If ``strict=False`` (default), the potential variables for which a strict transect can be derived.
# If ``strict=True``, the variables that are already a transect.
# """
# variables = [var for var in get_dataset_variables(ds) if is_transect(ds[var], strict=strict, squeeze=squeeze)]
# return sorted(variables)
[docs]
def get_vertical_variables(ds):
    """Return the sorted names of the xarray.Dataset variables that have a vertical dimension.

    Size-1 dimensions are squeezed before testing, and other dimensions are allowed.
    """

    def _has_vertical(name):
        return has_vertical_dim(ds[name], strict=False, squeeze=True)

    return sorted(filter(_has_vertical, get_dataset_variables(ds)))
[docs]
def get_frequency_variables(ds):
    """Return the sorted names of the xarray.Dataset variables that have a frequency-related dimension.

    Size-1 dimensions are squeezed before testing, and other dimensions are allowed.
    """

    def _has_frequency(name):
        return has_frequency_dim(ds[name], strict=False, squeeze=True)

    return sorted(filter(_has_frequency, get_dataset_variables(ds)))
[docs]
def get_bin_variables(ds):
    """Return the sorted names of radar "bin" variables of the xarray.Dataset.

    In L1 and L2 RADAR products, bin variables start with the ``bin`` prefix.
    In CMB products, bin variables end with the ``Bin`` suffix.
    """

    def _is_bin(name):
        return name.startswith("bin") or name.endswith("Bin")

    return sorted(filter(_is_bin, get_dataset_variables(ds)))