Utility functions
sgmcmcjax.util.build_grad_log_post(loglikelihood: Callable, logprior: Callable, data: Tuple, with_val: bool = False) → Callable
Build a function that computes the gradient of the log-posterior. The returned function has the following signature:
- grad_log_post (Callable)
- Args:
param (PyTree): parameters at which to evaluate the log-posterior
args: data (either a minibatch or the full batch) to pass to the log-likelihood
- Returns:
gradient of the log-posterior (PyTree) and, optionally, the value of the log-posterior (float)
- Parameters
loglikelihood (Callable) – log-likelihood for a single data point
logprior (Callable) – log-prior of the parameters
data (Tuple) – tuple of data arrays. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)
with_val (bool, optional) – whether the returned function also returns the value of the log-posterior alongside its gradient. Defaults to False.
- Raises
ValueError – if the ‘data’ argument is not a tuple of size 1 or 2
- Returns
A function that computes the gradient of the log-posterior
- Return type
Callable
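The behaviour described above can be sketched as follows. This is a minimal, hypothetical re-implementation (`build_grad_log_post_sketch` is not part of sgmcmcjax): it assumes the single-datapoint log-likelihood is vectorised with `jax.vmap` over the leading axis of each data array, and that the minibatch log-likelihood is rescaled to the full dataset size N. The exact scaling the library uses, and the order of (value, gradient) when `with_val=True`, are assumptions here; this sketch follows `jax.value_and_grad`, which returns the value first.

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp


def build_grad_log_post_sketch(loglikelihood: Callable, logprior: Callable,
                               data: Tuple, with_val: bool = False) -> Callable:
    """Hypothetical sketch; not the sgmcmcjax source."""
    # Vectorise the single-datapoint log-likelihood over the leading (batch) axis.
    if len(data) == 1:    # unsupervised: loglikelihood(param, x)
        batch_loglik = jax.vmap(loglikelihood, in_axes=(None, 0))
    elif len(data) == 2:  # supervised: loglikelihood(param, x, y)
        batch_loglik = jax.vmap(loglikelihood, in_axes=(None, 0, 0))
    else:
        raise ValueError("'data' should be a tuple of size 1 or 2")

    n_data = data[0].shape[0]  # full dataset size N

    def log_post(param, *args):
        # Assumption: rescale the batch log-likelihood as N * mean over the batch,
        # an unbiased estimate of the full-data sum for a uniformly sampled minibatch.
        return logprior(param) + n_data * jnp.mean(batch_loglik(param, *args))

    # jax.value_and_grad returns (value, gradient); jax.grad returns the gradient only.
    grad_fn = jax.value_and_grad(log_post) if with_val else jax.grad(log_post)
    return jax.jit(grad_fn)


# Usage: Gaussian-mean model with a standard-normal prior (up to constants).
def loglik(theta, x):
    return -0.5 * jnp.sum((x - theta) ** 2)


def logprior(theta):
    return -0.5 * jnp.dot(theta, theta)


X = jnp.ones((4, 2))
grad_log_post = build_grad_log_post_sketch(loglik, logprior, (X,))
g = grad_log_post(jnp.zeros(2), X[:2])  # evaluate on a minibatch of 2 points
```

The returned function takes the parameters first and the data arrays after, so the same function can be called on a minibatch or on the full dataset.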
sgmcmcjax.util.progress_bar_scan(num_samples: int, message: Union[None, str] = None) → Callable
Decorator factory that builds a tqdm progress bar for lax.scan. It returns a decorator to apply to the ‘body’ function passed to lax.scan.
- Parameters
num_samples (int) – number of iterations that lax.scan runs for
message (Union[None, str]) – message to display in the progress bar. If None (the default), the message is f"Running for {num_samples:,} iterations"
- Returns
decorator to apply to the body function used in lax.scan
- Return type
Callable
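A rough sketch of a decorator factory of this kind is shown below, assuming the scan is driven by `xs = jnp.arange(num_samples)` so that the body's second argument is the iteration index. To stay dependency-free it prints progress from the host with `jax.debug.callback` instead of drawing a tqdm bar. `progress_bar_scan_sketch` is a hypothetical name, and this is not the library's implementation.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax import lax


def progress_bar_scan_sketch(num_samples: int, message: Optional[str] = None) -> Callable:
    """Hypothetical stand-in for a scan progress bar; prints instead of using tqdm."""
    if message is None:
        message = f"Running for {num_samples:,} iterations"
    print_rate = max(1, num_samples // 10)  # report roughly ten times in total

    def _report(i):
        # Runs on the host; i arrives as a 0-d array holding the iteration index.
        i = int(i)
        if i == 0:
            print(message)
        if (i + 1) % print_rate == 0 or i + 1 == num_samples:
            print(f"  {i + 1:,}/{num_samples:,}")

    def decorator(body: Callable) -> Callable:
        def wrapped_body(carry, x):
            # Assumption: x is the iteration index (scan over jnp.arange(num_samples)).
            jax.debug.callback(_report, x)
            return body(carry, x)
        return wrapped_body

    return decorator


# Usage: sum 0..999 with lax.scan while reporting progress.
num_samples = 1000


@progress_bar_scan_sketch(num_samples)
def body(carry, x):
    return carry + x, None


total, _ = lax.scan(body, jnp.int32(0), jnp.arange(num_samples))
jax.effects_barrier()  # wait for any outstanding host callbacks to finish
```

The decorator leaves the body's inputs and outputs unchanged, so the result of the scan is the same with or without progress reporting; only the host-side side effect is added.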
sgmcmcjax.util.run_loop(f: Callable, state: Any, xs: Iterable, compiled: bool = True) → Any
Loop over an iterable and keep only the final state. The function f is applied as f(state, x) and should return (state, None).
- Parameters
f (Callable) – function to apply at every iteration
state (Any) – state of the system
xs (Iterable) – iterable to loop over
compiled (bool, optional) – if True, run the loop with lax.scan; otherwise, run a native Python loop. Defaults to True.
- Returns
final state of the system
- Return type
Any
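A minimal sketch of the two code paths, assuming f follows the lax.scan convention f(state, x) -> (state, None). `run_loop_sketch` is a hypothetical name, not the library's function. Note that lax.scan expects xs to be an array (or pytree of arrays) with a leading axis to iterate over, whereas the Python branch accepts any iterable.

```python
from typing import Any, Callable, Iterable

import jax.numpy as jnp
from jax import lax


def run_loop_sketch(f: Callable, state: Any, xs: Iterable, compiled: bool = True) -> Any:
    """Hypothetical sketch: apply f over xs and return only the final state."""
    if compiled:
        # lax.scan threads the state through f; f returns (state, None),
        # so there are no per-step outputs to keep and they are discarded.
        state, _ = lax.scan(f, state, xs)
    else:
        # Native Python loop: the same update, one call to f per element of xs.
        for x in xs:
            state, _ = f(state, x)
    return state


def step(total, x):
    # Accumulate a running sum; the second output is unused (None).
    return total + x, None


final_scan = run_loop_sketch(step, jnp.float32(0.0), jnp.arange(5.0))
final_py = run_loop_sketch(step, jnp.float32(0.0), jnp.arange(5.0), compiled=False)
```

Both branches produce the same final state; the compiled branch keeps the carry's shape and dtype fixed across iterations, as lax.scan requires.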