Source code for gpm.utils.slices

# -----------------------------------------------------------------------------.
# MIT License

# Copyright (c) 2024 GPM-API developers
#
# This file is part of GPM-API.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# -----------------------------------------------------------------------------.
"""This module contains utilities for list of slices processing."""

import numpy as np

####---------------------------------------------------------------------------.
#### Tools for list_slices


def get_list_slices_from_indices(indices):
    """Return a list of slices from a list/array of integer indices.

    The indices are de-duplicated and sorted, then every run of consecutive
    indices is converted into a single slice.

    Example: ``[0,1,2,4,5,8]`` --> ``[slice(0,3), slice(4,6), slice(8,9)]``

    Parameters
    ----------
    indices : int, float, list, tuple or numpy.ndarray
        Non-negative indices. A scalar (python scalar, numpy scalar or 0-d array)
        is treated as a single index. Values are cast to integers.

    Returns
    -------
    list_slices : list
        List of slice objects (with python integer bounds) sorted by start.
        An empty list is returned if ``indices`` is empty.

    Raises
    ------
    ValueError
        If any of the indices is negative.

    """
    # np.atleast_1d turns python/numpy scalars and 0-d arrays into 1-element arrays,
    # so that scalars of any flavour can be handled like a sequence.
    indices = np.atleast_1d(np.asarray(indices))
    if indices.size == 0:
        return []
    # np.unique flattens, removes duplicates and returns the values sorted
    indices = np.unique(indices.astype(int))
    # The values are sorted: checking the first one is enough
    if indices[0] < 0:
        raise ValueError("get_list_slices_from_indices expects only positive integer indices.")
    # Positions after which the sequence of consecutive indices is interrupted
    breaks = np.flatnonzero(np.diff(indices) != 1)
    # A run starts at position 0 and right after each break;
    # a run ends at each break and at the last position.
    starts = indices[np.concatenate(([0], breaks + 1))]
    stops = indices[np.concatenate((breaks, [indices.size - 1]))] + 1
    return [slice(int(start), int(stop)) for start, stop in zip(starts, stops)]
def get_indices_from_list_slices(list_slices, check_non_intersecting=True):
    """Return a numpy array of indices from a list of slices.

    Parameters
    ----------
    list_slices : list
        List of slice objects. Each slice is expanded with
        ``np.arange(slc.start, slc.stop, slc.step)``.
    check_non_intersecting : bool, optional
        If ``True`` (the default), raise an error if an index is covered
        by more than one slice.

    Returns
    -------
    indices : numpy.ndarray
        Sorted array of the unique indices covered by the slices.
        An empty integer array is returned if ``list_slices`` is empty.

    Raises
    ------
    ValueError
        If ``check_non_intersecting=True`` and some slices intersect.

    """
    if len(list_slices) == 0:
        # Use an integer dtype (np.array([]) would be float64) so that the output
        # can still be used for indexing or concatenated with other indices.
        return np.array([], dtype=int)
    list_indices = [np.arange(slc.start, slc.stop, slc.step) for slc in list_slices]
    # An index counted more than once belongs to more than one slice
    indices, counts = np.unique(np.concatenate(list_indices), return_counts=True)
    if check_non_intersecting and np.any(counts > 1):
        raise ValueError("The list of slices contains intersecting slices!")
    return indices
def _get_slices_intersection(slc1, slc2, min_size=1):
    """Return the slice shared by two slices.

    Parameters
    ----------
    slc1, slc2 : slice
        The slices to intersect.
    min_size : int, optional
        Minimum size of the intersection. The default is 1.

    Returns
    -------
    slice or None
        The intersecting slice, or ``None`` if the size of the intersection
        is smaller than ``min_size``.

    Raises
    ------
    TypeError
        If one of the inputs is not a slice object.

    """
    if not (isinstance(slc1, slice) and isinstance(slc2, slice)):
        raise TypeError("Expecting slice objects")
    # The overlap starts at the largest start and ends at the smallest stop
    overlap = slice(max(slc1.start, slc2.start), min(slc1.stop, slc2.stop))
    overlap_size = overlap.stop - overlap.start
    return None if overlap_size < min_size else overlap
