Source code for gpm.utils.xarray

# -----------------------------------------------------------------------------.
# MIT License

# Copyright (c) 2024 GPM-API developers
#
# This file is part of GPM-API.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# -----------------------------------------------------------------------------.
"""This module contains general utility for xarray objects."""
import functools

import numpy as np
import xarray as xr

####-------------------------------------------------------------------
#################
#### Checker ####
#################


def check_is_xarray(x):
    """Validate that ``x`` is an xarray object.

    Parameters
    ----------
    x : object
        Object to validate.

    Raises
    ------
    TypeError
        If ``x`` is neither a ``xarray.DataArray`` nor a ``xarray.Dataset``.
    """
    if isinstance(x, (xr.DataArray, xr.Dataset)):
        return
    raise TypeError("Expecting a xarray.Dataset or xarray.DataArray.")
def check_is_xarray_dataarray(x):
    """Validate that ``x`` is a ``xarray.DataArray``.

    Parameters
    ----------
    x : object
        Object to validate.

    Raises
    ------
    TypeError
        If ``x`` is not a ``xarray.DataArray``.
    """
    if isinstance(x, xr.DataArray):
        return
    raise TypeError("Expecting a xarray.DataArray.")
def check_is_xarray_dataset(x):
    """Validate that ``x`` is a ``xarray.Dataset``.

    Parameters
    ----------
    x : object
        Object to validate.

    Raises
    ------
    TypeError
        If ``x`` is not a ``xarray.Dataset``.
    """
    if isinstance(x, xr.Dataset):
        return
    raise TypeError("Expecting a xarray.Dataset.")
def check_variable_availabilty(ds, variable, argname):
    """Check that ``variable`` is specified and contained in ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Container tested with the ``in`` operator.
    variable : str or None
        Name of the variable to look for.
    argname : str
        Name of the caller's argument, reported in the error message.

    Raises
    ------
    ValueError
        If ``variable`` is ``None`` or is not contained in ``ds``.
    """
    if variable is None:
        raise ValueError("Please specify a dataset variable.")
    if variable in ds:
        return
    msg = f"{variable} is not a variable of the xarray.Dataset. Invalid {argname} argument."
    raise ValueError(msg)
####-------------------------------------------------------------------
###################
#### Utilities ####
###################
def get_dataset_variables(ds, sort=False):
    """Return the names of the data variables of a ``xarray.Dataset``.

    Parameters
    ----------
    ds : xarray.Dataset
        Object exposing a ``data_vars`` attribute.
    sort : bool, optional
        If ``True``, return the names sorted. The default is ``False``
        (order of ``ds.data_vars``).

    Returns
    -------
    list
        Names of the data variables.
    """
    names = list(ds.data_vars)
    return sorted(names) if sort else names
def get_xarray_variable(xr_obj, variable=None):
    """Return a ``xarray.DataArray`` from an xarray object.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Input object.
    variable : str, optional
        Name of the variable to extract. Only used (and then required)
        when ``xr_obj`` is a ``xarray.Dataset``.

    Returns
    -------
    xarray.DataArray
        ``xr_obj`` itself if it is a DataArray, otherwise ``xr_obj[variable]``.

    Raises
    ------
    TypeError
        If ``xr_obj`` is not an xarray object.
    ValueError
        If ``xr_obj`` is a Dataset and ``variable`` is ``None`` or not available.
    """
    check_is_xarray(xr_obj)
    # A DataArray is already the requested object
    if not isinstance(xr_obj, xr.Dataset):
        return xr_obj
    check_variable_availabilty(xr_obj, variable, argname="variable")
    return xr_obj[variable]
def get_dimensions_without(xr_obj, dims):
    """Return the dimensions of the xarray object without the specified dimensions.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object exposing a ``dims`` attribute.
    dims : str or iterable of hashable or None
        Dimension name(s) to exclude. Any iterable is accepted, including a
        ``set``. ``None`` excludes nothing.

    Returns
    -------
    list
        Dimension names of ``xr_obj``, in their original order, that are not
        listed in ``dims``.
    """
    if dims is None:
        return list(xr_obj.dims)
    if isinstance(dims, str):
        dims = [dims]
    # Plain membership test: keeps the original name objects untouched and
    # works for any iterable of names (a set passed to np.isin is treated as
    # one single element and would therefore exclude nothing).
    excluded = set(dims)
    return [dim for dim in xr_obj.dims if dim not in excluded]
