Gradient estimation

sgmcmcjax.gradient_estimation.build_gradient_estimation_fn(grad_log_post: Callable, data: Tuple, batch_size: int) → Tuple[Callable, Callable]

Build a standard minibatch gradient estimator

Parameters
  • grad_log_post (Callable) – gradient of the log-posterior

  • data (Tuple) – tuple of data arrays: a single array for unsupervised problems, or two arrays (inputs and targets) for supervised problems

  • batch_size (int) – number of data points in each minibatch

Returns

gradient estimation function and gradient initialisation function

Return type

Tuple[Callable, Callable]
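The idea behind a standard minibatch estimator can be illustrated with a minimal pure-Python sketch. It is not sgmcmcjax's implementation and does not reproduce the returned (estimation, initialisation) pair: it assumes a scalar parameter, a log-posterior gradient split into hypothetical `grad_log_prior` and per-point `grad_log_lik` functions, and minibatches drawn without replacement. `minibatch_grad` and `estimate_grad` are illustrative names. The `data` tuple follows the convention above: `(x,)` for unsupervised problems, `(x, y)` for supervised ones.

```python
import random
from typing import Callable, Sequence, Tuple

# Assumption: the posterior gradient splits into a prior term and a
# per-data-point likelihood term; both helpers are hypothetical.
GradPrior = Callable[[float], float]
GradLik = Callable[..., float]  # called as grad_log_lik(theta, *one_row)


def minibatch_grad(theta: float, data: Tuple[Sequence[float], ...],
                   idx: Sequence[int], grad_log_prior: GradPrior,
                   grad_log_lik: GradLik) -> float:
    """Prior gradient plus (N / n) times the minibatch sum of likelihood gradients."""
    n_data = len(data[0])
    # data is (x,) for unsupervised problems or (x, y) for supervised ones;
    # row i is passed to grad_log_lik as its entries across all arrays.
    lik_sum = sum(grad_log_lik(theta, *(arr[i] for arr in data)) for i in idx)
    # Scaling by N / n makes the estimate unbiased for the full-data gradient.
    return grad_log_prior(theta) + (n_data / len(idx)) * lik_sum


def estimate_grad(theta: float, data: Tuple[Sequence[float], ...],
                  batch_size: int, grad_log_prior: GradPrior,
                  grad_log_lik: GradLik, rng: random.Random) -> float:
    """Draw minibatch indices (without replacement) and estimate the gradient."""
    idx = rng.sample(range(len(data[0])), batch_size)
    return minibatch_grad(theta, data, idx, grad_log_prior, grad_log_lik)


# Usage: a N(0, 10) prior on theta, with either an unsupervised model
# x_i ~ N(theta, 1) or a supervised model y_i ~ N(theta * x_i, 1).
def grad_log_prior(t: float) -> float:
    return -t / 10.0


def lik_unsup(t: float, xi: float) -> float:
    return xi - t


def lik_sup(t: float, xi: float, yi: float) -> float:
    return (yi - t * xi) * xi


x = [0.5, 1.0, -0.3, 2.0, 1.5]
y = [1.1, 2.1, -0.5, 3.9, 3.2]
rng = random.Random(0)
g_unsup = estimate_grad(0.7, (x,), 2, grad_log_prior, lik_unsup, rng)
g_sup = estimate_grad(0.7, (x, y), 2, grad_log_prior, lik_sup, rng)
```

Averaged over every possible minibatch of a given size, the rescaled estimate equals the full-data gradient; any single draw is noisy, which is the variance the control variates and SVRG estimators try to reduce.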

sgmcmcjax.gradient_estimation.build_gradient_estimation_fn_CV(grad_log_post: Callable, data: Tuple, batch_size: int, centering_value: Union[Dict, List, Tuple, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray]) → Tuple[Callable, Callable]

Build a control variates gradient estimator

Parameters
  • grad_log_post (Callable) – gradient of the log-posterior

  • data (Tuple) – tuple of data arrays: a single array for unsupervised problems, or two arrays (inputs and targets) for supervised problems

  • batch_size (int) – number of data points in each minibatch

  • centering_value (PyTree) – centering value for the control variates; this should be the MAP estimate

Returns

gradient estimation function and gradient initialisation function

Return type

Tuple[Callable, Callable]
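A control variates estimator can be sketched as: the exact full-data gradient at the centering value, plus the minibatch estimate at the current parameter, minus the minibatch estimate at the centering value on the same minibatch. The pure-Python sketch below is an illustration of that technique under stated assumptions, not sgmcmcjax's code: the parameter is a scalar, the gradient is split into hypothetical `grad_log_prior` and `grad_log_lik` helpers, and `build_cv_estimator` and `cv_grad` are illustrative names. The centering value is the MAP estimate, as the parameter description above recommends.

```python
import random
from typing import Callable, Sequence

GradPrior = Callable[[float], float]
GradLik = Callable[[float, float], float]


def minibatch_grad(theta: float, x: Sequence[float], idx: Sequence[int],
                   grad_log_prior: GradPrior, grad_log_lik: GradLik) -> float:
    """Standard estimator: prior gradient + (N / n) * minibatch likelihood sum."""
    lik_sum = sum(grad_log_lik(theta, x[i]) for i in idx)
    return grad_log_prior(theta) + (len(x) / len(idx)) * lik_sum


def build_cv_estimator(x: Sequence[float], centering_value: float,
                       grad_log_prior: GradPrior, grad_log_lik: GradLik):
    """Return a control variates gradient estimator centred at centering_value."""
    # One full pass over the data at the centering value, done once up front.
    grad_center = minibatch_grad(centering_value, x, range(len(x)),
                                 grad_log_prior, grad_log_lik)

    def cv_grad(theta: float, idx: Sequence[int]) -> float:
        # Evaluate the same minibatch at theta and at the centering value;
        # their difference corrects the exact gradient at the centre.
        return (grad_center
                + minibatch_grad(theta, x, idx, grad_log_prior, grad_log_lik)
                - minibatch_grad(centering_value, x, idx, grad_log_prior, grad_log_lik))

    return cv_grad


# Usage: x_i ~ N(theta, 1) with a N(0, 10) prior, centred at the MAP estimate.
def grad_log_prior(t: float) -> float:
    return -t / 10.0


def grad_log_lik(t: float, xi: float) -> float:
    return xi - t


x = [0.5, 1.0, -0.3, 2.0, 1.5]
theta_map = sum(x) / (len(x) + 0.1)  # closed-form mode of this toy posterior
cv_grad = build_cv_estimator(x, theta_map, grad_log_prior, grad_log_lik)
rng = random.Random(0)
idx = rng.sample(range(len(x)), 2)
g = cv_grad(1.2, idx)
```

In this toy model the likelihood gradient `xi - t` is linear in the parameter, so the correction term is identical for every minibatch and the estimator returns the exact full-data gradient. For general models it only reduces variance, and the reduction is largest while the parameter stays close to the centering value.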

sgmcmcjax.gradient_estimation.build_gradient_estimation_fn_SVRG(grad_log_post: Callable, data: Tuple, batch_size: int, update_rate: int) → Tuple[Callable, Callable]

Build an SVRG (stochastic variance-reduced gradient) estimator

Parameters
  • grad_log_post (Callable) – gradient of the log-posterior

  • data (Tuple) – tuple of data arrays: a single array for unsupervised problems, or two arrays (inputs and targets) for supervised problems

  • batch_size (int) – number of data points in each minibatch

  • update_rate (int) – number of iterations between updates of the centering value used by the estimator

Returns

gradient estimation function and gradient initialisation function

Return type

Tuple[Callable, Callable]
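SVRG uses the same control variate construction, but moves the centering value every `update_rate` iterations instead of fixing it, paying one full-data gradient per move. The sketch below is a pure-Python illustration, not sgmcmcjax's code. It assumes a scalar parameter, hypothetical `grad_log_prior` and `grad_log_lik` helpers, that the first centering value is the initial parameter, and that each update moves the centre to the current parameter. The `svrg_estimate` / `svrg_init` pair loosely echoes the (estimation, initialisation) pair described above, but `SVRGState`, `make_svrg` and their signatures are hypothetical.

```python
import random
from typing import Callable, NamedTuple, Sequence

GradPrior = Callable[[float], float]
GradLik = Callable[[float, float], float]


class SVRGState(NamedTuple):
    step: int           # number of gradient estimates produced so far
    center: float       # current centering value
    grad_center: float  # full-data gradient at the centering value


def minibatch_grad(theta: float, x: Sequence[float], idx: Sequence[int],
                   grad_log_prior: GradPrior, grad_log_lik: GradLik) -> float:
    """Standard estimator: prior gradient + (N / n) * minibatch likelihood sum."""
    lik_sum = sum(grad_log_lik(theta, x[i]) for i in idx)
    return grad_log_prior(theta) + (len(x) / len(idx)) * lik_sum


def make_svrg(x: Sequence[float], batch_size: int, update_rate: int,
              grad_log_prior: GradPrior, grad_log_lik: GradLik):
    """Return hypothetical (estimate, init) functions for an SVRG estimator."""
    def full_grad(theta: float) -> float:
        return minibatch_grad(theta, x, range(len(x)), grad_log_prior, grad_log_lik)

    def svrg_init(theta: float) -> SVRGState:
        # Assumption: the first centering value is the initial parameter.
        return SVRGState(step=0, center=theta, grad_center=full_grad(theta))

    def svrg_estimate(theta: float, state: SVRGState, rng: random.Random):
        if state.step > 0 and state.step % update_rate == 0:
            # Every update_rate steps, re-centre at the current parameter
            # and pay for one full pass over the data.
            state = state._replace(center=theta, grad_center=full_grad(theta))
        idx = rng.sample(range(len(x)), batch_size)
        g = (state.grad_center
             + minibatch_grad(theta, x, idx, grad_log_prior, grad_log_lik)
             - minibatch_grad(state.center, x, idx, grad_log_prior, grad_log_lik))
        return g, state._replace(step=state.step + 1)

    return svrg_estimate, svrg_init


# Usage on a toy model: x_i ~ N(theta, 1) with a N(0, 10) prior,
# re-centring every 3 steps along a fixed sequence of parameter values.
def grad_log_prior(t: float) -> float:
    return -t / 10.0


def grad_log_lik(t: float, xi: float) -> float:
    return xi - t


x = [0.5, 1.0, -0.3, 2.0, 1.5]
estimate, init = make_svrg(x, batch_size=2, update_rate=3,
                           grad_log_prior=grad_log_prior, grad_log_lik=grad_log_lik)
rng = random.Random(1)
thetas = [0.1 * k for k in range(7)]
state = init(thetas[0])
centers = []
for t in thetas:
    g, state = estimate(t, state, rng)
    centers.append(state.center)
```

Keeping the estimator state in an immutable `NamedTuple` and returning the updated copy mirrors the functional style JAX code generally needs; with `update_rate=3` the centre above is reset at steps 0, 3 and 6.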