# -----------------------------------------------------------------------------.
# MIT License
# Copyright (c) 2024 GPM-API developers
#
# This file is part of GPM-API.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----------------------------------------------------------------------------.
"""This module contains basic functions for GPM-API data visualization."""
import inspect
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.interpolate import griddata
import gpm
from gpm import get_plot_kwargs
from gpm.utils.area import get_lonlat_corners_from_centroids
[docs]
def is_generator(obj):
    """Return ``True`` if ``obj`` is a generator function or a generator object."""
    checks = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(check(obj) for check in checks)
def _call_optimize_layout(self):
    """Resize the figure to the subplot aspect ratio, then tighten the layout.

    Meant to be bound as ``optimize_layout`` on a mappable (see
    ``add_optimize_layout_method``); ``self`` is that mappable, whose
    ``axes`` and ``figure`` attributes are used.
    """
    mappable_axes = self.axes
    adapt_fig_size(ax=mappable_axes)
    self.figure.tight_layout()
[docs]
def add_optimize_layout_method(p):
    """Attach an ``optimize_layout`` bound method to ``p`` (monkey patching).

    Parameters
    ----------
    p : object
        Mappable instance to patch. It is modified in place and also returned.
    """
    bound_method = _call_optimize_layout.__get__(p, type(p))
    setattr(p, "optimize_layout", bound_method)
    return p
[docs]
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Resize the figure so that its height matches the aspect ratio of cartopy subplots.

    Call this once all plotting is done. ``ax`` is taken as representative of every
    subplot: all subplots are assumed to share the same data aspect ratio and size.
    The implementation is inspired by the ``set_map_layout`` function of Mathias
    Hauser's mplotutils.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Representative subplot of the figure.
    nrow : int, optional
        Number of subplot rows in the figure. The default is 1.
    ncol : int, optional
        Number of subplot columns in the figure. The default is 1.
    """
    fig = ax.get_figure()
    width, original_height = fig.get_size_inches()
    # Render once so that the figure geometry reflects the latest layout state
    fig.canvas.draw()
    params = fig.subplotpars
    # Width of one subplot after removing the horizontal margins and inter-column spacing
    horizontal_margins = params.left + (1 - params.right)
    subplot_width = (width - width * horizontal_margins) / (ncol + (ncol - 1) * params.wspace)
    # Height of one subplot preserving the data aspect ratio
    subplot_height = subplot_width * ax.get_data_ratio()
    # Figure height needed for all rows, the inter-row spacing and the vertical margins
    vertical_margins = params.bottom + (1 - params.top)
    height = (subplot_height * (nrow + ((nrow - 1) * params.hspace))) / (1.0 - vertical_margins)
    shrink_ratio = original_height / height
    if shrink_ratio > 2:
        # The new height would be less than half the original one:
        # enlarge width and height uniformly so the height becomes original/2.
        scale_factor = shrink_ratio / 2
        width *= scale_factor
        height *= scale_factor
    fig.set_figwidth(width)
    fig.set_figheight(height)
####--------------------------------------------------------------------------.
[docs]
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` whose ``x``/``y`` coordinates have invalid values infilled.

    The infilling is delegated to ``get_valid_pcolormesh_inputs`` (called with
    ``data=None`` and ``mask_data=False``). As originally documented, values are
    interpolated within the convex hull of the valid coordinates and filled with the
    nearest neighbour outside of it.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object carrying the coordinates. It is not modified.
    x, y : str, optional
        Names of the longitude and latitude coordinates. Defaults are ``"lon"`` and ``"lat"``.
    """
    result = xr_obj.copy()
    lon_values, lat_values, _ = get_valid_pcolormesh_inputs(
        x=np.asanyarray(result[x].data),
        y=np.asanyarray(result[y].data),
        data=None,
        mask_data=False,
    )
    result[x].data = lon_values
    result[y].data = lat_values
    return result
def _interpolate_data(arr, method="linear"):
    """Fill non-finite values of a 1D (transect) or 2D (swath) coordinate array."""
    interpolator = _interpolate_1d_coord if arr.ndim == 1 else _interpolate_2d_coord
    return interpolator(arr, method=method)
def _interpolate_1d_coord(arr, method="linear"):
# Find invalid locations
is_invalid = ~np.isfinite(arr)
# Find the indices of NaN values
nan_indices = np.where(is_invalid)[0]
# Return array if not NaN values
if len(nan_indices) == 0:
return arr
# Find the indices of non-NaN values
non_nan_indices = np.where(~is_invalid)
# Create indices
indices = np.arange(len(arr))
# Points where we have valid data
points = indices[non_nan_indices]
# Points where data is NaN
points_nan = indices[nan_indices]
# Values at the non-NaN points
values = arr[non_nan_indices]
# Interpolate using griddata
arr_new = arr.copy()
arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
return arr_new
def _interpolate_2d_coord(arr, method="linear"):
# Find invalid locations
is_invalid = ~np.isfinite(arr)
# Find the indices of NaN values
nan_indices = np.where(is_invalid)
# Return array if not NaN values
if len(nan_indices) == 0:
return arr
# Find the indices of non-NaN values
non_nan_indices = np.where(~is_invalid)
# Create a meshgrid of indices
x, y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
# Points (X, Y) where we have valid data
points = np.array([y[non_nan_indices], x[non_nan_indices]]).T
# Points where data is NaN
points_nan = np.array([y[nan_indices], x[nan_indices]]).T
# Values at the non-NaN points
values = arr[non_nan_indices]
# Interpolate using griddata
arr_new = arr.copy()
arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
return arr_new
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
if np.ma.is_masked(arr):
if rgb:
antimeridian_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
combined_mask = np.logical_or(arr.mask, antimeridian_mask)
else:
combined_mask = np.logical_or(arr.mask, antimeridian_mask)
arr = np.ma.masked_where(combined_mask, arr)
else:
if rgb:
antimeridian_mask = np.broadcast_to(
np.expand_dims(antimeridian_mask, axis=-1),
arr.shape,
)
arr = np.ma.masked_where(antimeridian_mask, arr)
return arr
[docs]
def mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs):
    """Mask the array cells crossing the antimeridian.

    ``lon`` is assumed to contain no invalid coordinates anymore.
    Cartopy still misbehaves with several projections when data cross the
    antimeridian, so by default GPM-API masks those cells. The default can be
    changed with ``gpm.config.set({"viz_hide_antimeridian_data": False})``.

    Returns
    -------
    tuple
        ``(arr, plot_kwargs)``, both possibly updated. When nothing crosses the
        antimeridian the inputs are returned unchanged.
    """
    antimeridian_mask = get_antimeridian_mask(lon)
    if not np.any(antimeridian_mask):
        return arr, plot_kwargs
    # Cartopy requires a fully transparent 'bad' colour in the cmap
    plot_kwargs = _sanitize_cartopy_plot_kwargs(plot_kwargs)
    if gpm.config.get("viz_hide_antimeridian_data"):  # default is True
        arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)
    return arr, plot_kwargs