def list_slices_intersection(*args, min_size=1):
    """Return the intersecting slices from multiple list of slices.

    Parameters
    ----------
    *args : list
        Lists of slice objects to intersect.
    min_size : int, optional
        Minimum size of an intersection for it to be kept. The default is 1.

    Returns
    -------
    list_slices : list
        The slices covered by every input list. An empty list is returned
        if no list is provided or if there is no intersection.

    """
    if not args:
        return []
    # Start from a slice covering everything and narrow it down list after list
    current = [slice(-np.inf, np.inf)]
    for other in args:
        overlaps = []
        for slc1 in current:
            for slc2 in other:
                overlap = _get_slices_intersection(slc1, slc2, min_size)
                if overlap is not None:
                    overlaps.append(overlap)
        # Nothing left to intersect with the remaining lists
        if not overlaps:
            return []
        current = overlaps
    return current
def list_slices_union(*args):
    """Return the union slices from multiple list of slices.

    Parameters
    ----------
    *args : list
        Lists of slice objects to merge.

    Returns
    -------
    list_slices : list
        The slices covering all the indices selected by the input lists.
        An empty list is returned if no list is provided.

    """
    # np.concatenate raises an error on an empty sequence of arrays
    if len(args) == 0:
        return []
    list_indices = [get_indices_from_list_slices(list_slices) for list_slices in args]
    union_indices = np.unique(np.concatenate(list_indices))
    return get_list_slices_from_indices(union_indices)
def _get_slices_difference(slc1, slc2):
    """Return the list of slices covered by slc1 not intersecting slc2.

    The output contains at most two slices: the portion of ``slc1`` located
    before ``slc2`` and the portion located after it. Portions whose size
    is not positive are discarded.
    """
    before = slice(slc1.start, min(slc1.stop, slc2.start))
    after = slice(max(slc1.start, slc2.stop), slc1.stop)
    return [portion for portion in (before, after) if get_slice_size(portion) > 0]
def list_slices_difference(list_slices1, list_slices2):
    """Return the list of slices covered by list_slices1 not intersecting list_slices2.

    Parameters
    ----------
    list_slices1 : list
        List of slice objects to subtract from.
    list_slices2 : list
        List of slice objects to subtract. If empty, ``list_slices1`` is returned as is.

    Returns
    -------
    list_slices : list
        The list of slices resulting from the difference.

    """
    if len(list_slices2) == 0:
        return list_slices1
    # For each slice of list_slices2, collect what remains of list_slices1 once it is removed
    remainders = []
    for slc2 in list_slices2:
        remainder = []
        for slc1 in list_slices1:
            remainder.extend(_get_slices_difference(slc1, slc2))
        remainders.append(remainder)
    # What survives the removal of every slice is the intersection of the remainders
    # min_size=0 to keep holes from list_slices2
    return list_slices_intersection(*remainders, min_size=0)
def list_slices_combine(*args):
    """Concatenate multiple list of slices into a single list.

    The slices are returned in the order they are provided.
    No sorting, merging or de-duplication is performed.
    """
    combined = []
    for list_slices in args:
        combined.extend(list_slices)
    return combined
def list_slices_simplify(list_slices):
    """Simplify a list of sequential slices.

    The slices are expanded into indices (intersecting slices are allowed)
    and the indices are then converted back to a list of slices.

    Example 1: ``[slice(0,2), slice(2,4)]`` --> ``[slice(0,4)]``

    A list with zero or one slice is returned unchanged.
    """
    if len(list_slices) > 1:
        indices = get_indices_from_list_slices(list_slices, check_non_intersecting=False)
        list_slices = get_list_slices_from_indices(indices)
    return list_slices
def _list_slices_sort(list_slices):
    """Return a new list with the slices sorted by ``slice.start``."""
    ordered = list(list_slices)
    ordered.sort(key=lambda slc: slc.start)
    return ordered
def list_slices_sort(*args):
    """Sort a single or multiple list of slices by ``slice.start``.

    The input lists are first combined together: the function
    outputs a single list of slices!
    """
    return _list_slices_sort(list_slices_combine(*args))
