Source code for arviz_base.base

"""ArviZ basic functions and converters."""

import datetime
import importlib
import re
import warnings
from collections.abc import Callable, Hashable, Iterable, Mapping
from copy import deepcopy
from typing import TYPE_CHECKING, Any, TypeVar

import numpy as np
import xarray as xr

from arviz_base._version import __version__
from arviz_base.rcparams import rcParams
from arviz_base.types import CoordSpec, DictData, DimSpec

if TYPE_CHECKING:
    pass

RequiresArgTypeT = TypeVar("RequiresArgTypeT")
RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")


[docs] def generate_dims_coords( shape: Iterable[int], var_name: Hashable, dims: Iterable[Hashable] | None = None, coords: CoordSpec | None = None, index_origin: int | None = None, skip_event_dims: bool = False, check_conventions: bool = True, ): """Generate default dimensions and coordinates for a variable. Parameters ---------- shape : iterable of int Shape of the variable var_name : iterable of hashable Name of the variable. If no dimension name(s) is provided, ArviZ will generate a default dimension name using ``var_name``, e.g. ``"foo_dim_0"`` for the first dimension if ``var_name`` is ``"foo"``. dims : iterable of hashable, optional Dimension names (or identifiers) for the variable. If `skip_event_dims` is ``True`` it can be longer than `shape`. In that case, only the first ``len(shape)`` elements in `dims` will be used. Moreover, if needed, axis of length 1 in shape will also be given different names than the ones provided in `dims`. coords : dict of {hashable: array_like}, optional Map of dimension names to coordinate values. Dimensions without coordinate values mapped to them will be given an integer range as coordinate values. It can have keys for dimension names not present in that variable. index_origin : int, optional Starting value of generated integer coordinate values. Defaults to the value in rcParam ``data.index_origin``. skip_event_dims : bool, default False Whether to allow for different sizes between `shape` and `dims`. See description in `dims` for more details. check_conventions : bool, optional Check ArviZ conventions. Per the ArviZ schema, some dimension names have specific meaning and there might be inconsistencies caught here in the dimension naming step. 
Returns ------- dims : list of hashable Default dims for that variable coords : dict of {hashable: ndarray} Default coords for that variable """ if index_origin is None: index_origin = rcParams["data.index_origin"] if dims is None: dims = [] if coords is None: coords = {} coords = deepcopy(coords) dims = deepcopy(dims) if len(dims) > len(shape): if skip_event_dims: dims = dims[: len(shape)] else: raise ValueError( ( "In variable {var_name}, there are " + "more dims ({dims_len}) given than existing ones ({shape_len}). " + "dims and shape should match with `skip_event_dims=False`" ).format( var_name=var_name, dims_len=len(dims), shape_len=len(shape), ) ) if skip_event_dims: # In some cases, even when there is an event dim, the shape has the # right length but the length of the axis doesn't match. # For example, the log likelihood of a 3d MvNormal with 20 observations # should be (20,) but it can also be (20, 1). The code below ensures # the (20, 1) option also works. for i, (dim, dim_size) in enumerate(zip(dims, shape)): if (dim in coords) and (dim_size != len(coords[dim])): dims = dims[:i] break missing_dim_count = 0 for idx, dim_len in enumerate(shape): if idx + 1 > len(dims): dim_name = f"{var_name}_dim_{missing_dim_count}" missing_dim_count += 1 dims.append(dim_name) elif dims[idx] is None: dim_name = f"{var_name}_dim_{missing_dim_count}" missing_dim_count += 1 dims[idx] = dim_name dim_name = dims[idx] if dim_name not in coords: coords[dim_name] = np.arange(index_origin, dim_len + index_origin) coords = {dim_name: coords[dim_name] for dim_name in dims} if check_conventions: short_long_pairs = (("draw", "chain"), ("draw", "pred_id"), ("sample", "pred_id")) for long_dim, short_dim in short_long_pairs: if ( long_dim in dims and short_dim in dims and len(coords[short_dim]) > len(coords[long_dim]) ): warnings.warn( f"Found {short_dim} dimension to be longer than {long_dim} dimension, " "check dimensions are correctly named.", UserWarning, ) if "sample" in dims and 
(("draw" in dims) or ("chain" in dims)): warnings.warn( "Found dimension named 'sample' alongside 'chain'/'draw' ones, " "check dimensions are correctly named.", UserWarning, ) return dims, coords
[docs] def ndarray_to_dataarray( ary, var_name, *, dims=None, sample_dims=None, coords=None, index_origin=None, skip_event_dims=False, check_conventions=True, ): """Convert a numpy array to an xarray.DataArray. The conversion considers some ArviZ conventions and adds extra attributes, so it is similar to initializing an :class:`xarray.DataArray` but not equivalent. Parameters ---------- ary : scalar or array_like Values for the DataArray object to be created. var_name : hashable Name of the created DataArray object. dims : iterable of hashable, optional Dimensions of the DataArray. coords : dict of {hashable: array_like}, optional Coordinates for the dataarray sample_dims : iterable of hashable, optional Dimensions that should be assumed to be present. If missing, they will be added as the dimensions corresponding to the leading axes. index_origin : int, optional Passed to :func:`generate_dims_coords` skip_event_dims : bool, optional Passed to :func:`generate_dims_coords` check_conventions : bool, optional Check ArviZ conventions. Per the ArviZ schema, some dimension names have specific meaning and there might be inconsistencies caught here in the dimension naming step. Returns ------- DataArray See Also -------- dict_to_dataset """ if dims is None: dims = [] if sample_dims is None: sample_dims = rcParams["data.sample_dims"] if sample_dims: var_dims = [sample_dim for sample_dim in sample_dims if sample_dim not in dims] var_dims.extend(dims) else: var_dims = dims var_dims, var_coords = generate_dims_coords( ary.shape if hasattr(ary, "shape") else (), var_name=var_name, dims=var_dims, coords=coords, index_origin=index_origin, skip_event_dims=skip_event_dims, check_conventions=check_conventions, ) return xr.DataArray(ary, coords=var_coords, dims=var_dims)
def dict_to_dataset(
    data: DictData,
    *,
    attrs: Mapping[Any, Any] | None = None,
    inference_library: str | None = None,
    coords: CoordSpec | None = None,
    dims: DimSpec | None = None,
    sample_dims: Iterable[Hashable] | None = None,
    index_origin: int | None = None,
    skip_event_dims: bool = False,
    check_conventions: bool = True,
):
    """Convert a dictionary of numpy arrays to an xarray.Dataset.

    The conversion considers some ArviZ conventions and adds extra
    attributes, so it is similar to initializing an :class:`xarray.Dataset`
    but not equivalent.

    Parameters
    ----------
    data : dict of {hashable: array_like}
        Data to convert. Keys are variable names.
    attrs : dict, optional
        JSON-like arbitrary metadata to attach to the dataset, in addition to
        default attributes added by :func:`make_attrs`.

        .. note::

           No serialization checks are done in this function, so you might
           generate :class:`~xarray.Dataset` objects that can't be serialized
           or that can only be serialized to some backends.

    inference_library : module, optional
        Library used for performing inference. Will be included in the
        :class:`xarray.Dataset` attributes.
    coords : dict of {hashable: array_like}, optional
        Coordinates for the dataset
    dims : dict of {hashable: iterable of hashable}, optional
        Dimensions of each variable. The keys are variable names, values are
        lists of coordinates.
    sample_dims : iterable of hashable, optional
        Dimensions that should be assumed to be present in _all_ variables.
        If missing, they will be added as the dimensions corresponding to the
        leading axes.
    index_origin : int, optional
        Passed to :func:`generate_dims_coords`
    skip_event_dims : bool, optional
        Passed to :func:`generate_dims_coords`
    check_conventions : bool, optional
        Check ArviZ conventions. Per the ArviZ schema, some dimension names
        have specific meaning and there might be inconsistencies caught here
        in the dimension naming step.

    Returns
    -------
    Dataset

    See Also
    --------
    ndarray_to_dataarray
    convert_to_dataset
        General conversion to `xarray.Dataset` via :func:`convert_to_datatree`

    Examples
    --------
    Generate a :class:`~xarray.Dataset` with two variables using ``sample_dims``:

    .. jupyter-execute::

        import arviz_base as az
        import numpy as np

        rng = np.random.default_rng(2)
        az.dict_to_dataset(
            {"a": rng.normal(size=(4, 100)), "b": rng.normal(size=(4, 100))},
            sample_dims=["chain", "draw"],
        )

    Generate a :class:`~xarray.Dataset` with the ``chain`` and ``draw``
    dimensions in different position. Setting the dimensions for ``a`` to
    "group" and "chain", ``sample_dims`` will then be used to prepend the
    "draw" dimension only as "chain" is already there.

    .. jupyter-execute::

        az.dict_to_dataset(
            {"a": rng.normal(size=(10, 5, 4)), "b": rng.normal(size=(10, 4))},
            dims={"a": ["group", "chain"]},
            sample_dims=["draw", "chain"],
        )
    """
    if dims is None:
        dims = {}
    # Convert each variable independently: per-variable dims come from ``dims``
    # while ``coords``/``sample_dims`` and the remaining options are shared.
    data_vars = {}
    for var_name, values in data.items():
        data_vars[var_name] = ndarray_to_dataarray(
            values,
            var_name=var_name,
            dims=dims.get(var_name, []),
            sample_dims=sample_dims,
            coords=coords,
            index_origin=index_origin,
            skip_event_dims=skip_event_dims,
            check_conventions=check_conventions,
        )
    return xr.Dataset(
        data_vars=data_vars,
        attrs=make_attrs(attrs=attrs, inference_library=inference_library),
    )