[docs]
def get_antimeridian_mask(lons):
    """Return a boolean cell mask flagging cells whose corners jump across the antimeridian.

    Parameters
    ----------
    lons : numpy.ndarray
        2D array of cell-corner longitudes with shape ``(n_y, n_x)``.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(n_y - 1, n_x - 1)``. A cell is flagged when an adjacent
        corner pair differs by more than 180 degrees; the flags are then dilated by one
        cell in each direction.
    """
    from scipy.ndimage import binary_dilation

    n_rows, n_cols = lons.shape
    crossing = np.zeros((n_rows - 1, n_cols - 1))
    # Jumps between vertically adjacent corners
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=0)) > 180)
    crossing[jump_rows, np.clip(jump_cols - 1, 0, n_cols - 1)] = 1
    # Jumps between horizontally adjacent corners
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=1)) > 180)
    crossing[np.clip(jump_rows - 1, 0, n_rows - 1), jump_cols] = 1
    # Buffer by one cell: not needed in theory, but it avoids cartopy rendering bugs
    return binary_dilation(crossing)
####--------------------------------------------------------------------------.
########################
#### Plot utilities ####
########################
[docs]
def preprocess_rgb_dataarray(da, rgb):
    """Validate the RGB dimension and move it to the last position.

    When ``rgb`` is falsy, ``da`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``rgb`` is not a dimension of ``da`` or its size is neither 3 nor 4.
    """
    if not rgb:
        return da
    if rgb not in da.dims:
        raise ValueError(f"The specified rgb='{rgb}' must be a dimension of the DataArray.")
    if da[rgb].size not in [3, 4]:
        raise ValueError("The RGB dimension must have size 3 or 4.")
    return da.transpose(..., rgb)
[docs]
def preprocess_subplot_kwargs(subplot_kwargs):
    """Return a copy of ``subplot_kwargs`` guaranteed to contain a ``projection`` key.

    A missing projection defaults to ``ccrs.PlateCarree()``. The input is not modified.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" not in kwargs:
        kwargs["projection"] = ccrs.PlateCarree()
    return kwargs
[docs]
def infer_xy_labels(da, x=None, y=None, rgb=None):
    """Infer the x and y dimension names of ``da`` with xarray's private helper.

    ``imshow=True`` is passed so that xarray accepts an additional ``rgb`` dimension.
    """
    from xarray.plot.utils import _infer_xy_labels

    return _infer_xy_labels(da, x=x, y=y, imshow=True, rgb=rgb)
[docs]
def infer_map_xy_coords(da, x=None, y=None):
    """
    Infer possible map x and y coordinates for the given DataArray.

    Candidates are tried in order: ``x``, ``lon``, ``longitude`` for x and
    ``y``, ``lat``, ``latitude`` for y. Explicitly provided names are kept.

    Parameters
    ----------
    da : xarray.DataArray
        The input DataArray.
    x : str, optional
        The name of the x (i.e. longitude) coordinate. If None, it will be inferred.
    y : str, optional
        The name of the y (i.e. latitude) coordinate. If None, it will be inferred.

    Returns
    -------
    tuple
        The inferred (x, y) coordinates.

    Raises
    ------
    ValueError
        If a coordinate is not provided and none of its candidates is in ``da.coords``.
    """

    def _first_available(candidates):
        return next((name for name in candidates if name in da.coords), None)

    if x is None:
        x = _first_available(("x", "lon", "longitude"))
        if x is None:
            raise ValueError("Cannot infer x coordinate. Please provide the x coordinate.")
    if y is None:
        y = _first_available(("y", "lat", "latitude"))
        if y is None:
            raise ValueError("Cannot infer y coordinate. Please provide the y coordinate.")
    return x, y
[docs]
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
):
    """Create a cartopy figure/axes when ``ax`` is None and optionally draw the map background.

    Returns
    -------
    The axes to plot on: ``ax`` if given, otherwise a newly created one.
    """
    if ax is None:
        figure_options = preprocess_figure_args(
            ax=ax,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
        )
        _, ax = plt.subplots(subplot_kw=preprocess_subplot_kwargs(subplot_kwargs), **figure_options)
    return plot_cartopy_background(ax) if add_background else ax
[docs]
def plot_cartopy_background(ax):
    """Draw coastlines, land, ocean, borders and faint gridlines on a cartopy axes.

    Gridline labels are drawn on the left and bottom sides only. Returns ``ax``.
    """
    ax.coastlines()
    background_features = (
        (cartopy.feature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cartopy.feature.OCEAN, {"alpha": 0.6}),
        (cartopy.feature.BORDERS, {}),  # BORDERS also draws provinces, ...
    )
    for feature, feature_options in background_features:
        ax.add_feature(feature, **feature_options)
    gridliner = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color="gray",
        alpha=0.1,
        linestyle="-",
    )
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return ax
[docs]
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines.

    Parameters
    ----------
    sides : iterable
        Sequence of ``(lon, lat)`` tuples, one per side.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes where to draw the lines.
    **plot_kwargs
        Passed to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The line of the last plotted side, or ``None`` when ``sides`` is empty
        (previously an empty input raised UnboundLocalError).
    """
    last_line = None
    for side in sides:
        lines = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)
        last_line = lines[0]
    return last_line
