Source code for arviz_base.datasets

"""Base IO code for all datasets. Heavily influenced by scikit-learn's implementation."""

import hashlib
import itertools
import json
import os
import shutil
from collections import namedtuple
from urllib.request import urlretrieve

from xarray import open_datatree

from arviz_base.rcparams import rcParams

__all__ = ["get_data_home", "clear_data_home", "list_datasets", "load_arviz_data"]

LocalFileMetadata = namedtuple("LocalFileMetadata", ["name", "filename", "description"])

RemoteFileMetadata = namedtuple(
    "RemoteFileMetadata", ["name", "filename", "url", "checksum", "description"]
)
_EXAMPLE_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_data")
_LOCAL_DATA_DIR = os.path.join(_EXAMPLE_DATA_DIR, "data")

with open(os.path.join(_EXAMPLE_DATA_DIR, "data_local.json"), encoding="utf-8") as f:
    LOCAL_DATASETS = {
        entry["name"]: LocalFileMetadata(
            name=entry["name"],
            filename=os.path.join(_LOCAL_DATA_DIR, entry["filename"]),
            description=entry["description"],
        )
        for entry in json.load(f)
    }
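# A data_local.json entry is expected to provide the "name", "filename" and
# "description" keys used above. Hypothetical example entry (illustrative only,
# not copied from the actual file):
#
#     {
#         "name": "centered_eight",
#         "filename": "centered_eight.nc",
#         "description": "Centered eight schools model example."
#     }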

with open(os.path.join(_EXAMPLE_DATA_DIR, "data_remote.json"), encoding="utf-8") as f:
    REMOTE_DATASETS = {entry["name"]: RemoteFileMetadata(**entry) for entry in json.load(f)}


def get_data_home(data_home=None):
    """Return the path of the arviz data dir.

    This folder is used by some dataset loaders to avoid downloading the
    data several times.

    By default the data dir is set to a folder named 'arviz_data' in the
    user home folder.

    Alternatively, it can be set by the :envvar:`ARVIZ_DATA` environment
    variable or programmatically by giving an explicit folder path. The '~'
    symbol is expanded to the user home folder.

    If the folder does not already exist, it is automatically created.

    Parameters
    ----------
    data_home : str, optional
        The path to the arviz data dir.
    """
    if data_home is None:
        data_home = os.environ.get("ARVIZ_DATA", os.path.join("~", "arviz_data"))
    data_home = os.path.expanduser(data_home)
    if not os.path.exists(data_home):
        os.makedirs(data_home)
    return data_home
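# Usage sketch (illustrative, not part of the module): the cache directory
# resolves to ~/arviz_data unless ARVIZ_DATA is set or a path is passed.
#
#     >>> from arviz_base.datasets import get_data_home
#     >>> get_data_home()                 # e.g. '/home/user/arviz_data'
#     >>> get_data_home("~/my_cache")     # hypothetical explicit location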
def clear_data_home(data_home=None):
    """Delete all the content of the data home cache.

    Parameters
    ----------
    data_home : str, optional
        The path to the arviz data dir.
    """
    data_home = get_data_home(data_home)
    shutil.rmtree(data_home)
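# Usage sketch (illustrative): clearing the cache removes the resolved data home
# directory, so any remote datasets will be re-downloaded on their next use.
#
#     >>> from arviz_base.datasets import clear_data_home
#     >>> clear_data_home()               # removes ~/arviz_data (or ARVIZ_DATA)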
def _sha256(path):
    """Calculate the sha256 hash of the file at path."""
    sha256hash = hashlib.sha256()
    chunk_size = 8192
    with open(path, "rb") as buff:
        while True:
            buffer = buff.read(chunk_size)
            if not buffer:
                break
            sha256hash.update(buffer)
    return sha256hash.hexdigest()
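# Usage sketch (illustrative): _sha256 streams the file in 8 KiB chunks, so it
# can verify an arbitrarily large download against a known digest without
# loading it fully into memory.
#
#     >>> digest = _sha256("/tmp/some_download.nc")   # hypothetical path
#     >>> digest == "<expected hex digest>"           # same check as load_arviz_data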
def load_arviz_data(dataset=None, data_home=None, **kwargs):
    """Load a local or remote pre-made dataset.

    Run with no parameters to get a list of all available models.

    The directory to save to can also be set with the environment
    variable :envvar:`ARVIZ_DATA`.
    The checksum of the dataset is checked against a hardcoded value
    to watch for data corruption.
    Run `az.clear_data_home` to clear the data directory.

    Parameters
    ----------
    dataset : str
        Name of dataset to load.
    data_home : str, optional
        Where to save remote datasets.
    **kwargs : dict, optional
        Keyword arguments passed to :func:`xarray.open_datatree`.

    Returns
    -------
    xarray.DataTree
    """
    if dataset in LOCAL_DATASETS:
        resource = LOCAL_DATASETS[dataset]
        return open_datatree(resource.filename, **kwargs).load()

    if dataset in REMOTE_DATASETS:
        remote = REMOTE_DATASETS[dataset]
        home_dir = get_data_home(data_home=data_home)
        file_path = os.path.join(home_dir, remote.filename)

        if not os.path.exists(file_path):
            http_type = rcParams["data.http_protocol"]

            # Replaces http type. Redundant if http_type is http, useful if http_type is https
            url = remote.url.replace("http", http_type)
            urlretrieve(url, file_path)

        checksum = _sha256(file_path)
        if remote.checksum != checksum:
            raise OSError(
                f"{file_path} has an SHA256 checksum ({checksum}) differing from expected "
                f"({remote.checksum}), file may be corrupted. "
                "Run `arviz.clear_data_home()` and try again, or please open an issue."
            )
        return open_datatree(file_path, **kwargs).load()

    if dataset is None:
        return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()))

    raise ValueError(
        f"Dataset {dataset} not found! The following are available:\n{list_datasets()}"
    )
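# Usage sketch (illustrative, not part of the module):
#
#     >>> from arviz_base.datasets import load_arviz_data
#     >>> load_arviz_data()                      # dict of all available dataset metadata
#     >>> dt = load_arviz_data("centered_eight") # assuming this name is registered
#
# Remote datasets are downloaded into get_data_home() on first use and read
# from the cached file afterwards.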
def list_datasets():
    """Get a string representation of all available datasets with descriptions.

    Returns
    -------
    str
    """
    lines = []
    for name, resource in itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()):
        if isinstance(resource, LocalFileMetadata):
            location = f"local: {resource.filename}"
        elif isinstance(resource, RemoteFileMetadata):
            location = f"remote: {resource.url}"
        else:
            location = "unknown"

        lines.append(f"{name}\n{'=' * len(name)}\n{resource.description}\n\n{location}")

    return f"\n\n{10 * '-'}\n\n".join(lines)
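# Usage sketch (illustrative): list_datasets returns a single formatted string,
# so it is typically printed rather than iterated over.
#
#     >>> from arviz_base.datasets import list_datasets
#     >>> print(list_datasets())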