Kernels
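
All of the builders on this page return the same (init_fun, kernel, get_params) triple, so a single usage pattern covers every kernel. The sketch below shows that pattern for build_sgld_kernel on a toy Gaussian model. The exact call signatures of the returned functions are not stated in the docstrings; init_fun(key, params), kernel(i, key, state) and get_params(state) are assumed here from the triple's naming and the package README, so treat this as a sketch rather than a definitive recipe.

    import jax.numpy as jnp
    from jax import random

    from sgmcmcjax.kernels import build_sgld_kernel

    # Toy model: infer the mean of a 2D Gaussian. The data tuple holds a
    # single array, i.e. the "unsupervised" case described below.
    def loglikelihood(theta, x):
        return -0.5 * jnp.sum((x - theta) ** 2)

    def logprior(theta):
        return -0.5 * 0.01 * jnp.sum(theta ** 2)

    key = random.PRNGKey(0)
    X = 1.5 + random.normal(key, shape=(1000, 2))

    init_fun, kernel, get_params = build_sgld_kernel(
        dt=1e-5, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100)

    # Assumed protocol: init_fun(key, params) -> state,
    # kernel(i, key, state) -> state, get_params(state) -> params.
    state = init_fun(key, jnp.zeros(2))
    samples = []
    for i in range(1000):
        key, subkey = random.split(key)
        state = kernel(i, subkey, state)
        samples.append(get_params(state))

The shorter sketches further down this page reuse loglikelihood, logprior, X and key from this example.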

sgmcmcjax.kernels.build_badodabCV_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, centering_value: Union[Dict, List, Tuple, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], a: float = 0.01) → Tuple[Callable, Callable, Callable]

Build the BADODAB kernel with control variates.

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • centering_value (PyTree) – centering value for the control variates; this should be the MAP estimate of the parameters

  • a (float, optional) – initial value of alpha. Defaults to 0.01.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
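
For the control-variates builders, centering_value should be the MAP estimate of the model parameters. A minimal sketch, continuing the toy Gaussian example at the top of this page, where the sample mean (essentially the MAP under the weak prior) stands in for an optimiser-found value:

    import jax.numpy as jnp

    from sgmcmcjax.kernels import build_badodabCV_kernel

    # Reuses loglikelihood, logprior and X from the sketch at the top.
    # For this toy model the sample mean is essentially the MAP; in general
    # run an optimiser first and pass its result as centering_value.
    centering_value = jnp.mean(X, axis=0)

    init_fun, kernel, get_params = build_badodabCV_kernel(
        dt=1e-5, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100, centering_value=centering_value, a=0.01)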

sgmcmcjax.kernels.build_badodab_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, a: float = 0.01) → Tuple[Callable, Callable, Callable]

Build the BADODAB kernel, a splitting scheme for the 3-equation Langevin diffusion. See https://arxiv.org/abs/1505.06889

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • a (float, optional) – initial value of alpha. Defaults to 0.01.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]

sgmcmcjax.kernels.build_baoab_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, gamma: float, tau: float = 1.0) → Tuple[Callable, Callable, Callable]

Build the BAOAB kernel, a splitting scheme for the underdamped Langevin diffusion. See https://aip.scitation.org/doi/abs/10.1063/1.4802990

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • gamma (float) – friction coefficient

  • tau (float, optional) – temperature. Defaults to 1.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
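
In BAOAB the sampler state carries a momentum variable: gamma sets the friction applied to it and tau the temperature, with tau = 1.0 targeting the posterior. A minimal sketch, reusing the toy model from the top of this page:

    from sgmcmcjax.kernels import build_baoab_kernel

    # gamma: friction on the momentum; tau = 1.0 targets the posterior,
    # smaller tau tempers (sharpens) the target.
    init_fun, kernel, get_params = build_baoab_kernel(
        dt=1e-5, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100, gamma=2.0, tau=1.0)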

sgmcmcjax.kernels.build_psgld_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, alpha: float = 0.99, eps: float = 1e-05) → Tuple[Callable, Callable, Callable]

Build the preconditioned SGLD (pSGLD) kernel. See https://arxiv.org/abs/1512.07666

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • alpha (float, optional) – exponential-decay rate that balances the weights of historical and current gradients. Defaults to 0.99.

  • eps (float, optional) – controls the extremes of the curvature of the preconditioner. Defaults to 1e-5.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
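
pSGLD maintains an RMSProp-style diagonal preconditioner built from a running average of squared gradients: alpha is the decay rate of that running average and eps keeps the preconditioner bounded in flat directions. A minimal sketch, reusing the toy model from the top of this page:

    from sgmcmcjax.kernels import build_psgld_kernel

    # alpha: decay of the running average of squared gradients;
    # eps: regulariser that bounds the preconditioner's curvature.
    init_fun, kernel, get_params = build_psgld_kernel(
        dt=1e-4, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100, alpha=0.99, eps=1e-5)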

sgmcmcjax.kernels.build_sghmcCV_kernel(dt: float, L: int, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, centering_value: Union[Dict, List, Tuple, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], alpha: float = 0.01, compiled_leapfrog: bool = True) → Tuple[Callable, Callable, Callable]

Build the stochastic gradient HMC (SGHMC) kernel with control variates.

Parameters
  • dt (float) – step size

  • L (int) – number of leapfrog steps

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • centering_value (PyTree) – centering value for the control variates; this should be the MAP estimate of the parameters

  • alpha (float, optional) – friction coefficient. Defaults to 0.01.

  • compiled_leapfrog (bool, optional) – whether to run the leapfrog loop with lax.scan (compiled); if False, run a native Python loop instead. Defaults to True.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]

sgmcmcjax.kernels.build_sghmc_SVRG_kernel(dt: float, L: int, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, update_rate: int, alpha: float = 0.01, compiled_leapfrog: bool = True) → Tuple[Callable, Callable, Callable]

Build the stochastic gradient HMC (SGHMC) kernel with SVRG.

Parameters
  • dt (float) – step size

  • L (int) – number of leapfrog steps

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • update_rate (int) – number of iterations between updates of the centering value in the SVRG gradient estimator

  • alpha (float, optional) – friction coefficient. Defaults to 0.01.

  • compiled_leapfrog (bool, optional) – whether to run the leapfrog loop with lax.scan (compiled); if False, run a native Python loop instead. Defaults to True.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]

sgmcmcjax.kernels.build_sghmc_kernel(dt: float, L: int, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, alpha: float = 0.01, compiled_leapfrog: bool = True) → Tuple[Callable, Callable, Callable]

Build the stochastic gradient HMC (SGHMC) kernel. See https://arxiv.org/abs/1402.4102

Parameters
  • dt (float) – step size

  • L (int) – number of leapfrog steps

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • alpha (float, optional) – friction coefficient. Defaults to 0.01.

  • compiled_leapfrog (bool, optional) – whether to run the leapfrog loop with lax.scan (compiled); if False, run a native Python loop instead. Defaults to True.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
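
Each SGHMC kernel call integrates L leapfrog steps with friction alpha. With compiled_leapfrog=True the steps run inside lax.scan; setting it to False uses a plain Python loop, which is slower but convenient for debugging the integrator. A minimal sketch, reusing the toy model from the top of this page:

    from sgmcmcjax.kernels import build_sghmc_kernel

    # One kernel call = L leapfrog steps. Set compiled_leapfrog=False to
    # step through the integrator in plain Python while debugging.
    init_fun, kernel, get_params = build_sghmc_kernel(
        dt=1e-5, L=10, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100, alpha=0.01, compiled_leapfrog=True)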

sgmcmcjax.kernels.build_sgldAdam_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-08) → Tuple[Callable, Callable, Callable]

Build the SGLD-Adam kernel. See the appendix of the SGMCMCJax paper: https://arxiv.org/abs/2105.13059v1

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • beta1 (float, optional) – decay rate for the first-moment estimate of the gradients. Defaults to 0.9.

  • beta2 (float, optional) – decay rate for the second-moment estimate of the gradients. Defaults to 0.999.

  • eps (float, optional) – small value to avoid instabilities. Defaults to 1e-8.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
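
SGLD-Adam preconditions the Langevin update with Adam-style moment estimates, so beta1, beta2 and eps play the same roles as in the Adam optimiser. A minimal sketch, reusing the toy model from the top of this page:

    from sgmcmcjax.kernels import build_sgldAdam_kernel

    # beta1/beta2: decay rates of the first/second gradient moments, as in Adam.
    init_fun, kernel, get_params = build_sgldAdam_kernel(
        dt=1e-4, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100, beta1=0.9, beta2=0.999, eps=1e-8)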

sgmcmcjax.kernels.build_sgldCV_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, centering_value: Union[Dict, List, Tuple, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray]) → Tuple[Callable, Callable, Callable]

Build the SGLD kernel with control variates (SGLD-CV).

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • centering_value (PyTree) – centering value for the control variates; this should be the MAP estimate of the parameters

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]

sgmcmcjax.kernels.build_sgld_SVRG_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, update_rate: int) → Tuple[Callable, Callable, Callable]

Build the SGLD-SVRG kernel.

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • update_rate (int) – number of iterations between updates of the centering value in the SVRG gradient estimator

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
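
An SVRG estimator keeps a snapshot (a centering value plus its full-data gradient) and variance-reduces each minibatch gradient against it; update_rate is the number of iterations between snapshot refreshes. A minimal sketch, reusing the toy model from the top of this page:

    from sgmcmcjax.kernels import build_sgld_SVRG_kernel

    # Every 100 iterations the estimator recomputes the full-data gradient
    # at a fresh centering value (the SVRG snapshot).
    init_fun, kernel, get_params = build_sgld_SVRG_kernel(
        dt=1e-5, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100, update_rate=100)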

sgmcmcjax.kernels.build_sgld_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int) → Tuple[Callable, Callable, Callable]

Build the SGLD (stochastic gradient Langevin dynamics) kernel.

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]
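
Because the returned kernel is a pure JAX function, the whole sampling loop can be handed to lax.scan rather than written as a Python for-loop. A sketch, reusing the toy model and key from the top of this page together with the same assumed kernel(i, key, state) signature:

    import jax.numpy as jnp
    from jax import lax, random

    from sgmcmcjax.kernels import build_sgld_kernel

    init_fun, kernel, get_params = build_sgld_kernel(
        dt=1e-5, loglikelihood=loglikelihood, logprior=logprior,
        data=(X,), batch_size=100)

    def one_step(state, idx_and_key):
        i, subkey = idx_and_key
        state = kernel(i, subkey, state)
        return state, get_params(state)

    # lax.scan compiles one_step once and runs the whole loop on-device.
    Nsamples = 10_000
    keys = random.split(key, Nsamples)
    state = init_fun(key, jnp.zeros(2))
    final_state, samples = lax.scan(one_step, state,
                                    (jnp.arange(Nsamples), keys))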

sgmcmcjax.kernels.build_sgnhtCV_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, centering_value: Union[Dict, List, Tuple, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], a: float = 0.01) → Tuple[Callable, Callable, Callable]

Build the stochastic gradient Nosé-Hoover thermostat (SGNHT) kernel with control variates.

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • centering_value (PyTree) – centering value for the control variates; this should be the MAP estimate of the parameters

  • a (float, optional) – diffusion factor. Defaults to 0.01.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]

sgmcmcjax.kernels.build_sgnht_kernel(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int, a: float = 0.01) → Tuple[Callable, Callable, Callable]

Build the stochastic gradient Nosé-Hoover thermostat (SGNHT) kernel. See http://people.ee.duke.edu/~lcarin/sgnht-4.pdf

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the model parameters

  • data (Tuple) – tuple of data. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)

  • batch_size (int) – batch size

  • a (float, optional) – diffusion factor. Defaults to 0.01.

Returns

An (init_fun, kernel, get_params) triple.

Return type

Tuple[Callable, Callable, Callable]