####--------------------------------------------------------------------------.
############################
#### Colorbar utilities ####
############################
def _get_orientation_location(cbar_kwargs):
location = cbar_kwargs.get("location", None)
orientation = cbar_kwargs.get("orientation", None)
# Set defaults
if location is None and orientation is None:
return "vertical", "right"
# Check orientation is horizontal or vertical
if orientation is not None and orientation not in ("horizontal", "vertical"):
raise ValueError("Invalid orientation. Choose 'horizontal' or 'vertical'.")
# Check location is top, left, right or bottom
if location is not None and location not in ("top", "left", "right", "bottom"):
raise ValueError("Invalid location. Choose 'top', 'left', 'right', or 'bottom'.")
# Check compatible arguments
if orientation is not None and location is not None:
if orientation == "vertical":
# Raise error if not right or left
if location not in ("right", "left"):
raise ValueError("Invalid location for vertical orientation. Choose 'right' or 'left'.")
elif location not in ("bottom", "top"):
raise ValueError("Invalid location for horizontal orientation. Choose 'bottom' or 'top'.")
return orientation, location
# Return with default location if missing
if orientation is not None:
if orientation == "vertical":
return "vertical", "right"
return "horizontal", "bottom"
# Return with correct orientation if missing
# if location is not None:
if location in ("right", "left"):
return "vertical", location
return "horizontal", location
[docs]
def plot_colorbar(p, ax, cbar_kwargs=None):
    """Add a colorbar to a matplotlib/cartopy plot.

    The colorbar axes is appended next to ``ax`` with ``make_axes_locatable``.

    Parameters
    ----------
    p : matplotlib.image.AxesImage
        Mappable to describe.
    ax : cartopy.mpl.geoaxes.GeoAxesSubplot
        Axes next to which the colorbar is placed.
    cbar_kwargs : dict, optional
        Colorbar options. ``size`` (default ``"5%"``) and ``pad`` (default 0.1 for
        vertical, 0.25 for horizontal colorbars) control the colorbar axes size and its
        padding from the plot. ``ticklabels`` sets custom tick labels. ``orientation``
        and ``location`` are resolved by ``_get_orientation_location``. The remaining
        options are passed to ``plt.colorbar``. The input dictionary is not modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    # Copy so that the pops below do not alter the caller's dictionary
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation, location = _get_orientation_location(cbar_kwargs)
    # 'size' and 'pad' only configure the divider axes: they must be removed because
    # matplotlib's Colorbar does not accept a 'size' argument (TypeError otherwise).
    size = cbar_kwargs.pop("size", "5%")
    pad = cbar_kwargs.pop("pad", 0.1 if orientation == "vertical" else 0.25)
    # Define the colorbar axis
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(location, size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(cax)
    # Add colorbar
    cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs)
    if ticklabels is not None:
        if orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:
            cbar.ax.set_xticklabels(ticklabels)
    return cbar
[docs]
def set_colorbar_fully_transparent(p):
    """Hide the colorbar of ``p`` while keeping its space reserved in the figure.

    This is useful for animations where the colorbar should not be shown in
    every frame while the plot area must stay fixed. The colorbar axes is made
    invisible and an empty (fully transparent) rectangle is added at its position.
    """
    cbar_ax = p.colorbar.ax
    # Use the colorbar's own figure (not plt.gcf()) so the placeholder lands in the
    # right figure even when p does not belong to the current figure.
    fig = cbar_ax.figure
    # Position of the colorbar in figure coordinates
    cbar_pos = cbar_ax.get_position()
    # Hide the colorbar
    cbar_ax.set_visible(False)
    # Add an empty rectangle at the colorbar location
    placeholder = plt.Rectangle(
        (cbar_pos.x0, cbar_pos.y0),
        cbar_pos.width,
        cbar_pos.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.patches.append(placeholder)
####--------------------------------------------------------------------------.
##########################
#### Cartopy wrappers ####
##########################
def _sanitize_cartopy_plot_kwargs(plot_kwargs):
"""Sanitize 'cmap' to avoid cartopy bug related to cmap bad color.
Cartopy requires the bad color to be fully transparent.
"""
cmap = plot_kwargs.get("cmap", None)
if cmap is not None:
bad = cmap.get_bad()
bad[3] = 0 # enforce to 0 (transparent)
cmap.set_bad(bad)
plot_kwargs["cmap"] = cmap
return plot_kwargs
def _compute_extent(x_coords, y_coords):
"""Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
This function assumes that the spacing between each pixel is uniform.
"""
# Calculate the pixel size assuming uniform spacing between pixels
pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
# Adjust min and max to get the corners of the outer pixels
x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
return [x_min, x_max, y_min, y_max]
[docs]
def plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes where to draw the image.
    da : xarray.DataArray
        DataArray with 1D, evenly spaced ``x`` and ``y`` coordinates (at least 2 each).
    x, y : str or None
        Names of the x (longitude) and y (latitude) coordinates.
        If ``None``, they are inferred with ``infer_xy_labels``.
    interpolation : str, optional
        Interpolation passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. Ignored for RGB images.
    plot_kwargs : dict, optional
        Extra ``imshow`` arguments. The optional ``rgb`` key names the RGB dimension.
        The dictionary is not modified.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    The mappable returned by ``ax.imshow``.
    """
    # Copy so that popping 'rgb' does not modify the caller's dictionary
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    rgb = plot_kwargs.pop("rgb", None)
    # Infer x and y
    x, y = infer_xy_labels(da, x=x, y=y, rgb=rgb)
    # Ensure image with (y, x, ...) dimension order
    da = da.transpose(y, x, ...)
    arr = np.asanyarray(da.data)
    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()
    # _compute_extent returns edges in coordinate order. imshow expects
    # (left, right, bottom, top) and, with origin="upper", draws row 0 at extent[3]:
    # sort the y edges so that decreasing latitudes are not flipped upside down.
    x_start, x_end, y_start, y_end = _compute_extent(x_coords=x_coords, y_coords=y_coords)
    extent = [x_start, x_end, min(y_start, y_end), max(y_start, y_end)]
    # Increasing y -> first row is the bottom ("lower"); decreasing y -> first row is the top ("upper")
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"
    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )
    ax.set_extent(extent)
    if add_colorbar and not rgb:
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
[docs]
def plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    ``x`` and ``y`` must be the names of the longitude and latitude coordinates.
    Invalid coordinates are infilled and the data at those locations masked.
    Cells spanning across the antimeridian are masked as well, and zooming on regions
    across the antimeridian is currently not supported.

    If the DataArray has an RGB dimension, ``plot_kwargs`` should contain the ``rgb``
    key with the name of the RGB dimension; the colorbar is then disabled.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes where to draw.
    da : xarray.DataArray
        Data with 1D (transect / nadir-view) or 2D (swath) lon/lat coordinates.
    x, y : str
        Names of the longitude and latitude coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_swath_lines : bool, optional
        Whether to draw the first and last swath edges (2D case only).
    plot_kwargs : dict, optional
        Extra ``pcolormesh`` arguments. The dictionary is not modified.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    The mappable returned by ``ax.pcolormesh``.
    """
    # Copy so that popping 'rgb' / setting 'shading' does not modify the caller's dictionary
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    rgb = plot_kwargs.pop("rgb", False)
    # Get x, y, and array to plot
    da = preprocess_rgb_dataarray(da, rgb=rgb)
    da = da.compute()
    lon = da[x].data.copy()
    lat = da[y].data.copy()
    arr = da.data
    # 1D coordinates: orbit nadir-view / transect / cross-section case
    is_1d_case = lon.ndim == 1
    # Infill invalid coordinates and mask data at those locations
    # - No invalid coordinate values after this call
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb, mask_data=True)
    if is_1d_case:
        arr = np.expand_dims(arr, axis=1)
    if rgb:
        add_colorbar = False
    # Cell-corner coordinates for the pcolormesh quadrilateral mesh
    # - Enables correct masking of cells crossing the antimeridian
    lon, lat = get_lonlat_corners_from_centroids(lon, lat)
    # Mask cells crossing the antimeridian
    # - gpm.config.set({"viz_hide_antimeridian_data": False}) modifies this behaviour
    arr, plot_kwargs = mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs)
    plot_kwargs.setdefault("shading", "flat")
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )
    if add_swath_lines and not is_1d_case:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
####-------------------------------------------------------------------------------.
#########################
#### Xarray wrappers ####
#########################
def _preprocess_xr_kwargs(add_colorbar, plot_kwargs, cbar_kwargs):
if not add_colorbar:
cbar_kwargs = None
if "rgb" in plot_kwargs:
cbar_kwargs = None
add_colorbar = False
plot_kwargs = {"rgb": plot_kwargs.get("rgb")} # alpha currently skipped if RGB
return add_colorbar, plot_kwargs, cbar_kwargs
[docs]
def plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axes where to plot. If ``None``, xarray uses the current axes.
    da : xarray.DataArray
        Data to plot; ``da.name`` is used as title.
    x, y : str
        Coordinates passed to ``DataArray.plot.pcolormesh``.
    add_colorbar : bool, optional
        Whether to add a colorbar (forced off for RGB images).
    cbar_kwargs : dict, optional
        Colorbar options. The optional ``ticklabels`` key sets custom tick labels.
        The dictionary is not modified.
    **plot_kwargs
        Passed to ``DataArray.plot.pcolormesh``.

    Returns
    -------
    The mappable returned by xarray.
    """
    # Copy (and accept None) so that popping 'ticklabels' does not alter the caller's dict
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    # Title the axes actually used (plt.title would target the current axes,
    # which can differ from the given ax)
    title_ax = ax if ax is not None else plt.gca()
    title_ax.set_title(da.name)
    if add_colorbar and ticklabels is not None:
        cbar = p.colorbar
        if cbar.orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:
            cbar.ax.set_xticklabels(ticklabels)
    return p
