[1]:
import jax
import jax.numpy as jnp
from jax import random, vmap, jit

import numpy as np
import pymc as pm
import pymc.sampling_jax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

from sgmcmcjax.samplers import build_sgld_sampler, build_sgldCV_sampler
import optax
from sgmcmcjax.optimizer import build_optax_optimizer

print(f"Running on PyMC v{pm.__version__}")
/Users/jeremiecoullon/Documents/research/jax/sgmcmc_samplers/venv/lib/python3.8/site-packages/aesara/link/jax/dispatch.py:86: UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.
  warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
/Users/jeremiecoullon/Documents/research/jax/sgmcmc_samplers/venv/lib/python3.8/site-packages/pymc/sampling_jax.py:34: UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")
Running on PyMC v4.0.0b2

Sample from a PyMC model using SGMCMCJax

TODO:
- calculate ESS/sec and print it out for SGLD, SGLD-CV, and NUTS (see the sketch at the end of this notebook)
- try another model from the examples (https://docs.pymc.io/en/v3/nb_examples/index.html). Many of these don't use v4; the PMF example would be a good one to try, as NUTS takes an hour on it.

[2]:
from pymc.sampling_jax import get_jaxified_graph

def build_pymc_loglik_logprior(model):
    """Build JAX log-likelihood and log-prior functions from a PyMC model."""

    # Swap each observed RV's data for a dummy input so that the compiled
    # log-likelihood takes the data as an argument (needed for minibatching).
    original_data = []
    dummy_data_inputs = []
    for observed_RV in model.observed_RVs:
        data = model.rvs_to_values[observed_RV]
        dummy_data_input = data.type()
        # TODO: you should revert these in-place changes after you're done
        model.rvs_to_values[observed_RV] = dummy_data_input
        original_data.append(data.data)
        dummy_data_inputs.append(dummy_data_input)

    # build the log-likelihood: a function of the parameters and a batch of data
    loglike_fn = get_jaxified_graph(
        inputs=model.value_vars + dummy_data_inputs,
        outputs=[model.datalogpt],
    )

    def pymc_loglikelihood(theta, x):
        return loglike_fn(*theta, x)[0]

    # build the log-prior: a function of the parameters only
    logp_fn = get_jaxified_graph(
        inputs=model.value_vars,
        outputs=[model.varlogpt],
    )

    def pymc_logprior(theta):
        return logp_fn(*theta)[0]

    return pymc_loglikelihood, pymc_logprior

PyMC

[4]:

# generate dataset
N, D = 10_000, 10
key = random.PRNGKey(0)
X_data = random.normal(key, shape=(N, D))


with pm.Model() as model:
    x = pm.Normal("x", 0, 10, shape=D)
    y = pm.Normal("y", x, observed=X_data)


model
[4]:
$$ \begin{array}{rcl} \text{x} &\sim & \operatorname{N}(0,~10)\\\text{y} &\sim & \operatorname{N}(\text{x},~1) \end{array} $$
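Because both the prior and the likelihood are Gaussian with unit observation noise, the posterior over each component of x is available in closed form; a minimal sketch of this ground-truth calculation (using the N and X_data defined above), which the samplers below can be checked against:

[ ]:
# exact conjugate posterior: prior x_d ~ N(0, 10^2), likelihood y_i ~ N(x_d, 1), N observations
post_precision = 1 / 10**2 + N
post_mean = X_data.sum(axis=0) / post_precision
post_std = 1 / jnp.sqrt(post_precision)
print(post_mean, post_std)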

run PyMC sampler

[5]:
%%time

with model:
    posterior = pm.sample(10_000, chains=1)

# Note: adding this print adds a lot of time.
# Is there a "block_until_ready()" missing somewhere?
print(posterior.posterior.x.values[0,:][-1])
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [x]
100.00% [11000/11000 00:30<00:00 Sampling chain 0, 0 divergences]
Sampling 1 chain for 1_000 tune and 10_000 draw iterations (1_000 + 10_000 draws total) took 30 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks
[-0.00409768  0.00598126  0.00143343  0.01084321 -0.00564738  0.0039423
  0.01297556 -0.02975454 -0.01849593 -0.01484458]
CPU times: user 49.3 s, sys: 59.2 s, total: 1min 48s
Wall time: 2min 56s

run using SGMCMCJax

[6]:
pymc_loglikelihood, pymc_logprior = build_pymc_loglik_logprior(model)
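As a quick sanity check, the two jaxified functions can be evaluated directly; a minimal sketch (the test point and minibatch size are arbitrary): the log-likelihood takes the parameters as a list with one array per value variable plus a minibatch of data, while the log-prior takes the parameters alone.

[ ]:
# sketch: evaluate the jaxified functions at an arbitrary point
theta_test = [jnp.zeros(D)]                          # one array per value variable
print(pymc_loglikelihood(theta_test, X_data[:100]))  # log-likelihood of a 100-point minibatch
print(pymc_logprior(theta_test))                     # log-prior at theta_test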
[7]:
# build sampler
batch_size = int(0.1*N)
dt = 1e-5

rvs = [rv.name for rv in model.value_vars]
init_position_dict = model.compute_initial_point()
init_position = [init_position_dict[rv] for rv in rvs]


# ========
my_sampler_pymc = build_sgld_sampler(dt,
                                     pymc_loglikelihood,
                                     pymc_logprior,
                                     (X_data,),
                                     batch_size)

# get MAP for SGLD-CV

# Adam
batch_size_adam = int(0.1 * N)
dt_adam = 1e-3
opt = optax.adam(learning_rate=dt_adam)
optimizer = build_optax_optimizer(opt, pymc_loglikelihood, pymc_logprior, (X_data,), batch_size_adam)
Niters = 10_000
opt_params, log_post_list = optimizer(key, Niters, init_position)

# define SGLD-CV
my_sampler_pymc_CV = build_sgldCV_sampler(dt,
                                          pymc_loglikelihood,
                                          pymc_logprior,
                                          (X_data,),
                                          batch_size,
                                          opt_params)
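Before using opt_params as the control-variate centre it is worth checking that the Adam run converged; a minimal sketch, plotting the stochastic log-posterior trace returned by the optimizer:

[ ]:
# sketch: convergence check for the Adam run
plt.plot(log_post_list)
plt.xlabel("iteration")
plt.ylabel("stochastic log-posterior estimate")
plt.show()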

[8]:
%%time

Nsamples = 10_000
# run SGLD
samples_pymc = my_sampler_pymc(key, Nsamples, init_position)
# run SGLD-CV
samples_pymc_CV = my_sampler_pymc_CV(key, Nsamples, init_position)

# results are stored in a PyTree
_ = samples_pymc[0].block_until_ready()
_ = samples_pymc_CV[0].block_until_ready()
CPU times: user 4.32 s, sys: 78.5 ms, total: 4.4 s
Wall time: 4.43 s

compare posterior samples

[9]:
chain_num = 0
idx = 0

plt.rcParams.update({'font.size': 16})

fig, ax = plt.subplots(2, figsize=(11, 9))
sns.kdeplot(posterior.posterior.x.values[chain_num, :, idx],
            label="PyMC NUTS sampler", alpha=0.5, ax=ax[0], marker="+")
sns.kdeplot(samples_pymc_CV[0][:,idx], label="pymc: sgld-cv", alpha=0.5, ax=ax[0])
sns.kdeplot(samples_pymc[0][:,idx], label="pymc: sgld", alpha=0.5, ax=ax[0])

ax[1].plot(posterior.posterior.x.values[chain_num, :, idx],
           label="PyMC NUTS sampler", alpha=0.4, marker="+")
ax[1].plot(samples_pymc_CV[0][:,idx], label="pymc: sgld-cv", alpha=0.4)
ax[1].plot(samples_pymc[0][:,idx], label="pymc: sgld", alpha=0.4)

plt.legend()
[9]:
<matplotlib.legend.Legend at 0x7fe86dd97d00>
[figure: KDE plots (top) and trace plots (bottom) of component 0 of x, comparing the PyMC NUTS sampler with SGLD and SGLD-CV]
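Following up on the ESS/sec item in the TODO above: one way to compare the samplers is to compute the effective sample size with ArviZ and divide by wall-clock time. A minimal sketch, where nuts_seconds and sgld_seconds are hypothetical timings recorded from the %%time cells above:

[ ]:
import arviz as az

# ESS for the NUTS run (posterior is the InferenceData returned by pm.sample)
ess_nuts = az.ess(posterior)["x"].values

# wrap the SGLD samples (shape (Nsamples, D)) as a single-chain InferenceData
idata_sgld = az.convert_to_inference_data(np.asarray(samples_pymc[0])[None, ...])
ess_sgld = az.ess(idata_sgld)["x"].values

# ESS per second; nuts_seconds and sgld_seconds are hypothetical wall-clock timings
print("NUTS ESS/sec:", ess_nuts / nuts_seconds)
print("SGLD ESS/sec:", ess_sgld / sgld_seconds)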