def has_unique_chunking(ds):
    """Check if a dataset has unique chunking.

    The chunking is considered unique when, for every dimension, all chunked
    variables sharing that dimension split it into the same chunk sizes.
    Variables without chunks are ignored.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to inspect (data variables and coordinates are checked).

    Returns
    -------
    bool
        ``False`` as soon as two variables chunk the same dimension
        differently, ``True`` otherwise.

    Raises
    ------
    ValueError
        If ``ds`` is not a ``xarray.Dataset``.
    """
    if not isinstance(ds, xr.Dataset):
        raise ValueError("Input must be an xarray Dataset.")

    # First chunk layout encountered for each dimension
    reference_chunks = {}

    for variable in ds.variables.values():
        # ``chunks`` is None when the variable is not backed by a chunked
        # array; reading it avoids going through ``.data`` for every variable.
        var_chunks = variable.chunks
        if var_chunks is None:
            continue
        for dim, chunks in zip(variable.dims, var_chunks):
            # setdefault records the first layout and returns the stored one
            if reference_chunks.setdefault(dim, chunks) != chunks:
                return False

    return True
def ensure_unique_chunking(ds):
    """Return a dataset whose variables share the same chunking.

    Conversion to :py:class:`dask.dataframe.DataFrame` requires unique chunking.
    When the chunking of the ``xarray.Dataset`` is not unique, the chunks are
    unified with ``ds.unify_chunks``; otherwise the input is returned as is.

    The chunks of each variable can be inspected with::

        for var in ds.data_vars:
            print(var, ds[var].chunks)

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset.

    Returns
    -------
    xarray.Dataset
        Dataset with unified chunks.
    """
    return ds if has_unique_chunking(ds) else ds.unify_chunks()
####-------------------------------------------------------------------
####################
#### Decorators ####
####################
def ensure_dim_order_dataarray(da, func, *args, **kwargs):
    """Apply ``func`` to ``da`` and restore the input dimension order on the output.

    Dimensions created by ``func`` are moved to the last positions.
    The same reordering is applied to every output coordinate that was
    already a coordinate of the input.

    Parameters
    ----------
    da : xarray.DataArray
        Input array, passed as first argument to ``func``.
    func : callable
        Function called as ``func(da, *args, **kwargs)``.
    *args, **kwargs
        Additional arguments forwarded to ``func``.

    Returns
    -------
    xarray.DataArray
        Output of ``func`` with reordered dimensions.

    Raises
    ------
    TypeError
        If ``func`` does not return a ``xarray.DataArray``.
    """
    # Snapshot the layout of the input before calling the function
    input_dims = da.dims
    input_coord_dims = {name: da[name].dims for name in da.coords}

    result = func(da, *args, **kwargs)
    if not isinstance(result, xr.DataArray):
        raise TypeError("The function does not return a xr.DataArray.")

    # Original dimensions that survived come first, new ones go last (...)
    surviving = [dim for dim in input_dims if dim in result.dims]
    result = result.transpose(*surviving, ...)

    # Restore the dimension order of the coordinates known from the input
    for name in list(result.coords):
        if name not in input_coord_dims:
            continue
        kept = [dim for dim in input_coord_dims[name] if dim in result[name].dims]
        result[name] = result[name].transpose(*kept, ...)
    return result
def ensure_dim_order_dataset(ds, func, *args, **kwargs):
    """Apply ``func`` to ``ds`` and restore the input dimension order on the output.

    Each output data variable and coordinate that already existed in the
    input (under the same role) is transposed back to its input dimension
    order. Dimensions created by ``func`` are moved to the last positions.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset, passed as first argument to ``func``.
    func : callable
        Function called as ``func(ds, *args, **kwargs)``.
    *args, **kwargs
        Additional arguments forwarded to ``func``.

    Returns
    -------
    xarray.Dataset
        Output of ``func`` with reordered dimensions.

    Raises
    ------
    TypeError
        If ``func`` does not return a ``xarray.Dataset``.
    """
    # Snapshot the layout of the input before calling the function
    input_coord_dims = {name: ds[name].dims for name in ds.coords}
    input_var_dims = {name: ds[name].dims for name in ds.data_vars}

    result = func(ds, *args, **kwargs)
    if not isinstance(result, xr.Dataset):
        raise TypeError("The function does not return a xr.Dataset.")

    # Data variables first ...
    for name in list(result.data_vars):
        if name not in input_var_dims:
            continue
        kept = [dim for dim in input_var_dims[name] if dim in result[name].dims]
        result[name] = result[name].transpose(*kept, ...)

    # ... then coordinates
    for name in list(result.coords):
        if name not in input_coord_dims:
            continue
        kept = [dim for dim in input_coord_dims[name] if dim in result[name].dims]
        result[name] = result[name].transpose(*kept, ...)
    return result
