Bayesian Spatial Models: A Pedagogical Walkthrough

This notebook demonstrates the five spatial models currently implemented in bayespecon, alongside an OLS baseline:

  1. SLX

  2. SAR

  3. SEM

  4. SDM

  5. SDEM

Each section explains the model idea, fits the model, and inspects posterior summaries and spatial effects.

Conceptual Roadmap

Let \(W\) denote the spatial weights matrix and \(WX\) the spatial lag of covariates.

  • SLX: \(y = X\beta + WX\theta + \epsilon\)

  • SAR: \(y = \rho Wy + X\beta + \epsilon\)

  • SEM: \(y = X\beta + u\), \(u = \lambda Wu + \epsilon\)

  • SDM: \(y = \rho Wy + X\beta + WX\theta + \epsilon\)

  • SDEM: \(y = X\beta + WX\theta + u\), \(u = \lambda Wu + \epsilon\)
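Before fitting anything, it helps to see how one of these processes generates data. The sketch below simulates from the SAR reduced form \(y = (I - \rho W)^{-1}(X\beta + \epsilon)\) with a toy line-graph weights matrix (illustrative only; the notebook's actual \(W\) comes from libpysal below):

```python
import numpy as np

rng = np.random.default_rng(0)
n, rho, beta = 5, 0.4, 2.0

# Toy line-graph weights: each unit's neighbors are the adjacent indices.
W = np.zeros((n, n))
for i in range(n):
    for j in (i - 1, i + 1):
        if 0 <= j < n:
            W[i, j] = 1.0
W /= W.sum(axis=1, keepdims=True)  # row-normalize, as Graph.transform("r") does

x = rng.normal(size=n)
eps = rng.normal(size=n)

# SAR reduced form: y = (I - rho W)^{-1} (x beta + eps)
y = np.linalg.solve(np.eye(n) - rho * W, x * beta + eps)
```

Solving the linear system (rather than inverting) is the standard numerically stable way to apply \((I - \rho W)^{-1}\).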

In bayespecon, all models accept either:

  • formula mode: model_class(formula=..., data=..., W=...), or

  • matrix mode: model_class(y=..., X=..., W=...).

import arviz as az
import geodatasets
import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd
from libpysal.graph import Graph

from bayespecon import OLS, SAR, SDEM, SDM, SEM, SLX
from bayespecon.diagnostics import (
    bayesian_lm_error_test,
    bayesian_lm_lag_test,
    bayesian_lm_sdm_joint_test,
    bayesian_lm_slx_error_joint_test,
    bayesian_lm_wx_test,
)
from bayespecon.diagnostics.bayesfactor import bayes_factor_compare_models

az.style.use("arviz-white")
# Load sample spatial dataset
gdf = gpd.read_file(geodatasets.data.geoda.airbnb.url)
xcols = ["poverty", "rev_rating", "num_spots", "crowded"]
ycol = "price_pp"
gdf = gdf.dropna(subset=xcols + [ycol]).copy()

# Build contiguity graph (queen)
W = Graph.build_contiguity(gdf, rook=False).transform("r")

print(f"Observations: {len(gdf)}")
print(f"Predictors: {xcols}")
Observations: 67
Predictors: ['poverty', 'rev_rating', 'num_spots', 'crowded']

Helper Functions

To keep each model section focused, we use utility functions to:

  • fit with consistent MCMC settings,

  • print posterior summaries,

  • tabulate direct/indirect/total effects.

For a quick pedagogical run, we use small draw counts. Increase these for real inference.

def fit_and_report(model_cls, formula, data, W, draws=400, tune=400, chains=4, seed=42):
    """Fit a spatial model and return (model, summary, effects_df).

    Uses minimal MCMC settings for pedagogical demonstration.
    Increase draws/tune for real analyses.
    """
    model = model_cls(formula=formula, data=data, W=W, logdet_method="eigenvalue")
    model.fit(
        draws=draws,
        tune=tune,
        chains=chains,
        target_accept=0.9,
        random_seed=seed,
        progressbar=False,
        idata_kwargs={"log_likelihood": True},
    )
    summary = model.summary(round_to=3)
    effects_df = pd.DataFrame(model.spatial_effects())
    return model, summary, effects_df

Convergence Diagnostics Helper

For each fitted model, we will inspect:

  • r_hat (target close to 1.00)

  • effective sample sizes (ess_bulk, ess_tail)

  • trace plots for key parameters.

def diagnostics_table(idata, var_names):
    """Show key MCMC diagnostics for the given parameters."""
    cols = ["mean", "sd", "ess_bulk", "ess_tail", "r_hat"]
    return az.summary(idata, var_names=var_names, round_to=3)[cols]


def show_trace(idata, var_names, title):
    """Plot trace plots for the given parameters."""
    az.plot_trace(idata, var_names=var_names)
    plt.suptitle(title, y=1.02)
    plt.tight_layout()
    plt.show()

Models

0) OLS: What could go wrong?

Model:

\[ y = X\beta + \epsilon \]
ols = OLS(
    formula="price_pp ~ poverty + rev_rating + num_spots + crowded",
    data=gdf,
    W=W,
)
olsfit = ols.fit(
    draws=400,
    tune=400,
    chains=4,
    target_accept=0.9,
    random_seed=42,
    progressbar=False,
    idata_kwargs={"log_likelihood": True},
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [beta, sigma]
Sampling 4 chains for 400 tune and 400 draw iterations (1_600 + 1_600 draws total) took 17 seconds.
ols.summary()
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept -3.922 87.414 -173.156 159.941 3.412 2.274 660.0 745.0 1.00
poverty 0.272 0.327 -0.297 0.922 0.010 0.008 984.0 1007.0 1.00
rev_rating 0.790 0.921 -1.044 2.502 0.036 0.024 669.0 794.0 1.00
num_spots 0.120 0.024 0.079 0.164 0.001 0.001 1381.0 1128.0 1.00
crowded -2.295 0.902 -4.033 -0.665 0.027 0.023 1150.0 1066.0 1.00
sigma 26.398 2.131 22.378 30.346 0.057 0.055 1389.0 1035.0 1.01

1) SLX: Spatially Lagged Covariates Only

Model:

\[ y = X\beta + WX\theta + \epsilon \]

Interpretation:

  • beta captures local covariate effects.

  • theta captures neighbor-covariate spillovers.

  • No spatial lag on \(y\), so no autoregressive propagation through outcomes.
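Because there is no outcome feedback, SLX effects need no matrix inversion: per covariate, the direct effect is \(\beta\), the indirect effect is \(\theta\) (under a row-normalized \(W\)), and the total is their sum. A sketch with stand-in posterior draws (in practice these would come from the fitted model's trace):

```python
import numpy as np

rng = np.random.default_rng(1)
# Stand-in posterior draws for one covariate's local (beta) and
# neighbor (theta) coefficients.
beta_draws = rng.normal(-0.7, 0.45, size=1000)
theta_draws = rng.normal(1.2, 0.67, size=1000)

# With a row-normalized W, the SLX effects are read off directly:
direct = beta_draws              # own-covariate effect
indirect = theta_draws           # neighbor-covariate spillover
total = direct + indirect

total_mean = total.mean()
total_ci = np.percentile(total, [2.5, 97.5])
```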

slx, summary_slx, effects_slx = fit_and_report(
    SLX,
    formula="price_pp ~ poverty + rev_rating + num_spots + crowded",
    data=gdf,
    W=W,
)
display(summary_slx)
display(effects_slx)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [beta, sigma]
Sampling 4 chains for 400 tune and 400 draw iterations (1_600 + 1_600 draws total) took 29 seconds.
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 189.574 210.163 -167.015 627.688 8.939 6.045 552.011 683.690 1.007
poverty -0.675 0.455 -1.527 0.168 0.015 0.010 946.699 1202.923 1.001
rev_rating 0.343 0.889 -1.261 2.019 0.030 0.022 858.478 995.480 1.002
num_spots 0.021 0.030 -0.037 0.074 0.001 0.001 1059.799 938.293 1.005
crowded -1.030 1.165 -3.224 1.098 0.036 0.024 1025.749 1059.615 1.000
W*poverty 1.183 0.668 -0.014 2.446 0.028 0.018 598.411 832.256 1.006
W*rev_rating -1.763 1.814 -5.026 1.749 0.072 0.048 639.098 768.137 1.005
W*num_spots 0.201 0.046 0.122 0.288 0.001 0.001 1041.499 1223.241 1.003
W*crowded -1.435 1.572 -4.655 1.306 0.053 0.035 881.426 1109.505 1.003
sigma 23.648 2.047 20.086 27.515 0.062 0.052 1098.255 1117.649 1.001
direct direct_ci_lower direct_ci_upper direct_pvalue indirect indirect_ci_lower indirect_ci_upper indirect_pvalue total total_ci_lower total_ci_upper total_pvalue
variable
poverty -0.674834 -1.560399 0.207956 0.13375 1.183429 -0.121623 2.508287 0.06875 0.508595 -0.427303 1.553064 0.31375
rev_rating 0.342909 -1.341800 2.094372 0.68125 -1.762965 -5.199795 1.912176 0.31750 -1.420056 -5.765232 2.970336 0.49000
num_spots 0.020759 -0.035625 0.080420 0.49750 0.201125 0.109459 0.286015 0.00000 0.221883 0.161526 0.283079 0.00000
crowded -1.029607 -3.224310 1.321937 0.39250 -1.434761 -4.610449 1.562538 0.38000 -2.464368 -4.825579 -0.128478 0.04000
az.plot_forest(slx.inference_data)
display(diagnostics_table(slx.inference_data, ["beta", "sigma"]))
show_trace(slx.inference_data, ["sigma"], "SLX Trace: sigma")
mean sd ess_bulk ess_tail r_hat
beta[Intercept] 189.574 210.163 552.011 683.690 1.007
beta[poverty] -0.675 0.455 946.699 1202.923 1.001
beta[rev_rating] 0.343 0.889 858.478 995.480 1.002
beta[num_spots] 0.021 0.030 1059.799 938.293 1.005
beta[crowded] -1.030 1.165 1025.749 1059.615 1.000
beta[W*poverty] 1.183 0.668 598.411 832.256 1.006
beta[W*rev_rating] -1.763 1.814 639.098 768.137 1.005
beta[W*num_spots] 0.201 0.046 1041.499 1223.241 1.003
beta[W*crowded] -1.435 1.572 881.426 1109.505 1.003
sigma 23.648 2.047 1098.255 1117.649 1.001
[Figures: forest plot of SLX posteriors; trace plot for sigma]

2) SAR: Spatial Lag of Outcome

Model:

\[ y = \rho Wy + X\beta + \epsilon \]

Interpretation:

  • \(\rho\) controls feedback from neighboring outcomes.

  • Effects are amplified through \((I - \rho W)^{-1}\), so direct and indirect impacts differ from raw \(\beta\).
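The direct/indirect/total decomposition follows LeSage and Pace: for covariate \(k\), the impact matrix is \(S_k(W) = (I - \rho W)^{-1}\beta_k\); the direct effect is the average diagonal, the total effect the average row sum. A toy sketch (values illustrative, not the fitted posterior):

```python
import numpy as np

n, rho, beta_k = 4, 0.44, 0.08  # illustrative scalars

W = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
W /= W.sum(axis=1, keepdims=True)  # row-normalize

# Impact matrix for covariate k: S_k(W) = (I - rho W)^{-1} beta_k
S = np.linalg.inv(np.eye(n) - rho * W) * beta_k

direct = np.trace(S) / n    # average own-partial derivative (diagonal)
total = S.sum() / n         # average total impact (row sums)
indirect = total - direct   # spillovers through neighbors' outcomes
```

For a row-stochastic \(W\), the average total impact collapses to \(\beta_k / (1 - \rho)\), the global spatial multiplier.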

sar, summary_sar, effects_sar = fit_and_report(
    SAR,
    formula="price_pp ~ poverty + rev_rating + num_spots + crowded",
    data=gdf,
    W=W,
)
display(summary_sar)
display(effects_sar)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [rho, beta, sigma]
Sampling 4 chains for 400 tune and 400 draw iterations (1_600 + 1_600 draws total) took 19 seconds.
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept -43.650 78.768 -200.765 95.548 2.883 2.184 745.206 897.104 1.006
poverty 0.100 0.316 -0.488 0.677 0.010 0.009 1064.785 1027.666 1.005
rev_rating 0.913 0.825 -0.453 2.611 0.031 0.022 715.441 932.399 1.006
num_spots 0.077 0.025 0.030 0.126 0.001 0.001 1296.118 1029.851 1.002
crowded -1.748 0.860 -3.493 -0.265 0.024 0.021 1290.005 1084.390 1.002
rho 0.439 0.140 0.170 0.694 0.004 0.004 1251.185 938.602 1.009
sigma 24.315 2.152 20.449 28.357 0.062 0.059 1193.522 1006.979 1.001
direct direct_ci_lower direct_ci_upper direct_pvalue indirect indirect_ci_lower indirect_ci_upper indirect_pvalue total total_ci_lower total_ci_upper total_pvalue
variable
poverty 0.103270 -0.569727 0.750255 0.73125 0.050322 -0.659327 0.651505 0.73500 0.153591 -1.215315 1.303442 0.73125
rev_rating 0.971991 -0.739878 2.707951 0.26000 0.767987 -0.486534 3.057629 0.26375 1.739977 -1.183852 5.382396 0.26000
num_spots 0.081282 0.032140 0.134003 0.00000 0.058741 0.014567 0.140535 0.00375 0.140023 0.061422 0.235276 0.00000
crowded -1.852929 -3.615410 -0.017002 0.04750 -1.387262 -3.980152 0.002979 0.05125 -3.240191 -6.860549 -0.034940 0.04750
az.plot_forest(sar.inference_data)
display(diagnostics_table(sar.inference_data, ["rho", "beta", "sigma"]))
show_trace(sar.inference_data, ["rho", "sigma"], "SAR Trace: rho, sigma")
mean sd ess_bulk ess_tail r_hat
rho 0.439 0.140 1251.185 938.602 1.009
beta[Intercept] -43.650 78.768 745.206 897.104 1.006
beta[poverty] 0.100 0.316 1064.785 1027.666 1.005
beta[rev_rating] 0.913 0.825 715.441 932.399 1.006
beta[num_spots] 0.077 0.025 1296.118 1029.851 1.002
beta[crowded] -1.748 0.860 1290.005 1084.390 1.002
sigma 24.315 2.152 1193.522 1006.979 1.001
[Figures: forest plot of SAR posteriors; trace plots for rho and sigma]

3) SEM: Spatial Correlation in Errors

Model:

\[ y = X\beta + u, \quad u = \lambda Wu + \epsilon \]

Interpretation:

  • Spatial dependence is moved to latent shocks rather than outcome feedback.

  • Useful when omitted spatial factors induce correlated residuals.
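The error process \(u = (I - \lambda W)^{-1}\epsilon\) implies a spatially correlated error covariance, \(\Omega = \sigma^2\,[(I - \lambda W)^\top (I - \lambda W)]^{-1}\). A toy sketch showing that the implied off-diagonal covariances are nonzero:

```python
import numpy as np

n, lam, sigma = 4, 0.5, 1.0  # illustrative scalars
W = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
W /= W.sum(axis=1, keepdims=True)

# u = lam W u + eps  =>  u = (I - lam W)^{-1} eps with eps ~ N(0, sigma^2 I),
# so Cov(u) = sigma^2 [(I - lam W)' (I - lam W)]^{-1}.
B = np.eye(n) - lam * W
Omega = sigma**2 * np.linalg.inv(B.T @ B)

off_diag = Omega - np.diag(np.diag(Omega))
```

This is why SEM has zero indirect covariate effects in the table below: \(W\) enters only the error covariance, never the mean of \(y\).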

sem, summary_sem, effects_sem = fit_and_report(
    SEM,
    formula="price_pp ~ poverty + rev_rating + num_spots + crowded",
    data=gdf,
    W=W,
)
display(summary_sem)
display(effects_sem)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [lam, beta, sigma]
Sampling 4 chains for 400 tune and 400 draw iterations (1_600 + 1_600 draws total) took 22 seconds.
There were 7 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept -27.907 77.831 -165.614 125.951 2.905 2.386 712.488 703.147 1.004
poverty -0.118 0.444 -0.907 0.783 0.016 0.013 789.682 811.807 1.012
rev_rating 1.131 0.820 -0.455 2.625 0.030 0.025 763.064 741.629 1.004
num_spots 0.072 0.034 0.006 0.131 0.001 0.001 822.666 703.874 1.003
crowded -1.843 1.115 -3.898 0.288 0.037 0.030 919.523 985.320 1.004
lam 0.503 0.181 0.156 0.819 0.007 0.004 575.534 857.794 1.003
sigma 24.892 2.137 20.794 28.518 0.058 0.058 1345.146 1014.527 1.000
direct direct_ci_lower direct_ci_upper direct_pvalue indirect indirect_ci_lower indirect_ci_upper indirect_pvalue total total_ci_lower total_ci_upper total_pvalue
variable
poverty -0.117650 -1.046586 0.772897 0.77125 0.0 0.0 0.0 0.0 -0.117650 -1.046586 0.772897 0.77125
rev_rating 1.131456 -0.474084 2.757689 0.17000 0.0 0.0 0.0 0.0 1.131456 -0.474084 2.757689 0.17000
num_spots 0.071559 0.003338 0.135797 0.03500 0.0 0.0 0.0 0.0 0.071559 0.003338 0.135797 0.03500
crowded -1.843487 -3.971198 0.425201 0.10375 0.0 0.0 0.0 0.0 -1.843487 -3.971198 0.425201 0.10375
az.plot_forest(sem.inference_data)
display(diagnostics_table(sem.inference_data, ["lam", "beta", "sigma"]))
show_trace(sem.inference_data, ["lam", "sigma"], "SEM Trace: lam, sigma")
mean sd ess_bulk ess_tail r_hat
lam 0.503 0.181 575.534 857.794 1.003
beta[Intercept] -27.907 77.831 712.488 703.147 1.004
beta[poverty] -0.118 0.444 789.682 811.807 1.012
beta[rev_rating] 1.131 0.820 763.064 741.629 1.004
beta[num_spots] 0.072 0.034 822.666 703.874 1.003
beta[crowded] -1.843 1.115 919.523 985.320 1.004
sigma 24.892 2.137 1345.146 1014.527 1.000
[Figures: forest plot of SEM posteriors; trace plots for lam and sigma]

4) SDM: SAR + SLX

Model:

\[ y = \rho Wy + X\beta + WX\theta + \epsilon \]

Interpretation:

  • Includes both outcome feedback and neighbor-covariate channels.

  • Often used as a flexible nesting model for SAR/SLX.
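The SDM impact matrix combines both channels: \(S_k(W) = (I - \rho W)^{-1}(I\beta_k + W\theta_k)\). A toy sketch (coefficients illustrative):

```python
import numpy as np

n, rho = 4, 0.24
beta_k, theta_k = 0.02, 0.17  # illustrative scalars

W = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
W /= W.sum(axis=1, keepdims=True)

# SDM impact matrix: S_k(W) = (I - rho W)^{-1} (I beta_k + W theta_k)
S = np.linalg.inv(np.eye(n) - rho * W) @ (np.eye(n) * beta_k + W * theta_k)

direct = np.trace(S) / n
total = S.sum() / n
indirect = total - direct
```

For a row-stochastic \(W\), the average total effect is \((\beta_k + \theta_k)/(1 - \rho)\): both the neighbor-covariate channel and the outcome-feedback multiplier show up.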

sdm, summary_sdm, effects_sdm = fit_and_report(
    SDM,
    formula="price_pp ~ poverty + rev_rating + num_spots + crowded",
    data=gdf,
    W=W,
)
display(summary_sdm)
display(effects_sdm)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [rho, beta, sigma]
Sampling 4 chains for 400 tune and 400 draw iterations (1_600 + 1_600 draws total) took 33 seconds.
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 192.602 197.311 -194.214 552.073 7.174 5.346 752.245 825.749 1.005
poverty -0.691 0.431 -1.515 0.119 0.013 0.011 1083.248 1085.405 1.003
rev_rating 0.409 0.855 -1.095 2.154 0.025 0.022 1141.296 993.185 1.003
num_spots 0.016 0.031 -0.043 0.072 0.001 0.001 1165.213 1244.267 1.007
crowded -1.035 1.144 -3.073 1.183 0.035 0.025 1082.196 1244.503 1.002
W*poverty 1.023 0.651 -0.209 2.210 0.022 0.014 871.928 1166.753 1.001
W*rev_rating -2.006 1.767 -5.350 1.148 0.060 0.043 868.648 889.136 1.003
W*num_spots 0.166 0.050 0.074 0.261 0.002 0.001 872.363 1196.654 1.004
W*crowded -0.934 1.706 -4.007 2.152 0.053 0.038 1040.296 1225.345 1.002
rho 0.239 0.162 -0.062 0.551 0.005 0.004 1275.488 1187.169 1.002
sigma 23.187 2.055 19.475 26.918 0.055 0.055 1392.929 1066.343 0.999
direct direct_ci_lower direct_ci_upper direct_pvalue indirect indirect_ci_lower indirect_ci_upper indirect_pvalue total total_ci_lower total_ci_upper total_pvalue
variable
poverty -0.647269 -1.444526 0.235586 0.12875 1.069852 -0.464521 2.532650 0.17500 0.422583 -1.074571 1.767189 0.50375
rev_rating 0.285608 -1.447613 2.096370 0.74875 -2.529045 -8.089502 2.249513 0.29000 -2.243437 -8.905225 3.500816 0.43500
num_spots 0.026125 -0.031623 0.085046 0.36750 0.218146 0.111995 0.338539 0.00000 0.244271 0.161811 0.353908 0.00000
crowded -1.097486 -3.362976 1.020728 0.33250 -1.497818 -5.273957 2.287999 0.43375 -2.595303 -5.913159 0.723464 0.11625
az.plot_forest(sdm.inference_data)
display(diagnostics_table(sdm.inference_data, ["rho", "beta", "sigma"]))
show_trace(sdm.inference_data, ["rho", "sigma"], "SDM Trace: rho, sigma")
mean sd ess_bulk ess_tail r_hat
rho 0.239 0.162 1275.488 1187.169 1.002
beta[Intercept] 192.602 197.311 752.245 825.749 1.005
beta[poverty] -0.691 0.431 1083.248 1085.405 1.003
beta[rev_rating] 0.409 0.855 1141.296 993.185 1.003
beta[num_spots] 0.016 0.031 1165.213 1244.267 1.007
beta[crowded] -1.035 1.144 1082.196 1244.503 1.002
beta[W*poverty] 1.023 0.651 871.928 1166.753 1.001
beta[W*rev_rating] -2.006 1.767 868.648 889.136 1.003
beta[W*num_spots] 0.166 0.050 872.363 1196.654 1.004
beta[W*crowded] -0.934 1.706 1040.296 1225.345 1.002
sigma 23.187 2.055 1392.929 1066.343 0.999
[Figures: forest plot of SDM posteriors; trace plots for rho and sigma]

5) SDEM: SLX + Spatial Error

Model:

\[ y = X\beta + WX\theta + u, \quad u = \lambda Wu + \epsilon \]

Interpretation:

  • Neighbor covariates matter directly (through \(WX\)),

  • and unmodeled shocks are spatially correlated (through \(u\)).
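Unlike SDM, SDEM has no outcome feedback, so covariate effects stop at first-order neighbors: the total effect is just \(\beta_k + \theta_k\), with no \(1/(1-\rho)\) multiplier. A minimal contrast with illustrative numbers:

```python
beta_k, theta_k, rho = 0.03, 0.19, 0.24  # illustrative scalars

# SDEM: spillovers are local; effects do not propagate through y.
sdem_total = beta_k + theta_k

# SDM: the same coefficients get the global multiplier 1/(1 - rho).
sdm_total = (beta_k + theta_k) / (1 - rho)
```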

sdem, summary_sdem, effects_sdem = fit_and_report(
    SDEM,
    formula="price_pp ~ poverty + rev_rating + num_spots + crowded",
    data=gdf,
    W=W,
)
display(summary_sdem)
display(effects_sdem)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [lam, beta, sigma]
Sampling 4 chains for 400 tune and 400 draw iterations (1_600 + 1_600 draws total) took 54 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 252.539 234.668 -170.163 691.918 10.166 7.135 536.582 787.195 1.002
poverty -0.650 0.430 -1.494 0.120 0.014 0.011 934.876 1012.513 1.003
rev_rating 0.245 0.974 -1.367 2.272 0.033 0.025 860.500 986.327 1.003
num_spots 0.028 0.029 -0.027 0.083 0.001 0.001 819.267 954.465 1.001
crowded -1.070 1.100 -3.190 0.936 0.033 0.026 1112.542 1085.944 1.000
W*poverty 0.934 0.687 -0.381 2.162 0.025 0.016 762.106 1019.675 1.001
W*rev_rating -2.281 1.859 -5.898 0.955 0.077 0.051 578.298 885.051 1.004
W*num_spots 0.186 0.049 0.096 0.275 0.002 0.001 784.753 1040.884 1.002
W*crowded -1.588 1.791 -4.887 1.870 0.055 0.047 1073.592 965.759 1.000
lam 0.333 0.186 -0.038 0.653 0.007 0.004 775.186 1076.509 1.000
sigma 23.170 2.009 19.545 26.910 0.066 0.050 861.891 893.178 1.004
direct direct_ci_lower direct_ci_upper direct_pvalue indirect indirect_ci_lower indirect_ci_upper indirect_pvalue total total_ci_lower total_ci_upper total_pvalue
variable
poverty -0.650383 -1.510108 0.168970 0.12250 0.934437 -0.439177 2.225829 0.18125 0.284054 -1.136890 1.562165 0.62500
rev_rating 0.245450 -1.628970 2.166868 0.81125 -2.281243 -5.876027 1.319903 0.20375 -2.035793 -6.793964 2.587200 0.38875
num_spots 0.028453 -0.029373 0.085675 0.32875 0.185922 0.089171 0.277472 0.00000 0.214375 0.131928 0.294567 0.00000
crowded -1.069645 -3.238427 1.073698 0.33000 -1.587922 -5.236918 1.949305 0.36000 -2.657567 -6.150483 0.712910 0.11000
az.plot_forest(sdem.inference_data)
display(diagnostics_table(sdem.inference_data, ["lam", "beta", "sigma"]))
show_trace(sdem.inference_data, ["lam", "sigma"], "SDEM Trace: lam, sigma")
mean sd ess_bulk ess_tail r_hat
lam 0.333 0.186 775.186 1076.509 1.000
beta[Intercept] 252.539 234.668 536.582 787.195 1.002
beta[poverty] -0.650 0.430 934.876 1012.513 1.003
beta[rev_rating] 0.245 0.974 860.500 986.327 1.003
beta[num_spots] 0.028 0.029 819.267 954.465 1.001
beta[crowded] -1.070 1.100 1112.542 1085.944 1.000
beta[W*poverty] 0.934 0.687 762.106 1019.675 1.001
beta[W*rev_rating] -2.281 1.859 578.298 885.051 1.004
beta[W*num_spots] 0.186 0.049 784.753 1040.884 1.002
beta[W*crowded] -1.588 1.791 1073.592 965.759 1.000
sigma 23.170 2.009 861.891 893.178 1.004
[Figures: forest plot of SDEM posteriors; trace plots for lam and sigma]

MCMC Sampling Adequacy

Bayesian estimation of spatial models has a well-known efficiency pitfall: the spatial dependence parameter \(\rho\) (or \(\lambda\)) often mixes slowly, so a chain that looks converged in its point estimate can still produce posterior credible intervals that are 10–12 % too narrow because the sampler has not visited the tails enough times [Wolf et al., 2018]. The spatial_mcmc_diagnostic helper checks effective sample size, sampler yield, \(\hat{R}\), and HPDI stability for the spatial parameter and warns when any threshold is violated.
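The HPDI-stability idea can be sketched in plain NumPy: estimate the 95% highest-posterior-density interval from the full chain and from its last third, and flag a large relative drift. (This mimics the concept only; it is not the package's internal implementation.)

```python
import numpy as np

rng = np.random.default_rng(3)
draws = rng.normal(0.44, 0.14, size=1600)  # stand-in posterior draws for rho

def hpdi_width(x, prob=0.95):
    """Width of the narrowest interval containing `prob` of the draws."""
    x = np.sort(x)
    m = int(np.ceil(prob * len(x)))
    return (x[m - 1:] - x[: len(x) - m + 1]).min()

full = hpdi_width(draws)
last_third = hpdi_width(draws[-len(draws) // 3 :])
drift_pct = 100 * abs(last_third - full) / full  # flag if, say, > 5%
```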

from bayespecon import spatial_mcmc_diagnostic

# Run on the SDM (most parameters; both rho and impact summaries depend on it)
report = spatial_mcmc_diagnostic(sdm, emit_warnings=False)
report.to_frame()
ess_bulk ess_tail r_hat mcse_mean yield_pct hpdi_drift_pct adequate
parameter
rho 1275.487907 1187.16926 1.002182 0.00453 79.717994 5.818803 False
report
SpatialMCMCReport(parameters=['rho'], ess_bulk={'rho': 1275.4879068543516}, ess_tail={'rho': 1187.1692603602617}, r_hat={'rho': 1.0021817336718464}, mcse_mean={'rho': 0.004529722859119265}, nominal_size=1600, yield_pct={'rho': 79.71799417839698}, hpdi_drift_pct={'rho': 5.818802561427723}, warnings_triggered=["95% HPDI width for 'rho' drifts by 5.8% between the last third and the full chain (> 5%); the credible interval has not yet stabilised. Consider doubling `draws` and/or `tune`, or re-running with more chains. See Wolf, Anselin & Arribas-Bel (2018), Geographical Analysis 50:97-119."], adequate=False, adequate_by_param={'rho': False})

Model Comparison

# Collect all fitted models for comparison
model_dict = {
    "OLS": ols,
    "SLX": slx,
    "SAR": sar,
    "SEM": sem,
    "SDM": sdm,
    "SDEM": sdem,
}
idata_dict = {name: m.inference_data for name, m in model_dict.items()}

Bayes Factor Model Comparison

Bayes factors provide an alternative to information criteria for comparing competing models. While WAIC and LOO assess predictive performance, Bayes factors compare marginal likelihoods — the probability of the data under each model, integrated over all parameter values weighted by the prior:

\[BF_{ij} = \frac{p(y \mid \mathcal{M}_i)}{p(y \mid \mathcal{M}_j)} = \frac{ML_i}{ML_j}\]

This makes Bayes factors sensitive to the prior: models with diffuse priors on unnecessary parameters are penalized more heavily, because the marginal likelihood averages over all possible parameter values weighted by the prior. In spatial models, this means that models with many WX coefficients under wide priors (e.g., SLX, SDM, SDEM) can receive substantially lower marginal likelihoods than more parsimonious alternatives (e.g., OLS, SAR, SEM) — even if their log-likelihoods are similar.

Interpreting Bayes factors (Kass & Raftery, 1995):

BF range   Evidence strength
1 – 3      Anecdotal
3 – 10     Moderate
10 – 30    Strong
30 – 100   Very strong
> 100      Extreme

Method. We use bridge sampling (Meng & Wong, 1996) to estimate each model’s log marginal likelihood, following the R bridgesampling package (Gronau et al., 2020). The implementation uses ESS weighting, two-phase convergence, and MCSE diagnostics. For reliable estimates, 40,000+ posterior samples are recommended.

Important caveat. Bayes factors can be very sensitive to prior specification. Models with diffuse priors on many parameters will tend to have lower marginal likelihoods. This is by design — it is the Bayesian Occam’s razor at work — but it means that Bayes factors and information criteria may disagree when priors are uninformative.
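Whatever the estimation method, turning log marginal likelihoods into a pairwise Bayes-factor table is a one-liner; working in log space avoids overflow. A sketch with hypothetical log-ML values:

```python
import numpy as np

# Hypothetical log marginal likelihoods for three models (not the real estimates).
log_ml = np.array([-320.0, -317.3, -318.2])  # order: OLS, SAR, SEM
names = ["OLS", "SAR", "SEM"]

# BF[i, j] = ML_i / ML_j; subtract in log space, exponentiate once at the end.
bf = np.exp(log_ml[:, None] - log_ml[None, :])
```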

bayes_factor_compare_models(model_dict, method="bridge").round(3)
/tmp/ipykernel_8741/3295012235.py:1: UserWarning: Bridge sampling with 1600 posterior samples for 'OLS' may yield imprecise marginal-likelihood estimates. A conservative rule of thumb is 40,000+ samples (Gronau, Singmann, & Wagenmakers, 2017).
  bayes_factor_compare_models(model_dict, method="bridge").round(3)
OLS SLX SAR SEM SDM SDEM
OLS 1.000 5.775991e+20 0.066 0.166 1.021055e+21 5.600229e+20
SLX 0.000 1.000000e+00 0.000 0.000 1.768000e+00 9.700000e-01
SAR 15.254 8.810636e+21 1.000 2.537 1.557507e+22 8.542530e+21
SEM 6.012 3.472435e+21 0.394 1.000 6.138423e+21 3.366770e+21
SDM 0.000 5.660000e-01 0.000 0.000 1.000000e+00 5.480000e-01
SDEM 0.000 1.031000e+00 0.000 0.000 1.823000e+00 1.000000e+00
bayes_factor_compare_models(model_dict, method="bic").round(3)
OLS SLX SAR SEM SDM SDEM
OLS 1.000 0.426 0.123 1.609 0.966 1.613
SLX 2.346 1.000 0.289 3.774 2.266 3.783
SAR 8.120 3.461 1.000 13.062 7.841 13.095
SEM 0.622 0.265 0.077 1.000 0.600 1.002
SDM 1.036 0.441 0.128 1.666 1.000 1.670
SDEM 0.620 0.264 0.076 0.998 0.599 1.000

Bridge Sampling vs. BIC

  • BIC: SAR > SLX > OLS > SEM (SAR wins, SLX is moderate, SEM is worst)

  • Bridge: SAR > SEM > OLS (SAR wins, SEM is moderate)

The two methods can produce qualitatively different model rankings. This is expected and reflects a fundamental difference in what they assume about priors:

  • BIC approximates \(\log(ML) \approx \hat\ell_{\max} - \frac{k}{2}\log(n)\), which assumes unit-information priors (priors containing as much information as a single observation). The penalty per parameter is fixed at \(\frac{1}{2}\log(n) \approx 2.1\) for \(n = 67\).

  • Bridge sampling integrates over the actual priors in the model. When priors are wide (e.g., Normal(0, 100) on WX coefficients), the marginal likelihood penalizes each such parameter by roughly \(\log(\sigma_{\text{prior}} / \sigma_{\text{post}})\), which can be 5–10× larger than the BIC penalty.

This is why models with many WX terms (SLX, SDM, SDEM) may look reasonable under BIC but receive extreme Bayes factors under bridge sampling: the wide priors on the WX coefficients are “wasted” — they spread probability mass over implausible parameter values, reducing the marginal likelihood. This is Bayesian Occam’s razor at work, and bridge sampling is generally more trustworthy because it accounts for the actual prior specification.
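The BIC approximation above is easy to reproduce by hand; the numbers below are hypothetical log-likelihoods and parameter counts chosen for illustration:

```python
import numpy as np

def bic_log_ml(loglik_max, k, n):
    """BIC approximation to the log marginal likelihood: l_max - (k/2) log n."""
    return loglik_max - 0.5 * k * np.log(n)

n = 67  # observations in this dataset
# Hypothetical maximized log-likelihoods and parameter counts.
log_ml_sar = bic_log_ml(-310.0, k=7, n=n)
log_ml_slx = bic_log_ml(-308.5, k=10, n=n)

# Under BIC each extra parameter costs only 0.5*log(67) ~= 2.1 nats,
# far less than a wide prior typically costs under bridge sampling.
bf_sar_vs_slx = np.exp(log_ml_sar - log_ml_slx)
```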

Model Comparison: WAIC and LOO

We compare fitted models using information criteria from ArviZ.

# WAIC and LOO comparison
for ic in ("waic", "loo"):
    cmp = az.compare(idata_dict, ic=ic, method="BB-pseudo-BMA")
    az.plot_compare(cmp)
/home/runner/micromamba/envs/test/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
/home/runner/micromamba/envs/test/lib/python3.14/site-packages/arviz/stats/stats.py:782: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.69 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
[Figures: WAIC and LOO model-comparison plots]

Bayesian LM Specification Tests

The bayespecon.diagnostics module provides Bayesian Lagrange-multiplier tests that operate on posterior draws rather than point estimates. These replace the frequentist LM tests that were previously available.

Below we run the Bayesian LM tests from the fitted OLS model:

  • bayesian_lm_lag_test — tests \(H_0: \rho = 0\) (no spatial lag)

  • bayesian_lm_error_test — tests \(H_0: \lambda = 0\) (no spatial error)

  • bayesian_lm_wx_test — tests \(H_0: \gamma = 0\) (no WX, from SAR null)

  • bayesian_lm_sdm_joint_test — joint test for SDM (\(H_0: \rho = 0\) and \(\gamma = 0\))

  • bayesian_lm_slx_error_joint_test — joint test for SDEM (\(H_0: \lambda = 0\) and \(\gamma = 0\))

Each returns a BayesianLMTestResult with posterior summary statistics and a Bayesian p-value.
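As a rough intuition for a posterior-draw-based test (this is a generic sign-probability sketch, not bayespecon's actual LM statistic), one can ask how much posterior mass sits on either side of the null value:

```python
import numpy as np

rng = np.random.default_rng(5)
# Stand-in posterior draws of rho from a lag specification.
rho_draws = rng.normal(0.44, 0.14, size=1600)

# A simple posterior check for H0: rho = 0 -- twice the smaller tail mass
# around zero. Small values mean the posterior is far from the null.
p_tail = min((rho_draws > 0).mean(), (rho_draws < 0).mean())
bayes_pvalue = 2 * p_tail
```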

# Bayesian LM specification tests
lm_results = pd.DataFrame(
    [
        bayesian_lm_lag_test(ols).to_series(),
        bayesian_lm_error_test(ols).to_series(),
        bayesian_lm_wx_test(sar).to_series(),
        bayesian_lm_sdm_joint_test(ols).to_series(),
        bayesian_lm_slx_error_joint_test(ols).to_series(),
    ],
    index=["LM-Lag", "LM-Error", "LM-WX", "LM-SDM Joint", "LM-SLX-Error Joint"],
)

lm_results
lm_samples mean median credible_interval bayes_pvalue test_type df n_draws k_wx
LM-Lag [2141.9448000869183, 1058.4500549051181, 25.27... 9.060126e+02 4.230825e+02 (0.9196702949394995, 4476.948164232156) 0.0 bayesian_lm_lag 1 1600 NaN
LM-Error [1225.9131157509232, 2047.1881953470418, 1185.... 2.661902e+03 2.088608e+03 (467.53638748404313, 8210.710911824342) 0.0 bayesian_lm_error 1 1600 NaN
LM-WX [2383027.9493734804, 1712404.0729000492, 16642... 2.399841e+06 2.152488e+06 (445374.36773369374, 6049129.236334314) 0.0 bayesian_lm_wx 4 1600 4.0
LM-SDM Joint [3472113.2017246773, 4713480.8508704705, 36335... 5.215857e+06 4.873212e+06 (1795047.234990342, 10696517.375943972) 0.0 bayesian_lm_sdm_joint 5 1600 4.0
LM-SLX-Error Joint [3890599.0404604143, 5347438.158741046, 373684... 6.350730e+06 5.611159e+06 (1645005.9592452648, 15451396.01965429) 0.0 bayesian_lm_slx_error_joint 5 1600 4.0

Compare Spatial Parameters Across Models

This cell compares posterior means/intervals for the spatial scalar where present:

  • rho in SAR/SDM

  • lam in SEM/SDEM

# Compare spatial parameters across models
spatial_rows = []
for name, model, var in [
    ("SAR", sar, "rho"),
    ("SEM", sem, "lam"),
    ("SDM", sdm, "rho"),
    ("SDEM", sdem, "lam"),
]:
    if var in model.inference_data.posterior:
        summary = az.summary(model.inference_data, var_names=[var], round_to=3)
        summary.insert(0, "model", name)
        spatial_rows.append(summary)

pd.concat(spatial_rows)
model mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
rho SAR 0.439 0.140 0.170 0.694 0.004 0.004 1251.185 938.602 1.009
lam SEM 0.503 0.181 0.156 0.819 0.007 0.004 575.534 857.794 1.003
rho SDM 0.239 0.162 -0.062 0.551 0.005 0.004 1275.488 1187.169 1.002
lam SDEM 0.333 0.186 -0.038 0.653 0.007 0.004 775.186 1076.509 1.000

Matrix-Mode API Example (Optional)

The package also supports direct matrix inputs. This is useful if your design matrix is already engineered.

y = gdf[ycol]
X = gdf[xcols]

sar_matrix = SAR(y=y, X=X, W=W)
sar_matrix.fit(draws=200, tune=200, chains=2, random_seed=7, progressbar=False)
display(sar_matrix.summary(round_to=3))
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [rho, beta, sigma]
Sampling 2 chains for 200 tune and 200 draw iterations (400 + 400 draws total) took 6 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
poverty 0.046 0.281 -0.463 0.539 0.016 0.015 300.333 247.832 1.035
rev_rating 0.458 0.112 0.240 0.649 0.009 0.005 153.670 290.475 1.002
num_spots 0.079 0.023 0.042 0.127 0.001 0.001 255.255 334.916 1.004
crowded -1.668 0.878 -3.134 0.151 0.057 0.043 244.280 252.719 1.002
rho 0.426 0.127 0.191 0.654 0.011 0.007 145.755 254.152 1.003
sigma 24.381 2.171 20.962 28.240 0.166 0.182 192.654 230.408 1.009