Source code for arviz_plots.plots.ppc_interval_plot

"""Predictive intervals plot."""
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    filter_aes,
    get_visual_kwargs,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import ci_bound_y, labelled_title, labelled_x, labelled_y, point_y


[docs] def plot_ppc_interval( dt, var_names=None, filter_vars=None, group="posterior_predictive", coords=None, sample_dims=None, point_estimate=None, ci_kind=None, ci_probs=None, plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "trunk", "twig", "observed_markers", "prediction_markers", "xlabel", "ylabel", "title", ], Sequence[str] | bool, ] = None, visuals: Mapping[ Literal[ "trunk", "twig", "observed_markers", "prediction_markers", "xlabel", "ylabel", "title", ], Mapping[str, Any] | bool, ] = None, stats: Mapping[ Literal["trunk", "twig", "point_estimate"], Mapping[str, Any] | xr.Dataset ] = None, **pc_kwargs, ): """Plot posterior predictive intervals with observed data overlaid. Displays observed data as a point and predicted data as a point estimate plus two credible intervals. Parameters ---------- dt : DataTree Input data. It should contain the ``posterior_predictive`` and ``observed_data`` groups. var_names : str or list of str, optional One or more variables to be plotted. Prefix the variables by ~ when you want to exclude them from the plot. filter_vars : {None, "like", "regex"}, default=None If None, interpret var_names as the real variables names. If "like", interpret var_names as substrings of the real variables names. If "regex", interpret var_names as regular expressions on the real variables names. group : str Group to be plotted. Defaults to "posterior_predictive". It could also be "prior_predictive". coords : dict, optional Coordinates of `var_names` to be plotted. sample_dims : str or sequence of hashable, optional Dimensions to reduce unless mapped to an aesthetic. Defaults to ``rcParams["data.sample_dims"]`` point_estimate : {"mean", "median", "mode"}, optional Which point estimate to plot for the predictive distribution. Defaults to rcParam ``stats.point_estimate``. ci_kind : {"hdi", "eti"}, optional Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``. ci_probs : (float, float), optional Indicates the probabilities for the inner (twig) and outer (trunk) credible intervals. Defaults to ``(0.5, rcParams["stats.ci_prob"])``. It's assumed that ``ci_probs[0] < ci_probs[1]``, otherwise they are sorted. plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh", "plotly", "none"}, optional labeller : labeller, optional aes_by_visuals : mapping of {str : sequence of str or False}, optional Mapping of visuals to aesthetics that should use their mapping in `plot_collection` when plotted. visuals : mapping of {str : mapping or bool}, optional Valid keys are: * trunk, twig -> passed to :func:`~arviz_plots.visuals.ci_bound_y` * observed_markers -> passed to :func:`~arviz_plots.visuals.point_y` * prediction_markers -> passed to :func:`~arviz_plots.visuals.point_y` * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` defaults to False stats : mapping, optional Valid keys are: * trunk, twig -> passed to eti or hdi * point_estimate -> passed to mean, median or mode **pc_kwargs Passed to :class:`arviz_plots.PlotCollection.grid` Returns ------- PlotCollection See Also -------- plot_ppc_dist : Plot 1D marginals for the posterior/prior predictive and observed data. plot_ppc_rootogram : Plot ppc rootogram for discrete (count) data plot_forest : Plot forest plot for posterior/prior groups. Examples -------- Plot posterior predictive intervals for the radon dataset, with custom styling. .. plot:: :context: close-figs >>> from arviz_base import load_arviz_data >>> import arviz_plots as azp >>> azp.style.use("arviz-variat") >>> data = load_arviz_data("rugby") >>> pc = azp.plot_ppc_interval(data) """ if sample_dims is None: sample_dims = rcParams["data.sample_dims"] if isinstance(sample_dims, str): sample_dims = [sample_dims] sample_dims = list(sample_dims) if stats is None: stats = {} else: stats = stats.copy() if visuals is None: visuals = {} else: visuals = visuals.copy() if ci_kind is None: ci_kind = rcParams["stats.ci_kind"] if ci_kind not in ("hdi", "eti"): raise ValueError(f"ci_kind must be either 'hdi' or 'eti', but {ci_kind} was passed.") if point_estimate is None: point_estimate = rcParams["stats.point_estimate"] if ci_probs is None: rc_ci_prob = rcParams["stats.ci_prob"] ci_probs = (0.5, rc_ci_prob) ci_probs = np.array(ci_probs) if ci_probs.size != 2: raise ValueError("ci_probs must have two elements for twig and trunk intervals.") if np.any(ci_probs < 0) or np.any(ci_probs > 1): raise ValueError("ci_probs must be between 0 and 1.") ci_probs.sort() if backend is None: if plot_collection is None: backend = rcParams["plot.backend"] else: backend = plot_collection.backend if labeller is None: labeller = BaseLabeller() visuals.setdefault("title", False) plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") ds_predictive = process_group_variables_coords( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords ) # Extract observed data if "observed_data" in dt: observed_data = process_group_variables_coords( dt, group="observed_data", var_names=var_names, filter_vars=filter_vars, coords=coords, ) else: observed_data = None if ci_kind == "eti": ci_fun = ds_predictive.azstats.eti elif ci_kind == "hdi": ci_fun = ds_predictive.azstats.hdi ci_trunk = ci_fun(prob=ci_probs[1], dim=sample_dims, **stats.get("trunk", {})) ci_twig = ci_fun(prob=ci_probs[0], dim=sample_dims, **stats.get("twig", {})) if point_estimate == "median": point = ds_predictive.median(dim=sample_dims, **stats.get("point_estimate", {})) elif point_estimate == "mean": point = ds_predictive.mean(dim=sample_dims, **stats.get("point_estimate", {})) elif point_estimate == "mode": point = ds_predictive.azstats.mode(dim=sample_dims, **stats.get("point_estimate", {})) else: raise ValueError( f"point_estimate must be 'mean', 'median' or 'mode', but {point_estimate} was passed." ) if plot_collection is None: pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs.setdefault("cols", "__variable__") pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, ds_predictive) plot_collection = PlotCollection.wrap( ds_predictive, backend=backend, **pc_kwargs, ) visuals = {} if visuals is None else visuals aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals ## trunk intervals ci_trunk_kwargs = get_visual_kwargs(visuals, "trunk") _, ci_trunk_aes, ci_trunk_ignore = filter_aes( plot_collection, aes_by_visuals, "trunk", sample_dims ) if ci_trunk_kwargs is not False: if "color" not in ci_trunk_aes: ci_trunk_kwargs.setdefault("color", "C0") ci_trunk_kwargs.setdefault("alpha", 0.5) ci_trunk_kwargs.setdefault("width", 3) plot_collection.map( ci_bound_y, "trunk", data=ci_trunk, ignore_aes=ci_trunk_ignore, **ci_trunk_kwargs, ) ## twig intervals ci_twig_kwargs = get_visual_kwargs(visuals, "twig") _, ci_twig_aes, ci_twig_ignore = filter_aes( plot_collection, aes_by_visuals, "twig", sample_dims ) if ci_twig_kwargs is not False: if "color" not in ci_twig_aes: ci_twig_kwargs.setdefault("color", "C0") ci_twig_kwargs.setdefault("width", 3) plot_collection.map( ci_bound_y, "twig", data=ci_twig, ignore_aes=ci_twig_ignore, **ci_twig_kwargs, ) ## prediction_markers prediction_ms_kwargs = get_visual_kwargs(visuals, "prediction_markers") if prediction_ms_kwargs is not False: _, _, prediction_ms_ignore = filter_aes( plot_collection, aes_by_visuals, "prediction_markers", sample_dims ) prediction_ms_kwargs.setdefault("color", "C0") prediction_ms_kwargs.setdefault("marker", "C0") plot_collection.map( point_y, "prediction_markers", data=point, ignore_aes=prediction_ms_ignore, **prediction_ms_kwargs, ) ## observed_markers observed_ms_kwargs = get_visual_kwargs( visuals, "observed_markers", False if group == "prior_predictive" or observed_data is None else None, ) if observed_ms_kwargs is not False: _, _, observed_ms_ignore = filter_aes( plot_collection, aes_by_visuals, "observed_markers", sample_dims ) observed_ms_kwargs.setdefault("color", "B1") observed_ms_kwargs.setdefault("marker", "C6") plot_collection.map( point_y, "observed_markers", data=observed_data, ignore_aes=observed_ms_ignore, **observed_ms_kwargs, ) # set xlabel _, xlabels_aes, xlabels_ignore = filter_aes( plot_collection, aes_by_visuals, "xlabel", sample_dims ) xlabel_kwargs = get_visual_kwargs(visuals, "xlabel") if xlabel_kwargs is not False: if "color" not in xlabels_aes: xlabel_kwargs.setdefault("color", "B1") xlabel_kwargs.setdefault("text", "data point (index)") plot_collection.map( labelled_x, "xlabel", ignore_aes=xlabels_ignore, subset_info=True, **xlabel_kwargs, ) # set ylabel _, ylabels_aes, ylabels_ignore = filter_aes( plot_collection, aes_by_visuals, "ylabel", sample_dims ) ylabel_kwargs = get_visual_kwargs(visuals, "ylabel") if ylabel_kwargs is not False: if "color" not in ylabels_aes: ylabel_kwargs.setdefault("color", "B1") plot_collection.map( labelled_y, "ylabel", ignore_aes=ylabels_ignore, subset_info=True, labeller=labeller, **ylabel_kwargs, ) # set title _, title_aes, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims) title_kwargs = get_visual_kwargs(visuals, "title") if title_kwargs is not False: if "color" not in title_aes: title_kwargs.setdefault("color", "B1") plot_collection.map( labelled_title, "title", ignore_aes=title_ignore, subset_info=True, labeller=labeller, **title_kwargs, ) return plot_collection