"""Helper functions to reorganize data."""
from numbers import Number
import numpy as np
import pandas as pd
import xarray as xr
from arviz_base.converters import convert_to_dataset
from arviz_base.labels import BaseLabeller
from arviz_base.rcparams import rcParams
from arviz_base.sel_utils import xarray_sel_iter
from arviz_base.utils import _var_names

__all__ = [
    "dataset_to_dataarray",
    "dataset_to_dataframe",
    "explode_dataset_dims",
    "extract",
    "references_to_dataset",
]

# TODO: remove this ignore about too many statements once the code uses validator functions
def _stratified_resample(weights, rng):
    """Stratified resampling.

    Split [0, 1) into ``len(weights)`` equal strata, draw one uniform sample per
    stratum, and map each draw to an index through the cumulative sum of the weights.
    """
    N = len(weights)
    single_uniform = (rng.random(N) + np.arange(N)) / N
    indexes = np.zeros(N, dtype=int)
    cum_sum = np.cumsum(weights)
    i, j = 0, 0
    while i < N:
        if single_uniform[i] < cum_sum[j]:
            indexes[i] = j
            i += 1
        else:
            j += 1
    return indexes
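
# A minimal usage sketch for the helper above (illustrative only). It assumes
# ``weights`` is a 1D array of nonnegative weights that sums to 1:
#
#     rng = np.random.default_rng()
#     weights = np.full(4, 0.25)
#     indexes = _stratified_resample(weights, rng)
#
# With equal weights every stratum resolves to its own index, so ``indexes``
# here is always ``[0, 1, 2, 3]``.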


def dataset_to_dataarray(
    ds, sample_dims=None, labeller=None, add_coords=True, new_dim="label", label_type="flat"
):
    """Convert a Dataset to a stacked DataArray, using a labeller to set coordinate values.

    Parameters
    ----------
    ds : Dataset
        Input data.
    sample_dims : sequence of hashable, optional
        Dimensions that are present in all variables of `ds` and should be kept
        in the returned `DataArray`. All other variables and dimensions will be
        stacked into `new_dim`.
    labeller : labeller, optional
        Labeller instance with a `make_label_flat` or `make_label_vert` method that
        will be used to generate the coordinate values along `new_dim`.
    add_coords : bool, default True
        Return multiple coordinate variables along `new_dim`. These will contain the newly
        generated labels, the stacked variable names, and the stacked coordinate values.
    new_dim : hashable, default "label"
        Name of the new dimension that is created from stacking variables
        and dimensions not in `sample_dims`.
    label_type : {"flat", "vert"}, default "flat"
        If "flat", the `labeller.make_label_flat` method is used to generate the labels;
        if "vert", the `labeller.make_label_vert` method is used.

    Returns
    -------
    DataArray

    Examples
    --------
    Convert the posterior group into a stacked and labelled dataarray:

    .. jupyter-execute::

        import xarray as xr
        from arviz_base import load_arviz_data, dataset_to_dataarray

        xr.set_options(display_expand_data=False)
        idata = load_arviz_data("centered_eight")
        dataset_to_dataarray(idata.posterior.dataset)
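
    Multi-line labels can be generated instead with ``label_type="vert"``:

    .. jupyter-execute::

        dataset_to_dataarray(idata.posterior.dataset, label_type="vert")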
"""
if labeller is None:
labeller = BaseLabeller()
if sample_dims is None:
sample_dims = rcParams["data.sample_dims"]
if label_type not in ("flat", "vert"):
raise ValueError(f"Invalid label_type: {label_type}")
labeled_stack = ds.to_stacked_array(new_dim, sample_dims=sample_dims)
labels = [
(labeller.make_label_flat if label_type == "flat" else labeller.make_label_vert)(
var_name, sel, isel
)
for var_name, sel, isel in xarray_sel_iter(ds, skip_dims=set(sample_dims))
]
indexes = [
idx_name
for idx_name, idx in labeled_stack.xindexes.items()
if (idx_name not in sample_dims) and (idx.dim not in sample_dims)
]
labeled_stack = labeled_stack.drop_indexes(indexes).assign_coords({new_dim: labels})
for idx_name in indexes:
if idx_name == new_dim:
continue
if add_coords:
labeled_stack = labeled_stack.set_xindex(idx_name)
else:
labeled_stack = labeled_stack.drop_vars(idx_name)
return labeled_stack


def dataset_to_dataframe(ds, sample_dims=None, labeller=None, multiindex=False, new_dim="label"):
    """Convert a Dataset to a DataFrame via a stacked DataArray, using a labeller.

    Parameters
    ----------
    ds : Dataset
        Input data.
    sample_dims : sequence of hashable, optional
        Dimensions that are present in all variables of `ds`. They are kept as the
        row index of the DataFrame; everything else is stacked into the columns.
    labeller : labeller, optional
        Labeller instance used to generate the column labels.
    multiindex : {"row", "column"} or bool, default False
        Use a :class:`pandas.MultiIndex` built from all stacked coordinate values
        for the rows, the columns, or both (``True``).
    new_dim : hashable, default "label"
        Name of the intermediate dimension the variables are stacked into, which
        becomes the name of the column index.

    Returns
    -------
    pandas.DataFrame

    Examples
    --------
    The output will have whatever is used as `sample_dims` as the rows of
    the DataFrame, so when the stacked labels outnumber the samples we might
    want to transpose the output:

    .. jupyter-execute::

        from arviz_base import load_arviz_data, dataset_to_dataframe

        idata = load_arviz_data("centered_eight")
        dataset_to_dataframe(idata.posterior.dataset)

    The default is to return a single-level index, with the labels or tuples of
    coordinate values in the stacked dimensions. To keep all data from all
    coordinates as a MultiIndex use ``multiindex=True``:

    .. jupyter-execute::

        dataset_to_dataframe(idata.posterior.dataset, multiindex=True)
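
    The MultiIndex can also be restricted to one side of the DataFrame,
    for example the columns:

    .. jupyter-execute::

        dataset_to_dataframe(idata.posterior.dataset, multiindex="column")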
    The only restriction on `sample_dims` is that they are present in all variables
    of the dataset. Consequently, we can compute statistical summaries, concatenate
    the results into a single dataset along a new dimension, then use that new
    dimension as `sample_dims`:

    .. jupyter-execute::

        import xarray as xr

        dims = ["chain", "draw"]
        post = idata.posterior.dataset
        summaries = xr.concat(
            (
                post.mean(dims).expand_dims(summary=["mean"]),
                post.median(dims).expand_dims(summary=["median"]),
                post.quantile([.25, .75], dim=dims).rename(
                    quantile="summary"
                ).assign_coords(summary=["1st quartile", "3rd quartile"])
            ),
            dim="summary"
        )
        summaries

    Then convert the result into a DataFrame for ease of viewing:

    .. jupyter-execute::

        dataset_to_dataframe(summaries, sample_dims=["summary"]).T

    Note that if none of the summaries added a new dimension, it would not be
    necessary to use :meth:`~xarray.Dataset.expand_dims` or to rename dimensions;
    using :meth:`~xarray.Dataset.assign_coords` on the concatenated result to label
    the newly created dimension would be enough. But with this approach we generate
    a dimension with coordinate values from the start and can also combine summaries
    that do add new dimensions, like the quantiles.
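
    For instance, combining only the mean and the median (which add no new
    dimension), the shorter approach works as well:

    .. jupyter-execute::

        scalar_summaries = xr.concat(
            (post.mean(dims), post.median(dims)), dim="summary"
        ).assign_coords(summary=["mean", "median"])
        dataset_to_dataframe(scalar_summaries, sample_dims=["summary"]).T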
"""
if sample_dims is None:
sample_dims = rcParams["data.sample_dims"]
da = dataset_to_dataarray(ds, sample_dims=sample_dims, labeller=labeller, new_dim=new_dim)
sample_dim = sample_dims[0]
if len(sample_dims) > 1:
da = da.stack(sample=sample_dims)
sample_dim = "sample"
sample_idx = da[sample_dim]
label_idx = da[new_dim]
if multiindex is True or multiindex == "row":
idx_dict = {
idx_name: da[idx_name].to_numpy()
for idx_name in da.xindexes
if sample_dim in da[idx_name].dims
}
sample_idx = pd.MultiIndex.from_arrays(list(idx_dict.values()), names=list(idx_dict.keys()))
if multiindex is True or multiindex == "column":
idx_dict = {
idx_name: da[idx_name].to_numpy()
for idx_name in da.xindexes
if new_dim in da[idx_name].dims
}
label_idx = pd.MultiIndex.from_arrays(list(idx_dict.values()), names=list(idx_dict.keys()))
df = pd.DataFrame(
da.transpose(sample_dim, new_dim).to_numpy(), columns=label_idx, index=sample_idx
)
return df


def explode_dataset_dims(ds, dim, labeller=None):
    """Explode dims of a dataset so each slice along them becomes its own variable.

    Parameters
    ----------
    ds : Dataset
        Input data.
    dim : hashable or sequence of hashable
        Dimension or dimensions whose slices should be stored as independent variables.
    labeller : labeller, optional
        Instance of a labeller class used to label the slices generated when exploding
        along `dim`. The method ``make_label_flat`` is used.

    Returns
    -------
    Dataset
        The dataset with all variables that have `dim` exploded into the respective
        slices as new variables.

    Examples
    --------
    In some cases, instead of ``theta`` as a ``(..., school)`` shape variable we'll want
    independent variables for each slice:

    .. jupyter-execute::

        from arviz_base import load_arviz_data, explode_dataset_dims
        import xarray as xr

        idata = load_arviz_data("centered_eight")
        explode_dataset_dims(idata.posterior.dataset, "school")
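
    The slices then behave as regular, independent variables; with the default
    labeller they are keyed as e.g. ``theta[Choate]``:

    .. jupyter-execute::

        exploded = explode_dataset_dims(idata.posterior.dataset, "school")
        exploded["theta[Choate]"]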
"""
if isinstance(dim, str):
dim = [dim]
if labeller is None:
labeller = BaseLabeller()
return xr.Dataset(
{
labeller.make_label_flat(var_name, sel, isel): ds[var_name].sel(sel, drop=True)
for var_name, sel, isel in xarray_sel_iter(
ds, skip_dims={d for d in ds.dims if d not in dim}
)
}
)


def references_to_dataset(references, ds, sample_dims=None, ref_dim=None):
    """Generate an :class:`~xarray.Dataset` compatible with `ds` from `references`.

    Cast common formats used to provide references to a compatible Dataset.
    This function does not aim to be exhaustive; anything somewhat peculiar or complex
    will probably be better off building a Dataset manually instead.

    Parameters
    ----------
    references : scalar or 1D array-like or dict or DataArray or Dataset
        References to cast into a compatible dataset.

        * Scalar inputs are interpreted as a reference line for each combination of
          variable and coordinate values not in `sample_dims`.
        * Array-like inputs are interpreted as multiple reference lines for each
          combination of variable and coordinate values not in `sample_dims`.
          All subsets get the same references, and every reference applies to
          every subset.
        * dict inputs are interpreted as array-like, with each array matched to the
          variable corresponding to that dictionary key.
        * DataArray inputs are interpreted as an array-like if unnamed or as a
          single-key dictionary if named.
        * Dataset inputs are returned as is, without raising an error.
    ds : Dataset
        Dataset containing the data `references` should be compatible with.
    sample_dims : iterable of hashable, optional
        Sample dimensions in `ds`. The dimensions in the output will be the dimensions
        in `ds` minus `sample_dims`, plus optionally a "ref_dim" for non-scalar
        references.
    ref_dim : str or list, optional
        Names for the new dimensions created during reference value broadcasting.
        By default, "ref_dim" is used for 1D references and "ref_dim_0", "ref_dim_1",
        ... for N-dimensional references when broadcasting over one or more variables.

    Returns
    -------
    Dataset
        A Dataset containing a subset of the variables, dimensions, and coordinate
        names from `ds`, with additional "ref_dim" dimensions added when multiple
        references are requested for one or more variables.

    See Also
    --------
    xarray.Dataset : Dataset constructor

    Examples
    --------
    Generate a reference dataset with 0 compatible with the centered eight example data:

    .. jupyter-execute::

        from arviz_base import load_arviz_data, references_to_dataset

        idata = load_arviz_data("centered_eight")
        references_to_dataset(0, idata.posterior.dataset)
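
    An unnamed DataArray is treated as an array-like, so its values are used
    as references for every variable:

    .. jupyter-execute::

        import xarray as xr

        references_to_dataset(xr.DataArray([-1, 0, 1]), idata.posterior.dataset)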
    Generate a reference dataset with different references for each variable:

    .. jupyter-execute::

        references_to_dataset({"mu": -1, "tau": 1, "theta": 0}, idata.posterior.dataset)

    Or a similar case but with a different number of references for each variable:

    .. jupyter-execute::

        ref_ds = references_to_dataset(
            {"mu": [-1, 0, 1], "tau": [1, 10], "theta": 0},
            idata.posterior.dataset
        )
        ref_ds
    Once we have a compatible dataset, we can for example compute the probability
    of the samples being above the reference value(s):

    .. jupyter-execute::

        (idata.posterior.dataset > ref_ds).mean()
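
    The name of the new dimension can be set with `ref_dim`:

    .. jupyter-execute::

        references_to_dataset({"mu": [-1, 0, 1]}, idata.posterior.dataset, ref_dim=["ref"])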
"""
# quick exit if dataset input
if isinstance(references, xr.Dataset):
return references
# process argument defaults
if sample_dims is None:
sample_dims = rcParams["data.sample_dims"]
if isinstance(sample_dims, str):
sample_dims = [sample_dims]
if isinstance(ref_dim, str):
ref_dim = [ref_dim]
# start covering cases, for dataarray, if its name is a variable convert to dataset
# if it has no name treat is an array-like
if isinstance(references, xr.DataArray):
name = references.name
if name is not None:
if name not in ds.data_vars:
raise ValueError(
"Input is a named DataArray whose name doesn't match any variable in `ds`. "
"Either use an unamed DataArray or ndarray or make sure the name matches."
)
return references.to_dataset()
references = references.values
# for scalars generate a dataset with requested shape full of reference value
# check for numerical scalar following advise from
# https://numpy.org/doc/2.2/reference/generated/numpy.isscalar.html
if isinstance(references, Number):
aux_ds = ds if sample_dims is None else ds.isel({dim: 0 for dim in sample_dims})
return xr.full_like(aux_ds, references, dtype=np.array(references).dtype)
# for array-like convert to dict so it is handled later on
if isinstance(references, list | tuple | np.ndarray):
references = {var_name: references for var_name in ds.data_vars}
if isinstance(references, dict):
ref_dict = {}
for var_name, da in ds.items():
if var_name not in references:
continue
ref_values = np.atleast_1d(references[var_name])
new_dims = ref_values.shape
if ref_dim is None:
new_dim_names = (
["ref_dim"]
if len(new_dims) == 1
else [f"ref_dim_{i}" for i in range(len(new_dims))]
)
else:
if len(ref_dim) != len(new_dims):
raise ValueError(
f"ref_dim length ({len(ref_dim)}) does not match reference values "
f"length ({len(new_dims)}) for data variable {var_name}"
)
new_dim_names = ref_dim[: len(new_dims)]
sizes = {dim: length for dim, length in da.sizes.items() if dim not in sample_dims}
full_shape = list(sizes.values()) + list(new_dims)
data = np.broadcast_to(ref_values, full_shape)
ref_dict[var_name] = xr.DataArray(
data,
dims=list(sizes) + new_dim_names,
coords=dict(zip(new_dim_names, [np.arange(size) for size in new_dims]))
| {
coord_name: coord_da
for coord_name, coord_da in da.coords.items()
if not set(coord_da.dims).intersection(sample_dims)
},
)
return xr.Dataset(ref_dict)
raise TypeError("Unrecognized input type for `references`")