"""Compute PSIS-LOO-CV for approximate posteriors."""

from arviz_base import rcParams

from arviz_stats.loo.helper_loo import (
    _check_log_density,
    _check_log_jacobian,
    _prepare_loo_inputs,
    _warn_pareto_k,
    _warn_pointwise_loo,
)
from arviz_stats.utils import ELPDData


def loo_approximate_posterior(
    data, log_p, log_q, pointwise=None, var_name=None, log_jacobian=None
):
    r"""Compute PSIS-LOO-CV for approximate posteriors.

    Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
    importance sampling leave-one-out cross-validation (PSIS-LOO-CV) for approximate
    posteriors (e.g., from variational inference). Requires the log-densities of the target
    (``log_p``) and proposal (``log_q``) distributions.

    The PSIS-LOO-CV method is described in [1]_ and [2]_. The approximate posterior
    correction is computed using the method described in [3]_.

    See the EABM chapter on `Model Comparison for Large Data
    <https://arviz-devs.github.io/EABM/Chapters/Model_comparison_large_data.html>`_
    for more details.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain a log_likelihood group computed from samples drawn
        from the proposal distribution (q).
    log_p : ndarray or DataArray
        The (target) log-density evaluated at the S samples drawn from the proposal
        distribution (q). If ndarray, should be a vector of length S, where S is the number
        of samples. If DataArray, should have dimensions matching the sample dimensions
        ("chain", "draw").
    log_q : ndarray or DataArray
        The (proposal) log-density evaluated at the same S samples drawn from the proposal
        distribution (q). If ndarray, should be a vector of length S, where S is the number
        of samples. If DataArray, should have dimensions matching the sample dimensions
        ("chain", "draw").
    pointwise : bool, optional
        If True, returns pointwise values. Defaults to rcParams["stats.ic_pointwise"].
    var_name : str, optional
        The name of the variable in the log_likelihood group storing the pointwise log
        likelihood data to use for the loo computation.
    log_jacobian : DataArray, optional
        Log-Jacobian adjustment for variable transformations. Required when the model was
        fitted on transformed response data :math:`z = T(y)` but you want to compute ELPD
        on the original response scale :math:`y`. The value should be
        :math:`\log|\frac{dz}{dy}|` (the log absolute value of the derivative of the
        transformation). Must be a DataArray with dimensions matching the observation
        dimensions.

    Returns
    -------
    ELPDData
        Object with the following attributes:

        - **kind**: "loo"
        - **elpd**: expected log pointwise predictive density
        - **se**: standard error of the elpd
        - **p**: effective number of parameters
        - **n_samples**: number of samples
        - **n_data_points**: number of data points
        - **scale**: "log"
        - **warning**: True if the estimated shape parameter of the Pareto distribution is
          greater than ``good_k``.
        - **good_k**: For a sample size S, the threshold is computed as
          ``min(1 - 1/log10(S), 0.7)``
        - **elpd_i**: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
          only if ``pointwise=True``
        - **pareto_k**: :class:`~xarray.DataArray` with Pareto shape values,
          only if ``pointwise=True``
        - **approx_posterior**: True (approximate posterior correction applied)

    Examples
    --------
    To calculate PSIS-LOO-CV for posterior approximations, we need to provide the
    log-densities of the target and proposal distributions. Here we use dummy
    log-densities. In practice, the log-densities would typically be computed by a
    posterior approximation method such as the Laplace approximation or automatic
    differentiation variational inference (ADVI):

    .. ipython::

        In [1]: import numpy as np
           ...: import xarray as xr
           ...: from arviz_stats import loo_approximate_posterior
           ...: from arviz_base import load_arviz_data, extract
           ...:
           ...: data = load_arviz_data("centered_eight")
           ...: log_lik = extract(data, group="log_likelihood", var_names="obs", combined=False)
           ...: rng = np.random.default_rng(214)
           ...:
           ...: values_p = rng.normal(loc=0, scale=1, size=(log_lik.chain.size, log_lik.draw.size))
           ...: log_p = xr.DataArray(
           ...:     values_p,
           ...:     dims=["chain", "draw"],
           ...:     coords={"chain": log_lik.chain, "draw": log_lik.draw}
           ...: )
           ...:
           ...: values_q = rng.normal(loc=-1, scale=1, size=(log_lik.chain.size, log_lik.draw.size))
           ...: log_q = xr.DataArray(
           ...:     values_q,
           ...:     dims=["chain", "draw"],
           ...:     coords={"chain": log_lik.chain, "draw": log_lik.draw}
           ...: )

    Now we can calculate pointwise PSIS-LOO-CV for posterior approximations:

    .. ipython::

        In [2]: loo_approx = loo_approximate_posterior(
           ...:     data,
           ...:     log_p=log_p,
           ...:     log_q=log_q,
           ...:     var_name="obs",
           ...:     pointwise=True
           ...: )
           ...: loo_approx

    We can also calculate PSIS-LOO-CV for posterior approximations with subsampling for
    large datasets:

    .. ipython::

        In [3]: from arviz_stats import loo_subsample
           ...: loo_approx_subsample = loo_subsample(
           ...:     data,
           ...:     observations=4,
           ...:     var_name="obs",
           ...:     log_p=log_p,
           ...:     log_q=log_q,
           ...:     pointwise=True
           ...: )
           ...: loo_approx_subsample

    See Also
    --------
    loo : Standard PSIS-LOO-CV.
    loo_subsample : Sub-sampled PSIS-LOO-CV.
    compare : Compare models based on their ELPD.

    References
    ----------

    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing. 27(5) (2017)
        https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544

    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024)
        https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646

    .. [3] Magnusson, M., Riis Andersen, M., Jonasson, J., & Vehtari, A.
        *Bayesian Leave-One-Out Cross-Validation for Large Data.*
        Proceedings of the 36th International Conference on Machine Learning,
        PMLR 97:4244–4253 (2019)
        https://proceedings.mlr.press/v97/magnusson19a.html
        arXiv preprint https://arxiv.org/abs/1904.10679
    """
    loo_inputs = _prepare_loo_inputs(data, var_name)
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
    log_likelihood = loo_inputs.log_likelihood
    sample_dims = loo_inputs.sample_dims
    obs_dims = [dim for dim in log_likelihood.dims if dim not in sample_dims]

    log_p = _check_log_density(log_p, "log_p", log_likelihood, loo_inputs.n_samples, sample_dims)
    log_q = _check_log_density(log_q, "log_q", log_likelihood, loo_inputs.n_samples, sample_dims)
    jacobian_da = _check_log_jacobian(log_jacobian, obs_dims)

    elpd_i, pareto_k, p_loo_i = log_likelihood.azstats.loo_approximate_posterior(
        log_p=log_p, log_q=log_q, sample_dims=sample_dims, log_jacobian=jacobian_da
    )

    warn_mg, good_k = _warn_pareto_k(pareto_k, loo_inputs.n_samples)
    elpd, elpd_se, p_loo, _ = elpd_i.azstats.loo_summary(p_loo_i)

    if pointwise:
        _warn_pointwise_loo(elpd, elpd_i.values)

    return ELPDData(
        "loo",
        elpd,
        elpd_se,
        p_loo,
        loo_inputs.n_samples,
        loo_inputs.n_data_points,
        "log",
        warn_mg,
        good_k,
        elpd_i if pointwise else None,
        pareto_k if pointwise else None,
        approx_posterior=True,
    )
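

# Illustrative sketch only: _sketch_is_loo_approx is a hypothetical helper and is
# NOT part of the arviz_stats API. It shows, under simplifying assumptions, how the
# target and proposal log-densities enter the leave-one-out importance ratios when
# the draws come from an approximate posterior q:
#
#     log r_i(s) = log_p(s) - log_q(s) - log_lik_i(s)
#
# The p/q term corrects for sampling from q instead of the target p, and the
# 1 / p(y_i | theta_s) term removes observation i. Unlike loo_approximate_posterior,
# this sketch applies NO Pareto smoothing and reports no pareto_k diagnostic. It is
# plain self-normalized importance sampling, and it assumes log_lik is a plain
# (S, n) array with log_p and log_q of shape (S,), all evaluated at the same S draws.
def _sketch_is_loo_approx(log_lik, log_p, log_q):
    """Return unsmoothed importance-sampling estimates of elpd_i, shape (n,)."""
    import numpy as np

    log_lik = np.asarray(log_lik, dtype=float)
    # Approximation correction log p(theta_s) - log q(theta_s), shared by all i.
    log_pq = np.asarray(log_p, dtype=float) - np.asarray(log_q, dtype=float)
    # Leave-one-out ratios log r_i(s) for every draw s and observation i, shape (S, n).
    log_ratios = log_pq[:, None] - log_lik
    # Self-normalized estimate of p(y_i | y_-i):
    #     sum_s r_i(s) p(y_i | theta_s) / sum_s r_i(s),
    # where r_i(s) p(y_i | theta_s) = p(theta_s) / q(theta_s), so the numerator
    # does not depend on i. Both sums are computed stably in log space.
    log_num = np.logaddexp.reduce(log_pq)
    log_den = np.logaddexp.reduce(log_ratios, axis=0)
    return log_num - log_den


# Usage note: if log_p equals log_q (the draws come from the target itself), the
# sketch reduces to ordinary importance-sampling LOO, -log(mean(exp(-log_lik_i))).
# If log_lik_i is constant across draws, it returns that constant for any log_p
# and log_q, because the p/q weights cancel in the self-normalized ratio.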