[docs]
def plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    cbar_kwargs=None,
    visible_colorbar=True,
    **plot_kwargs,
):
    """Plot imshow with xarray.

    The colorbar is added by xarray, so multiple colorbars can be displayed when
    calling this function several times on different fields with different colorbars.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axes where to plot. If ``None``, xarray uses the current axes.
    da : xarray.DataArray
        Data to plot; ``da.name`` is used as title.
    x, y : str
        Dimensions/coordinates passed to ``DataArray.plot.imshow``.
    interpolation : str, optional
        Interpolation passed to imshow. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar (forced off for RGB images).
    cbar_kwargs : dict, optional
        Colorbar options. The optional ``ticklabels`` key sets custom tick labels.
        The dictionary is not modified.
    visible_colorbar : bool, optional
        If ``False``, the colorbar space is kept but the colorbar is hidden.
    **plot_kwargs
        Passed to ``DataArray.plot.imshow``.

    Returns
    -------
    The mappable returned by xarray.
    """
    # Copy (and accept None) so that popping 'ticklabels' does not alter the caller's dict
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    # Title the axes actually used (plt.title would target the current axes,
    # which can differ from the given ax)
    title_ax = ax if ax is not None else plt.gca()
    title_ax.set_title(da.name)
    if add_colorbar and ticklabels is not None:
        cbar = p.colorbar
        if cbar.orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:
            cbar.ax.set_xticklabels(ticklabels)
    # Hide the colorbar while keeping its space reserved
    # - TODO: this still causes issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)
    return p
