Module ocean_science_utilities.interpolate.nd_interp

Expand source code
import numpy as np

from typing import Any, Tuple, Callable, List, Generator, Set, Optional, Dict, Sequence

from ocean_science_utilities.interpolate.general import interpolation_weights_1d
from ocean_science_utilities.tools.math import wrapped_difference
from ocean_science_utilities.tools.grid import enclosing_points_1d


class NdInterpolator:
    def __init__(
        self,
        get_data: Callable[[List[np.ndarray], List[int]], np.ndarray],
        data_coordinates: Sequence[Tuple[str, np.ndarray[Any, Any]]],
        data_shape: Tuple[int, ...],
        interp_coord_names: List[str],
        interp_index_coord_name: str,
        data_periodic_coordinates: Dict[str, float],
        data_period: Optional[float] = None,
        data_discont: Optional[float] = None,
        nearest_neighbour: bool = False,
    ):
        self.get_data = get_data
        self.coord = [x[0] for x in data_coordinates]
        self.data_shape = data_shape
        self.interp_coord_names = interp_coord_names
        self.interp_index_coord_name = interp_index_coord_name
        self.data_coordinates = data_coordinates
        self.data_periodic_coordinates = data_periodic_coordinates
        self.data_period = data_period
        self.data_discont = data_discont
        self.nearest_neighbour = nearest_neighbour

    @property
    def passive_coordinate_names(self) -> List[str]:
        return [name for name in self.coord if name not in self.interp_coord_names]

    @property
    def passive_coord_dim_indices(self) -> List[int]:
        return [self.coord.index(x) for x in self.passive_coordinate_names]

    @property
    def output_passive_coord_dim_indices(self) -> Tuple[int, ...]:
        indices = list(range(self.output_ndims))
        _ = indices.pop(indices.index(self.output_index_coord_index))
        return tuple(indices)

    @property
    def interp_coord_dim_indices(self) -> List[int]:
        return [self.coord.index(x) for x in self.interp_coord_names]

    @property
    def interp_index_coord_index(self) -> int:
        return self.coord.index(self.interp_index_coord_name)

    def output_shape(self, number_of_points: int) -> np.ndarray:
        output_shape = np.ones(self.output_ndims, dtype="int32")
        interpolating_index = self.output_index_coord_index
        passive_ind = self.passive_coord_dim_indices
        jj = 0
        for index in range(self.output_ndims):
            if index == interpolating_index:
                output_shape[index] = number_of_points
            else:
                output_shape[index] = self.data_shape[passive_ind[jj]]
                jj += 1
        return output_shape

    @property
    def output_index_coord_index(self) -> int:
        return int(
            np.searchsorted(
                self.passive_coord_dim_indices, self.interp_index_coord_index
            )
        )

    @property
    def interpolating_coordinates(self) -> List[Tuple[str, np.ndarray]]:
        return [x for x in self.data_coordinates if x[0] in self.interp_coord_names]

    @property
    def output_ndims(self) -> int:
        return self.data_ndims - self.interp_ndims + 1

    @property
    def interp_ndims(self) -> int:
        return len(self.interp_coord_names)

    @property
    def data_ndims(self) -> int:
        return len(self.coord)

    def output_indexing_full(self, slicer: slice) -> Tuple[slice, ...]:
        indicer = [slice(None)] * self.output_ndims
        indicer[self.output_index_coord_index] = slicer
        return tuple(indicer)

    def output_indexing_broadcast(self, slicer: slice) -> Tuple[Any, ...]:
        indicer = [None] * self.output_ndims
        indicer[self.output_index_coord_index] = slicer  # type: ignore
        return tuple(indicer)

    def coordinate_period(self, coordinate_name: str) -> Optional[float]:
        if coordinate_name in self.data_periodic_coordinates:
            return self.data_periodic_coordinates[coordinate_name]
        else:
            return None

    @property
    def data_is_periodic(self) -> bool:
        return self.data_period is not None

    def interpolate(
        self,
        points: Dict[str, np.ndarray],
    ) -> np.ndarray:
        """

        :param self:
        :param points:

        :return:
        """
        number_points = len(points[self.interp_coord_names[0]])

        # Find indices and weights for the succesive 1d interpolation problems
        indices_1d = np.empty((self.interp_ndims, 2, number_points), dtype="int64")

        weights_1d = np.empty((self.interp_ndims, 2, number_points), dtype="float64")

        for index, (coordinate_name, coordinate) in enumerate(
            self.interpolating_coordinates
        ):
            period = self.coordinate_period(coordinate_name)
            indices_1d[index, :, :] = enclosing_points_1d(
                coordinate, points[coordinate_name], period=period
            )
            weights_1d[index, :, :] = interpolation_weights_1d(
                coordinate,
                points[coordinate_name],
                indices_1d[index, :, :],
                period=period,
                extrapolate_left=False,
                extrapolate_right=False,
                nearest_neighbour=self.nearest_neighbour,
            )

        if self.data_is_periodic:
            return self._periodic_data_interpolator(
                number_points, indices_1d, weights_1d
            )
        else:
            return self._data_interpolator(number_points, indices_1d, weights_1d)

    def _data_interpolator(
        self, number_points: int, indices_1d: np.ndarray, weights_1d: np.ndarray
    ) -> np.ndarray:
        # We keep a running sum of the weights, if a point is excluded because it
        # contains no data (NaN) the weights will no longer add up to 1 - and we
        # reschale to account for the missing value. This is an easy way to account
        # for interpolation near missing points. Note that if the contribution of
        # missing weights ( 1-weights_sum) exceeds 0.5 - we consider the point
        # invalid.
        output_shape = self.output_shape(number_points)
        weights_sum = np.zeros(output_shape)
        interp_val = np.zeros(output_shape, dtype=np.float64)

        for intp_indices_nd, intp_weight_nd in _next_point(
            self.interp_ndims, indices_1d, weights_1d
        ):
            # Loop over all interpolation points one at a time.
            val = self.get_data(intp_indices_nd, self.interp_coord_dim_indices)

            mask = np.all(
                ~np.isnan(val), axis=self.output_passive_coord_dim_indices
            ) & (intp_weight_nd > 0)

            weights_sum[self.output_indexing_full(mask)] += intp_weight_nd[
                self.output_indexing_broadcast(mask)
            ]

            interp_val[self.output_indexing_full(mask)] += (
                intp_weight_nd[self.output_indexing_broadcast(mask)]
                * val[self.output_indexing_full(mask)]
            )

        with np.errstate(invalid="ignore", divide="ignore"):
            return np.where(weights_sum > 0.5, interp_val / weights_sum, np.nan)

    def _periodic_data_interpolator(
        self, number_points: int, indices_1d: np.ndarray, weights_1d: np.ndarray
    ) -> np.ndarray:
        # We keep a running sum of the weights, if a point is excluded because it
        # contains no data (NaN) the weights will no longer add up to 1 - and we
        # reschale to account for the missing value. This is an easy way to account
        # for interpolation near missing points. Note that if the contribution of
        # missing weights ( 1-weights_sum) exceeds 0.5 - we consider the point
        # invalid.
        output_shape = self.output_shape(number_points)
        weights_sum = np.zeros(output_shape)

        interp_val = np.zeros(output_shape, dtype=np.complex64)
        for intp_indices_nd, intp_weight_nd in _next_point(
            self.interp_ndims, indices_1d, weights_1d
        ):
            # Loop over all interpolation points one at a time.
            to_rad = np.pi * 2 / self.data_period  # type: ignore
            val = np.exp(
                1j
                * self.get_data(intp_indices_nd, self.interp_coord_dim_indices)
                * to_rad
            )

            mask = np.all(~np.isnan(val), axis=self.output_passive_coord_dim_indices)

            weights_sum[self.output_indexing_full(mask)] += intp_weight_nd[
                self.output_indexing_broadcast(mask)
            ]

            interp_val[self.output_indexing_full(mask)] += (
                intp_weight_nd[self.output_indexing_broadcast(mask)]
                * val[self.output_indexing_full(mask)]
            )

        interp_val = (
            np.angle(  # type: ignore
                np.where(weights_sum > 0.5, interp_val / weights_sum, np.nan)
            )
            * self.data_period  # type: ignore
            / np.pi
            / 2
        )

        return wrapped_difference(
            delta=interp_val, period=self.data_period, discont=self.data_period
        )


