Gradient estimation
sgmcmcjax.gradient_estimation.build_gradient_estimation_fn(grad_log_post: Callable, data: Tuple, batch_size: int) → Tuple[Callable, Callable]
Build a standard gradient estimator, which estimates the gradient of the log-posterior from a random minibatch of the data.
- Parameters
grad_log_post (Callable) – gradient of the log-posterior
data (Tuple) – tuple of data arrays: either a single array (unsupervised problems) or two arrays (supervised problems)
batch_size (int) – number of data points in each minibatch
- Returns
gradient estimation function and gradient initialisation function
- Return type
Tuple[Callable, Callable]
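To make the standard estimator concrete, here is a minimal JAX sketch on a hypothetical one-dimensional Gaussian model. It assumes, for illustration only, that grad_log_post(theta, x_batch) returns the gradient of the log-prior plus N times the mean log-likelihood over the batch, so that a uniformly sampled minibatch gives an unbiased estimate of the full-data gradient. The helper names (loglikelihood, logprior, make_grad_log_post, make_standard_estimator) are hypothetical; this is not sgmcmcjax's implementation, and it returns a single estimation function rather than the (estimation, initialisation) pair documented above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (illustration only): x_i ~ N(theta, 1), theta ~ N(0, 10).
def loglikelihood(theta, x):
    return -0.5 * (x - theta) ** 2

def logprior(theta):
    return -0.5 * theta ** 2 / 10.0

def make_grad_log_post(N):
    """Gradient of log p(theta) + N * mean of log p(x_i | theta) over a batch."""
    def log_post(theta, x_batch):
        lik = jax.vmap(loglikelihood, in_axes=(None, 0))(theta, x_batch)
        return logprior(theta) + N * jnp.mean(lik)
    return jax.grad(log_post)

def make_standard_estimator(grad_log_post, data, batch_size):
    """Sketch of a standard minibatch estimator for unsupervised data=(x,)."""
    (x,) = data
    N = x.shape[0]

    def estimate_gradient(key, theta):
        # Draw batch_size distinct indices uniformly at random.
        idx = jax.random.choice(key, N, shape=(batch_size,), replace=False)
        # Unbiased estimate of the full-data log-posterior gradient.
        return grad_log_post(theta, x[idx])

    return estimate_gradient

# Usage: 10 observations, minibatches of 3.
x = jnp.arange(10.0)
estimate_gradient = make_standard_estimator(
    make_grad_log_post(x.shape[0]), (x,), batch_size=3
)
g = estimate_gradient(jax.random.PRNGKey(0), jnp.array(1.0))
```

Each call uses a fresh minibatch, so the estimate is noisy but correct on average; with batch_size equal to the number of data points it reduces to the exact gradient.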
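sgmcmcjax.gradient_estimation.build_gradient_estimation_fn_CV(grad_log_post: Callable, data: Tuple, batch_size: int, centering_value: Union[Dict, List, Tuple, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray]) → Tuple[Callable, Callable]
Build a control variates gradient estimator. It combines the full-data gradient at a fixed centering value with a minibatch correction term, which reduces the variance of the estimate when the parameter stays close to the centering value.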
- Parameters
grad_log_post (Callable) – gradient of the log-posterior
data (Tuple) – tuple of data arrays: either a single array (unsupervised problems) or two arrays (supervised problems)
batch_size (int) – number of data points in each minibatch
centering_value (PyTree) – centering value for the control variates; it should be the maximum a posteriori (MAP) estimate
- Returns
gradient estimation function and gradient initialisation function
- Return type
Tuple[Callable, Callable]
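The control variates idea can be sketched in JAX on the same hypothetical Gaussian model. The estimate is the full-data gradient at the centering value plus the minibatch difference between the gradients at the current parameter and at the centering value; that difference is computed on the same minibatch, so much of the sampling noise cancels. All helper names are hypothetical and the sketch is not sgmcmcjax's implementation; jax.tree_util.tree_map is used so the arithmetic would also work for PyTree parameters.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (illustration only): x_i ~ N(theta, 1), theta ~ N(0, 10).
def loglikelihood(theta, x):
    return -0.5 * (x - theta) ** 2

def logprior(theta):
    return -0.5 * theta ** 2 / 10.0

def make_grad_log_post(N):
    """Gradient of log p(theta) + N * mean of log p(x_i | theta) over a batch."""
    def log_post(theta, x_batch):
        lik = jax.vmap(loglikelihood, in_axes=(None, 0))(theta, x_batch)
        return logprior(theta) + N * jnp.mean(lik)
    return jax.grad(log_post)

def make_cv_estimator(grad_log_post, data, batch_size, centering_value):
    """Sketch of a control variates estimator for unsupervised data=(x,)."""
    (x,) = data
    N = x.shape[0]
    # Full-data gradient at the centering value: computed once, up front.
    full_grad_center = grad_log_post(centering_value, x)

    def estimate_gradient(key, theta):
        idx = jax.random.choice(key, N, shape=(batch_size,), replace=False)
        xb = x[idx]
        # Full-data gradient at the centre plus a minibatch correction whose
        # expectation is the full-data gradient at theta minus that at the centre.
        return jax.tree_util.tree_map(
            lambda g_full, g_new, g_old: g_full + g_new - g_old,
            full_grad_center,
            grad_log_post(theta, xb),
            grad_log_post(centering_value, xb),
        )

    return estimate_gradient

# Usage: centre at 4.5, close to the MAP (45 / 10.1, about 4.46) for this data.
x = jnp.arange(10.0)
estimate_gradient = make_cv_estimator(
    make_grad_log_post(x.shape[0]), (x,), batch_size=2, centering_value=jnp.array(4.5)
)
g = estimate_gradient(jax.random.PRNGKey(0), jnp.array(1.0))
```

Because this toy log-likelihood gradient is linear in theta, the correction term is exact and every minibatch returns the full-data gradient. For other models the variance is reduced rather than eliminated, and the reduction is largest when theta stays near the centering value, which is why the centering value should be the MAP.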
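sgmcmcjax.gradient_estimation.build_gradient_estimation_fn_SVRG(grad_log_post: Callable, data: Tuple, batch_size: int, update_rate: int) → Tuple[Callable, Callable]
Build an SVRG (stochastic variance-reduced gradient) estimator. Like the control variates estimator it uses a centering value, but instead of being supplied by the user, the centering value is periodically moved to the current parameter value.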
- Parameters
grad_log_post (Callable) – gradient of the log-posterior
data (Tuple) – tuple of data arrays: either a single array (unsupervised problems) or two arrays (supervised problems)
batch_size (int) – number of data points in each minibatch
update_rate (int) – number of iterations between updates of the centering value in the gradient estimator
- Returns
gradient estimation function and gradient initialisation function
- Return type
Tuple[Callable, Callable]
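An SVRG estimator can be sketched the same way, again on a hypothetical Gaussian model. In this illustrative version the estimator carries a state of (centering value, full-data gradient at that value): it starts from the initial parameter and, on every iteration index divisible by update_rate, moves the centre to the current parameter and recomputes the full-data gradient there. The helper names, the state layout and this exact update schedule are assumptions for illustration; sgmcmcjax's own state and schedule may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (illustration only): x_i ~ N(theta, 1), theta ~ N(0, 10).
def loglikelihood(theta, x):
    return -0.5 * (x - theta) ** 2

def logprior(theta):
    return -0.5 * theta ** 2 / 10.0

def make_grad_log_post(N):
    """Gradient of log p(theta) + N * mean of log p(x_i | theta) over a batch."""
    def log_post(theta, x_batch):
        lik = jax.vmap(loglikelihood, in_axes=(None, 0))(theta, x_batch)
        return logprior(theta) + N * jnp.mean(lik)
    return jax.grad(log_post)

def make_svrg_estimator(grad_log_post, data, batch_size, update_rate):
    """Sketch of an SVRG estimator for unsupervised data=(x,).

    State: (centering value, full-data gradient at the centering value).
    """
    (x,) = data
    N = x.shape[0]

    def init_state(theta):
        # Use the initial parameter as the first centering value.
        return theta, grad_log_post(theta, x)

    def estimate_gradient(i, key, theta, state):
        center, full_grad_center = state
        # Every update_rate iterations, move the centre to the current theta.
        # Plain Python `if`: this sketch runs eagerly, not under jax.jit.
        if i % update_rate == 0:
            center, full_grad_center = theta, grad_log_post(theta, x)
        idx = jax.random.choice(key, N, shape=(batch_size,), replace=False)
        xb = x[idx]
        g = jax.tree_util.tree_map(
            lambda g_full, g_new, g_old: g_full + g_new - g_old,
            full_grad_center,
            grad_log_post(theta, xb),
            grad_log_post(center, xb),
        )
        return g, (center, full_grad_center)

    return init_state, estimate_gradient

# Usage: a few iterations with stand-in parameter values.
x = jnp.arange(10.0)
init_state, estimate_gradient = make_svrg_estimator(
    make_grad_log_post(x.shape[0]), (x,), batch_size=3, update_rate=5
)
key = jax.random.PRNGKey(0)
state = init_state(jnp.array(0.0))
for i in range(3):
    key, subkey = jax.random.split(key)
    theta = jnp.array(1.0 + i)  # stand-in for a sampler's parameter updates
    g, state = estimate_gradient(i, subkey, theta, state)
```

The trade-off controlled by update_rate is visible here: each centre update costs one full-data gradient, while a stale centre drifts away from the current parameter and weakens the variance reduction. A jit-compiled version would need jax.lax.cond in place of the Python `if`.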