####--------------------------------------------------------------------------.
####################
#### Plot Image ####
####################
def _plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a single 2D GPM field (orbit or grid) as an image with xarray imshow.

    A figure is created from ``fig_kwargs`` when ``ax`` is None. Unless the call
    comes from a FacetGrid (``_is_facetgrid`` in ``plot_kwargs``), the returned
    mappable gets an ``optimize_layout`` method.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.facetgrid import sanitize_facetgrid_plot_kwargs

    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs)
    if ax is None:
        _, ax = plt.subplots(**fig_kwargs)

    # Remember the FacetGrid flag before sanitizing the kwargs set by FacetGrid.map_dataarray
    is_facetgrid = plot_kwargs.get("_is_facetgrid", False)
    plot_kwargs = sanitize_facetgrid_plot_kwargs(plot_kwargs)

    # Fill in product-specific defaults not set by the user
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    p = plot_xr_imshow(
        ax=ax,
        da=da,
        x=x,
        y=y,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    if is_orbit(da):
        axis_labels = ("Along-Track", "Cross-Track")
    elif is_grid(da):
        axis_labels = ("Longitude", "Latitude")
    else:
        axis_labels = None
    if axis_labels is not None:
        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])

    return p if is_facetgrid else add_optimize_layout_method(p)