def _next_point(
    recursion_depth: int, indices_1d: np.ndarray, weights_1d: np.ndarray, *narg
) -> Generator[Tuple[List[np.ndarray], np.ndarray], None, None]:
    """
    We are trying to interpolate over N dimensions. In bilinear interpolation
    this means we have to visit 2**N points. If N is known this is most
    clearly expressed as a set of N nested loops:

    J=-1
    for i1 in range(0,2):
        for i2 in range(0,2):
            ...
                for iN in range(0,2):
                    J+=1
                    do stuff for J'th item.

    Here instead, since we do not know N in advance, use a set of recursive
    loops to depth N, where at the final level we yield for each of the 2**N
    points the values of the points and the weights with which they contribute
    to the interpolated value.

    :param recursion_depth:
    :param indices_1d:
    :param weights_1d:
    :param narg: indices from outer recursive loops
    :return: generater function that yields the J"th thing to do stuff with.
    """
    number_of_coordinates = indices_1d.shape[0]
    number_of_points = indices_1d.shape[2]

    if recursion_depth > 0:
        # Loop over the n'th coordinate, with
        #    n = number_of_coordinates - recursion_depth
        for ii in range(0, 2):
            # Yield from next recursive loop, add the loop coordinate to the
            # arg of the next call
            arg = (*narg, ii)
            yield from _next_point(recursion_depth - 1, indices_1d, weights_1d, *arg)
    else:
        #
        # Here we construct the "fancy" indexes we will use to grab datavalues.
        indices_nd = []
        weights_nd = np.ones((number_of_points,), dtype="float64")
        for index in range(0, number_of_coordinates):
            # get the coordinate index for the current point.
            indices_nd.append(indices_1d[index, narg[index], :])

            # The N-dimensional weight is the multiplication of all weights
            # of the associated 1d problems
            weights_nd *= weights_1d[index, narg[index], :]

        yield indices_nd, weights_nd

