Special cases¶
When building a new sampler there are a few special cases that don't exactly follow the recipe from the previous section. Here we look at how to write splitting schemes and samplers that resample momentum.
Splitting schemes¶
Splitting schemes require a gradient evaluation halfway through the diffusion solver. This changes how we define the diffusion as well as the kernel factory.
The diffusion therefore uses a different decorator and includes two update functions. We show how this is done with the BAOAB sampler:
@diffusion_palindrome
def baoab(dt, gamma, tau=1):
    dt = make_schedule(dt)

    def init_fn(x):
        v = jnp.zeros_like(x)
        return x, v

    def update1(i, k, g, state):
        x, v = state

        v = v + dt(i)*0.5*g          # B: half momentum kick
        x = x + v*dt(i)*0.5          # A: half position drift

        c1 = jnp.exp(-gamma*dt(i))   # O: Ornstein-Uhlenbeck momentum refresh
        c2 = jnp.sqrt(1 - c1**2)
        v = c1*v + tau*c2*random.normal(k, shape=jnp.shape(v))

        x = x + v*dt(i)*0.5          # A: half position drift

        return x, v

    def update2(i, k, g, state):
        x, v = state
        v = v + dt(i)*0.5*g          # B: closing half kick with the new gradient
        return x, v

    def get_params(state):
        x, _ = state
        return x

    return init_fn, (update1, update2), get_params
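To see why two update functions are needed, write out the steps that update1 and update2 implement (B: half momentum kick, A: half position drift, O: Ornstein-Uhlenbeck refresh of the momentum):

$$
\begin{aligned}
v_{n+1/2} &= v_n + \tfrac{\Delta t}{2}\, g(x_n), \\
x_{n+1/2} &= x_n + \tfrac{\Delta t}{2}\, v_{n+1/2}, \\
\tilde{v} &= c_1 v_{n+1/2} + \tau c_2\, \xi_n, \qquad c_1 = e^{-\gamma \Delta t},\quad c_2 = \sqrt{1 - c_1^2},\quad \xi_n \sim \mathcal{N}(0, I), \\
x_{n+1} &= x_{n+1/2} + \tfrac{\Delta t}{2}\, \tilde{v}, \\
v_{n+1} &= \tilde{v} + \tfrac{\Delta t}{2}\, g(x_{n+1}).
\end{aligned}
$$

update1 performs the first four lines (B, A, O, A). The closing B step uses the gradient at the new position \(x_{n+1}\), so it lives in update2 and runs once that gradient has been estimated.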
Once you have the diffusion, you build the kernel factory in exactly the same way as for sgld, namely:
def build_baoab_kernel(dt, gamma, loglikelihood, logprior, data, batch_size):
    grad_log_post = build_grad_log_post(loglikelihood, logprior, data)
    estimate_gradient, init_gradient = build_gradient_estimation_fn(grad_log_post, data, batch_size)
    init_fn, baoab_kernel, get_params = _build_langevin_kernel(*baoab(dt, gamma),
                                                               estimate_gradient, init_gradient)
    return init_fn, baoab_kernel, get_params
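You can check the kernel factory on a toy model. This is a minimal sketch, assuming the kernel interface from the previous section, namely init_fn(key, params) and kernel(i, key, state); the Gaussian model here is only for illustration:

import jax.numpy as jnp
from jax import random

X = random.normal(random.PRNGKey(0), (1000, 3))                # synthetic dataset
loglikelihood = lambda theta, x: -0.5*jnp.sum((x - theta)**2)  # toy Gaussian likelihood
logprior = lambda theta: -0.5*jnp.sum(theta**2)                # standard Gaussian prior

init_fn, kernel, get_params = build_baoab_kernel(1e-3, 1., loglikelihood,
                                                 logprior, (X,), batch_size=100)

key = random.PRNGKey(1)
state = init_fn(key, jnp.zeros(3))
samples = []
for i in range(1000):
    key, subkey = random.split(key)
    state = kernel(i, subkey, state)
    samples.append(get_params(state))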
The way to build the sampler factory is unchanged.
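For completeness, a sketch of what that looks like, assuming a helper such as _build_sampler from the previous section (the name is hypothetical; substitute whatever your codebase uses):

def build_baoab_sampler(dt, gamma, loglikelihood, logprior, data, batch_size):
    # hypothetical _build_sampler helper: wraps the kernel in a sampling loop
    init_fn, my_kernel, get_params = build_baoab_kernel(dt, gamma, loglikelihood,
                                                        logprior, data, batch_size)
    return _build_sampler(init_fn, my_kernel, get_params)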
Momentum resampling¶
Samplers such as SGHMC resample momentum every L iterations (with L the number of leapfrog steps). The diffusion and kernel factory must therefore be written slightly differently.
The diffusion function must now include a resample_momentum function:
@diffusion_sghmc
def sghmc(dt, alpha=0.01, beta=0):
    "https://arxiv.org/abs/1402.4102"
    dt = make_schedule(dt)

    def init_fn(x):
        v = jnp.zeros_like(x)
        return x, v

    def update(i, k, g, state):
        x, v = state
        x = x + v   # position update
        # momentum update: gradient step, friction, and injected noise
        v = v + dt(i)*g - alpha*v + jnp.sqrt(2*(alpha - beta)*dt(i))*random.normal(k, shape=jnp.shape(x))
        return x, v

    def get_params(state):
        x, _ = state
        return x

    def resample_momentum(i, k, x):
        # draw fresh Gaussian momentum; the kernel calls this every L steps
        v = jnp.sqrt(dt(i))*random.normal(k, shape=jnp.shape(x))
        return x, v

    return init_fn, update, get_params, resample_momentum
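To make the role of resample_momentum concrete, here is a simplified, hypothetical version of the loop the kernel runs on every iteration: refresh the momentum, then take L update steps. The real helper uses the stochastic gradient estimate rather than the full gradient shown here, and presumably compiles the inner loop when compiled_leapfrog=True:

from jax import random

def sghmc_step_sketch(i, key, state, grad_fn, L, update, resample_momentum):
    x, _ = state
    key, subkey = random.split(key)
    state = resample_momentum(i, subkey, x)    # fresh momentum
    for _ in range(L):                         # L leapfrog steps
        key, subkey = random.split(key)
        state = update(i, subkey, grad_fn(state[0]), state)
    return state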
The kernel factory is slightly different from the previous cases: it simply uses a different helper function (_build_sghmc_kernel) to build the kernel:
def build_sghmc_kernel(dt, L, loglikelihood, logprior, data, batch_size, alpha=0.01, compiled_leapfrog=True):
    grad_log_post = build_grad_log_post(loglikelihood, logprior, data)
    estimate_gradient, init_gradient = build_gradient_estimation_fn(grad_log_post, data, batch_size)
    init_fn, sghmc_kernel, get_params = _build_sghmc_kernel(*sghmc(dt, alpha), estimate_gradient,
                                                            init_gradient, L, compiled_leapfrog=compiled_leapfrog)
    return init_fn, sghmc_kernel, get_params
The sampler factory is then built in the usual way.
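As with BAOAB, a sketch under the same assumption about a _build_sampler-style helper (hypothetical name):

def build_sghmc_sampler(dt, L, loglikelihood, logprior, data, batch_size, alpha=0.01):
    # hypothetical _build_sampler helper: wraps the kernel in a sampling loop
    init_fn, my_kernel, get_params = build_sghmc_kernel(dt, L, loglikelihood, logprior,
                                                        data, batch_size, alpha=alpha)
    return _build_sampler(init_fn, my_kernel, get_params)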