def _plot_image_facetgrid(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot 2D fields on an ``ImageFacetGrid`` (one ``_plot_image`` call per panel).

    Facet layout options (``col``, ``row``, ``col_wrap``, ``axes_pad``, ``aspect``,
    ``facet_height``, ``facet_aspect``) are taken out of ``plot_kwargs``. A single
    shared colorbar is added at the end unless disabled or the data is RGB.
    """
    from gpm.visualization.facetgrid import ImageFacetGrid

    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, is_facetgrid=True)
    # GPM-API default cmap and colorbar options for this variable
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )
    # No colorbar for RGB images
    # - TODO: move this to pycolorbar and also drop cmap, norm, vmin and vmax
    if plot_kwargs.get("rgb", False):
        add_colorbar = False
        cbar_kwargs = {}

    layout_options = {
        "col": plot_kwargs.pop("col", None),
        "row": plot_kwargs.pop("row", None),
        "col_wrap": plot_kwargs.pop("col_wrap", None),
        "axes_pad": plot_kwargs.pop("axes_pad", None),
        "aspect": plot_kwargs.pop("aspect", False),
        "facet_height": plot_kwargs.pop("facet_height", 3),
        "facet_aspect": plot_kwargs.pop("facet_aspect", 1),
    }
    fc = ImageFacetGrid(
        data=da.compute(),
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        add_colorbar=add_colorbar,
        **layout_options,
    )

    # Panels are drawn without colorbar: a shared one is added below
    fc = fc.map_dataarray(
        _plot_image,
        x=x,
        y=y,
        add_colorbar=False,
        interpolation=interpolation,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    fc.remove_duplicated_axis_labels()
    if add_colorbar:
        fc.add_colorbar(**cbar_kwargs)
    return fc
[docs]
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data using imshow.

    Parameters
    ----------
    da : xarray.DataArray
        xarray DataArray. It must be a spatial 2D object.
    x : str, optional
        X dimension name.
        If ``None``, takes the second dimension.
        The default is ``None``.
    y : str, optional
        Y dimension name.
        If ``None``, takes the first dimension.
        The default is ``None``.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the specified ``fig_kwargs``.
        The default is ``None``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options to be passed to :py:class:`matplotlib.pyplot.subplots`.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Raises
    ------
    ValueError
        If ``da`` is not a spatial 2D object.
    """
    from gpm.checks import check_is_spatial_2d, is_spatial_2d

    # Ensure the object can be displayed as a 2D image
    if not is_spatial_2d(da, strict=False):
        raise ValueError("Can not plot. It's not a spatial 2D object.")
    da = check_object_format(da, plot_kwargs=plot_kwargs, check_function=check_is_spatial_2d, strict=True)

    # FacetGrid when 'col' or 'row' is requested, a single xarray imshow otherwise
    use_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs
    plotter = _plot_image_facetgrid if use_facetgrid else _plot_image
    return plotter(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        interpolation=interpolation,
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
####--------------------------------------------------------------------------.
##################
#### Plot map ####
##################
[docs]
def plot_map(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data on a geographic map.

    ORBIT objects are dispatched to ``plot_orbit_map`` and GRID objects to
    ``plot_grid_map``.

    Parameters
    ----------
    da : xarray.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name.
        If ``None``, takes the second dimension.
        The default is ``None``.
    y : str, optional
        Latitude coordinate name.
        If ``None``, takes the first dimension.
        The default is ``None``.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If ``None``, a figure is initialized using the
        specified ``fig_kwargs`` and ``subplot_kwargs``.
        The default is ``None``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is ``True``.
        Only applies to ORBIT objects.
    add_background : bool, optional
        Whether to add the map background. The default is ``True``.
    interpolation : str, optional
        Argument to be passed to :py:class:`matplotlib.axes.Axes.imshow`. Only applies to GRID objects.
        The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options to be passed to `matplotlib.pyplot.subplots`.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    subplot_kwargs : dict, optional
        Dictionary of keyword arguments for :py:class:`matplotlib.pyplot.subplots`.
        Must contain the Cartopy CRS ``projection`` key if specified.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT nor a GPM GRID spatial 2D object.
    """
    from gpm.checks import has_spatial_dim, is_grid, is_orbit, is_spatial_2d
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    # Orbit: other dimensions are allowed (FacetGrid), as well as swaths of size 1 (nadir-looking)
    if is_orbit(da) and has_spatial_dim(da):
        return plot_orbit_map(add_swath_lines=add_swath_lines, **shared_kwargs, **plot_kwargs)
    if is_grid(da) and is_spatial_2d(da, strict=False):
        return plot_grid_map(interpolation=interpolation, **shared_kwargs, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial 2D object.")
[docs]
def plot_map_mesh(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the pixel mesh of a GPM ORBIT or GRID object on a cartographic map.

    ORBIT objects are dispatched to ``plot_orbit_mesh`` (after inferring the
    lon/lat coordinate names) and GRID objects to ``plot_grid_mesh``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_mesh
    from gpm.visualization.orbit import plot_orbit_mesh

    if is_orbit(xr_obj):
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
        return plot_orbit_mesh(
            da=xr_obj[y],
            ax=ax,
            x=x,
            y=y,
            edgecolors=edgecolors,
            linewidth=linewidth,
            add_background=add_background,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
            **plot_kwargs,
        )
    if is_grid(xr_obj):
        return plot_grid_mesh(
            xr_obj=xr_obj,
            x=x,
            y=y,
            ax=ax,
            edgecolors=edgecolors,
            linewidth=linewidth,
            add_background=add_background,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
            **plot_kwargs,
        )
    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object.")
[docs]
def plot_map_mesh_centroids(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    c="r",
    s=1,
    add_background=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot GPM orbit or grid mesh centroids on a cartographic map.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        GPM ORBIT or GRID object. Other objects require explicit ``x`` and ``y``.
    x, y : str, optional
        Names of the longitude and latitude coordinates. Inferred when ``None``.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        Axes where to plot. A figure is created when ``None``.
    c, s
        Marker color and size passed to ``ax.scatter``.
    add_background : bool, optional
        Whether to draw the map background.
    fig_kwargs, subplot_kwargs : dict, optional
        Figure/subplot options, only used if ``ax`` is ``None``.
    **plot_kwargs
        Additional arguments passed to ``ax.scatter``.

    Returns
    -------
    The mappable returned by ``ax.scatter``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is neither a GPM ORBIT nor a GPM GRID object and ``x``/``y``
        are not both specified.
    """
    from gpm.checks import is_grid, is_orbit

    obj_is_orbit = is_orbit(xr_obj)
    obj_is_grid = is_grid(xr_obj)
    # Fail early (before creating a figure) when the coordinates can not be inferred:
    # otherwise xr_obj[None] below would fail with an unhelpful error.
    if not (obj_is_orbit or obj_is_grid) and (x is None or y is None):
        raise ValueError(
            "Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object. "
            "Please specify the 'x' and 'y' coordinates.",
        )
    # Initialize figure if necessary
    ax = initialize_cartopy_plot(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        add_background=add_background,
    )
    # Orbit: retrieve the lon/lat coordinate names
    if obj_is_orbit:
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
    # Grid: expand the 1D coordinates into a 2D centroids mesh
    if obj_is_grid:
        x, y = infer_xy_labels(xr_obj, x=x, y=y)
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)
    lon = xr_obj[x].to_numpy()
    lat = xr_obj[y].to_numpy()
    return ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)
[docs]
def create_grid_mesh_data_array(xr_obj, x, y):
    """Create a 2D mesh coordinates DataArray.

    Takes as input the 1D coordinate arrays from an existing xarray.DataArray or xarray.Dataset object.
    The function creates a 2D grid (mesh) of x and y coordinates and initializes
    the data values to NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The input xarray object containing the 1D coordinate arrays.
    x : str
        The name of the x-coordinate in `xr_obj`.
    y : str
        The name of the y-coordinate in `xr_obj`.

    Returns
    -------
    da_mesh : xarray.DataArray
        A 2D xarray.DataArray with mesh coordinates for `x` and `y`, and NaN values for data points.

    Notes
    -----
    The resulting xarray.DataArray has dimensions named 'y' and 'x', corresponding to the
    y and x coordinates respectively. If `x` or `y` is itself named 'x' or 'y', the dimensions
    are renamed by appending '_dim' (e.g. 'y_dim', 'x_dim'). xarray does not allow a
    multi-dimensional coordinate to share its name with a dimension.
    The coordinate values are taken directly from the input 1D coordinate arrays,
    and the data values are set to NaN.
    """
    # Extract 1D coordinate arrays
    x_coords = xr_obj[x].to_numpy()
    y_coords = xr_obj[y].to_numpy()
    # Create 2D meshgrid for x and y coordinates
    X, Y = np.meshgrid(x_coords, y_coords, indexing="xy")
    # Create a 2D array of NaN values with the same shape as the meshgrid
    dummy_values = np.full(X.shape, np.nan)
    # Choose dimension names that do not collide with the 2D coordinate names
    coord_names = {x, y}
    y_dim, x_dim = "y", "x"
    while y_dim in coord_names:
        y_dim = f"{y_dim}_dim"
    while x_dim in coord_names:
        x_dim = f"{x_dim}_dim"
    dims = (y_dim, x_dim)
    # Create a new DataArray with 2D coordinates and NaN values
    return xr.DataArray(
        dummy_values,
        coords={x: (dims, X), y: (dims, Y)},
        dims=dims,
    )
####--------------------------------------------------------------------------.
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot a label array with one distinct color per label.

    The labels are relabeled from 1 to N for plotting, and values <= 0 are masked (NaN).
    The colorbar is displayed only if the number of labels does not exceed 'max_n_labels'.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object containing the labels.
    label_name : str, optional
        For a Dataset, the name of the label variable (required).
        For a DataArray, if given, ``xr_obj[label_name]`` is plotted instead of ``xr_obj``.
    max_n_labels : int, optional
        Maximum number of labels for which the colorbar is displayed.
    add_colorbar : bool, optional
        Whether to add the label colorbar.
    interpolation : str, optional
        Image interpolation forwarded to ``plot_image``.
    cmap : str, optional
        Colormap name used to build the label colors.
    fig_kwargs : dict, optional
        Figure creation arguments forwarded to ``plot_image``.
    **plot_kwargs
        Additional arguments forwarded to ``plot_image``. They override the label defaults.
        A ``cbar_kwargs`` entry is merged over the default label colorbar settings.

    Returns
    -------
    p
        The object returned by ``plot_image``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is an xarray.Dataset and ``label_name`` is not specified.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    from gpm.visualization.plot import plot_image

    # Select the label array
    if isinstance(xr_obj, xr.Dataset):
        if label_name is None:
            raise ValueError("'label_name' must be specified when plotting labels of an xarray.Dataset.")
        dataarray = xr_obj[label_name]
    else:
        dataarray = xr_obj[label_name] if label_name is not None else xr_obj
    dataarray = dataarray.compute()

    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        msg = (
            f"The array currently contains {n_labels} labels\n"
            f"and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"
        )
        print(msg)
        add_colorbar = False

    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Replace 0 with nan
    dataarray = dataarray.where(dataarray > 0)

    # Define appropriate colormap
    default_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    # A 'cbar_kwargs' left in plot_kwargs would collide with the explicit 'cbar_kwargs='
    # argument below (TypeError), so merge it over the label colorbar defaults instead.
    user_cbar_kwargs = plot_kwargs.pop("cbar_kwargs", None)
    if user_cbar_kwargs:
        cbar_kwargs = {**(cbar_kwargs or {}), **user_cbar_kwargs}
    default_plot_kwargs.update(plot_kwargs)

    # Plot image
    return plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **default_plot_kwargs,
    )