Classes

class NdInterpolator (get_data: Callable[[List[numpy.ndarray], List[int]], numpy.ndarray], data_coordinates: Sequence[Tuple[str, numpy.ndarray[Any, Any]]], data_shape: Tuple[int, ...], interp_coord_names: List[str], interp_index_coord_name: str, data_periodic_coordinates: Dict[str, float], data_period: Optional[float] = None, data_discont: Optional[float] = None, nearest_neighbour: bool = False)
Expand source code
class NdInterpolator:
    def __init__(
        self,
        get_data: Callable[[List[np.ndarray], List[int]], np.ndarray],
        data_coordinates: Sequence[Tuple[str, np.ndarray[Any, Any]]],
        data_shape: Tuple[int, ...],
        interp_coord_names: List[str],
        interp_index_coord_name: str,
        data_periodic_coordinates: Dict[str, float],
        data_period: Optional[float] = None,
        data_discont: Optional[float] = None,
        nearest_neighbour: bool = False,
    ):
        self.get_data = get_data
        self.coord = [x[0] for x in data_coordinates]
        self.data_shape = data_shape
        self.interp_coord_names = interp_coord_names
        self.interp_index_coord_name = interp_index_coord_name
        self.data_coordinates = data_coordinates
        self.data_periodic_coordinates = data_periodic_coordinates
        self.data_period = data_period
        self.data_discont = data_discont
        self.nearest_neighbour = nearest_neighbour

    @property
    def passive_coordinate_names(self) -> List[str]:
        return [name for name in self.coord if name not in self.interp_coord_names]

    @property
    def passive_coord_dim_indices(self) -> List[int]:
        return [self.coord.index(x) for x in self.passive_coordinate_names]

    @property
    def output_passive_coord_dim_indices(self) -> Tuple[int, ...]:
        indices = list(range(self.output_ndims))
        _ = indices.pop(indices.index(self.output_index_coord_index))
        return tuple(indices)

    @property
    def interp_coord_dim_indices(self) -> List[int]:
        return [self.coord.index(x) for x in self.interp_coord_names]

    @property
    def interp_index_coord_index(self) -> int:
        return self.coord.index(self.interp_index_coord_name)

    def output_shape(self, number_of_points: int) -> np.ndarray:
        output_shape = np.ones(self.output_ndims, dtype="int32")
        interpolating_index = self.output_index_coord_index
        passive_ind = self.passive_coord_dim_indices
        jj = 0
        for index in range(self.output_ndims):
            if index == interpolating_index:
                output_shape[index] = number_of_points
            else:
                output_shape[index] = self.data_shape[passive_ind[jj]]
                jj += 1
        return output_shape

    @property
    def output_index_coord_index(self) -> int:
        return int(
            np.searchsorted(
                self.passive_coord_dim_indices, self.interp_index_coord_index
            )
        )

    @property
    def interpolating_coordinates(self) -> List[Tuple[str, np.ndarray]]:
        return [x for x in self.data_coordinates if x[0] in self.interp_coord_names]

    @property
    def output_ndims(self) -> int:
        return self.data_ndims - self.interp_ndims + 1

    @property
    def interp_ndims(self) -> int:
        return len(self.interp_coord_names)

    @property
    def data_ndims(self) -> int:
        return len(self.coord)

    def output_indexing_full(self, slicer: slice) -> Tuple[slice, ...]:
        indicer = [slice(None)] * self.output_ndims
        indicer[self.output_index_coord_index] = slicer
        return tuple(indicer)

    def output_indexing_broadcast(self, slicer: slice) -> Tuple[Any, ...]:
        indicer = [None] * self.output_ndims
        indicer[self.output_index_coord_index] = slicer  # type: ignore
        return tuple(indicer)

    def coordinate_period(self, coordinate_name: str) -> Optional[float]:
        if coordinate_name in self.data_periodic_coordinates:
            return self.data_periodic_coordinates[coordinate_name]
        else:
            return None

    @property
    def data_is_periodic(self) -> bool:
        return self.data_period is not None

    def interpolate(
        self,
        points: Dict[str, np.ndarray],
    ) -> np.ndarray:
        """

        :param self:
        :param points:

        :return:
        """
        number_points = len(points[self.interp_coord_names[0]])

        # Find indices and weights for the succesive 1d interpolation problems
        indices_1d = np.empty((self.interp_ndims, 2, number_points), dtype="int64")

        weights_1d = np.empty((self.interp_ndims, 2, number_points), dtype="float64")

        for index, (coordinate_name, coordinate) in enumerate(
            self.interpolating_coordinates
        ):
            period = self.coordinate_period(coordinate_name)
            indices_1d[index, :, :] = enclosing_points_1d(
                coordinate, points[coordinate_name], period=period
            )
            weights_1d[index, :, :] = interpolation_weights_1d(
                coordinate,
                points[coordinate_name],
                indices_1d[index, :, :],
                period=period,
                extrapolate_left=False,
                extrapolate_right=False,
                nearest_neighbour=self.nearest_neighbour,
            )

        if self.data_is_periodic:
            return self._periodic_data_interpolator(
                number_points, indices_1d, weights_1d
            )
        else:
            return self._data_interpolator(number_points, indices_1d, weights_1d)

    def _data_interpolator(
        self, number_points: int, indices_1d: np.ndarray, weights_1d: np.ndarray
    ) -> np.ndarray:
        # We keep a running sum of the weights, if a point is excluded because it
        # contains no data (NaN) the weights will no longer add up to 1 - and we
        # reschale to account for the missing value. This is an easy way to account
        # for interpolation near missing points. Note that if the contribution of
        # missing weights ( 1-weights_sum) exceeds 0.5 - we consider the point
        # invalid.
        output_shape = self.output_shape(number_points)
        weights_sum = np.zeros(output_shape)
        interp_val = np.zeros(output_shape, dtype=np.float64)

        for intp_indices_nd, intp_weight_nd in _next_point(
            self.interp_ndims, indices_1d, weights_1d
        ):
            # Loop over all interpolation points one at a time.
            val = self.get_data(intp_indices_nd, self.interp_coord_dim_indices)

            mask = np.all(
                ~np.isnan(val), axis=self.output_passive_coord_dim_indices
            ) & (intp_weight_nd > 0)

            weights_sum[self.output_indexing_full(mask)] += intp_weight_nd[
                self.output_indexing_broadcast(mask)
            ]

            interp_val[self.output_indexing_full(mask)] += (
                intp_weight_nd[self.output_indexing_broadcast(mask)]
                * val[self.output_indexing_full(mask)]
            )

        with np.errstate(invalid="ignore", divide="ignore"):
            return np.where(weights_sum > 0.5, interp_val / weights_sum, np.nan)

    def _periodic_data_interpolator(
        self, number_points: int, indices_1d: np.ndarray, weights_1d: np.ndarray
    ) -> np.ndarray:
        # We keep a running sum of the weights, if a point is excluded because it
        # contains no data (NaN) the weights will no longer add up to 1 - and we
        # reschale to account for the missing value. This is an easy way to account
        # for interpolation near missing points. Note that if the contribution of
        # missing weights ( 1-weights_sum) exceeds 0.5 - we consider the point
        # invalid.
        output_shape = self.output_shape(number_points)
        weights_sum = np.zeros(output_shape)

        interp_val = np.zeros(output_shape, dtype=np.complex64)
        for intp_indices_nd, intp_weight_nd in _next_point(
            self.interp_ndims, indices_1d, weights_1d
        ):
            # Loop over all interpolation points one at a time.
            to_rad = np.pi * 2 / self.data_period  # type: ignore
            val = np.exp(
                1j
                * self.get_data(intp_indices_nd, self.interp_coord_dim_indices)
                * to_rad
            )

            mask = np.all(~np.isnan(val), axis=self.output_passive_coord_dim_indices)

            weights_sum[self.output_indexing_full(mask)] += intp_weight_nd[
                self.output_indexing_broadcast(mask)
            ]

            interp_val[self.output_indexing_full(mask)] += (
                intp_weight_nd[self.output_indexing_broadcast(mask)]
                * val[self.output_indexing_full(mask)]
            )

        interp_val = (
            np.angle(  # type: ignore
                np.where(weights_sum > 0.5, interp_val / weights_sum, np.nan)
            )
            * self.data_period  # type: ignore
            / np.pi
            / 2
        )

        return wrapped_difference(
            delta=interp_val, period=self.data_period, discont=self.data_period
        )