def list_slices_filter(list_slices, min_size=None, max_size=None):
    """Filter list of slices by size.

    Parameters
    ----------
    list_slices : list
        List of slice objects. Elements which are not slice objects are
        considered to have size 0.
    min_size : int, optional
        Minimum size (inclusive) of the slices to keep.
        The default is ``None`` (no lower bound).
    max_size : int, optional
        Maximum size (inclusive) of the slices to keep.
        The default is ``None`` (no upper bound).

    Returns
    -------
    list_slices : list
        The slices with size between ``min_size`` and ``max_size``.
        If both ``min_size`` and ``max_size`` are ``None``, the input list is returned as is.

    """
    if min_size is None and max_size is None:
        return list_slices
    # Define min and max size if one is not specified
    min_size = 0 if min_size is None else min_size
    max_size = np.inf if max_size is None else max_size

    def _size(slc):
        """Return the slice size, or 0 for objects which are not slices."""
        return get_slice_size(slc) if isinstance(slc, slice) else 0

    # Filter with a plain comprehension instead of a numpy object array,
    # which would try to interpret sequence-like elements as array dimensions.
    return [slc for slc in list_slices if min_size <= _size(slc) <= max_size]
def list_slices_flatten(list_slices):
    """Flatten out list of slices with 2 nested level.

    Elements which are lists are expanded in place, while the other
    elements are kept as they are.

    Examples
    --------
    ``[[slice(1, 7934, None)], [slice(1, 2, None)]] --> [slice(1, 7934, None), slice(1, 2, None)]``

    ``[slice(1, 7934, None), slice(1, 2, None)] --> [slice(1, 7934, None), slice(1, 2, None)]``

    """
    flattened = []
    for item in list_slices:
        flattened.extend(item if isinstance(item, list) else [item])
    return flattened
def get_list_slices_from_bool_arr(bool_arr, include_false=True, skip_consecutive_false=True):
    """Return the slices corresponding to sequences of ``True`` in the input arrays.

    If ``include_false=True``, the last element of each slice sequence (except the last) will be ``False``.
    If ``include_false=False``, no element in each slice sequence will be ``False``.
    If ``skip_consecutive_false=True`` (default), the first element of each slice must be a ``True``.
    If ``skip_consecutive_false=False``, it returns also slices of size 1 which selects just the ``False`` values.
    If ``include_false=False``, ``skip_consecutive_false`` is automatically ``True``.

    Examples
    --------
    If ``include_false=True`` and ``skip_consecutive_false=False``:

    - ``[False, False]`` --> ``[slice(0,1), slice(1,2)]``

    If ``include_false=True`` and ``skip_consecutive_false=True``:

    - ``[False, False]`` --> ``[]``
    - ``[False, False, True]`` --> ``[slice(2,3)]``
    - ``[False, False, True, False]`` --> ``[slice(2,4)]``

    If ``include_false=False``:

    - ``[False, False, True, False]`` --> ``[slice(2,3)]``

    """
    # Excluding the False values implies skipping the consecutive False values
    skip_consecutive_false = skip_consecutive_false or not include_false
    bool_arr = np.array(bool_arr)
    n_elements = len(bool_arr)

    # Only True values: a single slice covers the full array
    if np.all(bool_arr):
        return [slice(0, n_elements)]

    # Only False values
    if np.all(~bool_arr):
        if skip_consecutive_false:
            return []
        return [slice(i, i + 1) for i in range(n_elements)]

    # Mix of True and False values
    # - Positions of the False values, preceded by -1 so that the first slice
    #   starts at 0 when the first element is not False
    false_positions = np.append(-1, np.argwhere(~bool_arr).flatten())
    list_slices = []
    for previous_false, current_false in zip(false_positions[:-1], false_positions[1:]):
        # No True value between two adjacent False values
        if skip_consecutive_false and current_false - previous_false == 1:
            continue
        stop = current_false + 1 if include_false else current_false
        list_slices.append(slice(previous_false + 1, stop))
    # - Add the trailing slice if the array does not end with a False
    last_false = false_positions[-1]
    if last_false < n_elements - 1:
        list_slices.append(slice(last_false + 1, n_elements))
    return list_slices
# tests for get_list_slices_from_bool_arr
# bool_arr = np.array([True, False, True, True, True])
# bool_arr = np.array([True, True, True, False, True])
# bool_arr = np.array([True, True, True, True, False])
# bool_arr = np.array([True, False, False, True, True])
# bool_arr = np.array([True, True, True, False, False])
# bool_arr = np.array([False, True, True, True, False])
# bool_arr = np.array([False, False, True, True, True])
# bool_arr = np.array([False])


