Source code for arviz_base.sel_utils

"""Utilities for selecting and iterating on xarray objects."""

from itertools import product, tee

import numpy as np
import xarray as xr

from arviz_base.labels import BaseLabeller
from arviz_base.rcparams import rcParams

__all__ = ["xarray_sel_iter", "xarray_var_iter", "xarray_to_ndarray"]


def _dims(data, var_name, skip_dims):
    return [dim for dim in data[var_name].dims if dim not in skip_dims]


def _zip_dims(new_dims, vals):
    return [dict(zip(new_dims, prod)) for prod in product(*vals)]


def xarray_sel_iter(
    data, var_names=None, combined=None, skip_dims=None, dim_to_idx=None, reverse_selections=False
):
    """Convert xarray data to an iterator over variable names and selections.

    Iterates over each var_name and all of its dimensions, returning the variable
    names and selections that allow properly obtaining the data subsets from ``data``
    as desired.

    The iterable returned defines an exhaustive collection of subsets.
    Both the input object and the selections defined can have any dimensionality,
    and the selections from each element in the iterable can have different
    dimensionality between them.

    When looping within a dimension, this can be done over the dimension itself
    or via unique items of explicit indexes for that dimension.

    Parameters
    ----------
    data : Dataset or DataArray
        Posterior data in an xarray
    var_names : sequence of hashable, optional
        Should be a subset of data.data_vars. Defaults to all of them.
    combined : bool, optional
        Whether to combine chains or leave them separate. By default (``None``),
        this is ignored and the `chain` dimension is looped over or skipped
        based on `skip_dims`. If set to ``True``/``False`` then `skip_dims`
        is modified in order to ensure combining chains or not.
    skip_dims : set or other iterable of hashable, optional
        Dimensions to not iterate over. Defaults to rcParam ``data.sample_dims``.
        A single string is interpreted as one dimension name. The object passed
        in is copied, it is never modified in place.
    dim_to_idx : mapping of {hashable : hashable}, optional
        Mapping from dimensions to indexes to loop over these dimensions using
        unique items in the provided index.
    reverse_selections : bool
        Whether to reverse selections before iterating.

    Yields
    ------
    var_name : str
        Variable name to which `selection` and `iselection` correspond to.
    selection : dict of {hashable : any}
        Keys are coordinate names and values are scalar coordinate values.
        To get the values of the variable at these coordinates, do
        ``data[var_name].sel(selection)`` for :class:`~xarray.Dataset` or
        ``data.sel(selection)`` for :class:`~xarray.DataArray`.
    iselection : dict of {hashable : any}
        Keys are dimension names and values are positional indexes
        (might not be scalars).
        To get the values of the variable at these coordinates, do
        ``data[var_name].isel(iselection)`` for :class:`~xarray.Dataset` or
        ``data.isel(iselection)`` for :class:`~xarray.DataArray`.

    Examples
    --------
    Let's create a 3d :class:`~xarray.DataArray` with dimensions "chain", "draw"
    and "obs_dim".

    .. jupyter-execute::

        import xarray as xr
        import numpy as np
        from arviz_base import xarray_sel_iter
        xr.set_options(display_expand_data=False, display_expand_indexes=True)

        data = xr.DataArray(
            np.random.default_rng(2).normal(size=(2,3,7)),
            dims=["chain", "draw", "obs_dim"],
            coords={"chain": [1, 2]},
            name="sample"
        )
        data

    By default, ``xarray_sel_iter`` will return an iterable with the subsets
    that are generated from looping over all dimensions not in
    ``rcParams["data.sample_dims"]`` (the default value for `skip_dims`).
    Here, it will be an iterable of length 7, selecting each position in the
    "obs_dim" dimension:

    .. jupyter-execute::

        list(xarray_sel_iter(data))

    Here we are using a ``DataArray``, so the first position in each tuple,
    ``var_name`` is always the same and corresponds to its name.
    If we want to iterate over each _sample_ (pair of "chain", "draw" values)
    we can use `skip_dims`:

    .. jupyter-execute::

        list(xarray_sel_iter(data, skip_dims={"obs_dim"}))

    Now there are 6 elements, 3 values for "draw" times 2 values for "chain".
    Note also how the two returned selections now differ. The _coordinate_
    values for "chain" are ``1, 2`` whereas their corresponding _positions_
    are ``0, 1``.

    To go further in the examples, and show the usage of `dim_to_idx` we need
    to add some explicit indexes to the ``DataArray``. We do that by adding
    new coordinates and then setting them as indexes.

    .. jupyter-execute::

        data = data.assign_coords(
            {"obs_id": ("obs_dim", np.arange(7)), "label_id": ("obs_dim", list("babacbc"))}
        ).set_xindex("obs_id").set_xindex("label_id")
        data

    Note that both the "Coordinates" and the "Indexes" sections of the output
    have been updated. We can now loop over the "obs_dim" dimension by "itself"
    (like we did in the first example) or using either of these two new indexes.
    If we use "label_id", the returned iterator will have length 3, as there
    are only 3 unique values in "label_id", ``a, b, c``.

    .. jupyter-execute::

        list(xarray_sel_iter(data, dim_to_idx={"obs_dim": "label_id"}))

    Note that the order of the coordinate values is preserved.
    Moreover, now not only the values in the selection dict values are
    different, also their keys. "label_id" is a coordinate+index, but not a
    dimension, so it can not be used for positional indexing.

    See Also
    --------
    xarray_var_iter
        Return a similar iterator whose elements also include the selected
        subset as a DataArray.
    """
    if skip_dims is None:
        skip_dims = rcParams["data.sample_dims"]
    elif isinstance(skip_dims, str):
        # a bare string would otherwise be split into its characters by set()
        skip_dims = [skip_dims]
    # Always work on a private copy: the `combined` handling below edits the set
    # and that must never leak into the object owned by the caller.
    skip_dims = set(skip_dims)
    if dim_to_idx is None:
        dim_to_idx = {}

    if combined is not None:
        if combined:
            skip_dims.add("chain")
        else:
            # discard (not remove): "chain" may legitimately be absent already
            skip_dims.discard("chain")

    if var_names is None:
        if isinstance(data, xr.Dataset):
            var_names = list(data.data_vars)
        elif isinstance(data, xr.DataArray):
            var_names = [data.name]
            # wrap so that ``data[var_name]`` works like it does for a Dataset
            data = {data.name: data}

    for var_name in var_names:
        if var_name not in data:
            continue
        variable = data[var_name]
        new_dims = _dims(data, var_name, skip_dims)
        # name used as key in `selection`: the index name when one was requested
        new_dimsidx = [dim_to_idx.get(dim, dim) for dim in new_dims]
        vals = [
            (
                variable.xindexes[dim_to_idx[dim]].to_pandas_index().unique().values
                if dim in dim_to_idx
                else variable[dim].values
            )
            for dim in new_dims
        ]
        dims = _zip_dims(new_dimsidx, vals)

        ivals = []
        for dim, dim_vals in zip(new_dims, vals):
            if dim in dim_to_idx:
                # positions along `dim` where the index takes each unique value;
                # a single match is unwrapped to a plain scalar
                idx_values = variable[dim_to_idx[dim]].values
                dim_ivals = [np.argwhere(idx_values == v).squeeze() for v in dim_vals]
                ivals.append([v.item() if v.size == 1 else v for v in dim_ivals])
            else:
                ivals.append(range(len(dim_vals)))
        idims = _zip_dims(new_dims, ivals)

        if reverse_selections:
            dims = reversed(dims)
            idims = reversed(idims)
        for selection, iselection in zip(dims, idims):
            yield var_name, selection, iselection