[docs]
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot label arrays, one distinct color per label.

    Parameters
    ----------
    obj : xarray.Dataset, xarray.DataArray or generator
        The object to plot. A generator must yield ``(label_id, xr_obj)`` tuples.
        Each ``xr_obj`` is plotted in turn and displayed with ``plt.show()``.
    label_name : str, optional
        Name of the label variable. It must be provided when plotting xarray.Dataset objects.
    max_n_labels : int, optional
        Maximum number of labels for which the colorbar is displayed.
    add_colorbar : bool, optional
        Whether to add the label colorbar.
    interpolation : str, optional
        Image interpolation. The default is ``"nearest"``.
    cmap : str, optional
        Colormap name used to build the label colors. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure creation arguments.
    **plot_kwargs
        Additional plotting arguments.

    Returns
    -------
    p
        The mappable of the (last) plot, or ``None`` if a generator yields no items.
    """
    label_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
    }
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **label_kwargs, **plot_kwargs)

    # Initialize the return value so an empty generator returns None
    # instead of raising UnboundLocalError at 'return p'.
    # NOTE(review): is_generator() also returns True for generator *functions*,
    # which are not iterable themselves; such input would fail in the loop below.
    p = None
    for _, xr_obj in obj:  # label_id, xr_obj
        p = _plot_labels(xr_obj=xr_obj, **label_kwargs, **plot_kwargs)
        plt.show()
    return p
[docs]
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot and display each patch yielded by a patch generator.

    Plotting is best-effort. A patch that fails to plot is skipped with a warning
    instead of stopping the iteration.

    Parameters
    ----------
    patch_gen : iterable
        Iterable yielding ``(label_id, patch)`` tuples, where ``patch`` is an
        xarray.DataArray or xarray.Dataset.
    variable : str, optional
        Variable to plot when the patches are xarray.Dataset objects.
    add_colorbar : bool, optional
        Whether to add a colorbar.
    interpolation : str, optional
        Image interpolation. The default is ``"nearest"``.
    fig_kwargs, cbar_kwargs : dict, optional
        Figure and colorbar arguments forwarded to ``plot_image``.
    **plot_kwargs
        Additional arguments forwarded to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is an xarray.Dataset and ``variable`` is not specified.
    """
    import warnings

    from gpm.visualization.plot import plot_image

    # Plot patches
    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xarray.Dataset patches.")
            xr_patch = xr_patch[variable]
        # Skip patches that cannot be plotted, but report why instead of silently ignoring the error
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            warnings.warn(
                f"Skipping patch {label_id!r}: plotting failed with {type(err).__name__}: {err}",
                stacklevel=2,
            )