####----------------------------------------------------------------------------.
#### Tools for slice manipulation
def ensure_is_slice(slc):
    """Ensure the input is a slice object.

    Parameters
    ----------
    slc : slice, int, numpy.integer, list, tuple or numpy.ndarray
        A slice object (returned as is), an integer index, or a
        list/tuple/array containing a single index.

    Returns
    -------
    slice
        The input slice, or the slice of size 1 selecting the specified index.

    Raises
    ------
    ValueError
        If the input cannot be converted to a slice object.

    """
    if isinstance(slc, slice):
        return slc
    # numpy integer scalars (e.g. np.int64) are not instances of int
    if isinstance(slc, (int, np.integer)):
        index = int(slc)
    elif isinstance(slc, (list, tuple)) and len(slc) == 1:
        index = slc[0]
    elif isinstance(slc, np.ndarray) and slc.size == 1:
        index = slc.item()
    else:
        raise ValueError("Impossible to convert to a slice object.")
    return slice(index, index + 1)
def get_slice_size(slc):
    """Get size of the slice.

    The size is computed as ``slc.stop - slc.start``.

    Note: The computed slice size may not be representative of the true number of
    selected elements if ``slice.stop`` is larger than the length of the object to be sliced.

    Raises
    ------
    TypeError
        If the input is not a slice object.

    """
    if isinstance(slc, slice):
        return slc.stop - slc.start
    raise TypeError("Expecting slice object")
def get_slice_from_idx_bounds(idx_start, idx_end):
    """Return the slice required to include the idx bounds.

    The slice stop is set to ``idx_end + 1`` so that ``idx_end`` is included.
    """
    stop = idx_end + 1
    return slice(idx_start, stop)
def pad_slice(slc, padding, min_start=0, max_stop=np.inf):
    """Increase/decrease the slice with the padding argument.

    The padding is subtracted from the slice start and added to the slice stop.
    The resulting start and stop are then clipped to ``min_start`` and ``max_stop``.
    Does not ensure that all output slices have same size.

    Parameters
    ----------
    slc : slice
        Slice object.
    padding : int
        Padding to be applied to the slice.
    min_start : int, optional
        The minimum value for the start of the new slice.
        The default is 0.
    max_stop : int
        The maximum value for the stop of the new slice.
        The default is np.inf.

    Returns
    -------
    slice
        The slice after applying padding.

    """
    padded_start = max(slc.start - padding, min_start)
    padded_stop = min(slc.stop + padding, max_stop)
    return slice(padded_start, padded_stop)
def pad_slices(list_slices, padding, valid_shape):
    """Increase/decrease the list of slices with the padding argument.

    Each slice is padded with its own padding value and clipped
    between 0 and its own ``valid_shape`` value.

    Parameters
    ----------
    list_slices : list
        List of slice objects.
    padding : int or tuple
        Padding to be applied on each slice.
        An integer is applied to all the slices.
    valid_shape : int or tuple
        The shape of the array which the slices should be valid on.
        An integer is applied to all the slices.

    Returns
    -------
    list_slices : list
        The list of slices after applying padding.

    Raises
    ------
    ValueError
        If ``padding`` or ``valid_shape`` is a list/tuple with a length
        differing from the length of ``list_slices``.

    """
    n_slices = len(list_slices)
    # Broadcast integers to all slices, otherwise check the number of values
    if isinstance(padding, int):
        padding = [padding] * n_slices
    elif isinstance(padding, (list, tuple)) and len(padding) != n_slices:
        raise ValueError(
            "Invalid padding. The length of padding should be the same as the length of list_slices.",
        )
    if isinstance(valid_shape, int):
        valid_shape = [valid_shape] * n_slices
    elif isinstance(valid_shape, (list, tuple)) and len(valid_shape) != n_slices:
        raise ValueError(
            "Invalid valid_shape. The length of valid_shape should be the same as the length of list_slices.",
        )
    # Apply padding
    return [
        pad_slice(slc, padding=pad, min_start=0, max_stop=valid_shape[position])
        for position, (slc, pad) in enumerate(zip(list_slices, padding))
    ]