Instance variables

var data_is_periodic : bool
Expand source code
@property
def data_is_periodic(self) -> bool:
    return self.data_period is not None
var data_ndims : int
Expand source code
@property
def data_ndims(self) -> int:
    return len(self.coord)
var interp_coord_dim_indices : List[int]
Expand source code
@property
def interp_coord_dim_indices(self) -> List[int]:
    return [self.coord.index(x) for x in self.interp_coord_names]
var interp_index_coord_index : int
Expand source code
@property
def interp_index_coord_index(self) -> int:
    return self.coord.index(self.interp_index_coord_name)
var interp_ndims : int
Expand source code
@property
def interp_ndims(self) -> int:
    return len(self.interp_coord_names)
var interpolating_coordinates : List[Tuple[str, numpy.ndarray]]
Expand source code
@property
def interpolating_coordinates(self) -> List[Tuple[str, np.ndarray]]:
    return [x for x in self.data_coordinates if x[0] in self.interp_coord_names]
var output_index_coord_index : int
Expand source code
@property
def output_index_coord_index(self) -> int:
    return int(
        np.searchsorted(
            self.passive_coord_dim_indices, self.interp_index_coord_index
        )
    )
var output_ndims : int
Expand source code
@property
def output_ndims(self) -> int:
    return self.data_ndims - self.interp_ndims + 1
