Gibbs Sampler for Gaussian Spatial Models

This notebook demonstrates the block-Gibbs sampler for Gaussian spatial regression models (SAR, SEM, SDM, SDEM) in bayespecon. The sampler exploits conditional conjugacy to sidestep the difficult posterior geometry that the spatial Jacobian creates for NUTS, which can make sampling 10–100× faster on large datasets.

When to Use Gibbs vs NUTS

| Criterion | NUTS (default) | Gibbs |
| --- | --- | --- |
| Model type | Any (Gaussian, Student-t, NB, Probit, Tobit) | Gaussian only (SAR, SEM, SDM, SDEM) |
| Posterior geometry | Handles banana/funnel shapes via adaptation | Avoids banana geometry via conjugacy |
| Speed | Slow for large \(n\) (gradient through Jacobian) | Fast (direct draws for β, σ²; 1-D slice/MALA for ρ/λ) |
| Robust models | ✅ Student-t errors | ❌ Not supported |
| Non-Gaussian | ✅ NB, Probit, Tobit | ❌ Not supported |
| ESS/s | Low for spatial models | High (conjugate blocks mix well) |
| Tuning | 1000+ tuning steps needed | Minimal (adaptive slice width or MALA step size) |

Rule of thumb: Use sampler='gibbs' for Gaussian spatial models with \(n > 100\). Use NUTS (default) for non-Gaussian models or when you need Student-t robustness.

%load_ext autoreload
%autoreload 2

import arviz as az
import geopandas as gpd
import libpysal
import numpy as np

from bayespecon.models import SAR, SDEM, SDM, SEM

Data: Columbus Crime Dataset

We use the classic Columbus (OH) neighbourhood crime dataset from libpysal.

# Load Columbus dataset
gdf = gpd.read_file(libpysal.examples.get_path("columbus.shp"))

# Response: crime rate
y = gdf["CRIME"].values.astype(float)

# Covariates: income + housing value
X = np.column_stack(
    [
        np.ones(len(y)),
        gdf["INC"].values.astype(float),
        gdf["HOVAL"].values.astype(float),
    ]
)

# Spatial weights: contiguity (build_contiguity defaults to rook=True; pass rook=False for queen)
W = libpysal.graph.Graph.build_contiguity(gdf)
W = W.transform("r")  # row-standardise

print(f"n = {len(y)}, k = {X.shape[1]}")
n = 49, k = 3

SAR Model: Gibbs vs NUTS

The SAR (Spatial Autoregressive) model is:

\[y = \rho W y + X\beta + \varepsilon, \quad \varepsilon \sim N(0, \sigma^2 I)\]
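
Solving the model for ε shows where the Jacobian term comes from. The transformation \(\varepsilon = (I - \rho W)y - X\beta\) has Jacobian \(|I - \rho W|\), so the log-likelihood takes the standard SAR form (stated here for reference, not taken from the bayespecon source):

\[\log p(y \mid \rho, \beta, \sigma^2) = -\frac{n}{2}\log(2\pi\sigma^2) + \log|I - \rho W| - \frac{1}{2\sigma^2}\,\lVert (I - \rho W)y - X\beta \rVert^2\]

The middle term is the log-determinant that NUTS has to differentiate through. For a fixed ρ, the remaining terms are those of an ordinary linear regression of \((I - \rho W)y\) on \(X\), which is what makes the β and σ² updates conjugate.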

The Gibbs sampler uses a 3-block strategy:

  1. β | ρ, σ², y — conjugate normal (direct draw)

  2. σ² | β, ρ, y — conjugate inverse-gamma (direct draw)

  3. ρ | β, σ², y — 1-D slice sampling or MALA (non-conjugate, scalar)

Only ρ lacks a conjugate update, and because it is a scalar, a one-dimensional slice or MALA step handles it cheaply.
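
The three blocks listed above can be sketched in plain NumPy. This is a minimal illustration, not bayespecon's implementation: it assumes flat priors on β and ρ, a prior proportional to 1/σ², a weights matrix with real eigenvalues of magnitude at most 1, and a simple shrinkage slice sampler on (-0.999, 0.999). The names slice_rho, gibbs_sar, and the toy ring-lattice data are hypothetical.

```python
import numpy as np


def slice_rho(rho0, logp, rng, lo=-0.999, hi=0.999):
    """One slice-sampling update for a scalar on a bounded interval (shrinkage)."""
    level = logp(rho0) + np.log(rng.uniform())  # log-height defining the slice
    while True:
        rho1 = rng.uniform(lo, hi)
        if logp(rho1) > level:
            return rho1
        # Rejected: shrink the interval towards the current point.
        if rho1 < rho0:
            lo = rho1
        else:
            hi = rho1


def gibbs_sar(y, X, W, draws=500, tune=200, seed=0):
    """Toy 3-block Gibbs sampler for SAR under flat priors (illustration only)."""
    rng = np.random.default_rng(seed)
    n, k = X.shape
    eigs = np.linalg.eigvals(W).real  # assumes W has real eigenvalues
    Wy = W @ y
    XtX_inv = np.linalg.inv(X.T @ X)
    chol = np.linalg.cholesky(XtX_inv)
    rho, sigma2 = 0.0, float(np.var(y))
    out = np.empty((draws, k + 2))

    for it in range(tune + draws):
        Ay = y - rho * Wy  # (I - rho W) y

        # Block 1: beta | rho, sigma2, y ~ N(beta_hat, sigma2 (X'X)^-1)
        beta_hat = XtX_inv @ (X.T @ Ay)
        beta = beta_hat + np.sqrt(sigma2) * (chol @ rng.standard_normal(k))

        # Block 2: sigma2 | beta, rho, y ~ Inv-Gamma(n/2, SSR/2)
        resid = Ay - X @ beta
        sigma2 = 0.5 * (resid @ resid) / rng.gamma(n / 2.0)

        # Block 3: rho | beta, sigma2, y via slice sampling on the log conditional
        Xb = X @ beta

        def logp(r):
            e = y - r * Wy - Xb
            return np.sum(np.log1p(-r * eigs)) - (e @ e) / (2.0 * sigma2)

        rho = slice_rho(rho, logp, rng)

        if it >= tune:
            out[it - tune] = np.concatenate(([rho, sigma2], beta))
    return out


# Synthetic data on a ring lattice (two neighbours per unit, rows sum to 1).
rng = np.random.default_rng(1)
n = 100
idx = np.arange(n)
W_toy = np.zeros((n, n))
W_toy[idx, (idx - 1) % n] = 0.5
W_toy[idx, (idx + 1) % n] = 0.5
X_toy = np.column_stack([np.ones(n), rng.standard_normal(n)])
rho_true, beta_true = 0.5, np.array([1.0, 2.0])
y_toy = np.linalg.solve(
    np.eye(n) - rho_true * W_toy, X_toy @ beta_true + rng.standard_normal(n)
)

samples = gibbs_sar(y_toy, X_toy, W_toy)
print("posterior means (rho, sigma2, beta0, beta1):", samples.mean(axis=0).round(3))
```

Each column of samples holds one parameter, in the order ρ, σ², β. The sketch conditions ρ on β and σ² exactly as in the list above; a production sampler may instead integrate β and σ² out before updating ρ.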

# --- NUTS ---
model_nuts = SAR(y=y, X=X, W=W)
idata_nuts = model_nuts.fit(
    sampler="nuts",
    draws=2000,
    tune=1000,
    chains=4,
    target_accept=0.9,
    random_seed=42,
    idata_kwargs={"log_likelihood": True},
)

# --- Gibbs (numpy path) ---
model_gibbs = SAR(y=y, X=X, W=W)
idata_gibbs = model_gibbs.fit(
    sampler="gibbs",
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
    n_jobs=-1,
    idata_kwargs={"log_likelihood": True},
)

print("NUTS posterior means:")
print(f"  rho   = {float(idata_nuts.posterior['rho'].mean()):.4f}")
print(f"  beta  = {idata_nuts.posterior['beta'].mean(dim=['chain', 'draw']).values}")
print(f"  sigma = {float(idata_nuts.posterior['sigma'].mean()):.4f}")
print()
print("Gibbs posterior means:")
print(f"  rho   = {float(idata_gibbs.posterior['rho'].mean()):.4f}")
print(f"  beta  = {idata_gibbs.posterior['beta'].mean(dim=['chain', 'draw']).values}")
print(f"  sigma = {float(idata_gibbs.posterior['sigma'].mean()):.4f}")
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [rho, beta, sigma2]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 16 seconds.
Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 5s (2358 draws/s)
NUTS posterior means:
  rho   = 0.4076
  beta  = [45.96578678 -1.051098   -0.2580071 ]
  sigma = 10.4814

Gibbs posterior means:
  rho   = 0.4074
  beta  = [45.9925485  -1.04618758 -0.26122261]
  sigma = 10.4943
az.summary(idata_nuts)["ess_bulk"]
beta[x0]    3313.0
beta[x1]    3581.0
beta[x2]    4056.0
rho         3902.0
sigma2      4677.0
sigma       4677.0
Name: ess_bulk, dtype: float64
az.summary(idata_gibbs)["ess_bulk"]
rho         8002.0
sigma       7017.0
sigma2      7017.0
beta[x0]    8134.0
beta[x1]    7962.0
beta[x2]    8045.0
Name: ess_bulk, dtype: float64
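
To turn these numbers into the ESS/s figure from the comparison table, divide bulk ESS by wall-clock time. The snippet below is a rough calculation using only the values printed above (16 seconds for NUTS, 5 seconds for Gibbs, and the bulk ESS for rho). The timings are rounded and include start-up overhead, so treat the result as indicative only.

```python
# Figures copied from the sampler logs and az.summary output in this notebook.
nuts = {"seconds": 16.0, "ess_bulk_rho": 3902.0}
gibbs = {"seconds": 5.0, "ess_bulk_rho": 8002.0}

ess_per_s_nuts = nuts["ess_bulk_rho"] / nuts["seconds"]
ess_per_s_gibbs = gibbs["ess_bulk_rho"] / gibbs["seconds"]
ratio = ess_per_s_gibbs / ess_per_s_nuts

print(f"NUTS : {ess_per_s_nuts:.0f} ESS/s")
print(f"Gibbs: {ess_per_s_gibbs:.0f} ESS/s")
print(f"ratio: {ratio:.1f}x")
```

On this 49-observation dataset that works out to roughly 244 versus 1600 ESS/s for rho, a factor of about 6.6.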
model_gibbs.spatial_diagnostics_decision()
../_images/31f211cdb7b189f785d7392694d8932af37c8a7f057e9b67f73ea58f12317596.svg

Gibbs Backend Options

The Gibbs sampler supports two execution backends:

| Backend | gibbs_method | ρ/λ update | JIT compiled | Requires |
| --- | --- | --- | --- | --- |
| NumPy | "numpy" (default) | Adaptive slice sampling | No | NumPy only |
| JAX | "jax" | Slice sampling (default), MALA, or RW-MH | Yes (@eqx.filter_jit) | JAX + equinox |

NumPy path (default)

The NumPy path uses adaptive slice sampling for ρ/λ. It’s pure Python/NumPy — no JAX dependency. Best for moderate \(n\) (up to ~1000). Chains run in parallel by default (n_jobs=-1 uses all CPUs via joblib). Set n_jobs=1 for sequential execution with progress bars.

JAX path

The JAX path compiles the entire 3-block Gibbs step into a single XLA kernel via @eqx.filter_jit. It supports three ρ/λ samplers:

  • Slice sampling (default, use_slice=True): Neal (2003) slice sampling with persistent interval reuse. Best ESS per sample.

  • MALA (use_mala=True, use_slice=False): Metropolis-adjusted Langevin algorithm with JAX autodiff for exact gradients.

  • RW-MH (use_mala=False, use_slice=False): Random-walk Metropolis–Hastings. Simplest but worst mixing.

Best for large \(n\) where JIT compilation amortizes over many iterations. Chains run in parallel via jax.vmap by default (chain_method="vectorized").
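
To make the MALA option concrete, here is a minimal sketch of a single MALA update for a bounded scalar, written in NumPy with a hand-coded gradient rather than JAX autodiff. It is not bayespecon's code: mala_step and the toy target are hypothetical, and the target (a rescaled Beta(7, 3) density on (-1, 1)) only mimics the boundary behaviour of the log-determinant term.

```python
import numpy as np


def mala_step(x, logp, grad, step, rng, lo=-1.0, hi=1.0):
    """One MALA update for a scalar x restricted to the open interval (lo, hi)."""
    mean_fwd = x + 0.5 * step**2 * grad(x)
    prop = mean_fwd + step * rng.standard_normal()
    if not (lo < prop < hi):
        return x, False  # zero target density outside the support: reject
    mean_rev = prop + 0.5 * step**2 * grad(prop)
    # Log Metropolis-Hastings ratio, including the asymmetric Gaussian proposal.
    log_q_fwd = -((prop - mean_fwd) ** 2) / (2.0 * step**2)
    log_q_rev = -((x - mean_rev) ** 2) / (2.0 * step**2)
    log_alpha = logp(prop) - logp(x) + log_q_rev - log_q_fwd
    if np.log(rng.uniform()) < log_alpha:
        return prop, True
    return x, False


# Toy target on (-1, 1): p(x) proportional to (1 + x)^6 (1 - x)^2,
# i.e. (x + 1) / 2 ~ Beta(7, 3), whose mean maps to x = 0.4.
def logp(x):
    return 6.0 * np.log1p(x) + 2.0 * np.log1p(-x)


def grad(x):
    return 6.0 / (1.0 + x) - 2.0 / (1.0 - x)


rng = np.random.default_rng(0)
x, n_accept = 0.0, 0
chain = np.empty(20_000)
for i in range(chain.size):
    x, accepted = mala_step(x, logp, grad, step=0.2, rng=rng)
    n_accept += accepted
    chain[i] = x

print(f"acceptance rate = {n_accept / chain.size:.2f}")
print(f"sample mean = {chain.mean():.3f} (exact mean is 0.4)")
```

Setting the gradient term to zero reduces the proposal to a symmetric random walk, which is the RW-MH variant. The drift is what lets MALA take larger steps at a given acceptance rate.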

# JAX Gibbs with MALA (gradient-guided proposals)
model_jax_mala = SAR(y=y, X=X, W=W)
idata_jax_mala = model_jax_mala.fit(
    sampler="gibbs",
    gibbs_method="jax",
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
    use_mala=True,
    use_slice=False,  # MALA instead of slice
)

# JAX Gibbs with slice sampling (default for JAX path)
model_jax_slice = SAR(y=y, X=X, W=W)
idata_jax_slice = model_jax_slice.fit(
    sampler="gibbs",
    gibbs_method="jax",  # slice sampling is the default
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
)

print("JAX MALA posterior means:")
print(f"  rho   = {float(idata_jax_mala.posterior['rho'].mean()):.4f}")
print(f"  sigma = {float(idata_jax_mala.posterior['sigma'].mean()):.4f}")
print()
print("JAX Slice posterior means:")
print(f"  rho   = {float(idata_jax_slice.posterior['rho'].mean()):.4f}")
print(f"  sigma = {float(idata_jax_slice.posterior['sigma'].mean()):.4f}")
Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 4s (3030 draws/s)

Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 4s (3215 draws/s)
JAX MALA posterior means:
  rho   = 0.4289
  sigma = 10.4654

JAX Slice posterior means:
  rho   = 0.4055
  sigma = 10.5022

Log-Determinant Method Choices

The log-determinant of the spatial Jacobian, \(\log|I - \rho W|\), is the computational bottleneck. bayespecon supports several methods for computing or approximating it:

| Method | logdet_method | Extra option | Precompute | Per-ρ cost | JIT-compatible | Best for |
| --- | --- | --- | --- | --- | --- | --- |
| Eigenvalue | "eigenvalue" | | O(n³) | O(n) | Yes | n < 500 |
| Chebyshev | "chebyshev" | exact coefficients when feasible | O(n³) or O(nm) | O(m) | Yes | default for n ≥ 500 |
| Chebyshev + stochastic traces | "chebyshev" | trace_estimator="hutchpp" (default), "hutchinson", or "xtrace" | O(kmn) | O(m) | Yes | large sparse W |

The Gibbs sampler automatically uses the same logdet configuration as the model. All of the public choices above are JIT-compatible, so any of them can be used with the JAX path.
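
The cost columns are easier to read with a small example. The sketch below computes the log-determinant from eigenvalues (one decomposition, then O(n) per ρ), checks it against a dense determinant, and then builds a degree-m Chebyshev interpolant in ρ that costs O(m) per evaluation. It shows one way a Chebyshev approximation can work and is not necessarily how bayespecon derives its coefficients. The ring-lattice W_toy, the interval [-0.9, 0.9], and m = 30 are assumptions chosen for the demo.

```python
import numpy as np
from numpy.polynomial import Chebyshev
from numpy.polynomial.chebyshev import chebpts1

# Row-standardised ring lattice as a stand-in for W (toy data).
n = 50
idx = np.arange(n)
W_toy = np.zeros((n, n))
W_toy[idx, (idx - 1) % n] = 0.5
W_toy[idx, (idx + 1) % n] = 0.5

# Eigenvalue method: one O(n^3) decomposition, then O(n) per rho.
eigs = np.linalg.eigvals(W_toy).real


def logdet_eig(rho):
    return float(np.sum(np.log1p(-rho * eigs)))


# Reference value from a dense determinant, O(n^3) for every rho.
rho = 0.4
sign, ref = np.linalg.slogdet(np.eye(n) - rho * W_toy)
print(f"eigenvalue: {logdet_eig(rho):.6f}, slogdet: {ref:.6f}")

# Chebyshev-style surrogate: interpolate rho -> log|I - rho W| on [lo, hi]
# at Chebyshev nodes, then evaluate the degree-m polynomial in O(m) per rho.
lo, hi, m = -0.9, 0.9, 30
nodes = 0.5 * (hi - lo) * chebpts1(m + 1) + 0.5 * (hi + lo)
values = np.array([logdet_eig(r) for r in nodes])
cheb = Chebyshev.fit(nodes, values, deg=m, domain=[lo, hi])

grid = np.linspace(lo, hi, 201)
exact = np.array([logdet_eig(r) for r in grid])
max_err = float(np.max(np.abs(cheb(grid) - exact)))
print(f"max |Chebyshev - exact| on [{lo}, {hi}]: {max_err:.2e}")
```

The approximation degrades as the interval approaches the singularities at ρ = ±1, which is why the demo stops at ±0.9.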

# Chebyshev logdet with JAX Gibbs
model_cheb = SAR(y=y, X=X, W=W, logdet_method="chebyshev")
idata_cheb = model_cheb.fit(
    sampler="gibbs",
    gibbs_method="jax",
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
)

print(
    f"Chebyshev logdet + JAX Gibbs: rho = {float(idata_cheb.posterior['rho'].mean()):.4f}"
)
Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 4s (2961 draws/s)
Chebyshev logdet + JAX Gibbs: rho = 0.4050

SEM, SDM, and SDEM Models

The Gibbs sampler works for all four Gaussian spatial model types:

| Model | Spatial parameter | ρ/λ update | Null model for ρ/λ |
| --- | --- | --- | --- |
| SAR | ρ (lag) | Collapsed slice/MALA | Integrates out β, σ² |
| SEM | λ (error) | Conditional slice/MALA | Conditional on β, σ² |
| SDM | ρ (lag) | Collapsed slice/MALA | Same as SAR, Z = [X, WX] |
| SDEM | λ (error) | Conditional slice/MALA | Same as SEM, Z = [X, WX] |
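
For the Durbin variants, the design matrix Z = [X, WX] simply appends spatially lagged covariates to X. The sketch below builds it by hand on toy data. It lags only the non-constant columns, which is consistent with the five coefficients (intercept, two covariates, two lags) reported for SDM and SDEM in this notebook, but exactly how bayespecon assembles Z internally is an assumption here.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
# Intercept plus two covariates, mirroring the layout of X in this notebook.
X_toy = np.column_stack([np.ones(n), rng.standard_normal((n, 2))])

# Row-standardised ring lattice as a stand-in for W (toy data).
idx = np.arange(n)
W_toy = np.zeros((n, n))
W_toy[idx, (idx - 1) % n] = 0.5
W_toy[idx, (idx + 1) % n] = 0.5

# Lag only the non-constant columns: W @ 1 = 1 for a row-standardised W,
# so a lagged intercept would just duplicate the intercept column.
WX = W_toy @ X_toy[:, 1:]
Z = np.column_stack([X_toy, WX])
print(Z.shape)  # (6, 5)
```

With Z in place of X, the conjugate β and σ² updates are unchanged; only the number of regression coefficients grows.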

# SEM Gibbs
model_sem = SEM(y=y, X=X, W=W)
idata_sem = model_sem.fit(
    sampler="gibbs",
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
    n_jobs=-1,
)
print(
    f"SEM Gibbs: lam = {float(idata_sem.posterior['lam'].mean()):.4f}, "
    f"sigma = {float(idata_sem.posterior['sigma'].mean()):.4f}"
)

# SDM Gibbs
model_sdm = SDM(y=y, X=X, W=W)
idata_sdm = model_sdm.fit(
    sampler="gibbs",
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
    n_jobs=-1,
)
print(
    f"SDM Gibbs: rho = {float(idata_sdm.posterior['rho'].mean()):.4f}, "
    f"beta shape = {idata_sdm.posterior['beta'].shape}"
)

# SDEM Gibbs
model_sdem = SDEM(y=y, X=X, W=W)
idata_sdem = model_sdem.fit(
    sampler="gibbs",
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42,
    n_jobs=-1,
    chain_method="parallel",
)
print(
    f"SDEM Gibbs: lam = {float(idata_sdem.posterior['lam'].mean()):.4f}, "
    f"beta shape = {idata_sdem.posterior['beta'].shape}"
)
Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 3s (4027 draws/s)
SEM Gibbs: lam = 0.5549, sigma = 10.4000
SDM Gibbs: rho = 0.3989, beta shape = (4, 2000, 5)
SDEM Gibbs: lam = 0.4959, beta shape = (4, 2000, 5)

Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 2s (7491 draws/s)
Sampling 4 chains for 1000 tune and 2000 draw iterations, 4 x 3,000 draws total took 3s (3917 draws/s)
az.summary(idata_sdm)
              mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
rho          0.399   0.167   0.089    0.715      0.002    0.002    7711.0    6145.0    1.0
sigma       10.544   1.145   8.472   12.675      0.014    0.012    6684.0    7359.0    1.0
sigma2     112.491  25.060  69.148  157.401      0.306    0.324    6684.0    7359.0    1.0
beta[x0]    43.841  13.834  17.207   69.528      0.158    0.138    7749.0    6217.0    1.0
beta[x1]    -0.945   0.372  -1.681   -0.279      0.004    0.003    7983.0    7811.0    1.0
beta[x2]    -0.296   0.101  -0.485   -0.111      0.001    0.001    7956.0    7581.0    1.0
beta[W*x1]  -0.455   0.628  -1.664    0.682      0.007    0.005    8024.0    7814.0    1.0
beta[W*x2]   0.235   0.201  -0.139    0.617      0.002    0.002    8081.0    7450.0    1.0

InferenceData Compatibility

The Gibbs sampler produces the same az.InferenceData output as NUTS, with posterior, log_likelihood, and observed_data groups. This means all ArviZ diagnostics (LOO, WAIC, summary, etc.) work seamlessly.

# ArviZ diagnostics work with Gibbs-produced InferenceData
print("Groups:", idata_gibbs.groups())
print()

# LOO cross-validation
loo = az.loo(idata_gibbs)
print(f"LOO: elpd = {loo.elpd_loo:.2f}, SE = {loo.se:.2f}")
print()

# WAIC
waic = az.waic(idata_gibbs)
print(f"WAIC: elpd = {waic.elpd_waic:.2f}, SE = {waic.se:.2f}")
print()

# Summary
az.summary(idata_gibbs, var_names=["rho", "sigma"])
Groups: ['posterior', 'log_likelihood', 'observed_data']

LOO: elpd = -249.07, SE = 21.00

WAIC: elpd = -220.49, SE = 15.51
/home/runner/micromamba/envs/test/lib/python3.14/site-packages/arviz/stats/stats.py:782: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
/home/runner/micromamba/envs/test/lib/python3.14/site-packages/arviz/stats/stats.py:1652: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. 
See http://arxiv.org/abs/1507.04544 for details
  warnings.warn(
         mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
rho     0.407  0.127   0.166    0.640      0.001    0.001    8002.0    5635.0    1.0
sigma  10.494  1.087   8.463   12.444      0.013    0.011    7017.0    7494.0    1.0
az.plot_trace(idata_gibbs)
../_images/b1d44ac3d87f9341e6efac89300e14611bdff00450936f2cf65c007beb3bd410.png

Spatial Diagnostics

The Gibbs-produced InferenceData also works with bayespecon’s Bayesian LM diagnostics, which test for spatial dependence in the residuals.

# Run spatial diagnostics on the Gibbs-fitted SAR model
report = model_gibbs.spatial_diagnostics()
print(report)
                   statistic      median  df       p_value  ci_lower  \
test                                                                   
LM-Error           63.568933    4.022099   1  1.554312e-15  0.015997   
LM-WX             357.458738  123.740412   2  0.000000e+00  1.491911   
Robust-LM-WX     1622.911781  552.478228   2  0.000000e+00  2.573554   
Robust-LM-Error     0.207840    0.134203   1  6.484657e-01  0.025429   

                    ci_upper  
test                          
LM-Error          527.343692  
LM-WX            2054.380655  
Robust-LM-WX     9378.016085  
Robust-LM-Error     0.844670

Summary

The block-Gibbs sampler for Gaussian spatial models in bayespecon provides:

  1. Same output format as NUTS (az.InferenceData with posterior, log_likelihood, observed_data)

  2. Two backends: NumPy (adaptive slice, n_jobs controls parallelism, default parallel) and JAX (slice/MALA/RW-MH, chain_method controls parallelism, default vectorized)

  3. All four Gaussian models: SAR, SEM, SDM, SDEM

  4. Full ArviZ compatibility: LOO, WAIC, summary, plot

  5. Full diagnostics compatibility: Bayesian LM tests, spatial diagnostics

# Quick reference
model = SAR(y=y, X=X, W=W)

# NumPy Gibbs (default, n_jobs=-1 runs chains in parallel via joblib)
idata = model.fit(sampler="gibbs", draws=2000, tune=1000, chains=4)

# NumPy Gibbs with sequential chains (for debugging)
idata = model.fit(sampler="gibbs", draws=2000, tune=1000, chains=4, n_jobs=1)

# JAX Gibbs with slice sampling (default for JAX path)
idata = model.fit(sampler="gibbs", gibbs_method="jax", draws=2000, tune=1000, chains=4)

# JAX Gibbs with MALA
idata = model.fit(sampler="gibbs", gibbs_method="jax", use_mala=True, use_slice=False, draws=2000, tune=1000, chains=4)

# JAX Gibbs with RW-MH
idata = model.fit(sampler="gibbs", gibbs_method="jax", use_mala=False, use_slice=False, draws=2000, tune=1000, chains=4)

# Chebyshev logdet + JAX Gibbs
model = SAR(y=y, X=X, W=W, logdet_method="chebyshev")
idata = model.fit(sampler="gibbs", gibbs_method="jax", draws=2000, tune=1000, chains=4)