Bayesian Statistics Tutorial (Intermediate)

PyMC Model Summary

A hands-on PyMC tutorial on model summary & structure.

In this tutorial, we will cover the tools PyMC provides for summarizing a model and inspecting its structure.

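Let's start by installing the libraries we need, either with pip or from within your IDE. Note that the graphviz Python package also needs the Graphviz system binaries to render diagrams.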

pip install numpy pymc arviz graphviz

Now let's move into Python and begin with some imports.

import pymc as pm
import numpy as np
import arviz as az

We won't dwell on this section, but here we will create a simple hierarchical model that we will use to explore the PyMC functionality.

# Create a simple hierarchical model for demonstration
np.random.seed(42)
groups = np.repeat([0, 1, 2], 20)

group_means = np.array([10, 15, 12])
y_obs = group_means[groups] + np.random.normal(0, 2, size=60) + np.random.normal(0, 1, 60)

with pm.Model(coords={"group": [0, 1, 2]}) as hierarchical_model:
    # Hyperpriors
    mu_global = pm.Normal("mu_global", mu=0, sigma=10)
    sigma_global = pm.HalfNormal("sigma_global", sigma=5)

    # Group-level parameters
    mu_group = pm.Normal("mu_group", mu=mu_global, sigma=sigma_global, dims="group")
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=2)

    # Likelihood
    y = pm.Normal("y", mu=mu_group[groups], sigma=sigma_obs, observed=y_obs)
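
Before exploring the model, it can help to sanity-check the simulated data. The sketch below regenerates y_obs exactly as above and prints per-group sample statistics. The names total_sd, group_mean_est, and group_sd_est are introduced here for illustration only. Because two independent noise terms are added (standard deviations 2 and 1), the effective observation noise has standard deviation sqrt(2² + 1²) ≈ 2.24, which is roughly what sigma_obs should recover.

```python
import numpy as np

# Regenerate the simulated data exactly as in the tutorial
np.random.seed(42)
groups = np.repeat([0, 1, 2], 20)
group_means = np.array([10, 15, 12])
y_obs = group_means[groups] + np.random.normal(0, 2, size=60) + np.random.normal(0, 1, 60)

# Independent Gaussian noise terms add in variance, not in standard deviation
total_sd = np.sqrt(2**2 + 1**2)
print(f"Effective noise sd: {total_sd:.3f}")

# Per-group sample statistics; each mean should sit near its true value
group_mean_est = np.array([y_obs[groups == g].mean() for g in range(3)])
group_sd_est = np.array([y_obs[groups == g].std(ddof=1) for g in range(3)])
for g in range(3):
    n_g = int(np.sum(groups == g))
    print(f"group {g}: n={n_g}, mean={group_mean_est[g]:.2f}, sd={group_sd_est[g]:.2f}")
```

The indexing expression group_means[groups] is the same trick the model uses in mu_group[groups]: it expands three group-level values into one value per observation.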

Now we can explore and summarize the model design.

We can start by showing the graph of the model. pm.model_to_graphviz() builds a Graphviz representation of the model's structure as a directed acyclic graph (DAG). Calling view() on the result renders the diagram to a file (PDF by default) and opens it in your default viewer.

If you run into problems with graphviz, see the user guide: https://graphviz.readthedocs.io/en/stable/manual.html.

graph = pm.model_to_graphviz(hierarchical_model)

graph.view()

The command below lists the model's basic random variables, which include every stochastic variable in the model: both the parameters to estimate and the observed data. If a variable in the model has a probability distribution, it is a basic random variable.

for rv in hierarchical_model.basic_RVs:
    print(f"  - {rv.name}: {rv.type}")

The output should look something like this:

  - mu_global: Scalar(float64, shape=())
  - sigma_global: Scalar(float64, shape=())
  - mu_group: Vector(float64, shape=(?,))
  - sigma_obs: Scalar(float64, shape=())
  - y: Vector(float64, shape=(60,))

The free random variables are the unobserved parameters that need to be estimated. These are the variables that MCMC will sample.

for rv in hierarchical_model.free_RVs:
    shape_info = f"shape={rv.type.shape}" if hasattr(rv.type, 'shape') else ""
    print(f"  - {rv.name}: {rv.type} {shape_info}")

Output:

  - mu_global: Scalar(float64, shape=()) shape=()
  - sigma_global: Scalar(float64, shape=()) shape=()
  - mu_group: Vector(float64, shape=(?,)) shape=(None,)
  - sigma_obs: Scalar(float64, shape=()) shape=()

Finally we can look at just the observed random variable.

for rv in hierarchical_model.observed_RVs:
    print(f"  - {rv.name}: shape={rv.type.shape}")

Output:

  - y: shape=(60,)

Next we will take a look at the dictionary of initial values for all free parameters in the model. This represents PyMC's best guess for a starting point.

We see initial values for mu_global (a scalar, shape ()) and mu_group (a vector with one entry per group, shape (3,)).

Why sigma_global_log__ rather than sigma_global? PyMC applies a log transformation to map constrained parameters (a half-normal takes only positive values) to an unconstrained space. Once transformed, sampling can happen on the whole real line (-inf to +inf).

test_point = hierarchical_model.initial_point()
print("\nInitial parameter values:")
for param, value in test_point.items():
    if np.isscalar(value):
        print(f"  {param}: {value:.3f}")
    else:
        print(f"  {param}: {value} (shape={value.shape})")

Initial parameter values:
  mu_global: 0.0 (shape=())
  sigma_global_log__: 1.6094379124341003 (shape=())
  mu_group: [0. 0. 0.] (shape=(3,))
  sigma_obs_log__: 0.6931471805599453 (shape=())
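
The two log-transformed values in this output can be checked by hand. The minimal sketch below copies them from the output and applies the inverse of the log transform, which is the exponential. The variable names are chosen here for illustration. The results equal 5 and 2, which are the sigma values given to the two HalfNormal priors in the model.

```python
import math

# Unconstrained initial values copied from the printed output
sigma_global_log = 1.6094379124341003
sigma_obs_log = 0.6931471805599453

# Backward transform: exponentiate to return to the positive, constrained scale
sigma_global_init = math.exp(sigma_global_log)
sigma_obs_init = math.exp(sigma_obs_log)

print(f"sigma_global initial value: {sigma_global_init:.3f}")
print(f"sigma_obs initial value:    {sigma_obs_init:.3f}")

# Forward transform: the log of the constrained value gives the unconstrained one
print(f"log(5) = {math.log(5.0):.6f}, log(2) = {math.log(2.0):.6f}")
```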

Now we can check the model validity.

The first line represents the log of the joint probability of all variables at their initial values. It combines the prior probabilities and the likelihood.

log P(data, parameters) = log P(data | parameters) + log P(parameters)

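Because each term is the log of a probability or a probability density, the total is usually negative. Densities of continuous variables can exceed 1, so a positive log probability is possible and is not by itself a sign of trouble.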
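
Log probabilities of continuous variables are usually negative, but not always. As a small illustration, the sketch below evaluates a normal log density at its mean for a wide and a narrow distribution. The helper normal_logpdf is written here for the example and is not part of PyMC.

```python
import math

def normal_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2

# A wide prior: the density at the mean is well below 1, so its log is negative
wide = normal_logpdf(0.0, 0.0, 10.0)

# A narrow distribution: the density at the mean exceeds 1, so its log is positive
narrow = normal_logpdf(0.0, 0.0, 0.1)

print(f"sigma=10:  {wide:.3f}")
print(f"sigma=0.1: {narrow:.3f}")
```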

In addition to this, the probability by variable is also given.

What to look out for:

  • If you get -inf, the initial values are impossible under the model.
  • Very large negative numbers (around -1e10) suggest problems.
  • Otherwise, expect moderate negative values. The total grows with the number of observations; for this model it is on the order of -1e3.

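
The rules of thumb above can be sketched as a small helper. The function name triage_logp and the suspicious_below cutoff are hypothetical and introduced only for illustration; the -1e10 default is taken from the list above and is not a PyMC setting.

```python
import math

def triage_logp(logp, suspicious_below=-1e10):
    """Classify a joint log probability using the rules of thumb in the text."""
    if math.isnan(logp):
        return "invalid (nan)"
    if logp == -math.inf:
        return "impossible initial values"
    if logp <= suspicious_below:
        return "suspiciously low"
    return "plausible"

# Example values: an impossible point, a suspicious one, and a typical one
for value in (-math.inf, -3.0e12, -1254.07):
    print(f"{value}: {triage_logp(value)}")
```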
try:
    # Compute log probability at the initial point
    log_prob = hierarchical_model.compile_logp()(hierarchical_model.initial_point())
    print(f"✓ Model is valid! Log probability: {float(log_prob):.2f}")
    print("\nLog probability by variable:")
    logp_func = hierarchical_model.compile_logp(vars=hierarchical_model.free_RVs, sum=False)
    point_logps = logp_func(hierarchical_model.initial_point())
    for var, lp in zip([rv.name for rv in hierarchical_model.free_RVs], point_logps):
        lp_val = float(lp) if np.isscalar(lp) else float(np.sum(lp))
        print(f"  {var}: {lp_val:.3f}")
except Exception as e:
    print(f"✗ Model has issues: {e}")

✓ Model is valid! Log probability: -1254.07

Log probability by variable:
  mu_global: -3.222
  sigma_global: -0.726
  mu_group: -7.585
  sigma_obs: -0.726
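
The per-variable values in this output can be reproduced by hand, which is a useful way to see what they contain. The sketch below uses only the standard library and evaluates each prior at the initial point (mu_global = 0, sigma_global = 5, mu_group = [0, 0, 0], sigma_obs = 2). It assumes that the values reported for the log-transformed variables include the log-Jacobian term of the transform; under that assumption the arithmetic matches the printed numbers. The helper functions are written for this example and are not PyMC functions.

```python
import math

def normal_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2

def halfnormal_logpdf(x, sigma):
    """Log density of HalfNormal(sigma) at x >= 0: twice the Normal(0, sigma) density."""
    return math.log(2) + normal_logpdf(x, 0.0, sigma)

# mu_global ~ Normal(0, 10), evaluated at 0
lp_mu_global = normal_logpdf(0.0, 0.0, 10.0)

# sigma_global ~ HalfNormal(5), evaluated at 5, plus the log-Jacobian log(5)
# of the log transform (assumption: the printed value includes this term)
lp_sigma_global = halfnormal_logpdf(5.0, 5.0) + math.log(5.0)

# mu_group: three independent Normal(mu_global=0, sigma_global=5) terms at 0
lp_mu_group = 3 * normal_logpdf(0.0, 0.0, 5.0)

# sigma_obs ~ HalfNormal(2), evaluated at 2, plus the log-Jacobian log(2)
lp_sigma_obs = halfnormal_logpdf(2.0, 2.0) + math.log(2.0)

prior_total = lp_mu_global + lp_sigma_global + lp_mu_group + lp_sigma_obs
print(f"mu_global:    {lp_mu_global:.3f}")
print(f"sigma_global: {lp_sigma_global:.3f}")
print(f"mu_group:     {lp_mu_group:.3f}")
print(f"sigma_obs:    {lp_sigma_obs:.3f}")
print(f"sum of prior terms: {prior_total:.2f}")
```

The four prior terms sum to about -12.26, so nearly all of the total of -1254.07 comes from the likelihood of the 60 observations evaluated at the initial point.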

We can also print a short summary of the model. This is handy as a standardized report to include alongside results.

print(f"\nNumber of free parameters: {len(hierarchical_model.free_RVs)}")
print(f"Number of observed variables: {len(hierarchical_model.observed_RVs)}")
print(f"Coordinates: {hierarchical_model.coords}")

Output:

Number of free parameters: 4
Number of observed variables: 1
Coordinates: {'group': (0, 1, 2)}

We can also look at the model dimensions.

for name, rv in hierarchical_model.named_vars.items():
    if hasattr(rv, 'eval'):
        try:
            actual_shape = rv.eval().shape
            if actual_shape:
                print(f"  {name}: {actual_shape}")
        except Exception:  # evaluation failed; fall back to the static shape
            if hasattr(rv.type, 'shape') and rv.type.shape:
                print(f"  {name}: {rv.type.shape} (symbolic)")
    elif hasattr(rv.type, 'shape') and rv.type.shape:
        print(f"  {name}: {rv.type.shape}")

Output:

  mu_group: (3,)
  y: (60,)

Let's sample from the posterior.

The sampling parameters are:

  • 500 (the draws argument): Draw 500 samples from the posterior per chain
  • tune=500: Tune the sampler for 500 iterations first (then discard these)
  • chains=2: Run 2 independent chains (for convergence checking)
  • Total posterior samples: 500 × 2 = 1000
with hierarchical_model:
    trace = pm.sample(500, tune=500, chains=2, return_inferencedata=True,
                      random_seed=42, progressbar=False)

print("\nArviZ InferenceData structure:")
print(trace)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_global, sigma_global, mu_group, sigma_obs]
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 0 seconds.
There were 9 divergences after tuning. Increase target_accept or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics

ArviZ InferenceData structure:
Inference data with groups:
    > posterior
    > sample_stats
    > observed_data
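
The counts in the sampler log follow directly from the sampling arguments. As a quick check, the sketch below redoes the arithmetic and expresses the reported divergences as a share of the retained draws. The numbers are copied from the call and the log above; the variable names are chosen for illustration.

```python
# Sampling settings copied from the pm.sample call
draws, tune, chains = 500, 500, 2

# Draws kept for inference and draws discarded after tuning
posterior_draws = draws * chains
tuning_draws = tune * chains

# Divergences reported in the sampler log, as a share of retained draws
divergences = 9
divergence_rate = divergences / posterior_draws

print(f"Posterior draws kept: {posterior_draws}")
print(f"Tuning draws discarded: {tuning_draws}")
print(f"Divergence rate: {divergence_rate:.1%}")
```

Divergences mean the sampler had trouble in part of the posterior. The log's own suggestions are to raise target_accept in pm.sample or to reparameterize the model.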

Now we can use arviz to summarize the results.

The trace holds all of the sampling results as an ArviZ InferenceData object; trace is a conventional name for it. az.summary() computes summary statistics and convergence diagnostics for the variables, and the var_names argument restricts the table to the three variables listed.

print("\nQuick summary statistics:")
print(az.summary(trace, var_names=["mu_global", "sigma_global", "mu_group"]))

Quick summary statistics:
                mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mu_global     11.299  2.406   6.812   15.604      0.099    0.156     754.0     502.0    1.0
sigma_global   3.620  1.963   0.934    7.400      0.079    0.079     639.0     475.0    1.0
mu_group[0]    9.697  0.453   8.921   10.578      0.013    0.015    1186.0     584.0    1.0
mu_group[1]   14.355  0.487  13.361   15.187      0.015    0.019    1119.0     610.0    1.0
mu_group[2]   11.966  0.483  11.100   12.922      0.015    0.017    1058.0     506.0    1.0
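
Because the data were simulated, we know the true group means (10, 15, 12) and can check whether the model recovered them. The sketch below copies the hdi_3% and hdi_97% columns for mu_group from the summary and tests whether each true value falls inside its 94% interval. The names hdi_94 and covered are introduced here for illustration.

```python
# True group means used to simulate the data
true_means = [10, 15, 12]

# 94% HDI bounds (hdi_3%, hdi_97%) copied from the summary table
hdi_94 = {
    "mu_group[0]": (8.921, 10.578),
    "mu_group[1]": (13.361, 15.187),
    "mu_group[2]": (11.100, 12.922),
}

# Check that each true mean lies inside its interval
covered = {}
for (name, (low, high)), truth in zip(hdi_94.items(), true_means):
    covered[name] = low <= truth <= high
    status = "inside" if covered[name] else "outside"
    print(f"{name}: true value {truth} is {status} [{low}, {high}]")
```

All three true values fall inside their intervals, which is what we would hope to see from a correctly specified model on simulated data.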