"""CmdStanPy specific conversion code."""
import logging
import re
from pathlib import Path
import numpy as np
from xarray import DataTree
from arviz_base.base import dict_to_dataset, infer_stan_dtypes, requires
from arviz_base.rcparams import rcParams
_log = logging.getLogger(__name__)
class CmdStanPyConverter:
"""Encapsulate CmdStanPy specific logic."""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
*,
posterior=None,
posterior_predictive=None,
predictions=None,
prior=None,
prior_predictive=None,
observed_data=None,
constant_data=None,
predictions_constant_data=None,
log_likelihood=False,
index_origin=None,
coords=None,
dims=None,
save_warmup=None,
dtypes=None,
):
self.posterior = posterior # CmdStanPy CmdStanMCMC object
self.posterior_predictive = posterior_predictive
self.predictions = predictions
self.prior = prior
self.prior_predictive = prior_predictive
self.observed_data = observed_data
self.constant_data = constant_data
self.predictions_constant_data = predictions_constant_data
self.log_likelihood = log_likelihood
self.index_origin = index_origin
self.coords = coords
self.dims = dims
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
import cmdstanpy # pylint: disable=import-error
if dtypes is None:
dtypes = {}
elif isinstance(dtypes, cmdstanpy.model.CmdStanModel):
model_code = dtypes.code()
dtypes = infer_stan_dtypes(model_code)
elif isinstance(dtypes, str):
dtypes_path = Path(dtypes)
if dtypes_path.exists():
with dtypes_path.open("r", encoding="UTF-8") as f_obj:
model_code = f_obj.read()
else:
model_code = dtypes
dtypes = infer_stan_dtypes(model_code)
self.dtypes = dtypes
if self.log_likelihood is True and "log_lik" in self.posterior.stan_variables():
self.log_likelihood = ["log_lik"]
if isinstance(self.log_likelihood, bool):
self.log_likelihood = None
self.cmdstanpy = cmdstanpy
def _warmup_return_to_dict(self, data, data_warmup, group):
res = {
group: dict_to_dataset(
data,
inference_library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
skip_event_dims="log_likelihood" == group,
),
}
if self.save_warmup and data_warmup:
res[f"warmup_{group}"] = dict_to_dataset(
data_warmup,
inference_library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
skip_event_dims="log_likelihood" == group,
)
return res
@requires("posterior")
def posterior_to_xarray(self):
"""Extract posterior samples from output csv."""
items = list(self.posterior.stan_variables().keys())
if self.posterior_predictive is not None:
try:
items = _filter(items, self.posterior_predictive)
except ValueError:
pass
if self.predictions is not None:
try:
items = _filter(items, self.predictions)
except ValueError:
pass
if self.log_likelihood is not None:
try:
items = _filter(items, self.log_likelihood)
except ValueError:
pass
data, data_warmup = _unpack_fit(
self.posterior,
items,
self.save_warmup,
self.dtypes,
)
return self._warmup_return_to_dict(data, data_warmup, "posterior")
@requires("posterior")
def sample_stats_to_xarray(self):
"""Extract sample_stats from prosterior fit."""
data, data_warmup = self.stats_to_xarray(self.posterior)
return self._warmup_return_to_dict(data, data_warmup, "sample_stats")
@requires("prior")
def sample_stats_prior_to_xarray(self):
"""Extract sample_stats from prior fit."""
data, data_warmup = self.stats_to_xarray(self.prior)
return self._warmup_return_to_dict(data, data_warmup, "sample_stats_prior")
def stats_to_xarray(self, fit):
"""Extract sample_stats from fit."""
dtypes = {
"divergent__": bool,
"n_leapfrog__": np.int64,
"treedepth__": np.int64,
**self.dtypes,
}
items = list(fit.method_variables().keys())
rename_dict = {
"divergent": "diverging",
"n_leapfrog": "n_steps",
"treedepth": "tree_depth",
"stepsize": "step_size",
"accept_stat": "acceptance_rate",
}
data, data_warmup = _unpack_fit(
fit,
items,
self.save_warmup,
self.dtypes,
)
for item in items:
name = re.sub("__$", "", item)
name = rename_dict.get(name, name)
data[name] = data.pop(item).astype(dtypes.get(item, float))
if data_warmup:
data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
return (data, data_warmup)
@requires("posterior")
@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
data, data_warmup = self.predictive_to_xarray(self.posterior_predictive, self.posterior)
return self._warmup_return_to_dict(data, data_warmup, "posterior_predictive")
@requires("prior")
@requires("prior_predictive")
def prior_predictive_to_xarray(self):
"""Convert prior_predictive samples to xarray."""
data, data_warmup = self.predictive_to_xarray(self.prior_predictive, self.prior)
return self._warmup_return_to_dict(data, data_warmup, "prior_predictive")
def predictive_to_xarray(self, names, fit):
"""Convert predictive samples to xarray."""
predictive = _as_set(names)
data, data_warmup = _unpack_fit(
fit,
predictive,
self.save_warmup,
self.dtypes,
)
return (data, data_warmup)
@requires("posterior")
@requires("predictions")
def predictions_to_xarray(self):
"""Convert out of sample predictions samples to xarray."""
predictions = _as_set(self.predictions)
data, data_warmup = _unpack_fit(
self.posterior,
predictions,
self.save_warmup,
self.dtypes,
)
return self._warmup_return_to_dict(data, data_warmup, "predictions")
@requires("posterior")
@requires("log_likelihood")
def log_likelihood_to_xarray(self):
"""Convert elementwise log likelihood samples to xarray."""
log_likelihood = _as_set(self.log_likelihood)
data, data_warmup = _unpack_fit(
self.posterior,
log_likelihood,
self.save_warmup,
self.dtypes,
)
if isinstance(self.log_likelihood, dict):
data = {obs_name: data[lik_name] for obs_name, lik_name in self.log_likelihood.items()}
if data_warmup:
data_warmup = {
obs_name: data_warmup[lik_name]
for obs_name, lik_name in self.log_likelihood.items()
}
return self._warmup_return_to_dict(data, data_warmup, "log_likelihood")
@requires("prior")
def prior_to_xarray(self):
"""Convert prior samples to xarray."""
items = list(self.prior.stan_variables().keys())
if self.prior_predictive is not None:
try:
items = _filter(items, self.prior_predictive)
except ValueError:
pass
data, data_warmup = _unpack_fit(
self.prior,
items,
self.save_warmup,
self.dtypes,
)
return self._warmup_return_to_dict(data, data_warmup, "prior")
@requires("observed_data")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
return dict_to_dataset(
self.observed_data,
inference_library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
sample_dims=[],
index_origin=self.index_origin,
)
@requires("constant_data")
def constant_data_to_xarray(self):
"""Convert constant data to xarray."""
return dict_to_dataset(
self.constant_data,
inference_library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
sample_dims=[],
index_origin=self.index_origin,
)
@requires("predictions_constant_data")
def predictions_constant_data_to_xarray(self):
"""Convert constant data to xarray."""
return dict_to_dataset(
self.predictions_constant_data,
inference_library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
sample_dims=[],
index_origin=self.index_origin,
)
def to_datatree(self):
"""Convert all available data to an InferenceData object.
Note that if groups can not be created (i.e., there is no `output`, so
the `posterior` and `sample_stats` can not be extracted), then the InferenceData
will not have those groups.
"""
datadict = {
"observed_data": self.observed_data_to_xarray(),
"constant_data": self.constant_data_to_xarray(),
"predictions_constant_data": self.predictions_constant_data_to_xarray(),
}
datalist = [
self.posterior_to_xarray(),
self.sample_stats_to_xarray(),
self.posterior_predictive_to_xarray(),
self.predictions_to_xarray(),
self.prior_to_xarray(),
self.sample_stats_prior_to_xarray(),
self.prior_predictive_to_xarray(),
self.log_likelihood_to_xarray(),
]
for ds_dict in datalist:
if ds_dict is not None:
datadict.update(ds_dict)
return DataTree.from_dict({group: ds for group, ds in datadict.items() if ds is not None})
def _as_set(spec):
"""Uniform representation for args which be name or list of names."""
if spec is None:
return []
if isinstance(spec, str):
return [spec]
try:
return set(spec.values())
except AttributeError:
return set(spec)
def _filter(names, spec):
"""Remove names from list of names."""
if isinstance(spec, str):
names.remove(spec)
elif isinstance(spec, list):
for item in spec:
names.remove(item)
elif isinstance(spec, dict):
for item in spec.values():
names.remove(item)
return names
def _unpack_fit(fit, items, save_warmup, dtypes):
"""Transform fit to dictionary containing ndarrays.
Parameters
----------
data : cmdstanpy.CmdStanMCMC
items : list
save_warmup : bool
dtypes : dict
Returns
-------
dict
key, values pairs. Values are formatted to shape = (chains, draws, *shape)
"""
num_warmup = 0
if save_warmup:
if not fit._save_warmup: # pylint: disable=protected-access
save_warmup = False
else:
num_warmup = fit.num_draws_warmup
nchains = fit.chains
sample = {}
sample_warmup = {}
stan_variables = set(fit.stan_variables())
method_variables = fit.method_variables()
for item in items:
if item in stan_variables:
raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
raw_draws = np.swapaxes(
raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
)
elif item in method_variables:
raw_draws = np.swapaxes(method_variables[item].reshape((-1, nchains), order="F"), 0, 1)
else:
raise ValueError(f"fit data, unknown variable: {item}")
raw_draws = raw_draws.astype(dtypes.get(item))
if save_warmup:
if item in method_variables:
sample[item] = raw_draws
else:
sample_warmup[item] = raw_draws[:, :num_warmup, ...]
sample[item] = raw_draws[:, num_warmup:, ...]
else:
sample[item] = raw_draws
return sample, sample_warmup
[docs]
def from_cmdstanpy(
posterior=None,
*,
posterior_predictive=None,
predictions=None,
prior=None,
prior_predictive=None,
observed_data=None,
constant_data=None,
predictions_constant_data=None,
log_likelihood=None,
index_origin=None,
coords=None,
dims=None,
save_warmup=None,
dtypes=None,
):
"""Convert CmdStanPy data into an InferenceData object.
For a usage example read the
:ref:`Creating InferenceData section on from_cmdstanpy <creating_InferenceData>`
Parameters
----------
posterior : cmdstanpy.CmdStanMCMC, optional
CmdStanPy CmdStanMCMC
posterior_predictive : str, list of str, optional
Posterior predictive samples for the fit.
predictions : str, list of str, optional
Out of sample prediction samples for the fit.
prior : cmdstanpy.CmdStanMCMC, optional
CmdStanPy CmdStanMCMC
prior_predictive : str, list of str, optional
Prior predictive samples for the fit.
observed_data : dict, optional
Observed data used in the sampling.
constant_data : dict, optional
Constant data used in the sampling.
predictions_constant_data : dict, optional
Constant data for predictions used in the sampling.
log_likelihood : str, list of str, dict of {str: str}, optional
Pointwise log_likelihood for the data. If a dict, its keys should represent var_names
from the corresponding observed data and its values the stan variable where the
data is stored. By default, if a variable ``log_lik`` is present in the Stan model,
it will be retrieved as pointwise log likelihood values. Use ``False``
to avoid this behaviour.
index_origin : int, optional
Starting value of integer coordinate values. Defaults to the value in rcParam
``data.index_origin``.
coords : dict, optional
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
dims : mapping of {hashable : sequence of hashable}, optional
A mapping from variables to a list of coordinate names for the variable.
save_warmup : bool, optional
Save warmup iterations into InferenceData object, if found in the input files.
If not defined, use default defined by the rcParams.
dtypes : dict, str or cmdstanpy.CmdStanModel, optional
A dictionary containing dtype information (int, float) for parameters.
If input is a string, it is assumed to be a model code or path to model code file.
Model code can extracted from cmdstanpy.CmdStanModel object.
Returns
-------
DataTree
"""
return CmdStanPyConverter(
posterior=posterior,
posterior_predictive=posterior_predictive,
predictions=predictions,
prior=prior,
prior_predictive=prior_predictive,
observed_data=observed_data,
constant_data=constant_data,
predictions_constant_data=predictions_constant_data,
log_likelihood=log_likelihood,
index_origin=index_origin,
coords=coords,
dims=dims,
save_warmup=save_warmup,
dtypes=dtypes,
).to_datatree()