"""Source code for sgmcmcjax.samplers: sampler factories built from the kernels in sgmcmcjax.kernels."""

import functools
from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Tuple

import jax.numpy as jnp
from jax import jit, lax, random
from tqdm.auto import tqdm

from .kernels import (
    build_badodab_kernel,
    build_badodabCV_kernel,
    build_baoab_kernel,
    build_psgld_kernel,
    build_sghmc_kernel,
    build_sghmc_SVRG_kernel,
    build_sghmcCV_kernel,
    build_sgld_kernel,
    build_sgld_SVRG_kernel,
    build_sgldCV_kernel,
    build_sgnht_kernel,
    build_sgnhtCV_kernel,
)
from .types import DiffusionState, PRNGKey, PyTree, SamplerState, SVRGState
from .util import progress_bar_scan

def _build_compiled_sampler(
    init_fn: Callable[[PRNGKey, PyTree], SamplerState],
    kernel: Callable[[int, PRNGKey, SamplerState], SamplerState],
    get_params: Callable[[SamplerState], PyTree],
    pbar: bool = True,
) -> Callable:
    """Build a generic sampler whose whole chain runs inside one jitted ``lax.scan``.

    Args:
        init_fn (Callable[[PRNGKey, PyTree], SamplerState]): function to initialise the state of the chain
        kernel (Callable[[int, PRNGKey, SamplerState], SamplerState]): transition kernel, called as ``kernel(i, key, state)``
        get_params (Callable[[SamplerState], PyTree]): function that extracts the target parameters from the state
        pbar (bool, optional): whether or not to display the progress bar. Defaults to True.

    Returns:
        Callable: ``sampler(key, Nsamples, params)``. It initialises the chain
        from ``params``, runs ``Nsamples`` kernel steps and returns the output
        of ``get_params`` after every step, stacked along a new leading axis
        by ``lax.scan``.
    """

    # Nsamples fixes the scan length, so it has to be a static argument;
    # each distinct value of Nsamples triggers a fresh compilation.
    @partial(jit, static_argnums=(1,))
    def sampler(key, Nsamples, params):
        def body(carry, i):
            # Carry the PRNG key alongside the state so every step gets a fresh subkey.
            key, state = carry
            key, subkey = random.split(key)
            state = kernel(i, subkey, state)
            return (key, state), get_params(state)

        key, subkey = random.split(key)
        state = init_fn(subkey, params)

        # progress_bar_scan (from .util) decorates the scan body; presumably it
        # drives a progress bar. It is only applied when pbar is True.
        step = progress_bar_scan(Nsamples)(body) if pbar else body
        _, samples = lax.scan(step, (key, state), jnp.arange(Nsamples))
        return samples

    return sampler

def _build_noncompiled_sampler(
    init_fn: Callable[[PRNGKey, PyTree], SamplerState],
    kernel: Callable[[int, PRNGKey, SamplerState], SamplerState],
    get_params: Callable[[SamplerState], PyTree],
    pbar: bool = True,
) -> Callable:
    """Build a generic sampler that steps the chain in a plain Python loop.

    Args:
        init_fn (Callable[[PRNGKey, PyTree], SamplerState]): function to initialise the state of the chain
        kernel (Callable[[int, PRNGKey, SamplerState], SamplerState]): transition kernel, called as ``kernel(i, key, state)``
        get_params (Callable[[SamplerState], PyTree]): function that extracts the target parameters from the state
        pbar (bool, optional): whether or not to display a tqdm progress bar. Defaults to True.

    Returns:
        Callable: ``sampler(key, Nsamples, params)``. It initialises the chain
        from ``params``, runs ``Nsamples`` kernel steps and returns a Python
        list holding the output of ``get_params`` after every step.
    """

    def sampler(key, Nsamples, params):
        samples = []
        key, subkey = random.split(key)
        state = init_fn(subkey, params)

        iterations = tqdm(range(Nsamples)) if pbar else range(Nsamples)

        for i in iterations:
            key, subkey = random.split(key)
            state = kernel(i, subkey, state)
            # One sample per step, matching the per-step get_params output
            # that the compiled (lax.scan) sampler emits.
            samples.append(get_params(state))

        return samples

    return sampler

def sgmcmc_sampler(build_kernel_fn: Callable) -> Callable:
    """Decorator that turns a kernel factory into a sampler factory.

    These samplers have exactly the same signatures as the kernels they're
    built from (see sgmcmcjax.kernels). The returned factory also accepts two
    optional keyword arguments that are consumed here and never forwarded to
    ``build_kernel_fn``:

    - ``compiled`` (bool, default True): run the whole chain inside one jitted
      ``lax.scan``. If False, loop in Python over a jitted kernel.
    - ``pbar`` (bool, default True): whether or not to display a progress bar.

    Args:
        build_kernel_fn (Callable): kernel factory returning the triple
            ``(init_fn, kernel, get_params)``.

    Returns:
        Callable: sampling function with the same signature as build_kernel_fn.
        Calling it yields a ``sampler(key, Nsamples, params)`` function.
    """

    @functools.wraps(build_kernel_fn)
    def wrapper(*args, **kwargs):
        # Strip the sampler-level options so the kernel factory never sees them.
        compiled = kwargs.pop("compiled", True)
        pbar = kwargs.pop("pbar", True)
        init_fn, my_kernel, get_params = build_kernel_fn(*args, **kwargs)
        if compiled:
            return _build_compiled_sampler(init_fn, my_kernel, get_params, pbar=pbar)
        # Non-compiled path: jit each kernel step on its own and loop in Python.
        return _build_noncompiled_sampler(
            init_fn, jit(my_kernel), get_params, pbar=pbar
        )

    return wrapper
# Public sampler factories: each wraps the matching kernel factory from
# .kernels with sgmcmc_sampler, so it takes the kernel factory's arguments
# plus the optional ``compiled`` and ``pbar`` keyword arguments.
build_sgld_sampler = sgmcmc_sampler(build_sgld_kernel)
build_sgldCV_sampler = sgmcmc_sampler(build_sgldCV_kernel)
build_sgld_SVRG_sampler = sgmcmc_sampler(build_sgld_SVRG_kernel)
build_psgld_sampler = sgmcmc_sampler(build_psgld_kernel)
build_sghmc_sampler = sgmcmc_sampler(build_sghmc_kernel)
build_sghmcCV_sampler = sgmcmc_sampler(build_sghmcCV_kernel)
build_sghmc_SVRG_sampler = sgmcmc_sampler(build_sghmc_SVRG_kernel)
build_baoab_sampler = sgmcmc_sampler(build_baoab_kernel)
build_sgnht_sampler = sgmcmc_sampler(build_sgnht_kernel)
build_sgnhtCV_sampler = sgmcmc_sampler(build_sgnhtCV_kernel)
build_badodab_sampler = sgmcmc_sampler(build_badodab_kernel)
build_badodabCV_sampler = sgmcmc_sampler(build_badodabCV_kernel)