####--------------------------------------------------------------------------.
[docs]
def get_inset_bounds(
    ax,
    loc="upper right",
    inset_height=0.2,
    inside_figure=True,
    aspect_ratio=1,
):
    """Calculate the bounds of an inset axes placed at a corner of a parent axes.

    The bounds are expressed in normalized coordinates of the parent axes:
    ``(0, 0)`` is its lower-left corner and ``(1, 1)`` its upper-right corner.
    This is the default coordinate system of ``matplotlib.axes.Axes.inset_axes``.
    The function helps add insets (e.g. maps or zoomed plots) at predefined positions.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The parent axes. Only its position in figure coordinates (``ax.get_position()``) is used.
    loc : str, optional
        The corner of the parent axes where the inset is placed. Valid options are
        ``'lower left'``, ``'lower right'``, ``'upper left'``, and ``'upper right'``.
        The default is ``'upper right'``.
    inset_height : float, optional
        The inset height as a fraction of the parent axes height.
        For example, a value of 0.2 makes the inset 20% as tall as the parent axes.
        The ``aspect_ratio`` governs the resulting inset width.
    inside_figure : bool, optional
        If ``True`` (default), the inset is placed fully within the parent axes.
        If ``False``, the inset is shifted by half its width and height toward the chosen
        corner, so that it straddles the parent axes edges.
    aspect_ratio : float, optional
        The width-to-height ratio of the inset.
        A value greater than 1 gives an inset wider than it is tall.
        A value less than 1 gives an inset taller than it is wide.
        The default value is 1.0.
        The ratio is applied in figure-fraction units, so it matches the physical
        aspect ratio only when the figure itself is square.

    Returns
    -------
    inset_bounds : list of float
        The inset bounds ``[x0, y0, width, height]`` in normalized parent-axes coordinates.
        ``(x0, y0)`` is the lower-left corner of the inset.

    Raises
    ------
    ValueError
        If ``loc`` is not one of the valid options.
    """
    valid_locs = ("upper right", "upper left", "lower right", "lower left")
    if loc not in valid_locs:
        raise ValueError(f"Invalid 'loc' {loc!r}. Valid options are {list(valid_locs)}.")

    # Get the bounding box of the parent axes in figure coordinates
    bbox = ax.get_position()
    parent_width = bbox.width
    parent_height = bbox.height

    # Compute the inset width as a fraction of the parent axes width
    # - Convert the height to figure units, apply the aspect ratio, convert back
    inset_height_abs = inset_height * parent_height
    inset_width_abs = inset_height_abs * aspect_ratio
    inset_width = inset_width_abs / parent_width

    # Anchor the inset at the requested corner of the parent axes
    is_left = loc.endswith("left")
    is_lower = loc.startswith("lower")
    inset_x = 0 if is_left else 1 - inset_width
    inset_y = 0 if is_lower else 1 - inset_height

    # Shift the inset by half its size toward the corner to let it straddle the edges
    if not inside_figure:
        inset_x += inset_width / 2 * (-1 if is_left else 1)
        inset_y += inset_height / 2 * (-1 if is_lower else 1)
    return [inset_x, inset_y, inset_width, inset_height]
[docs]
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True):
    """Add an inset map to a Cartopy GeoAxes, highlighting the extent of the main map.

    This function creates a smaller map inset within a larger map plot to show a global view
    (the contextual location) of the main map's extent.
    The padded extent of the main map is outlined in red within the inset to provide
    geographical context.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The main map axes. It must provide ``get_extent``, i.e. be a Cartopy GeoAxes.
    loc : str, optional
        The location of the inset map within the main axes.
        Options include ``'lower left'``, ``'lower right'``,
        ``'upper left'``, and ``'upper right'``. The default is ``'upper left'``.
    inset_height : float, optional
        The inset height as a fraction of the main axes height.
        For example, a value of 0.2 makes the inset 20% as tall as the main axes.
        The aspect ratio of the inset projection governs the inset width.
    projection : cartopy.crs.Projection, optional
        The projection of the inset map. If ``None``, an Orthographic projection
        centered on the (padded) main map extent is used.
    inside_figure : bool, optional
        If ``True`` (default), the inset is placed fully within the main axes.
        If ``False``, the inset is shifted by half its size so that it straddles the
        main axes edges.

    Returns
    -------
    ax2 : cartopy.mpl.geoaxes.GeoAxes
        The inset map axes.

    Notes
    -----
    The main map extent is retrieved in longitude/latitude (PlateCarree) coordinates,
    padded with ``extend_geographic_extent(extent, padding=0.5)``, and drawn as a red
    outline on a global inset map with land and ocean features.

    Examples
    --------
    >>> p = da.gpm.plot_map()
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)

    This example creates a main plot and adds an upper-left inset map
    showing the global context of the main plot's extent.
    """
    from shapely import Polygon

    from gpm.utils.geospatial import extend_geographic_extent

    # Retrieve the main map extent in lon/lat coordinates
    # - Without 'crs', GeoAxes.get_extent returns the extent in the axes' own projection
    #   (e.g. meters for projected maps), which the geographic helpers below do not expect.
    extent = ax.get_extent(crs=ccrs.PlateCarree())
    extent = extend_geographic_extent(extent, padding=0.5)
    # Reorder (lon_min, lon_max, lat_min, lat_max) into shapely (minx, miny, maxx, maxy)
    bounds = [extent[i] for i in [0, 2, 1, 3]]
    polygon = Polygon.from_bounds(*bounds)

    # Default to an Orthographic projection centered on the extent
    if projection is None:
        lon_min, lon_max, lat_min, lat_max = extent
        projection = ccrs.Orthographic(
            central_latitude=(lat_min + lat_max) / 2,
            central_longitude=(lon_min + lon_max) / 2,
        )

    # Width-to-height ratio of the inset projection domain
    # - Computed from the limit tuples directly: calling float() on a 1-element
    #   ndarray (as np.diff would return) is deprecated since NumPy 1.25.
    x_min, x_max = projection.x_limits
    y_min, y_max = projection.y_limits
    aspect_ratio = float((x_max - x_min) / (y_max - y_min))

    # Define inset location relative to main plot (ax) in normalized units
    # - Lower-left corner of inset Axes, and its width and height
    # - [x0, y0, width, height]
    inset_bounds = get_inset_bounds(
        ax=ax,
        loc=loc,
        inset_height=inset_height,
        inside_figure=inside_figure,
        aspect_ratio=aspect_ratio,
    )
    ax2 = ax.inset_axes(
        inset_bounds,
        projection=projection,
    )

    # Add global map
    ax2.set_global()
    ax2.add_feature(cfeature.LAND)
    ax2.add_feature(cfeature.OCEAN)

    # Add extent polygon
    _ = ax2.add_geometries(
        [polygon],
        ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="red",
        linewidth=0.3,
    )
    return ax2