Metropolis-Hastings in JAX

Add a dash of performance to classical statistics
probability
python
Bayesian
Author
Affiliation

Alin Morariu

Published

June 27, 2025

One of my favourite things to do is to translate math to code and code to math. Algorithms tend to be written down on paper in a way that is easy to read but not always obvious to implement. In this post, I’m going to go through the Metropolis-Hastings algorithm and show one way of implementing it in JAX. There are many ways to do this, some slower and some faster, but this one is mine.

Bayesian statistics and MCMC

The setup for Bayesian problems goes something like this: given some observed data, we assume it comes from a data-generating process (we call this the model), and we want to estimate the parameters of that model. Let \(\mathcal{D}\) be the observed data and \(\theta\) be the model parameters; then we can write the following:

\[ \pi(\theta \mid \mathcal{D}) \propto p(\theta) \, \mathcal{L}(\mathcal{D} \mid \theta) \]

where \(p(\theta)\) is called the prior distribution (on the model parameters), \(\mathcal{L}(\mathcal{D} \mid \theta)\) is the likelihood function, and \(\pi(\theta \mid \mathcal{D})\) is the posterior. The goal of MCMC is to draw samples from the posterior distribution, which is typically known only up to its normalizing constant. In this post, the posterior distribution is going to be very simple (so much so that you can work it out by hand), but you can imagine that for larger, more complex models these posterior distributions will not be “nice” to work with. What’s cool about MCMC algorithms is that they will (eventually) generate samples from these complex posterior distributions, so we can estimate the parameters of our models and quantify our uncertainty about them.
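For the model used later in this post, a Normal likelihood with known standard deviation \(\sigma\) and a Normal prior \(\mu \sim N(\mu_0, \tau_0^2)\) on the mean, “working it out by hand” gives the standard conjugate result. With \(n\) observations \(x_1, \dots, x_n\):

\[ \pi(\mu \mid \mathcal{D}) = N(\mu_n, \tau_n^2), \qquad \frac{1}{\tau_n^2} = \frac{1}{\tau_0^2} + \frac{n}{\sigma^2}, \qquad \mu_n = \tau_n^2 \left( \frac{\mu_0}{\tau_0^2} + \frac{1}{\sigma^2} \sum_{i=1}^{n} x_i \right) \]

In words, the precisions add, and the posterior mean is a precision-weighted average of the prior mean and the data. With a prior as diffuse as the one used below (\(\tau_0 = 10\)), the prior precision is negligible next to \(n/\sigma^2\), so \(\mu_n\) is very close to the sample mean.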

Metropolis-Hastings

The Metropolis-Hastings (MH) algorithm is a fundamental MCMC method designed to generate samples from complex probability distributions, such as our posterior. It is useful when direct sampling isn’t possible, for example when only an unnormalized density function is available. The algorithm goes something like this:

At each step \(t\), the algorithm proposes a new state \(\theta^*\) from a proposal distribution \(q(\theta^*|\theta_t)\), which often depends on the current state \(\theta_t\). The proposed state is then accepted with a probability \(\alpha\), known as the Metropolis-Hastings acceptance ratio:

\[ \alpha(\theta_t, \theta^*) = \min\left(1, \frac{\pi(\theta^*)q(\theta_t|\theta^*)}{\pi(\theta_t)q(\theta^*|\theta_t)}\right) \]

If the proposed state is accepted, \(\theta_{t+1} = \theta^*\); otherwise, the chain remains at its current state, \(\theta_{t+1} = \theta_t\) (see footnote 1). This acceptance criterion corrects for any bias introduced by the proposal distribution and ensures that the generated Markov chain has \(\pi(\theta)\) as its stationary distribution. This outlines a procedure that we can easily write as a for loop; however, in plain Python that style of implementation tends to be slow and cumbersome.
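Two standard observations make this ratio easier to work with in code. First, any normalizing constant of \(\pi\) cancels in the ratio, which is why an unnormalized posterior is enough, and computing the ratio in log space avoids underflow when \(\pi\) is a product of many likelihood terms. Second, when the proposal is symmetric, \(q(\theta_t \mid \theta^*) = q(\theta^* \mid \theta_t)\), as for a Gaussian random walk, the proposal terms cancel as well (this special case is the original Metropolis algorithm):

\[ \log \alpha(\theta_t, \theta^*) = \min\left(0,\; \log \pi(\theta^*) - \log \pi(\theta_t) + \log q(\theta_t \mid \theta^*) - \log q(\theta^* \mid \theta_t)\right) \]

\[ q(\theta_t \mid \theta^*) = q(\theta^* \mid \theta_t) \;\Longrightarrow\; \alpha(\theta_t, \theta^*) = \min\left(1, \frac{\pi(\theta^*)}{\pi(\theta_t)}\right) \]

Drawing \(u \sim U(0, 1)\) and accepting when \(u < \alpha\) (equivalently, \(\log u < \log \alpha\)) implements the accept/reject step.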

Note

If you’d like a more thorough treatment of the algorithm, I’d recommend checking out Scalable Monte Carlo for Bayesian Learning.
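For comparison, here is a minimal sketch of the for-loop style mentioned above, written in plain Python with only the standard library. It assumes a Gaussian random-walk proposal and takes the target as a log-density function. `mh_loop` and its arguments are illustrative names, not from any library.

```python
import math
import random


def mh_loop(log_target, initial, num_samples, proposal_sd, seed=0):
    """Random-walk Metropolis-Hastings written as a plain Python for loop.

    log_target: returns the (possibly unnormalized) log-density at a float.
    Returns the list of samples and the fraction of accepted proposals.
    """
    rng = random.Random(seed)
    current = initial
    current_lp = log_target(current)
    samples = []
    accepted = 0
    for _ in range(num_samples):
        # Gaussian random-walk proposal centred at the current state; it is
        # symmetric, so the q terms of the acceptance ratio cancel.
        proposal = rng.gauss(current, proposal_sd)
        proposal_lp = log_target(proposal)
        log_alpha = min(0.0, proposal_lp - current_lp)
        # Accept with probability alpha; 1 - random() lies in (0, 1], so log is defined.
        if math.log(1.0 - rng.random()) < log_alpha:
            current, current_lp = proposal, proposal_lp
            accepted += 1
        samples.append(current)
    return samples, accepted / num_samples


# Usage: sample from an unnormalized standard Normal, log p(x) = -x^2 / 2.
samples, acc_rate = mh_loop(lambda x: -0.5 * x * x, initial=0.0,
                            num_samples=20_000, proposal_sd=1.0, seed=1)
print(sum(samples) / len(samples), acc_rate)
```

Every iteration here runs in the Python interpreter. The `jax.lax.scan` version later in the post avoids this overhead by compiling the whole loop.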

Data simulation

import jax
import jax.numpy as jnp
import distrax
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time 

from typing import Callable, Tuple

We need some observed data, so we will simulate 40 observations from a one-dimensional Gaussian distribution.

rng_key = jax.random.PRNGKey(20250627) # use a fixed seed for reproducibility

# define the parameters of our *true* data-generating model
# true_mean is what we are aiming to "discover"/estimate through MCMC
true_mean = jnp.array(4.2, dtype=jnp.float32)
true_stddev = jnp.array(1.5, dtype=jnp.float32)

# this model represents the process that generates the observed data.
def my_model(mean, stddev):
    return distrax.Normal(loc=mean, scale=stddev)

model = my_model(mean=true_mean, stddev=true_stddev)

observed_data = model.sample(seed=rng_key, sample_shape=(40,))
# plot the observed data - think of this as exploratory data analysis
plt.figure(figsize=(9,4))
sns.histplot(observed_data, bins=30, kde=False, stat='density', color='green', label='Observed Data Density', alpha=0.6)
plt.title('Histogram of Observed Data')
plt.xlabel('Value')
plt.ylabel('Density')
plt.grid(True, linestyle=':', alpha=0.6)
plt.show()

Important

Problem statement: given the data above, estimate the mean of the data-generating model, assuming its standard deviation is known to be 1.5. \[ x_i \sim N(\mu, 1.5^2), \quad i = 1, \dots, 40 \quad \leftarrow \text{Find } \mu \]

Now we can create a function that builds our target log-probability function (i.e. the unnormalized log-posterior).

Note

In a real data science workflow the data is not simulated, so we can’t know for sure that our model is correct. From here on, we will pretend that we don’t know the data is normal and instead simply assume that it is.

# create a diffuse (weakly informative) prior for the parameter mu: Normal(0, 10^2)
prior_mean = jnp.array(0.0, dtype=jnp.float32)
prior_stddev = jnp.array(10.0, dtype=jnp.float32)

prior_distribution = distrax.Normal(loc=prior_mean, scale=prior_stddev)

def make_target_log_prob_fn(prior_dist: distrax.Distribution,
                            obs_data: jax.Array,
                            obs_likelihood_stddev: jax.Array):
  """
  Factory function to create a callable for the target log-probability (posterior).
  This function forms a closure over the prior distribution, observed data,
  and the likelihood standard deviation.

  Args:
    prior_dist (distrax.Distribution): The prior distribution for the parameter mu.
    obs_data (jax.Array): The observed data.
    obs_likelihood_stddev (jax.Array): The standard deviation for the data likelihood.

  Returns:
    A callable function `target_log_prob_fn(x)` that returns the unnormalized
    log-posterior probability of mu given the observed data.
  """
  @jax.jit # JIT compile the log_prob function
  def target_log_prob_fn(x):
    """
    Calculates the unnormalized log-posterior probability for a given parameter value 'x'.
    This is proportional to log(prior(x)) + log(likelihood(observed_data | x)).
    """
    # calculate the log-probability of 'x' under the prior distribution
    log_prior = prior_dist.log_prob(x)

    # define the likelihood distribution for the observed data, assuming 'x' is the mean
    likelihood_dist = distrax.Normal(loc=x, scale=obs_likelihood_stddev)

    # calculate the log-probability of the observed data under this likelihood
    # and sum over all data points (assuming independence).
    log_likelihood = jnp.sum(likelihood_dist.log_prob(obs_data))

    # unnormalized log-posterior is the sum of the log-prior and log-likelihood.
    return log_prior + log_likelihood
  return target_log_prob_fn

Let’s test that it works.

# create the specific target_log_prob_callable using our defined prior, observed data, and likelihood stddev.
target_log_prob_fn = make_target_log_prob_fn(
    prior_distribution, observed_data, true_stddev)

# test the target_log_prob_callable function at a few points.
# `prior_mean` (0.0) is a good test point for the parameter 'x'.
print(f"Log-posterior at prior mean ({prior_mean.item():.2f}): {target_log_prob_fn(prior_mean).item():.4f}")
print(f"Log-posterior at 6.0: {target_log_prob_fn(jnp.array(6.0)).item():.4f}")
Log-posterior at prior mean (0.00): -229.0388
Log-posterior at 6.0: -101.6034
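Because the target is a pure JAX function of a scalar, it can also be evaluated at many points at once with `jax.vmap`. This is handy for plotting the unnormalized log-posterior or checking roughly where it peaks. The sketch below rebuilds an equivalent target with `jax.scipy.stats.norm.logpdf` rather than distrax so that it runs on its own. The simulated stand-in data, the grid, and the name `log_post` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Stand-in data: 40 draws from N(4.2, 1.5^2), mirroring the post's setup.
key = jax.random.PRNGKey(0)
data = 4.2 + 1.5 * jax.random.normal(key, (40,))

def log_post(mu):
    """Unnormalized log-posterior: Normal(0, 10^2) prior plus Normal(mu, 1.5^2) likelihood."""
    log_prior = norm.logpdf(mu, loc=0.0, scale=10.0)
    log_lik = jnp.sum(norm.logpdf(data, loc=mu, scale=1.5))
    return log_prior + log_lik

# vmap maps log_post over the leading axis of the grid: one value per grid point.
grid = jnp.linspace(0.0, 8.0, 801)
log_post_grid = jax.jit(jax.vmap(log_post))(grid)

# The grid point with the highest log-posterior approximates the posterior mode.
mode_estimate = grid[jnp.argmax(log_post_grid)]
print(log_post_grid.shape, float(mode_estimate))
```

With the post’s own objects, `jax.vmap(target_log_prob_fn)(grid)` should work the same way.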

A modular solution

The goal in this implementation is to “section off” parts of the algorithm into functions that each do one thing (specialists of sorts). The most obvious candidate for this is the proposal distribution \(q(\theta^* \mid \theta_t; \sigma)\): a Normal centred at the current state with standard deviation \(\sigma\).

proposal_stddev = jnp.array(1.0, dtype=jnp.float32)

def make_proposal_distribution_fn(proposal_stddev_val):
  """
  Factory function to create a callable that generates a Normal proposal distribution.

  Args:
    proposal_stddev_val (jax.Array): The standard deviation for the proposal distribution.

  Returns:
    A callable function `proposal_dist_fn(center_val)` that returns a
    distrax.Normal object centered at `center_val` with
    the given `proposal_stddev_val`.
  """
  def proposal_dist_fn(center_val):
    """
    Creates a Normal distribution centered at `center_val`.
    """
    return distrax.Normal(loc=center_val, scale=proposal_stddev_val)
  return proposal_dist_fn

# Create the specific proposal distribution callable using our global standard deviation.
proposal_distribution_fn = make_proposal_distribution_fn(proposal_stddev)

Since the algorithm requires several probability evaluations, I liked the idea of bundling these calculations into one function. (A better design might give each probability its own function, since that would be easier to unit test.)

def make_compute_mh_log_probabilities(target_log_prob_fn, proposal_dist_fn):
    def compute_mh_log_probabilities(current_sample, proposed_sample):
        """
        Computes all necessary log probabilities for the Metropolis-Hastings acceptance ratio.
        This function implicitly accesses target_log_prob_fn and proposal_dist_fn from the closure.

        Args:
            current_sample (jax.Array): The current value in the MCMC chain.
            proposed_sample (jax.Array): The proposed candidate value.

        Returns:
            A tuple containing:
            - proposed_target_log_prob (jax.Array): Log-prob of proposed_sample under target.
            - current_target_log_prob (jax.Array): Log-prob of current_sample under target.
            - log_proposal_forward_prob (jax.Array): Log-prob of proposing proposed_sample from current_sample.
            - log_proposal_reverse_prob (jax.Array): Log-prob of proposing current_sample from proposed_sample.
        """
        # Calculate the log-probability of the proposed sample under the target distribution.
        proposed_target_log_prob = target_log_prob_fn(proposed_sample)

        # Calculate the log-probability of the current sample under the target distribution.
        current_target_log_prob = target_log_prob_fn(current_sample)

        # Define the proposal distribution for the forward step (current -> proposed)
        proposal_forward_dist = proposal_dist_fn(current_sample)
        log_proposal_forward_prob = proposal_forward_dist.log_prob(proposed_sample)

        # Define the proposal distribution for the reverse step (proposed -> current)
        proposal_reverse_dist = proposal_dist_fn(proposed_sample)
        log_proposal_reverse_prob = proposal_reverse_dist.log_prob(current_sample)

        return (proposed_target_log_prob, current_target_log_prob,
                log_proposal_forward_prob, log_proposal_reverse_prob)
        
    return compute_mh_log_probabilities
    
compute_mh_log_probabilities = make_compute_mh_log_probabilities(
    target_log_prob_fn, 
    proposal_distribution_fn
    )
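In the spirit of the unit-testing remark above, one cheap property to check is that a Gaussian random-walk proposal is symmetric. The forward and reverse log-proposal densities computed by the helper should be equal, so they cancel in the acceptance ratio. This sketch checks the property on random pairs of points, with `jax.scipy.stats.norm.logpdf` standing in for the distrax proposal. `log_q` and `PROPOSAL_SD` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

PROPOSAL_SD = 1.0  # same scale as the post's proposal_stddev

def log_q(to_value, from_value):
    """Log-density of proposing `to_value` from `from_value` under N(from_value, PROPOSAL_SD^2)."""
    return norm.logpdf(to_value, loc=from_value, scale=PROPOSAL_SD)

# Random (current, proposed) pairs to test the property on.
key_a, key_b = jax.random.split(jax.random.PRNGKey(42))
current = 3.0 * jax.random.normal(key_a, (1000,))
proposed = 3.0 * jax.random.normal(key_b, (1000,))

forward = log_q(proposed, current)   # log q(proposed | current)
reverse = log_q(current, proposed)   # log q(current | proposed)

# For a symmetric proposal the difference should be zero up to rounding.
max_abs_diff = float(jnp.max(jnp.abs(forward - reverse)))
print(max_abs_diff)
```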

And finally for the algorithm!

def make_metropolis_hastings_step(target_log_prob_fn, proposal_dist_fn):
  """
  Factory function to create a Metropolis-Hastings step function with a closure
  over the target log-probability function and the proposal distribution callable.

  Args:
    target_log_prob_fn (callable): A function that returns the log-probability
                                    of a sample under the target distribution.
    proposal_dist_fn (callable): A function `(center_val) -> distrax.Normal`
                                 that generates the proposal distribution centered at `center_val`.

  Returns:
    A callable `jax.jit` compiled function that performs one step of the Metropolis-Hastings algorithm.
  """
  # JIT compile the main step function
  @jax.jit
  def metropolis_hastings_step(carry, x): # x is a dummy variable from jax.lax.scan's `xs`
    """
    Performs one step of the Metropolis-Hastings algorithm.
    This function uses the target_log_prob_fn and proposal_dist_fn from its closure.

    Args:
      carry: A tuple containing (current_sample, current_log_prob, rng_key).
             - current_sample (jax.Array): The current value in the MCMC chain.
             - current_log_prob (jax.Array): The log-probability of the current_sample
                                             under the target distribution.
             - rng_key (jax.Array): The JAX PRNG key for random operations.
      x: A dummy element from the `xs` sequence of `jax.lax.scan` (unused here).

    Returns:
      A tuple (next_carry, output_value) for `jax.lax.scan`:
        - next_carry: (next_sample, next_log_prob, updated_rng_key) for the next iteration.
        - output_value: next_sample (the actual sample to be collected).
    """
    current_sample, current_log_prob, rng_key = carry

    # Split the RNG key for distinct random operations within this step
    proposal_key, uniform_key, next_rng_key = jax.random.split(rng_key, 3)

    # 1. Propose a new candidate sample using the `proposal_dist_fn` from closure.
    # Pass the proposal_key for reproducibility
    proposed_sample = proposal_dist_fn(current_sample).sample(seed=proposal_key)

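    # 2. Compute all necessary log probabilities with a helper built from this factory's
    #    own arguments (not the global `compute_mh_log_probabilities`), so the step
    #    always uses the target and proposal it was constructed with.
    step_log_probs = make_compute_mh_log_probabilities(target_log_prob_fn, proposal_dist_fn)
    (proposed_target_log_prob, _,  # current_log_prob is already in the carry, so ignore this return
     log_proposal_forward_prob, log_proposal_reverse_prob) = \
        step_log_probs(current_sample, proposed_sample)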

    # 3. Calculate the Metropolis-Hastings acceptance ratio in log-space
    log_acceptance_ratio = (proposed_target_log_prob - current_log_prob) + \
                           (log_proposal_reverse_prob - log_proposal_forward_prob)

    # The acceptance ratio 'alpha' must be between 0 and 1.
    acceptance_ratio = jnp.exp(jnp.minimum(0.0, log_acceptance_ratio))

    # 4. Generate a uniform random number for acceptance check
    # Pass the uniform_key for reproducibility
    u = jax.random.uniform(uniform_key, shape=current_sample.shape, dtype=current_sample.dtype)

    # 5. Decide whether to accept the proposed sample
    accept = jnp.less(u, acceptance_ratio)

    # Select the next sample based on acceptance
    next_sample = jnp.where(accept, proposed_sample, current_sample)

    # Select the log-probability corresponding to the next sample.
    next_log_prob = jnp.where(accept, proposed_target_log_prob, current_log_prob)

    # Return the new carry state (for next iteration) and the sample to collect
    return (next_sample, next_log_prob, next_rng_key), next_sample

  return metropolis_hastings_step
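The step above only returns the sample, so the acceptance rate has to be inferred later from changes between consecutive samples. One possible variation is to also emit the accept decision as a scan output, which makes the acceptance rate exact. The sketch below is a simplified, self-contained random-walk step (symmetric Gaussian proposal, no distrax, a standard Normal stand-in target). `make_rw_step` and the other names are illustrative, not part of the post’s code.

```python
import jax
import jax.numpy as jnp


def make_rw_step(log_target, proposal_sd):
    """Random-walk MH step for jax.lax.scan that also outputs the accept flag."""
    def step(carry, _):
        sample, log_prob, key = carry
        key, prop_key, u_key = jax.random.split(key, 3)
        proposal = sample + proposal_sd * jax.random.normal(prop_key, sample.shape)
        proposal_log_prob = log_target(proposal)
        # Symmetric proposal, so the q terms cancel in the log acceptance ratio.
        log_alpha = jnp.minimum(0.0, proposal_log_prob - log_prob)
        accept = jnp.log(jax.random.uniform(u_key, sample.shape)) < log_alpha
        new_sample = jnp.where(accept, proposal, sample)
        new_log_prob = jnp.where(accept, proposal_log_prob, log_prob)
        # Output a (sample, accepted?) pair at every iteration.
        return (new_sample, new_log_prob, key), (new_sample, accept)
    return step


def log_target(x):
    # Stand-in target: unnormalized standard Normal log-density.
    return -0.5 * x ** 2


step = make_rw_step(log_target, proposal_sd=1.0)
init = jnp.array(0.0)
carry0 = (init, log_target(init), jax.random.PRNGKey(0))
_, (samples, accepts) = jax.lax.scan(step, carry0, xs=None, length=5000)

acceptance_rate = float(jnp.mean(accepts))
print(samples.shape, acceptance_rate)
```

Passing `xs=None` with `length=5000` also removes the need for a dummy `jnp.arange` sequence to drive the iterations.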

Running the algorithm

# Set the total number of samples to generate
num_samples = 5000

# Define the initial state (starting point) of our MCMC chain.
initial_sample = jnp.array(0.0, dtype=jnp.float32)
initial_log_prob = target_log_prob_fn(initial_sample)

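# The initial_state (carry) for jax.lax.scan is a tuple of (initial_sample, initial_log_prob, initial_rng_key).
# Note: this reuses the key that generated the data. In general it is safer to split it first
# (e.g. `data_key, mcmc_key = jax.random.split(rng_key)`) so the chain's randomness is independent of the data.
initial_carry = (initial_sample, initial_log_prob, rng_key)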
# Create the specific MH step function (which is already JIT-compiled by its factory)
mh_step_function = make_metropolis_hastings_step(
    target_log_prob_fn, proposal_distribution_fn)
%%time 
print(f"Starting Metropolis-Hastings sampling for {num_samples} steps...")

# Use jax.lax.scan to run the Metropolis-Hastings steps iteratively.
# jax.lax.scan returns (final_carry, accumulated_outputs)
final_carry, mh_samples_jax = jax.lax.scan(
    f=mh_step_function,
    init=initial_carry,
    xs=jnp.arange(num_samples) # A dummy sequence of length num_samples to drive the iterations
)

print("Sampling complete!")
Starting Metropolis-Hastings sampling for 5000 steps...
Sampling complete!
CPU times: user 439 ms, sys: 21.5 ms, total: 461 ms
Wall time: 110 ms
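These timings come from a single call, so they most likely include tracing and XLA compilation as well as the sampling itself. JAX also dispatches work asynchronously, so a fair benchmark waits for the result with `block_until_ready()`. Below is a minimal sketch of timing compilation and execution separately. `run_chain` and its standard Normal target are illustrative stand-ins, not the post’s sampler.

```python
import time

import jax
import jax.numpy as jnp


def log_target(x):
    # Stand-in target: unnormalized standard Normal log-density.
    return -0.5 * x ** 2


@jax.jit
def run_chain(key, init):
    """Run 5000 random-walk Metropolis steps with lax.scan inside one jitted function."""
    def step(carry, _):
        x, lp, key = carry
        key, k_prop, k_u = jax.random.split(key, 3)
        prop = x + jax.random.normal(k_prop)
        prop_lp = log_target(prop)
        accept = jnp.log(jax.random.uniform(k_u)) < prop_lp - lp
        x = jnp.where(accept, prop, x)
        lp = jnp.where(accept, prop_lp, lp)
        return (x, lp, key), x
    _, samples = jax.lax.scan(step, (init, log_target(init), key), xs=None, length=5000)
    return samples


key = jax.random.PRNGKey(0)
init = jnp.array(0.0)

t0 = time.perf_counter()
samples = run_chain(key, init).block_until_ready()  # first call: trace + compile + run
t1 = time.perf_counter()
samples = run_chain(key, init).block_until_ready()  # second call: cached executable, run only
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s (includes compilation)")
print(f"second call: {t2 - t1:.4f} s")
```

On repeated calls with the same input shapes and dtypes, `jax.jit` reuses the compiled executable, so the second timing reflects the sampling alone.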

Visualizing the results

mh_samples = np.array(mh_samples_jax) 
plt.figure(figsize=(9,4))
plt.plot(mh_samples, color='blue', alpha=0.6)
# Plot the true mean of the data-generating process
plt.axhline(true_mean.item(), color='red', linestyle='--', linewidth = 2, label=f'True Mean (Data Generating) ({true_mean.item():.2f})')
# Plot the estimated mean from the latter part of the MCMC chain (after some burn-in)
plt.axhline(np.mean(mh_samples[500:]), color='orange', linestyle='--', linewidth = 2, label=f'Estimated Mean (Post Burn-in) ({np.mean(mh_samples[500:]):.2f})')
plt.title('Metropolis-Hastings Trace Plot')
plt.xlabel('Iteration')
plt.ylabel('Sample Value')
plt.grid(True, linestyle=':', alpha=0.6)
plt.legend()
plt.show()
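The trace plot above treats the first 500 iterations as burn-in. Once those are discarded, the remaining draws can be summarized directly: for example, by their mean, their standard deviation, and an equal-tailed 95% credible interval taken from quantiles. The following is a minimal sketch that assumes a 1-D array of draws. `summarize_chain` is an illustrative helper, and the stand-in chain below is synthetic rather than the post’s output.

```python
import jax
import jax.numpy as jnp


def summarize_chain(samples, burn_in, level=0.95):
    """Posterior mean, sd and equal-tailed credible interval after dropping burn-in.

    samples: 1-D array of MCMC draws; burn_in: number of leading draws to discard.
    """
    kept = samples[burn_in:]
    tail = (1.0 - level) / 2.0
    lower, upper = jnp.quantile(kept, jnp.array([tail, 1.0 - tail]))
    return jnp.mean(kept), jnp.std(kept), (lower, upper)


# Synthetic stand-in chain: 500 draws drifting up from 0, then draws around 4.2.
key = jax.random.PRNGKey(0)
warmup = jnp.linspace(0.0, 4.2, 500)
stationary = 4.2 + 0.24 * jax.random.normal(key, (4500,))
chain = jnp.concatenate([warmup, stationary])

mean, sd, (lo, hi) = summarize_chain(chain, burn_in=500)
print(f"mean={float(mean):.3f}  sd={float(sd):.3f}  95% CI=({float(lo):.3f}, {float(hi):.3f})")
```

With the post’s samples, this would be called as `summarize_chain(jnp.asarray(mh_samples), burn_in=500)`.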

plt.figure(figsize=(9,4))
sns.histplot(mh_samples, bins=50, kde=True, stat='density', color='green', label='MH Samples Density', alpha=0.6)

# Generate points for the prior PDF
x_range = np.linspace(np.min(mh_samples) - 1, np.max(mh_samples) + 1, 500)
# Convert JAX array from distrax.prob to NumPy for plotting
prior_pdf = np.array(prior_distribution.prob(jnp.array(x_range)))
plt.plot(x_range, prior_pdf, color='red', linestyle='--', label='Prior PDF')

plt.title('Distribution of MH Samples vs. Prior PDF')
plt.xlabel('Value')
plt.ylabel('Density')
plt.legend()
plt.grid(True, linestyle=':', alpha=0.6)
plt.show()

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten() # Flatten the 2x2 array of axes for easy iteration

iteration_points = [10, 100, 1000, 5000]

# The exact posterior is rarely available in practice, so we compare the MCMC
# estimates against the true data-generating mean (true_mean = 4.2) as a reference point.

for i, num_iters in enumerate(iteration_points):
    # Slice the samples up to the current iteration point
    current_samples = mh_samples[:num_iters]

    ax = axes[i]
    sns.kdeplot(current_samples, ax=ax, color='blue', fill=True, label='MCMC Estimate')
    # Plot a vertical line at the true data-generating mean for reference
    ax.axvline(true_mean.item(), color='red', linestyle='--', label='True Data Mean')
    ax.set_title(f'KDE after {num_iters} Iterations')
    ax.set_xlabel('Parameter Value (X)')
    ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True, linestyle=':', alpha=0.6)

plt.suptitle('Evolution of Parameter Distribution with MCMC Iterations', fontsize=16, y=0.98)
plt.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout to prevent title overlap
plt.show()

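I really like this plot because it shows the intuition behind the algorithm. As the number of iterations grows, the samples settle into the posterior distribution, which here concentrates close to the true parameter value: with 40 observations, the likelihood easily overpowers the diffuse, weakly informative prior.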

print("\n--- Sample Statistics ---")
# note: these statistics use the whole chain, including the early (burn-in) samples
print(f"Mean of MH samples: {np.mean(mh_samples):.4f}")
print(f"Standard deviation of MH samples: {np.std(mh_samples):.4f}")
# Use .item() for scalar JAX arrays when printing
print(f"True Mean: {true_mean.item():.4f}")

# Calculate differences between consecutive samples
diffs = np.diff(mh_samples)
# Count where the difference is not zero (i.e., a new sample was accepted)
accepted_steps = np.sum(diffs != 0)
# Acceptance rate is the number of accepted steps divided by total steps (excluding initial)
acceptance_rate = accepted_steps / (num_samples - 1) # Subtract 1 because diff reduces length by 1

print(f"Approximate Acceptance Rate: {acceptance_rate:.4f}")

--- Sample Statistics ---
Mean of MH samples: 4.1865
Standard deviation of MH samples: 0.2855
True Mean: 4.2000
Approximate Acceptance Rate: 0.2845
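Because this toy model is conjugate (Normal likelihood with known σ, Normal prior on μ), the exact posterior is available, which gives an optional sanity check on the sampler. Below is a sketch in plain Python. `normal_known_sd_posterior` is an illustrative helper name, and the data in the usage line are toy values.

```python
import math


def normal_known_sd_posterior(data, prior_mean, prior_sd, lik_sd):
    """Exact posterior N(post_mean, post_sd^2) for the mean of a Normal with known sd.

    Prior: mu ~ N(prior_mean, prior_sd^2); likelihood: x_i ~ N(mu, lik_sd^2), i.i.d.
    """
    n = len(data)
    # Precisions add: prior precision plus n times the per-observation precision.
    post_precision = 1.0 / prior_sd**2 + n / lik_sd**2
    post_var = 1.0 / post_precision
    # Posterior mean is a precision-weighted average of prior mean and data.
    post_mean = post_var * (prior_mean / prior_sd**2 + sum(data) / lik_sd**2)
    return post_mean, math.sqrt(post_var)


# Usage with toy data; with the post's objects you would pass
# [float(v) for v in observed_data], 0.0, 10.0, 1.5.
mean, sd = normal_known_sd_posterior([3.9, 4.4, 4.1, 4.6], prior_mean=0.0, prior_sd=10.0, lik_sd=1.5)
print(f"posterior mean = {mean:.4f}, posterior sd = {sd:.4f}")
```

For the post’s setup (n = 40, σ = 1.5, prior sd 10), the exact posterior standard deviation is \(1/\sqrt{0.01 + 40/2.25} \approx 0.237\), regardless of the data values. This is a useful reference point for the sample standard deviation printed above, which is computed over the whole chain with burn-in included.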

Thanks for reading

In my spare time, I like to take photos, so I’m going to add one photo I like at the end of each post as a thank you :)

Lake District, UK, 2024

Note

Documentation for functions was written with help from Google’s Gemini platform.

Footnotes

  1. If you want a formal description of the algorithm, you can find it on the Wikipedia page.