def xarray_var_iter(
    data,
    var_names=None,
    combined=None,
    skip_dims=None,
    dim_to_idx=None,
    reverse_selections=False,
    dim_order=None,
):
    """Convert xarray data to an iterator over vectors.

    Iterates over each var_name and all of its coordinates, returning selected
    subsets as DataArray.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Posterior data in an xarray
    var_names : sequence of hashable, optional
        Should be a subset of data.data_vars. Defaults to all of them.
        Passed to :func:`~arviz_base.xarray_sel_iter`.
    combined : bool, optional
        Whether to combine chains or leave them separate.
        Passed to :func:`~arviz_base.xarray_sel_iter`.
    skip_dims : set, optional
        Dimensions to not iterate over.
        Passed to :func:`~arviz_base.xarray_sel_iter`.
    dim_to_idx : mapping of {hashable : hashable}, optional
        Mapping from dimension names to index names to define a different way
        to loop over that dimension.
        Passed to :func:`~arviz_base.xarray_sel_iter`.
    reverse_selections : bool, optional
        Whether to reverse selections before iterating.
        Passed to :func:`~arviz_base.xarray_sel_iter`.
    dim_order : str or list, optional
        Order for the first dimensions. Skips dimensions not found in the variable.
        A single string is treated as a one-element list.

    Yields
    ------
    var_name : str
        Variable name to which `selection`, `iselection` and `data_subset`
        correspond to.
    selection : dict of {hashable : any}
        Keys are coordinate names and values are scalar coordinate values.
    iselection : dict of {hashable : any}
        Keys are dimension names and values are positional indexes
        (might not be scalars).
    data_subset : DataArray
        Values of the variable at those coordinates.

    See Also
    --------
    xarray_sel_iter
        Return the iterator without the DataArray subset corresponding to
        the selection.
    """
    data_to_sel = data
    if var_names is None and isinstance(data, xr.DataArray):
        # mirror the wrapping done inside xarray_sel_iter so that
        # ``data_to_sel[var_name]`` resolves to the DataArray itself
        data_to_sel = {data.name: data}

    if isinstance(dim_order, str):
        dim_order = [dim_order]

    for var_name, selection, iselection in xarray_sel_iter(
        data,
        var_names=var_names,
        combined=combined,
        skip_dims=skip_dims,
        dim_to_idx=dim_to_idx,
        reverse_selections=reverse_selections,
    ):
        selected_data = data_to_sel[var_name].sel(selection)
        if dim_order is not None:
            dim_order_selected = [dim for dim in dim_order if dim in selected_data.dims]
            if dim_order_selected:
                # requested dims go first, ``...`` keeps the rest in their order
                selected_data = selected_data.transpose(*dim_order_selected, ...)
        # Yield the DataArray itself, as documented, so that dims and coords
        # stay available to the consumer (use ``.values`` for the raw array).
        yield var_name, selection, iselection, selected_data
