[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import random, jit

Build a transition kernelΒΆ

In this series of notebooks we use logistic regression as the example. However the usage is exactly the same for other models.

[2]:
# import model and create dataset
from sgmcmcjax.models.logistic_regression import gen_data, loglikelihood, logprior

key = random.PRNGKey(42)
dim = 10
Ndata = 100000

theta_true, X, y_data = gen_data(key, dim, Ndata)

data = (X, y_data)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
generating data, with N=100000 and dim=10

Here we build the transition kernel for sgld given some hyperparameters such as step size and batch size

[3]:
from sgmcmcjax.kernels import build_sgld_kernel
from tqdm.auto import tqdm

batch_size = int(0.01*X.shape[0])
dt = 1e-5

init_fn, my_kernel, get_params = build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)
my_kernel = jit(my_kernel)

If we build the kernel function we also get init_fn and get_params. init_fn takes in a random key and the inital parameters and returns the state of the sampler. get_params takes in state and returns the current parameter. So we will pass state into the transition kernel to update it.

We can now write the loop ourselves and update the state using kernel function. Note that we must also split the random key, and save the samples ourselves.

Writing the loop manually is useful if we want to do things like calculate the accuracy on a test dataset throughout the sampling.

[4]:
%%time

key = random.PRNGKey(0)
Nsamples = 1000
samples = []

key, subkey = random.split(key)
state = init_fn(subkey, theta_true)

for i in tqdm(range(Nsamples)):
    key, subkey = random.split(key)
    state = my_kernel(i, subkey, state)
    samples.append(get_params(state))

samples = np.array(samples)
CPU times: user 1.96 s, sys: 46.3 ms, total: 2.01 s
Wall time: 2.01 s
[5]:

plt.plot(samples[:,0])
[5]:
[<matplotlib.lines.Line2D at 0x7fbfbd0d5ee0>]
../_images/nbs_kernel_7_1.png