# Source code for arviz_plots.plots.khat_plot

"""Plot Pareto tail indices."""

from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.base.stats_utils import calculate_khat_bin_edges

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    annotate_bin_text,
    enable_hover_labels,
    filter_aes,
    format_coords_as_labels,
    get_visual_kwargs,
    set_wrap_layout,
)
from arviz_plots.visuals import (
    annotate_xy,
    hline,
    labelled_title,
    labelled_x,
    labelled_y,
    scatter_xy,
    set_xlim,
    set_xticks,
)


def plot_khat(
    elpd_data,
    threshold=None,
    hover_format="{index}: {label}",
    legend=None,
    color=None,
    marker=None,
    hline_values=None,
    bin_format="{pct:.1f}%",
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "khat",
            "threshold_text",
            "hover",
            "title",
            "xlabel",
            "ylabel",
            "ticks",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "khat",
            "hlines",
            "bin_text",
            "threshold_text",
            "hover",
            "title",
            "xlabel",
            "ylabel",
            "legend",
            "ticks",
        ],
        Mapping[str, Any] | bool,
    ] = None,
    **pc_kwargs,
):
    r"""Plot Pareto tail indices for diagnosing convergence in PSIS-LOO-CV.

    The Generalized Pareto distribution (GPD) is fitted to the largest importance ratios to
    diagnose convergence rates. The shape parameter :math:`\hat{k}` estimates the pre-asymptotic
    convergence rate based on the fractional number of finite moments. Values
    :math:`\hat{k} > 0.7` indicate impractically low convergence rates and unreliable estimates.
    Details are presented in [1]_ and [2]_.

    Parameters
    ----------
    elpd_data : ELPDData
        ELPD data object returned by :func:`arviz_stats.loo` containing Pareto k diagnostics.
        It must expose a non-None ``pareto_k`` attribute; ``good_k`` is used if present.
    threshold : float, optional
        Annotate every khat value strictly above this threshold with its coordinate label.
        If None, no points are highlighted.
    hover_format : str, default ``"{index}: {label}"``
        Format string for hover annotations. Supports ``{index}``, ``{label}``, and ``{value}``.
    legend : bool, optional
        Whether to display a legend when color aesthetics are active.
        If None, a legend is shown when a color mapping is available.
    color : color spec or str, optional
        Color for scatter points when no aesthetic mapping supplies one.
        If the value matches a dimension name, that dimension is mapped to the color aesthetic.
    marker : marker spec or str, optional
        Marker style for scatter points when no aesthetic mapping supplies one.
        If the value matches a dimension name, that dimension is mapped to the marker aesthetic.
    hline_values : sequence of float, optional
        Custom horizontal line positions (also used as bin edges for ``bin_text``).
        Defaults to ``[0.0, good_k, 1.0]``, where ``good_k`` is ``elpd_data.good_k`` when
        that attribute exists and 0.7 otherwise.
    bin_format : str, default ``"{pct:.1f}%"``
        Format string for bin percentages. Supports ``{count}`` and ``{pct}`` placeholders.
    plot_collection : PlotCollection, optional
        Existing collection to draw into. If None, a new one is created from ``pc_kwargs``.
    backend : {"matplotlib", "bokeh", "plotly"}, optional
        Plotting backend to use. Defaults to ``plot_collection.backend`` when a collection is
        given, otherwise to ``rcParams["plot.backend"]``.
    labeller : labeller, optional
        Labeller used to build coordinate labels (ticks, hover, threshold annotations) and
        facet titles. Defaults to :class:`arviz_base.labels.BaseLabeller`.
    aes_by_visuals : mapping of {str : sequence of str or False}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`. The mapping passed in is
        not modified. By default:

        * khat -> uses all available aesthetic mappings
        * threshold_text -> uses no aesthetic mappings
        * hover -> uses no aesthetic mappings
        * title -> uses no aesthetic mappings
        * xlabel -> uses no aesthetic mappings
        * ylabel -> uses no aesthetic mappings
        * ticks -> uses no aesthetic mappings

    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * khat -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * hlines -> passed to :func:`~arviz_plots.visuals.hline`, defaults to False
        * bin_text -> passed to :func:`~arviz_plots.visuals.annotate_xy`, defaults to False
        * threshold_text -> passed to :func:`~arviz_plots.visuals.annotate_xy`
        * hover -> enables interactive hover annotations, defaults to False
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`, defaults to False
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
        * legend -> passed to :class:`arviz_plots.PlotCollection.add_legend`
        * ticks -> passed to :func:`~arviz_plots.visuals.set_xticks`, defaults to False

        The mapping passed in is not modified.
    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`, or to
        :class:`arviz_plots.PlotCollection.grid` when a non-empty ``rows`` is given.
        Ignored when `plot_collection` is provided.

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `elpd_data` has no ``pareto_k`` attribute or it is None.

    Warnings
    --------
    When using custom markers via the ``visuals`` dict, ensure the marker type is compatible
    with your chosen backend. Not all marker types support separate facecolor and edgecolor
    across different backends.

    Examples
    --------
    The most basic usage plots the Pareto k values from a LOO-CV computation.
    Each point represents one observation, with higher k values indicating
    less reliable importance sampling for that observation.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_khat, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> from arviz_stats import loo
        >>> dt = load_arviz_data("rugby")
        >>> elpd_data = loo(dt, var_name="home_points", pointwise=True)
        >>> plot_khat(elpd_data, figure_kwargs={"figsize": (10, 5)})

    .. minigallery:: plot_khat

    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing. 27(5) (2017).
        https://doi.org/10.1007/s11222-016-9696-4.
        arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024)
        https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
    if hline_values is None:
        good_k = getattr(elpd_data, "good_k", 0.7)
        hline_values = [0.0, good_k, 1.0]
    else:
        hline_values = list(hline_values)

    # Work on shallow copies: the defaults filled in below must never leak back into the
    # caller's dictionaries (which would affect later calls reusing the same mapping).
    visuals = {} if visuals is None else dict(visuals)
    visuals.setdefault("title", False)
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    if labeller is None:
        labeller = BaseLabeller()

    if not hasattr(elpd_data, "pareto_k") or elpd_data.pareto_k is None:
        raise ValueError(
            "Could not find 'pareto_k' in the ELPDData object. "
            "Please ensure the LOO computation includes Pareto k diagnostics."
        )

    khat_data = elpd_data.pareto_k
    distribution = khat_data.to_dataset(name="pareto_k")
    n_data_points = khat_data.size
    khat_dims = list(khat_data.dims)
    coord_map = {dim: khat_data.coords[dim] for dim in khat_dims if dim in khat_data.coords}
    khat_flat = np.asarray(khat_data.values).reshape(-1)

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["figure_kwargs"].setdefault("sharex", False)
        pc_kwargs["figure_kwargs"].setdefault("sharey", True)
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()

        # A dimension name given as `color`/`marker` becomes an aesthetic mapping instead of
        # a literal style value.
        if isinstance(color, str) and (color in distribution.dims or color in distribution.coords):
            pc_kwargs["aes"]["color"] = [color]
            color = None
        elif color is None and "model" in distribution.dims and "color" not in pc_kwargs["aes"]:
            pc_kwargs["aes"]["color"] = ["model"]

        if isinstance(marker, str) and (
            marker in distribution.dims or marker in distribution.coords
        ):
            pc_kwargs["aes"]["marker"] = [marker]
            marker = None

        use_grid = "rows" in pc_kwargs and pc_kwargs["rows"]
        if use_grid:
            plot_collection = PlotCollection.grid(
                distribution,
                backend=backend,
                **pc_kwargs,
            )
        else:
            pc_kwargs.setdefault("cols", [])
            pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
            plot_collection = PlotCollection.wrap(
                distribution,
                backend=backend,
                **pc_kwargs,
            )

    aes_by_visuals.setdefault("khat", plot_collection.aes_set)
    aes_by_visuals.setdefault("threshold_text", [])
    aes_by_visuals.setdefault("hover", [])
    aes_by_visuals.setdefault("title", [])
    aes_by_visuals.setdefault("xlabel", [])
    aes_by_visuals.setdefault("ylabel", [])
    aes_by_visuals.setdefault("ticks", [])

    # Dimensions that are not faceted are laid out along the x axis.
    reduce_dims = [d for d in khat_data.dims if d not in plot_collection.facet_dims]
    has_facets = bool(reduce_dims and plot_collection.facet_dims)
    if has_facets:
        # x positions restart at 0 inside every facet.
        reduce_shape = [khat_data.sizes[d] for d in reduce_dims]
        reduce_size = int(np.prod(reduce_shape))
        x_positions_per_facet = np.arange(reduce_size).reshape(reduce_shape)
        # The explicit transpose guarantees xdata shares khat_data's dim order, so the
        # flattened x and khat arrays below line up element by element.
        xdata = (
            xr.DataArray(
                x_positions_per_facet,
                dims=reduce_dims,
                coords={d: khat_data.coords[d] for d in reduce_dims if d in khat_data.coords},
            )
            .broadcast_like(khat_data)
            .transpose(*khat_dims)
        )
    else:
        # One global running index over all points.
        x_positions = (
            np.arange(n_data_points).reshape(khat_data.shape)
            if n_data_points
            else np.zeros(khat_data.shape, dtype=float)
        )
        xdata = xr.DataArray(x_positions, dims=khat_dims, coords=coord_map, name="pareto_k")

    x_flat = np.asarray(xdata.values).reshape(-1)
    x_min = x_flat.min() if x_flat.size else 0.0
    x_dataset = xr.Dataset({"pareto_k": xdata})
    # Stack x positions and khat values along "plot_axis" (coords "x"/"y"); this combined
    # dataset is what the scatter visual is mapped over.
    khat_dataset = xr.concat([x_dataset, distribution], dim="plot_axis").assign_coords(
        plot_axis=["x", "y"]
    )

    new_xlim = None
    flat_coord_labels = None
    hover_label_data = None
    # Coordinate labels are only needed by threshold annotations, tick labels and hover.
    if n_data_points and (
        threshold is not None
        or get_visual_kwargs(visuals, "ticks", default=False) is not False
        or get_visual_kwargs(visuals, "hover", default=False) is not False
    ):
        flat_coord_labels = format_coords_as_labels(khat_data, labeller=labeller)
        if flat_coord_labels.size == khat_data.size:
            hover_label_data = xr.DataArray(
                flat_coord_labels.reshape(khat_data.shape),
                dims=khat_dims,
                coords=coord_map,
                name="labels",
            )

    # 0-d placeholder dataset for visuals that receive their values through kwargs.
    scalar_ds = xr.Dataset({"pareto_k": xr.DataArray(0)})

    hlines_kwargs = get_visual_kwargs(visuals, "hlines", default=False)
    if hlines_kwargs is not False and hline_values:

        def _hline_scalar(da, target, **kw):
            # Every element of the constant-filled subset is the same; draw one line.
            scalar_val = da.values.flat[0] if da.size > 0 else 0
            return hline(xr.DataArray(scalar_val), target, **kw)

        # Loop invariant: one default linestyle per reference line.
        default_linestyles = plot_bknd.get_default_aes("linestyle", len(hline_values), {})
        for idx, value in enumerate(hline_values):
            hline_kwargs = hlines_kwargs.copy()
            hline_kwargs.setdefault("linestyle", default_linestyles[idx])
            hline_kwargs.setdefault("color", f"C{idx + 1}")
            hline_kwargs.setdefault("alpha", 0.7)
            # Constant array with khat's dims, so the line is mapped over the same subsets.
            hline_dataset = xr.full_like(khat_data, value).to_dataset(name="pareto_k")
            plot_collection.map(
                _hline_scalar,
                f"hline_{idx}",
                data=hline_dataset,
                ignore_aes="all",
                **hline_kwargs,
            )

    khat_kwargs = get_visual_kwargs(visuals, "khat")
    if khat_kwargs is not False:
        _, khat_aes, khat_ignore = filter_aes(plot_collection, aes_by_visuals, "khat", [])
        if "color" not in khat_aes:
            khat_kwargs.setdefault("color", color if color is not None else "C0")
        if "marker" not in khat_aes and marker is not None:
            khat_kwargs.setdefault("marker", marker)
        plot_collection.map(
            scatter_xy,
            "khat",
            data=khat_dataset,
            ignore_aes=khat_ignore,
            **khat_kwargs,
        )

    bin_text_kwargs = get_visual_kwargs(visuals, "bin_text", default=False)
    if bin_text_kwargs is not False:
        bin_text_kwargs.setdefault("color", "B1")
        bin_text_kwargs.setdefault("horizontal_align", "center")
        bin_edges = calculate_khat_bin_edges(khat_flat, list(hline_values))
        if bin_edges is not None and n_data_points:
            _, _, bin_text_ignore = filter_aes(plot_collection, aes_by_visuals, "bin_text", [])
            if reduce_dims:
                x_max_per_facet = xdata.max(dim=reduce_dims)
            else:
                x_max_per_facet = xdata.max()
            # Bin labels sit right of the last point; the margin grows with the x range
            # but is never less than half a unit.
            span = x_flat.max() - x_flat.min() if x_flat.size else 1.0
            span = max(span, 1.0)
            x_margin = max(0.05 * span, 0.5)
            if plot_collection.facet_dims:
                x_text_per_facet = x_max_per_facet + x_margin
                # We need to extract the scalar value here for Bokeh compatibility
                new_xlim_max = float(x_max_per_facet.max().item()) + x_margin
            else:
                x_text_per_facet = x_flat.max() + x_margin if x_flat.size else x_margin
                new_xlim_max = x_text_per_facet + x_margin
            new_xlim = (x_min, new_xlim_max)

            num_bins = len(bin_edges) - 1
            bin_edges_arr = np.array(bin_edges)
            # Each label is placed at the vertical midpoint of its bin.
            y_positions = (bin_edges_arr[:-1] + bin_edges_arr[1:]) / 2

            def compute_bin_counts(data_slice):
                flat_data = np.asarray(data_slice).reshape(-1)
                if flat_data.size == 0:
                    return np.zeros(num_bins, dtype=int)
                counts, _ = np.histogram(flat_data, bins=bin_edges)
                return counts

            if reduce_dims:
                # Histogram over the x-axis dims, keeping facet dims and adding "bin".
                counts_per_facet = xr.apply_ufunc(
                    compute_bin_counts,
                    khat_data,
                    input_core_dims=[reduce_dims],
                    output_core_dims=[["bin"]],
                    vectorize=True,
                )
                n_per_facet = khat_data.count(dim=reduce_dims)
            else:
                counts_per_facet = xr.DataArray(compute_bin_counts(khat_data), dims=["bin"])
                n_per_facet = khat_data.size

            for i in range(num_bins):
                # Both branches above produce a "bin" dimension.
                bin_counts = counts_per_facet.isel(bin=i)
                plot_collection.map(
                    annotate_bin_text,
                    f"bin_{i}",
                    data=distribution,
                    x=x_text_per_facet,
                    y=y_positions[i],
                    count_da=bin_counts,
                    n_da=n_per_facet,
                    bin_format=bin_format,
                    ignore_aes=bin_text_ignore,
                    **bin_text_kwargs,
                )

    threshold_text_kwargs = get_visual_kwargs(visuals, "threshold_text")
    if (
        threshold_text_kwargs is not False
        and threshold is not None
        and flat_coord_labels is not None
    ):
        _, _, threshold_text_ignore = filter_aes(
            plot_collection, aes_by_visuals, "threshold_text", []
        )
        threshold_text_kwargs.setdefault("color", "B1")
        threshold_text_kwargs.setdefault("vertical_align", "bottom")
        threshold_text_kwargs.setdefault("horizontal_align", "center")
        mask = np.asarray(khat_data > threshold).reshape(-1)
        for flat_idx in np.flatnonzero(mask):
            plot_collection.map(
                annotate_xy,
                f"threshold_{flat_idx}",
                data=scalar_ds,
                x=x_flat[flat_idx],
                y=khat_flat[flat_idx],
                text=str(flat_coord_labels[flat_idx]),
                ignore_aes=threshold_text_ignore,
                **threshold_text_kwargs,
            )

    ticks_kwargs = get_visual_kwargs(visuals, "ticks", default=False)
    if ticks_kwargs is not False and flat_coord_labels is not None:
        if flat_coord_labels.size:
            ticks_kwargs.setdefault("rotation", 45)
            plot_collection.map(
                set_xticks,
                "ticks",
                data=scalar_ds,
                values=x_flat.tolist(),
                labels=[str(label) for label in flat_coord_labels],
                ignore_aes="all",
                store_artist=False,
                **ticks_kwargs,
            )

    title_kwargs = get_visual_kwargs(visuals, "title")
    if title_kwargs is not False:
        _, title_aes, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", [])
        if "color" not in title_aes:
            title_kwargs.setdefault("color", "B1")

        def title_coords_only(da, target, sel=None, isel=None, **kw):
            # Title shows only the facet's coordinate selection, not the variable name.
            text = labeller.sel_to_str(sel, isel) if (sel or isel) else None
            return labelled_title(da, target, text=text, **kw)

        plot_collection.map(
            title_coords_only,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            **title_kwargs,
        )

    xlabel_kwargs = get_visual_kwargs(visuals, "xlabel")
    if xlabel_kwargs is not False:
        _, xlabel_aes, xlabel_ignore = filter_aes(plot_collection, aes_by_visuals, "xlabel", [])
        if "color" not in xlabel_aes:
            xlabel_kwargs.setdefault("color", "B1")
        xlabel_kwargs.setdefault("text", "Data Point")
        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabel_ignore,
            subset_info=True,
            **xlabel_kwargs,
        )

    ylabel_kwargs = get_visual_kwargs(visuals, "ylabel")
    if ylabel_kwargs is not False:
        _, ylabel_aes, ylabel_ignore = filter_aes(plot_collection, aes_by_visuals, "ylabel", [])
        if "color" not in ylabel_aes:
            ylabel_kwargs.setdefault("color", "B1")
        ylabel_kwargs.setdefault("text", "Shape parameter k")
        plot_collection.map(
            labelled_y,
            "ylabel",
            ignore_aes=ylabel_ignore,
            subset_info=True,
            **ylabel_kwargs,
        )

    if legend is not False:
        legend_kwargs = get_visual_kwargs(visuals, "legend")
        if legend_kwargs is not False and "color" in plot_collection.aes.children:
            color_mapping = plot_collection.aes["color"].data_vars.get("mapping")
            if color_mapping is not None:
                legend_kwargs.setdefault("dim", list(color_mapping.dims) or ["color"])
                plot_collection.add_legend(**legend_kwargs)

    # Widen the x range so the bin labels drawn past the last point stay visible.
    if new_xlim is not None:
        plot_collection.map(
            set_xlim,
            "xlim",
            data=distribution,
            ignore_aes="all",
            store_artist=False,
            limits=new_xlim,
        )

    hover_kwargs = get_visual_kwargs(visuals, "hover", default=False)
    if hover_kwargs is not False and flat_coord_labels is not None:
        if hover_label_data is None:
            hover_label_data = xr.DataArray(
                flat_coord_labels.reshape(khat_data.shape),
                dims=khat_dims,
                coords=coord_map,
                name="labels",
            )
        enable_hover_labels(
            backend,
            plot_collection,
            hover_format,
            labels=hover_label_data,
            colors=None,
            values=khat_data,
        )

    return plot_collection