var output_passive_coord_dim_indices : Tuple[int, ...]
Expand source code
@property
def output_passive_coord_dim_indices(self) -> Tuple[int, ...]:
    indices = list(range(self.output_ndims))
    _ = indices.pop(indices.index(self.output_index_coord_index))
    return tuple(indices)
var passive_coord_dim_indices : List[int]
Expand source code
@property
def passive_coord_dim_indices(self) -> List[int]:
    return [self.coord.index(x) for x in self.passive_coordinate_names]
var passive_coordinate_names : List[str]
Expand source code
@property
def passive_coordinate_names(self) -> List[str]:
    return [name for name in self.coord if name not in self.interp_coord_names]

Methods

def coordinate_period(self, coordinate_name: str) ‑> Optional[float]
Expand source code
def coordinate_period(self, coordinate_name: str) -> Optional[float]:
    if coordinate_name in self.data_periodic_coordinates:
        return self.data_periodic_coordinates[coordinate_name]
    else:
        return None
def interpolate(self, points: Dict[str, numpy.ndarray]) ‑> numpy.ndarray

:param self: :param points:

:return:

Expand source code
def interpolate(
    self,
    points: Dict[str, np.ndarray],
) -> np.ndarray:
    """

    :param self:
    :param points:

    :return:
    """
    number_points = len(points[self.interp_coord_names[0]])

    # Find indices and weights for the succesive 1d interpolation problems
    indices_1d = np.empty((self.interp_ndims, 2, number_points), dtype="int64")

    weights_1d = np.empty((self.interp_ndims, 2, number_points), dtype="float64")

    for index, (coordinate_name, coordinate) in enumerate(
        self.interpolating_coordinates
    ):
        period = self.coordinate_period(coordinate_name)
        indices_1d[index, :, :] = enclosing_points_1d(
            coordinate, points[coordinate_name], period=period
        )
        weights_1d[index, :, :] = interpolation_weights_1d(
            coordinate,
            points[coordinate_name],
            indices_1d[index, :, :],
            period=period,
            extrapolate_left=False,
            extrapolate_right=False,
            nearest_neighbour=self.nearest_neighbour,
        )

    if self.data_is_periodic:
        return self._periodic_data_interpolator(
            number_points, indices_1d, weights_1d
        )
    else:
        return self._data_interpolator(number_points, indices_1d, weights_1d)
def output_indexing_broadcast(self, slicer: slice) ‑> Tuple[Any, ...]
Expand source code
def output_indexing_broadcast(self, slicer: slice) -> Tuple[Any, ...]:
    indicer = [None] * self.output_ndims
    indicer[self.output_index_coord_index] = slicer  # type: ignore
    return tuple(indicer)
def output_indexing_full(self, slicer: slice) ‑> Tuple[slice, ...]
Expand source code
def output_indexing_full(self, slicer: slice) -> Tuple[slice, ...]:
    indicer = [slice(None)] * self.output_ndims
    indicer[self.output_index_coord_index] = slicer
    return tuple(indicer)
def output_shape(self, number_of_points: int) ‑> numpy.ndarray
Expand source code
def output_shape(self, number_of_points: int) -> np.ndarray:
    output_shape = np.ones(self.output_ndims, dtype="int32")
    interpolating_index = self.output_index_coord_index
    passive_ind = self.passive_coord_dim_indices
    jj = 0
    for index in range(self.output_ndims):
        if index == interpolating_index:
            output_shape[index] = number_of_points
        else:
            output_shape[index] = self.data_shape[passive_ind[jj]]
            jj += 1
    return output_shape