def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None):
    """Take xarray data and unpacks into variables and data into list and numpy array respectively.

    Assumes that chain and draw are in coordinates

    Parameters
    ----------
    data : Dataset or DataArray
        Data in an xarray from an InferenceData object. Examples include posterior or sample_stats
    var_names : sequence of hashable, optional
        Should be a subset of data.data_vars not including chain and draws. Defaults to all of them
    combined : bool, default True
        Whether to combine chain into one array
    label_fun : callable, optional
        Called as ``label_fun(var_name, selection, isel)`` for every subset to
        generate its label. Defaults to ``BaseLabeller().make_label_vert``.

    Returns
    -------
    var_names : list
        One label per subset, as returned by `label_fun`.
    data : ndarray
        Data values, one row per subset with the subset flattened along
        the columns. An empty ``(0, 0)`` array if there are no subsets.
    """
    if label_fun is None:
        label_fun = BaseLabeller().make_label_vert

    data_to_sel = data
    if var_names is None and isinstance(data, xr.DataArray):
        data_to_sel = {data.name: data}

    # Materialize once; the subsets are needed both for labels and for values.
    vars_and_sel = list(xarray_sel_iter(data, var_names=var_names, combined=combined))
    unpacked_var_names = [
        label_fun(var_name, selection, isel) for var_name, selection, isel in vars_and_sel
    ]

    if not vars_and_sel:
        # nothing to unpack (e.g. no requested variable is present in `data`)
        return unpacked_var_names, np.empty((0, 0))

    # Merge chains and variables, check dtype to be compatible with divergences data
    first_var_name, first_selection, _ = vars_and_sel[0]
    # The selection is passed as a dict (not ``**selection``) so coordinate
    # names that are not strings or that clash with ``sel`` keywords still work.
    data0 = data_to_sel[first_var_name].sel(first_selection)
    unpacked_data = np.empty((len(unpacked_var_names), data0.size), dtype=data0.dtype)
    for idx, (var_name, selection, _) in enumerate(vars_and_sel):
        unpacked_data[idx] = data_to_sel[var_name].sel(selection).values.flatten()

    return unpacked_var_names, unpacked_data