Linear regression (Stan)

We will now fit the same linear regression as in the previous chapter, but with Stan rather than PyMC. The Stan documentation also includes an example of a linear regression model.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import arviz as az
from cmdstanpy import CmdStanModel

Data

df = pd.read_csv('data/linreg.csv')
# Separating pre and post periods
df_train = df[df['period']=="pre"]
df_test = df[df['period']=="post"]

# A plot of the dependent variable vs. each of the explanatory variables
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].scatter(df_train['HDD'], df_train['consumption'], alpha=0.6)
axes[1].scatter(df_train['CLASSDAYS'], df_train['consumption'], alpha=0.6)
axes[0].set_ylabel("consumption")
axes[0].set_xlabel("hdd")
axes[1].set_xlabel("classdays")
[Figure: consumption vs. HDD (left) and vs. CLASSDAYS (right) for the pre period]

Model specification

A Stan model is a block of text that can be written either in a separate file or in the same script as the rest of the code. Stan has interfaces for many popular languages, including R, Python and Julia. Here, we'll use CmdStanPy, a Python interface to CmdStan.

data {
  int<lower=0> N;     // number of data items
  int<lower=0> K;     // number of predictors
  matrix[N, K] x;     // predictor matrix
  vector[N] y;        // outcome vector
}
parameters {
  real alpha;           // intercept
  vector[K] beta;       // coefficients for predictors
  real<lower=0> sigma;  // error scale
}
model {
  y ~ normal(x * beta + alpha, sigma);  // likelihood
}
generated quantities {
  array[N] real log_lik;
  array[N] real y_hat;
  for (j in 1:N) {
    log_lik[j] = normal_lpdf(y[j] | x[j] * beta + alpha, sigma);
    y_hat[j] = normal_rng(x[j] * beta + alpha, sigma);
  }
}
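Written out, the program above corresponds to the following model. Here $\mathbf{x}_i$ is the $i$-th row of the predictor matrix (x[j] in the Stan code), and $\sigma$ is a standard deviation, as in Stan's normal distribution. Because the model block states no priors, Stan implicitly uses improper flat priors: uniform over the real line for $\alpha$ and $\boldsymbol{\beta}$, and uniform over $(0, \infty)$ for $\sigma$, given its lower bound.

$$
\begin{aligned}
y_i &\sim \mathcal{N}\left(\alpha + \mathbf{x}_i^\top \boldsymbol{\beta},\ \sigma\right), \quad i = 1, \dots, N,\\
\mathrm{log\_lik}_i &= \log \mathcal{N}\left(y_i \mid \alpha + \mathbf{x}_i^\top \boldsymbol{\beta},\ \sigma\right)
 = -\tfrac{1}{2}\log\left(2\pi\sigma^2\right) - \frac{\left(y_i - \alpha - \mathbf{x}_i^\top \boldsymbol{\beta}\right)^2}{2\sigma^2},\\
\mathrm{y\_hat}_i &\sim \mathcal{N}\left(\alpha + \mathbf{x}_i^\top \boldsymbol{\beta},\ \sigma\right) \quad \text{(one new draw per posterior draw)}.
\end{aligned}
$$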

This code has been written to a dedicated file, models/linearregression.stan, which we then load:

model = CmdStanModel(stan_file='models/linearregression.stan')

If no compiled executable of this model exists at this location, the constructor compiles one. If one exists and the model has not been modified since, it is loaded without recompilation. This saves time when fitting the same model repeatedly. Compiling a new or modified model takes a while, but the compiled executable is then saved and can be reused with different data sets.
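The Stan program does not have to be maintained as a separate file by hand. It can also live in the Python script as a string, as long as it is written to disk before CmdStanModel is called, since stan_file expects a path.

Below is a minimal sketch using a hypothetical helper, write_if_changed, which is not part of CmdStanPy. It only rewrites the file when the content differs. This rests on an assumption: that CmdStanPy decides whether to recompile by comparing the modification times of the .stan file and the executable. If so, rewriting an unchanged file on every run would trigger a needless recompilation.

```python
from pathlib import Path

# The Stan program shown above, kept as a Python string in the same script
STAN_CODE = """\
data {
  int<lower=0> N;     // number of data items
  int<lower=0> K;     // number of predictors
  matrix[N, K] x;     // predictor matrix
  vector[N] y;        // outcome vector
}
parameters {
  real alpha;           // intercept
  vector[K] beta;       // coefficients for predictors
  real<lower=0> sigma;  // error scale
}
model {
  y ~ normal(x * beta + alpha, sigma);  // likelihood
}
generated quantities {
  array[N] real log_lik;
  array[N] real y_hat;
  for (j in 1:N) {
    log_lik[j] = normal_lpdf(y[j] | x[j] * beta + alpha, sigma);
    y_hat[j] = normal_rng(x[j] * beta + alpha, sigma);
  }
}
"""


def write_if_changed(path: Path, code: str) -> bool:
    """Hypothetical helper: write `code` to `path` only if it is missing or different.

    Returns True when the file was (re)written. An unchanged file is left untouched,
    so its modification time does not move.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists() and path.read_text() == code:
        return False
    path.write_text(code)
    return True


stan_file = Path("models") / "linearregression.stan"
write_if_changed(stan_file, STAN_CODE)
# The file can then be compiled or loaded as above:
# model = CmdStanModel(stan_file=str(stan_file))
```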

Training

The sample() method performs Bayesian inference on the model conditioned on the data, using Hamiltonian Monte Carlo (HMC) sampling. It runs Stan's HMC-NUTS sampler and returns a CmdStanMCMC object. By default it runs 4 chains, each with 1000 warmup and 1000 sampling iterations. CmdStanPy accepts the data as a Python dictionary whose keys match the variable names declared in the data block.

In the PyMC linear regression example, we fit two models. The first has HDD as its only explanatory variable, and the second also includes the number of class days (CLASSDAYS). Let us do the same here:

# Data declaration for model 1
model1_data = {
    "N": len(df_train),
    "K": 1,
    "x": df_train[['HDD']].values,
    "y": df_train['consumption'].values,
}

# Fitting model 1
fit1 = model.sample(data=model1_data, show_progress=False)
09:54:13 - cmdstanpy - INFO - CmdStan start processing
09:54:13 - cmdstanpy - INFO - Chain [1] start processing
09:54:13 - cmdstanpy - INFO - Chain [2] start processing
09:54:13 - cmdstanpy - INFO - Chain [3] start processing
09:54:13 - cmdstanpy - INFO - Chain [4] start processing
09:54:14 - cmdstanpy - INFO - Chain [3] done processing
09:54:15 - cmdstanpy - INFO - Chain [1] done processing
09:54:15 - cmdstanpy - INFO - Chain [4] done processing
09:54:16 - cmdstanpy - INFO - Chain [2] done processing
# Data declaration for model 2
model2_data = {
    "N": len(df_train),
    "K": 2,
    "x": df_train[['HDD', 'CLASSDAYS']].values,
    "y": df_train['consumption'].values,
}

# Fitting model 2
fit2 = model.sample(data=model2_data, show_progress=False)
09:54:16 - cmdstanpy - INFO - CmdStan start processing
09:54:16 - cmdstanpy - INFO - Chain [1] start processing
09:54:16 - cmdstanpy - INFO - Chain [2] start processing
09:54:16 - cmdstanpy - INFO - Chain [3] start processing
09:54:16 - cmdstanpy - INFO - Chain [4] start processing
09:54:18 - cmdstanpy - INFO - Chain [3] done processing
09:54:19 - cmdstanpy - INFO - Chain [4] done processing
09:54:19 - cmdstanpy - INFO - Chain [1] done processing
09:54:19 - cmdstanpy - INFO - Chain [2] done processing
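Note the double brackets in df_train[['HDD']]: they keep x two-dimensional (N × 1), as the matrix[N, K] declaration requires. By contrast, df_train['HDD'].values would be one-dimensional.

As far as we understand, a mismatch like this is only reported by CmdStan once sampling starts. It can therefore help to check the dictionaries beforehand. The sketch below uses a hypothetical helper, check_linreg_data (not part of CmdStanPy). It mirrors the data block of linearregression.stan.

```python
import numpy as np


def check_linreg_data(data: dict) -> None:
    """Hypothetical helper: check `data` against the data block of linearregression.stan.

    Raises ValueError on the first mismatch found.
    """
    N, K = data["N"], data["K"]
    # int<lower=0> N; int<lower=0> K;
    if not all(isinstance(v, (int, np.integer)) and v >= 0 for v in (N, K)):
        raise ValueError("N and K must be non-negative integers (int<lower=0>)")
    x = np.asarray(data["x"], dtype=float)
    y = np.asarray(data["y"], dtype=float)
    if x.shape != (N, K):  # matrix[N, K] x
        raise ValueError(f"x has shape {x.shape}, expected (N, K) = ({N}, {K})")
    if y.shape != (N,):  # vector[N] y
        raise ValueError(f"y has shape {y.shape}, expected (N,) = ({N},)")
    if not (np.isfinite(x).all() and np.isfinite(y).all()):
        raise ValueError("x and y must not contain NaN or infinite values")


# Usage with the dictionaries defined above:
# check_linreg_data(model1_data)
# check_linreg_data(model2_data)
```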

Sampling is fast once the model has been compiled: according to the log timestamps, each fit takes about three seconds.

Note: the show_progress=False argument used above is only to simplify the display on this page. It can be removed or set to True to track the progress of sampling.

CmdStan provides two diagnostic utilities, stansummary and diagnose. CmdStanPy exposes them as the summary() and diagnose() methods of the fit object.

The summary() method runs a utility that reads the outputs of all chains and computes summary statistics for all sampler and model parameters and quantities of interest. It also summarizes lp__, the total log probability density of the model (up to an additive constant).

fs = fit1.summary(percentiles=(5, 50, 95))
fs.loc[['alpha', 'beta[1]', 'sigma', 'lp__']]
                Mean        MCSE        StdDev          5%          50%         95%        N_Eff     N_Eff/s     R_hat
alpha    351957.0000  367.627000  13780.000000  329221.000  352384.0000  374190.000  1405.020000  730.259000  0.999853
beta[1]     100.2520    0.697350     26.573900      57.367      99.9952     143.995  1452.140000  754.750000  0.999405
sigma     27345.6132  203.997122   7403.755025   18267.100   25965.8000   41039.000  1317.212973  684.622127  1.000971
lp__       -117.5860    0.044984      1.419820    -120.228    -117.2350    -116.031   996.235000  517.794000  1.002270
fs = fit2.summary(percentiles=(5, 50, 95))
fs.loc[['alpha', 'beta[1]', 'beta[2]', 'sigma', 'lp__']]
                 Mean        MCSE      StdDev           5%          50%          95%       N_Eff     N_Eff/s     R_hat
alpha    325605.00000  235.844000  9436.19000  310892.0000  325514.0000  340604.0000  1600.83000  1137.76000  1.004220
beta[1]      56.48890    0.383132    17.16070      29.7856      56.5464      83.7613  2006.19000  1425.86000  0.999786
beta[2]    3261.10000   17.712800   734.32600    2063.7100    3259.4900    4438.6500  1718.71000  1221.55000  1.002670
sigma     14078.61626  125.454950  4278.56431    9125.7000   13254.4000   21583.4000  1163.10931   826.65907  1.003890
lp__       -110.08700    0.062631     1.80114    -113.5180    -109.7060    -108.0380   827.02500   587.79300  1.006430
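To show what the main columns measure, here is a minimal numpy sketch. It computes Mean, StdDev, percentiles, the MCSE of the mean and a basic split R-hat from the draws of one scalar parameter, arranged as (chains, draws).

The MCSE column is the posterior standard deviation divided by the square root of the effective sample size. For alpha in model 1, for instance, 13780 / sqrt(1405.02) ≈ 367.6, which matches the table.

The R-hat function uses the classic split R-hat formula. This is an assumption: stansummary may use a different estimator, so its values need not match exactly. The effective sample size itself is not re-implemented here. The helper names are hypothetical.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic split R-hat for one scalar parameter; `draws` has shape (n_chains, n_draws)."""
    n_draws = draws.shape[1]
    half = n_draws // 2
    # Split every chain into a first and a second half (middle draw dropped if n_draws is odd)
    halves = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)     # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n     # pooled posterior variance estimate
    return float(np.sqrt(var_plus / within))


def mcse_mean(sd: float, ess: float) -> float:
    """Monte Carlo standard error of the posterior mean."""
    return float(sd / np.sqrt(ess))


def summarize(draws: np.ndarray, percentiles=(5, 50, 95)) -> dict:
    """Mean, StdDev, percentiles and split R-hat of one parameter's draws."""
    flat = draws.ravel()
    row = {"Mean": float(flat.mean()), "StdDev": float(flat.std(ddof=1))}
    for p in percentiles:
        row[f"{p}%"] = float(np.percentile(flat, p))
    row["R_hat"] = split_rhat(draws)
    return row


# Usage on synthetic draws: 4 well-mixed chains of 1000 draws each
rng = np.random.default_rng(1)
print(summarize(rng.normal(loc=100.0, scale=25.0, size=(4, 1000))))
print(mcse_mean(13780.0, 1405.02))  # alpha, model 1
```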

Additionally, the diagnose() method runs a utility that analyzes the per-draw sampler parameters across all chains. It looks for problems which indicate that the draws may not be a representative sample from the posterior.

print(fit2.diagnose())
Processing csv files: /tmp/tmplq8zu57i/linearregressionk24hs4cb/linearregression-20250923095416_1.csv, /tmp/tmplq8zu57i/linearregressionk24hs4cb/linearregression-20250923095416_2.csv, /tmp/tmplq8zu57i/linearregressionk24hs4cb/linearregression-20250923095416_3.csv, /tmp/tmplq8zu57i/linearregressionk24hs4cb/linearregression-20250923095416_4.csv

Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
No divergent transitions found.

Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.

Split R-hat values satisfactory all parameters.

Processing complete, no problems detected.
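Several of these checks can be approximated directly from the sampler columns returned by draws_pd(): chain__, divergent__, treedepth__ and energy__. The sketch below is a hypothetical helper, not CmdStan's implementation, and it rests on two assumptions. First, the default maximum tree depth is 10. Second, to our understanding, CmdStan flags E-BFMI values below 0.3.

```python
import numpy as np
import pandas as pd


def sampler_checks(draws: pd.DataFrame, max_treedepth: int = 10) -> dict:
    """Hypothetical helper approximating some checks reported by CmdStan's diagnose.

    `draws` must contain the sampler columns returned by draws_pd():
    chain__, divergent__, treedepth__ and energy__, with rows in draw order
    within each chain.
    """
    n_divergent = int(draws["divergent__"].sum())
    # Transitions that reached the tree-depth limit
    n_max_depth = int((draws["treedepth__"] >= max_treedepth).sum())
    ebfmi = {}
    for chain, group in draws.groupby("chain__"):
        energy = group["energy__"].to_numpy()
        # E-BFMI: squared energy changes between successive draws,
        # relative to the spread of the energy around its mean
        ebfmi[chain] = float(np.sum(np.diff(energy) ** 2) / np.sum((energy - energy.mean()) ** 2))
    low_ebfmi = [c for c, v in ebfmi.items() if v < 0.3]
    return {"divergent": n_divergent, "max_treedepth_hits": n_max_depth,
            "ebfmi": ebfmi, "low_ebfmi_chains": low_ebfmi}


# Usage: sampler_checks(fit2.draws_pd())
```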

Extracting results

The fit1 and fit2 objects provide several methods to access the sampler outputs. They are all listed in CmdStanPy's documentation; two of them are:

  • draws_pd(): returns a pandas.DataFrame with one column for each column of the Stan CSV output (sampler diagnostics and model variables). The vars argument restricts it to one or more variables.

  • draws_xr(): returns an xarray.Dataset which maps model variable names to their respective values.

res_pd = fit1.draws_pd()
res_xr = fit1.draws_xr()
res_pd.head()
chain__ iter__ draw__ lp__ accept_stat__ stepsize__ treedepth__ n_leapfrog__ divergent__ energy__ ... y_hat[3] y_hat[4] y_hat[5] y_hat[6] y_hat[7] y_hat[8] y_hat[9] y_hat[10] y_hat[11] y_hat[12]
0 1.0 1.0 1.0 -118.238 0.971145 0.339103 3.0 7.0 0.0 118.663 ... 373191.0 408420.0 449295.0 444354.0 519855.0 458727.0 473612.0 422564.0 345776.0 357898.0
1 1.0 2.0 2.0 -118.237 0.993430 0.339103 4.0 15.0 0.0 119.749 ... 355499.0 364834.0 377669.0 427364.0 470069.0 417393.0 413354.0 397723.0 398708.0 354786.0
2 1.0 3.0 3.0 -117.646 0.985189 0.339103 3.0 7.0 0.0 118.872 ... 337474.0 403931.0 393499.0 427051.0 424879.0 438156.0 407038.0 403268.0 361106.0 359410.0
3 1.0 4.0 4.0 -118.426 0.659121 0.339103 3.0 7.0 0.0 119.950 ... 364002.0 395530.0 387969.0 494433.0 485013.0 462230.0 418238.0 418811.0 447337.0 341333.0
4 1.0 5.0 5.0 -122.657 0.941061 0.339103 3.0 15.0 0.0 123.148 ... 479027.0 311505.0 310124.0 368594.0 493771.0 488210.0 505653.0 423499.0 392101.0 304690.0

5 rows × 37 columns
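The 37 columns break down as follows:

- the three bookkeeping columns chain__, iter__ and draw__;
- the seven sampler diagnostics, from lp__ to energy__;
- the parameters alpha, beta[1] and sigma;
- twelve entries each for log_lik and y_hat, one per training observation.

The frame is flat, with one row per draw across all chains. Getting back the (chain, draw) layout used by draws_xr() and ArviZ therefore takes a reshape. The two helpers below are hypothetical (not part of CmdStanPy), and they assume every chain has the same number of draws.

```python
import numpy as np
import pandas as pd


def to_chain_draw(draws: pd.DataFrame, column: str) -> np.ndarray:
    """Reshape one column of a draws_pd()-style frame into (n_chains, n_draws)."""
    ordered = draws.sort_values(["chain__", "draw__"])
    n_chains = ordered["chain__"].nunique()
    # Assumes equal numbers of draws per chain
    return ordered[column].to_numpy().reshape(n_chains, -1)


def columns_for(draws: pd.DataFrame, name: str) -> list[str]:
    """Columns holding a variable, e.g. 'y_hat' -> ['y_hat[1]', ..., 'y_hat[12]']."""
    return [c for c in draws.columns if c == name or c.startswith(name + "[")]


# Usage (assuming res_pd from above):
# alpha_cd = to_chain_draw(res_pd, "alpha")              # shape (4, 1000)
# y_hat_mean = res_pd[columns_for(res_pd, "y_hat")].mean()
```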

The ArviZ package for exploratory analysis of Bayesian models can be used with CmdStanPy. To make the most of ArviZ, it is recommended to convert the results to InferenceData. Unlike the single res_xr Dataset, InferenceData assigns each variable to the appropriate group: posterior, posterior_predictive, log_likelihood, sample_stats, observed_data and constant_data.

# Convert each CmdStanPy fit to ArviZ InferenceData, naming the groups explicitly
idata1 = az.from_cmdstanpy(
    posterior=fit1,
    posterior_predictive="y_hat",   # replicated data from generated quantities
    observed_data={"y": df_train['consumption'].values},
    constant_data={"x": df_train[['HDD']].values},
    log_likelihood="log_lik",       # pointwise log-likelihood, used by LOO
    dims={
        "y_hat": ["observations"],
        "log_lik": ["observations"],
        "y": ["observations"],
        "beta": ["features"],
    },
)

idata2 = az.from_cmdstanpy(
    posterior=fit2,
    posterior_predictive="y_hat",
    observed_data={"y": df_train['consumption'].values},
    constant_data={"x": df_train[['HDD', 'CLASSDAYS']].values},
    log_likelihood="log_lik",
    dims={
        "y_hat": ["observations"],
        "log_lik": ["observations"],
        "y": ["observations"],
        "beta": ["features"],
    },
)

idata2
arviz.InferenceData
    • posterior: <xarray.Dataset> Size: 136kB
      Dimensions:   (chain: 4, draw: 1000, features: 2)
      Coordinates:
        * chain     (chain) int64 32B 0 1 2 3
        * draw      (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * features  (features) int64 16B 0 1
      Data variables:
          alpha     (chain, draw) float64 32kB 3.308e+05 3.232e+05 ... 3.237e+05
          beta      (chain, draw, features) float64 64kB 26.95 3.918e+03 ... 3.976e+03
          sigma     (chain, draw) float64 32kB 1.053e+04 1.226e+04 ... 1.716e+04
      Attributes:
          created_at:                 2025-09-23T07:54:26.529345+00:00
          arviz_version:              0.22.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

    • posterior_predictive: <xarray.Dataset> Size: 392kB
      Dimensions:       (chain: 4, draw: 1000, observations: 12)
      Coordinates:
        * chain         (chain) int64 32B 0 1 2 3
        * draw          (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observations  (observations) int64 96B 0 1 2 3 4 5 6 7 8 9 10 11
      Data variables:
          y_hat         (chain, draw, observations) float64 384kB 3.307e+05 ... 3.3...
      Attributes:
          created_at:                 2025-09-23T07:54:26.630012+00:00
          arviz_version:              0.22.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

    • log_likelihood: <xarray.Dataset> Size: 392kB
      Dimensions:       (chain: 4, draw: 1000, observations: 12)
      Coordinates:
        * chain         (chain) int64 32B 0 1 2 3
        * draw          (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observations  (observations) int64 96B 0 1 2 3 4 5 6 7 8 9 10 11
      Data variables:
          log_lik       (chain, draw, observations) float64 384kB -10.64 ... -10.99
      Attributes:
          created_at:                 2025-09-23T07:54:26.666022+00:00
          arviz_version:              0.22.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

    • sample_stats: <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB -110.9 -108.5 ... -108.6 -109.9
          acceptance_rate  (chain, draw) float64 32kB 1.0 0.9958 ... 0.9849 0.9353
          step_size        (chain, draw) float64 32kB 0.1959 0.1959 ... 0.2104 0.2104
          tree_depth       (chain, draw) int64 32kB 2 3 3 4 4 3 3 4 ... 4 4 5 3 3 4 4
          n_steps          (chain, draw) int64 32kB 3 7 11 31 15 11 ... 31 15 15 15 15
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 116.9 111.3 ... 109.9 112.2
      Attributes:
          created_at:                 2025-09-23T07:54:26.586935+00:00
          arviz_version:              0.22.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

    • observed_data: <xarray.Dataset> Size: 192B
      Dimensions:       (observations: 12)
      Coordinates:
        * observations  (observations) int64 96B 0 1 2 3 4 5 6 7 8 9 10 11
      Data variables:
          y             (observations) int64 96B 321120 335520 ... 401760 339840
      Attributes:
          created_at:                 2025-09-23T07:54:26.640046+00:00
          arviz_version:              0.22.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

    • constant_data: <xarray.Dataset> Size: 304B
      Dimensions:  (x_dim_0: 12, x_dim_1: 2)
      Coordinates:
        * x_dim_0  (x_dim_0) int64 96B 0 1 2 3 4 5 6 7 8 9 10 11
        * x_dim_1  (x_dim_1) int64 16B 0 1
      Data variables:
          x        (x_dim_0, x_dim_1) float64 192B 12.7 0.0 19.4 6.0 ... 17.5 54.2 0.0
      Attributes:
          created_at:                 2025-09-23T07:54:26.652370+00:00
          arviz_version:              0.22.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

We can now use the same posterior plots as we did with the PyMC output:

az.plot_trace(idata1)
[Figure: trace plots for model 1, one row each for alpha, beta and sigma]

And here is the posterior predictive distribution of model 2, plotted against each explanatory variable:

y_post = idata2.posterior_predictive["y_hat"]

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].scatter(df_train['HDD'], df_train['consumption'], alpha=0.6, label='data')
axes[0].scatter(df_train['HDD'], y_post.mean(("chain", "draw")), c="C1", alpha=0.6, label='posterior mean')
az.plot_hdi(df_train['HDD'], y_post, ax=axes[0])
axes[0].set_xlabel("hdd")
axes[0].set_ylabel("consumption")
axes[0].legend()

axes[1].scatter(df_train['CLASSDAYS'], df_train['consumption'], alpha=0.6)
axes[1].scatter(df_train['CLASSDAYS'], y_post.mean(("chain", "draw")), c="C1", alpha=0.6)
az.plot_hdi(df_train['CLASSDAYS'], y_post, ax=axes[1])
axes[1].set_xlabel("class days")
[Figure: model 2 observed consumption, posterior predictive mean and HDI band, against HDD (left) and class days (right)]

Reassuringly, this looks essentially the same as the posterior predictive distribution we obtained with PyMC.
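The shaded bands come from az.plot_hdi, which computes a highest density interval (HDI) for each observation. To our understanding, its default probability is 0.94.

For a unimodal distribution, the HDI can be computed by sorting the samples and taking the narrowest window that contains the requested fraction of them. This is the approach sketched below. The helper names are hypothetical, and plot_hdi additionally smooths the band, which this sketch does not do.

```python
import numpy as np


def hdi(samples: np.ndarray, prob: float = 0.94) -> tuple[float, float]:
    """Narrowest interval containing a fraction `prob` of the samples (unimodal case)."""
    s = np.sort(np.asarray(samples, dtype=float).ravel())
    n = s.size
    span = int(np.floor(prob * n))        # index distance between interval endpoints
    n_intervals = n - span
    widths = s[span:] - s[:n_intervals]   # width of every candidate interval
    start = int(np.argmin(widths))
    return float(s[start]), float(s[start + span])


def hdi_per_observation(y_draws: np.ndarray, prob: float = 0.94) -> np.ndarray:
    """HDI for each observation; `y_draws` has shape (chain, draw, observation)."""
    flat = y_draws.reshape(-1, y_draws.shape[-1])   # pool chains and draws
    return np.array([hdi(flat[:, j], prob) for j in range(flat.shape[1])])


# Usage (assuming idata2 from above); returns one (lower, upper) row per observation:
# bands = hdi_per_observation(idata2.posterior_predictive["y_hat"].values)
```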

Model comparison metrics

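ArviZ's built-in LOO method can now be used (see the previous chapter for a reminder of what is happening here). It relies on the pointwise log-likelihood values that the generated quantities block stores in log_lik, which from_cmdstanpy placed in the log_likelihood group: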

df_comp_loo = az.compare({"model1": idata1, "model2": idata2}, ic="loo")
df_comp_loo
/home/simon/anaconda3/envs/bayes_env/lib/python3.11/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
        rank     elpd_loo     p_loo  elpd_diff  weight        se       dse  warning  scale
model2     0  -132.533091  3.046308   0.000000     1.0  1.155242  0.000000     True    log
model1     1  -140.311554  2.490824   7.778464     0.0  2.074431  2.016874    False    log
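Reading the table: elpd_diff is the difference in elpd_loo between the best model and each other model, here -132.533 - (-140.312) ≈ 7.78. The dse column is the standard error of that difference, computed from the pointwise differences. The difference is about 3.9 standard errors (7.78 / 2.02), which favors model 2.

However, the warning column is True for model 2. According to the warning printed above, this means at least one observation has an estimated Pareto shape parameter above 0.7, so model 2's elpd_loo estimate should be treated with some caution.

The sketch below reproduces these quantities from pointwise elpd values. ArviZ can return such values via az.loo(idata, pointwise=True), in its loo_i field. To our understanding, ArviZ computes both standard errors as sqrt(n · var) with the population variance. Treat the helper, whose name is hypothetical, as an illustration rather than ArviZ's own code.

```python
import numpy as np


def compare_pointwise(elpd_best: np.ndarray, elpd_other: np.ndarray) -> dict:
    """Sketch of elpd_loo, se, elpd_diff and dse from pointwise LOO values.

    Both arrays hold one elpd value per observation, in the same order.
    """
    n = elpd_best.size
    diff = elpd_best - elpd_other
    return {
        "elpd_best": float(elpd_best.sum()),
        "se_best": float(np.sqrt(n * np.var(elpd_best))),   # np.var uses ddof=0
        "elpd_diff": float(diff.sum()),
        "dse": float(np.sqrt(n * np.var(diff))),
    }


# Usage (assuming idata1 and idata2 from above):
# loo1 = az.loo(idata1, pointwise=True)
# loo2 = az.loo(idata2, pointwise=True)
# compare_pointwise(loo2.loo_i.values, loo1.loo_i.values)
```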