[docs] def make_attrs(attrs=None, inference_library=None): """Make standard attributes to attach to xarray datasets. Parameters ---------- attrs : dict, optional Additional attributes to add or overwrite inference_library : module, optional Library used to perform inference. Returns ------- dict attrs """ default_attrs = { "created_at": datetime.datetime.now(datetime.UTC).isoformat(), "creation_library": "ArviZ", "creation_library_version": __version__, "creation_library_language": "Python", } if inference_library is not None: library_name = inference_library.__name__ default_attrs["inference_library"] = library_name try: version = importlib.metadata.version(library_name) default_attrs["inference_library_version"] = version except importlib.metadata.PackageNotFoundError: if hasattr(inference_library, "__version__"): version = inference_library.__version__ default_attrs["inference_library_version"] = version if attrs is not None: default_attrs.update(attrs) return default_attrs
class requires: # pylint: disable=invalid-name """Decorator to return None if an object does not have the required attribute. If the decorator is called various times on the same function with different attributes, it will return None if one of them is missing. If instead a list of attributes is passed, it will return None if all attributes in the list are missing. Both functionalities can be combined as desired. It can only be used to decorate functions/methods with a single argument, e.g. ``posterior_to_xarray(self)`` is valid, but ``posterior_to_xarray(self, other_arg)`` would not be. See https://github.com/arviz-devs/arviz/pull/1504 for more discussion. """ def __init__(self, *props: str | list[str]) -> None: self.props: tuple[str | list[str], ...] = props def __call__( self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT] ) -> Callable[[RequiresArgTypeT], RequiresReturnTypeT | None]: # noqa: D202 """Wrap the decorated function.""" def wrapped(cls: RequiresArgTypeT) -> RequiresReturnTypeT | None: """Return None if not all props are available.""" for prop in self.props: prop_list = [prop] if isinstance(prop, str) else prop if all(getattr(cls, prop_i) is None for prop_i in prop_list): return None return func(cls) return wrapped def infer_stan_dtypes(stan_code): """Infer Stan integer variables from generated quantities block.""" # Remove old deprecated comments stan_code = "\n".join( line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines() ) pattern_remove_comments = re.compile( r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE ) stan_code = re.sub(pattern_remove_comments, "", stan_code) # Check generated quantities if "generated quantities" not in stan_code: return {} # Extract generated quantities block gen_quantities_location = stan_code.index("generated quantities") block_start = gen_quantities_location + stan_code[gen_quantities_location:].index("{") curly_bracket_count = 0 block_end = 
None for block_end, char in enumerate(stan_code[block_start:], block_start + 1): if char == "{": curly_bracket_count += 1 elif char == "}": curly_bracket_count -= 1 if curly_bracket_count == 0: break stan_code = stan_code[block_start:block_end] stan_integer = r"int" stan_limits = r"(?:\<[^\>]+\>)*" # ignore group: 0 or more <....> stan_param = r"([^;=\s\[]+)" # capture group: ends= ";", "=", "[" or whitespace stan_ws = r"\s*" # 0 or more whitespace stan_ws_one = r"\s+" # 1 or more whitespace pattern_int = re.compile( "".join((stan_integer, stan_ws_one, stan_limits, stan_ws, stan_param)), re.IGNORECASE ) dtypes = {key.strip(): "int" for key in re.findall(pattern_int, stan_code)} return dtypes