# min_size = 10
# min_start = 0
# max_stop = 20
# slc = slice(1, 5)  # left bound
# slc = slice(15, 20)  # right bound
# slc = slice(8, 12)  # middle
def enlarge_slice(slc, min_size, min_start=0, max_stop=np.inf):
    """Enlarge a slice object to have at least a size of min_size.

    The function enforces the left and right bounds of the slice by `max_stop` and `min_start`.
    If the original slice size is larger than (or equal to) `min_size`, the original slice will be returned.
    Otherwise the missing size is split between the two sides of the slice
    (with the left side receiving one more index when the missing size is odd),
    and what would exceed one bound is moved to the opposite side.

    Parameters
    ----------
    slc : slice
        The original slice object to be enlarged.
    min_size : int
        The desired minimum size of the new slice.
    min_start : int, optional
        The minimum value for the start of the new slice.
        The default is 0.
    max_stop : int, optional
        The maximum value for the stop of the new slice.
        The default is np.inf.

    Returns
    -------
    slice
        The new slice object with a size of at least min_size and respecting the left and right bounds.

    Raises
    ------
    ValueError
        If ``min_size`` is larger than the size available between ``min_start`` and ``max_stop``.

    """
    current_size = get_slice_size(slc)

    # The requested size must fit between the bounds
    available_size = max_stop - min_start
    if min_size > available_size:
        raise ValueError(
            f"'min_size' {min_size} is too large to generate a slice between {min_start} and {max_stop}.",
        )

    # Nothing to do if the slice is already large enough
    if current_size >= min_size:
        return slc

    # Split the missing size between the two sides (the left side gets the extra index if odd)
    missing_size = min_size - current_size
    grow_right = missing_size // 2
    grow_left = grow_right + 1 if missing_size % 2 == 1 else grow_right

    # Start and stop of the slice if the bounds are not accounted for
    unbounded_start = slc.start - grow_left
    unbounded_stop = slc.stop + grow_right

    # Move to the right side what exceeds the left bound
    if unbounded_start <= min_start:
        left_excess = min_start - unbounded_start
        grow_left -= left_excess
        grow_right += left_excess

    # Move to the left side what exceeds the right bound
    if unbounded_stop >= max_stop:
        right_excess = unbounded_stop - max_stop
        grow_left += right_excess
        grow_right -= right_excess

    enlarged_slice = slice(slc.start - grow_left, slc.stop + grow_right)

    # Sanity check of the resulting size
    assert get_slice_size(enlarged_slice) == min_size
    return enlarged_slice
def enlarge_slices(list_slices, min_size, valid_shape):
    """Enlarge a list of slice object to have at least a size of min_size.

    The function enforces the left and right bounds of the slice to be between 0 and valid_shape.
    If the original slice size is larger than min_size, the original slice will be returned.

    Parameters
    ----------
    list_slices : list
        List of slice objects.
    min_size : int or tuple
        Minimum size of the output slice.
        An integer is applied to all the slices.
    valid_shape : int or tuple
        The shape of the array which the slices should be valid on.
        An integer is applied to all the slices.

    Returns
    -------
    list_slices : list
        The list of slices after enlarging it (if necessary).

    Raises
    ------
    ValueError
        If ``min_size`` or ``valid_shape`` is a list/tuple with a length
        differing from the length of ``list_slices``.

    """
    n_slices = len(list_slices)
    # Broadcast integers to all slices, otherwise check the number of values
    if isinstance(min_size, int):
        min_size = [min_size] * n_slices
    elif isinstance(min_size, (list, tuple)) and len(min_size) != n_slices:
        raise ValueError(
            "Invalid min_size. The length of min_size should be the same as the length of list_slices.",
        )
    if isinstance(valid_shape, int):
        valid_shape = [valid_shape] * n_slices
    elif isinstance(valid_shape, (list, tuple)) and len(valid_shape) != n_slices:
        raise ValueError(
            "Invalid valid_shape. The length of valid_shape should be the same as the length of list_slices.",
        )
    # Enlarge the slice
    return [
        enlarge_slice(slc, min_size=size, min_start=0, max_stop=valid_shape[position])
        for position, (slc, size) in enumerate(zip(list_slices, min_size))
    ]