[1]:
%matplotlib inline

import matplotlib.pyplot as plt

import jax.numpy as jnp
import jax
from jax import random, jit

from tqdm.auto import tqdm
from sgmcmcjax.kernels import build_sgld_kernel, build_psgld_kernel, build_sgldAdam_kernel, build_sghmc_kernel

import tensorflow_datasets as tfds

from flax import linen as nn

Flax CNN

A Colab notebook shows how to run this example on a TPU.

[2]:
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
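
As a quick shape sanity check (an illustrative snippet, not an original cell of this notebook), the model can be initialised on a dummy MNIST-sized input and applied to a small batch; it assumes only the imports from cell [1]:

# Sanity check (illustrative, not part of the original notebook): initialise
# the CNN on a dummy 28x28x1 input and confirm it returns one log-probability
# per digit class.
check_key = random.PRNGKey(0)
check_params = CNN().init(check_key, jnp.ones([1, 28, 28, 1]))['params']
check_logits = CNN().apply({'params': check_params}, jnp.ones([8, 28, 28, 1]))
print(check_logits.shape)  # (8, 10)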

[3]:
cnn = CNN()

def loglikelihood(params, x, y):
    x = x[jnp.newaxis]  # add a batch axis so the model works on a single data point
    logits = cnn.apply({'params': params}, x)
    label = jax.nn.one_hot(y, num_classes=10)
    return jnp.sum(logits * label)  # log-probability of the true class

def logprior(params):
    return 1.  # constant, i.e. a flat (improper) prior: its gradient is zero

@jit
def accuracy_cnn(params, X, y):
    "Fraction of points whose argmax prediction matches the label"
    predicted_class = jnp.argmax(cnn.apply({'params': params}, X), axis=1)
    return jnp.mean(predicted_class == y)


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


train_ds, test_ds = get_datasets()

X_train_s = train_ds['image']
y_train_s = jnp.array(train_ds['label'])
X_test_s = test_ds['image']
y_test_s = jnp.array(test_ds['label'])


data = (X_train_s, y_train_s)
batch_size = int(0.01*data[0].shape[0])  # 1% of the 60,000 training images: minibatches of 600
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
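
For reference, the samplers target the unnormalised log-posterior built from the two functions above: logprior(params) plus the sum of loglikelihood over all data points, with the gradient estimated from random minibatches of batch_size points. A conceptual sketch of that quantity (not an original cell, and not how sgmcmcjax evaluates it internally):

# Full-data unnormalised log-posterior (conceptual only; the kernels below
# estimate its gradient from random minibatches rather than evaluating this).
def full_logposterior(params):
    per_point = jax.vmap(lambda x, y: loglikelihood(params, x, y))(X_train_s, y_train_s)
    return logprior(params) + jnp.sum(per_point)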
[4]:
def run_sgmcmc(key, Nsamples, init_fn, my_kernel, get_params, accuracy_rate=100):
    "Run an SGMCMC sampler and record the test accuracy every `accuracy_rate` iterations"
    accuracy_list = []
    params = cnn.init(key, jnp.ones([1, 28, 28, 1]))['params']
    key, subkey = random.split(key)
    state = init_fn(subkey, params)

    for i in tqdm(range(Nsamples)):
        key, subkey = random.split(key)
        state = my_kernel(i, subkey, state)
        if i % accuracy_rate == 0:
            accuracy_list.append(accuracy_cnn(get_params(state), X_test_s, y_test_s))

    return accuracy_list

SGLD

[5]:
%%time

init_fn, sgld_kernel, get_params = build_sgld_kernel(5e-6, loglikelihood, logprior, data, batch_size)

sgld_kernel = jit(sgld_kernel)

Nsamples = 2000
accuracy_list_sgld = run_sgmcmc(random.PRNGKey(0), Nsamples, init_fn, sgld_kernel, get_params)
CPU times: user 27min 30s, sys: 6min 57s, total: 34min 27s
Wall time: 8min 6s
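
For context, SGLD is an Euler discretisation of the overdamped Langevin diffusion on the log-posterior, run here with stepsize dt = 5e-6. In one common convention (a sketch; conventions differ by factors of two, and this is not a statement of sgmcmcjax internals):

$$\theta_{t+1} = \theta_t + dt\,\widehat{\nabla}\log p(\theta_t \mid \mathcal{D}) + \sqrt{2\,dt}\,\xi_t, \qquad \xi_t \sim \mathcal{N}(0, I),$$

where $\widehat{\nabla}\log p$ is the minibatch estimate of the log-posterior gradient.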

SGHMC

[6]:
%%time

init_fn, sghmc_kernel, get_params = build_sghmc_kernel(1e-6, 4, loglikelihood,
                                                       logprior, data, batch_size, compiled_leapfrog=False)

Nsamples = 500
accuracy_list_sghmc = run_sgmcmc(random.PRNGKey(0), Nsamples, init_fn,
                                 sghmc_kernel, get_params, accuracy_rate=25)
CPU times: user 30min 22s, sys: 7min 39s, total: 38min 1s
Wall time: 9min 49s
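
SGHMC augments the parameters with a momentum variable $v$ and takes L = 4 leapfrog steps per iteration here. Schematically, it discretises the underdamped Langevin diffusion with friction coefficient $\gamma$ (a sketch of the target dynamics, not the library's exact update):

$$d\theta = v\,dt, \qquad dv = \widehat{\nabla}\log p(\theta \mid \mathcal{D})\,dt - \gamma v\,dt + \sqrt{2\gamma}\,dW_t.$$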

pSGLD

[15]:
%%time

init_fn, psgld_kernel, get_params = build_psgld_kernel(1e-3, loglikelihood, logprior, data, batch_size)

psgld_kernel = jit(psgld_kernel)

Nsamples = 2000
accuracy_list_psgld = run_sgmcmc(random.PRNGKey(0), Nsamples, init_fn, psgld_kernel, get_params)
CPU times: user 13min 5s, sys: 3min 11s, total: 16min 17s
Wall time: 4min 6s
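
pSGLD (Li et al., 2016) is SGLD with an RMSProp-style diagonal preconditioner. Roughly, and ignoring the small curvature-correction term from the paper, one step looks like the sketch below; the function and its parameter values are illustrative, not sgmcmcjax's implementation:

# Sketch of one pSGLD step in the same dt convention as SGLD above
# (illustrative only; beta and eps values are assumptions).
def psgld_step_sketch(theta, v, g, noise, dt=1e-3, beta=0.99, eps=1e-5):
    v = beta * v + (1. - beta) * g**2      # running second moment of the gradient
    precond = 1. / (eps + jnp.sqrt(v))     # diagonal preconditioner
    theta = theta + dt * precond * g + jnp.sqrt(2. * dt * precond) * noise
    return theta, v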

sgldAdam

[8]:
%%time

init_fn, sgldAdam_kernel, get_params = build_sgldAdam_kernel(1e-2, loglikelihood, logprior, data, batch_size)

sgldAdam_kernel = jit(sgldAdam_kernel)

Nsamples = 2000
accuracy_list_sgldAdam = run_sgmcmc(random.PRNGKey(0), Nsamples, init_fn, sgldAdam_kernel, get_params)
CPU times: user 28min 18s, sys: 7min 16s, total: 35min 35s
Wall time: 8min 12s
[11]:
plt.figure(figsize=(10,7))

plt.plot(jnp.arange(0,2000,100), accuracy_list_sgld, label="SGLD", marker="+") # 5e-6
plt.plot(jnp.arange(0,500,25), accuracy_list_sghmc, label="SGHMC", marker="+") # dt=1e-6, L=4
plt.plot(jnp.arange(0,2000,100), accuracy_list_psgld, label="pSGLD", marker=">") # 1e-3
plt.plot(jnp.arange(0,2000,100), accuracy_list_sgldAdam, label="sgldAdam", marker="o") # 1e-2

plt.legend(fontsize=16)
plt.title("Test accuracy of SGMCMC samplers on MNIST with a CNN")
plt.xlabel("Iterations", size=20)
plt.ylabel("Test accuracy", size=20)

print(f"sgld: {100*accuracy_list_sgld[-1]:.1f}%")
print(f"sghmc: {100*accuracy_list_sghmc[-1]:.1f}%")
print(f"psgld: {100*accuracy_list_psgld[-1]:.1f}%")
print(f"sgldAdam: {100*accuracy_list_sgldAdam[-1]:.1f}%")

sgld: 98.4%
sghmc: 97.8%
psgld: 97.0%
sgldAdam: 89.6%
[Figure: test accuracy of SGLD, SGHMC, pSGLD, and sgldAdam on MNIST with a CNN, plotted against iterations]