def xr_ensure_dimension_order(func):
    """Decorate ``func`` so its xarray output keeps the dimension order of its input.

    The decorated function must take the xarray object as first positional
    argument and return the same type of xarray object.

    The decorator can deal with functions that:

    - return an xarray object with new dimensions,
    - return an xarray object with fewer dimensions than the original.

    New dimensions are moved to the last positions.

    Parameters
    ----------
    func : callable
        Function to decorate.

    Returns
    -------
    callable
        Wrapped function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The xarray object is assumed to be the first positional argument
        xr_obj = args[0]
        remaining = args[1:]
        if isinstance(xr_obj, xr.Dataset):
            handler = ensure_dim_order_dataset
        else:
            handler = ensure_dim_order_dataarray
        return handler(xr_obj, func, *remaining, **kwargs)

    return wrapper
@xr_ensure_dimension_order
def squeeze_unsqueeze_dataarray(da, func, *args, **kwargs):
    """Apply ``func`` to the squeezed DataArray and add the squeezed dimensions back.

    Dimensions of size 1 are kept also if the function drops them.
    New dimensions are moved to the last positions.

    Parameters
    ----------
    da : xarray.DataArray
        Input array. It is squeezed before being passed to ``func``.
    func : callable
        Function called as ``func(squeezed_da, *args, **kwargs)``.
    *args, **kwargs
        Additional arguments forwarded to ``func``.

    Returns
    -------
    xarray.DataArray
        Output of ``func`` with the squeezed dimensions expanded back.

    Raises
    ------
    TypeError
        If ``func`` does not return a ``xarray.DataArray``.
    """
    # Squeeze once and derive which dimensions disappeared
    squeezed = da.squeeze()
    squeezed_dims = set(da.dims) - set(squeezed.dims)

    # For each squeezed dimension, list the coordinates that carried it
    dict_squeezed = {dim: [coord for coord in da.coords if dim in da[coord].dims] for dim in squeezed_dims}

    # Call the function with the squeezed array
    da = func(squeezed, *args, **kwargs)
    if not isinstance(da, xr.DataArray):
        raise TypeError("The function does not return a xr.DataArray.")

    # Unsqueeze back
    for dim, coords in dict_squeezed.items():
        if dim not in da.dims:
            da = da.expand_dims(dim=dim, axis=None)
        for coord in coords:
            # The function may have dropped the coordinate: nothing to expand
            if coord not in da.coords:
                continue
            # A coordinate may already carry the dimension after the
            # expand_dims above, hence the check before expanding it
            if dim not in da[coord].dims:
                da[coord] = da[coord].expand_dims(dim=dim, axis=None)

    # NOTE: a coordinate named like a squeezed dimension but not spanning it
    # is not given any special treatment here.
    return da
@xr_ensure_dimension_order
def squeeze_unsqueeze_dataset(ds, func, *args, **kwargs):
    """Apply ``func`` to the squeezed Dataset and add the squeezed dimensions back.

    Dimensions of size 1 are kept also if the function drops them.
    New dimensions are moved to the last positions.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset. It is squeezed before being passed to ``func``.
    func : callable
        Function called as ``func(squeezed_ds, *args, **kwargs)``.
    *args, **kwargs
        Additional arguments forwarded to ``func``.

    Returns
    -------
    xarray.Dataset
        Output of ``func`` where each variable that carried a squeezed
        dimension, and is still present, has that dimension expanded back.

    Raises
    ------
    TypeError
        If ``func`` does not return a ``xarray.Dataset``.
    """
    # Squeeze once and derive which dimensions disappeared
    squeezed = ds.squeeze()
    squeezed_dims = set(ds.dims) - set(squeezed.dims)

    # For each squeezed dimension, list the variables (coords + data
    # variables) that carried it
    dict_squeezed = {dim: [var for var in ds.variables if dim in ds[var].dims] for dim in squeezed_dims}

    # Call the function with the squeezed dataset
    ds = func(squeezed, *args, **kwargs)
    if not isinstance(ds, xr.Dataset):
        raise TypeError("The function does not return a xr.Dataset.")

    # Unsqueeze back. The expanded dimension is not put back at its original
    # position here; the decorator restores the dimension order afterwards.
    for dim, variables in dict_squeezed.items():
        for var in variables:
            # The function may have dropped the variable: nothing to expand
            if var not in ds:
                continue
            if dim not in ds[var].dims:
                ds[var] = ds[var].expand_dims(dim=dim, axis=None)
    return ds
def xr_squeeze_unsqueeze(func):
    """Decorate ``func`` so it receives a squeezed xarray object and returns an unsqueezed one.

    This decorator allows keeping the dimensions of the xarray object intact:
    dimensions of size 1 are kept also if the function drops them.
    The dimension order of the arrays is conserved.
    New dimensions are moved to the last positions.

    The decorated function must take the xarray object as first positional
    argument.

    Parameters
    ----------
    func : callable
        Function to decorate.

    Returns
    -------
    callable
        Wrapped function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The xarray object is assumed to be the first positional argument
        xr_obj = args[0]
        remaining = args[1:]
        if isinstance(xr_obj, xr.Dataset):
            handler = squeeze_unsqueeze_dataset
        else:
            handler = squeeze_unsqueeze_dataarray
        return handler(xr_obj, func, *remaining, **kwargs)

    return wrapper