[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random, jit
Bayesian Neural Network¶
Use an SGLD kernel and compute the accuracy on the test dataset while training
MNIST dataset
[2]:
from sgmcmcjax.models.bayesian_NN.NN_data import X_train, X_test, y_train, y_test
from sgmcmcjax.models.bayesian_NN.NN_model import init_network, loglikelihood, logprior, accuracy
from sgmcmcjax.kernels import build_sgld_kernel
from tqdm.auto import tqdm
# Each minibatch is 1% of the training rows (truncated to an int).
n_train = X_train.shape[0]
batch_size = int(0.01 * n_train)

# The sampler is handed the full training set as an (inputs, labels) pair.
data = (X_train, y_train)

# Build the SGLD sampler. The leading 5e-5 is presumably the step size --
# confirm against build_sgld_kernel's signature.
init_fn, sgld_step, get_params = build_sgld_kernel(5e-5, loglikelihood, logprior, data, batch_size)

# Wrap the update step in jax.jit so it is compiled rather than re-traced
# on every call of the sampling loop.
my_kernel = jit(sgld_step)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[3]:
# Define the initial state of the sampler.
# Seed the PRNG with a fixed value, then split it three ways: `key` is kept
# for later use, `subkey1` goes to the network initialisation and `subkey2`
# to the sampler's init function.
key = random.PRNGKey(10)
key, subkey1, subkey2 = random.split(key,3)
# Passed to init_network as the network sizes; presumably layer widths
# (784 inputs, 100 hidden units, 10 outputs) -- TODO confirm.
sizes = [784, 100, 10]
params_IC = init_network(subkey1, sizes)
# Sampler state built from the initial parameters; updated by the kernel later.
state = init_fn(subkey2, params_IC)
[4]:
%%time
Nsamples = 2000
samples = []
accuracy_list = []
for i in tqdm(range(Nsamples)):
key, subkey = random.split(key)
state = my_kernel(i, subkey, state)
samples.append(get_params(state))
if i%10==0:
accuracy_list.append(accuracy(get_params(state), X_test, y_test))
CPU times: user 29.5 s, sys: 2.79 s, total: 32.3 s
Wall time: 12.9 s
[5]:
# Plot the test accuracy recorded during sampling. One point is stored per
# evaluation, so the x-axis is the evaluation index, not the raw step count.
plt.plot(accuracy_list)
plt.xlabel("Evaluation index")
plt.ylabel("Test accuracy")

# Report the most recent evaluation: index -1 is the last recorded accuracy
# (the previous -8 silently reported an earlier, non-final value).
print(f"Final accuracy: {100*accuracy_list[-1]:.1f}%")
Final accuracy: 92.1%