[1]:
%matplotlib inline
import matplotlib.pyplot as plt
from functools import partial
import jax.numpy as jnp
from jax import random, jit
from sgmcmcjax.samplers import build_sgld_sampler
Gaussian posterior
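Assume each data point $x_i \in \mathbb{R}^D$ is drawn from a Gaussian $\mathcal{N}(\theta, I)$ with unknown mean $\theta$ and identity covariance. We estimate $\theta$ with stochastic gradient Langevin dynamics (SGLD) under a broad Gaussian prior $\theta \sim \mathcal{N}(0, 100\,I)$.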
[8]:
# define model in JAX
def loglikelihood(theta, x):
    """Gaussian log-likelihood of observation(s) ``x`` given mean ``theta``.

    The covariance is the identity. The additive normalising constant
    (-D/2 * log(2*pi)) is dropped because it does not change gradients.

    Parameters
    ----------
    theta : array
        Mean parameter (presumably shape (D,) -- confirm against the sampler).
    x : array
        Observation; presumably a single data point of shape (D,), with the
        sampler handling the minibatch -- NOTE(review): confirm with sgmcmcjax.

    Returns
    -------
    Scalar ``-0.5 * ||x - theta||^2`` for 1-D inputs.
    """
    residual = x - theta
    return -0.5 * jnp.dot(residual, residual)
def logprior(theta, precision=0.01):
    """Log density, up to an additive constant, of an isotropic Gaussian prior on ``theta``.

    Parameters
    ----------
    theta : array
        Parameter vector (assumed 1-D, length D -- TODO confirm against sampler usage).
    precision : float, default 0.01
        Prior precision (inverse variance). The default 0.01 is a N(0, 100 * I)
        prior, i.e. a broad prior. Exposed as a keyword so the prior strength can
        be tuned without editing the function body; existing one-argument calls
        behave exactly as before.

    Returns
    -------
    Scalar ``-0.5 * precision * ||theta||^2`` for 1-D ``theta``.
    """
    return -0.5*jnp.dot(theta, theta)*precision
[9]:
# generate dataset: N points drawn i.i.d. from a D-dimensional standard normal,
# so the data-generating mean is the zero vector.
N, D = 10000, 100
# Split the seed rather than reusing it: `data_key` is consumed here, and `key`
# is kept for downstream randomness (e.g. the sampler). Reusing a single JAX key
# for both would make the data and the sampler noise share one random stream.
key, data_key = random.split(random.PRNGKey(0))
X_data = random.normal(data_key, shape=(N, D))
[24]:
# build sampler: SGLD over the full dataset, using 10% of the points per minibatch
dt = 1e-5  # SGLD step size
batch_size = int(0.1*N)  # minibatch size
# jit the sampler. Positional argument 1 (`Nsamples` at the call site) is marked
# static, presumably because it fixes the number of iterations / output length.
my_sampler = jit(
    build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size),
    static_argnums=(1,),
)
[29]:
%%time
# run sampler
Nsamples = 10000
samples = my_sampler(key, Nsamples, jnp.zeros(D))
idx = 0
plt.plot(samples[:, idx])
CPU times: user 835 ms, sys: 14.2 ms, total: 850 ms
Wall time: 1.02 s
[29]:
[<matplotlib.lines.Line2D at 0x7fbba4542400>]