Utility functions

sgmcmcjax.util.build_grad_log_post(loglikelihood: Callable, logprior: Callable, data: Tuple, with_val: bool = False) → Callable

Build a function that computes the gradient of the log-posterior. The returned function has the signature:

grad_log_post (Callable)
Args:

param (PyTree): parameters at which to evaluate the log-posterior

args: data (either a minibatch or the full batch) to pass to the log-likelihood

Returns:

the gradient of the log-posterior (PyTree) and, if with_val is True, the value of the log-posterior (float)

Parameters
  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior, evaluated at the parameters

  • data (Tuple) – tuple of data. It should either have a single array (for unsupervised problems) or have two arrays (for supervised problems)

  • with_val (bool, optional) – whether the returned function also returns the value of the log-posterior alongside the gradient. Defaults to False.

Raises

ValueError – if the ‘data’ argument is not a tuple of size 1 or 2

Returns

A function that computes the gradient of the log-posterior

Return type

Callable
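How such a gradient function can be assembled is easiest to see in code. The following is a minimal JAX sketch, not the library's implementation: `make_grad_log_post` is a hypothetical name, and it assumes the log-posterior is estimated as `logprior(param)` plus N times the minibatch mean of the per-point log-likelihood, with N the number of rows in the first array of `data`. When `with_val` is set it uses `jax.value_and_grad`, which returns `(value, gradient)`; the ordering used by sgmcmcjax is not stated on this page.

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp


def make_grad_log_post(loglikelihood: Callable, logprior: Callable,
                       data: Tuple, with_val: bool = False) -> Callable:
    """Hypothetical sketch of building a log-posterior gradient function."""
    if len(data) not in (1, 2):
        raise ValueError("'data' should be a tuple of size 1 or 2")
    n_data = data[0].shape[0]  # size of the full dataset (assumption: rows of data[0])
    # Map the per-point log-likelihood over the leading axis of every data array,
    # keeping the parameters fixed (in_axes=None for param).
    batch_loglik = jax.vmap(loglikelihood, in_axes=(None,) + (0,) * len(data))

    def log_post(param, *args):
        # Minibatch mean rescaled by n_data estimates the full-data log-likelihood sum.
        return logprior(param) + n_data * jnp.mean(batch_loglik(param, *args))

    # Differentiate with respect to the first argument (param) only.
    if with_val:
        return jax.jit(jax.value_and_grad(log_post))  # returns (value, gradient)
    return jax.jit(jax.grad(log_post))


# Usage: Gaussian mean model with a single unsupervised data array.
X = jnp.ones((100,))
grad_log_post = make_grad_log_post(
    lambda theta, x: -0.5 * (x - theta) ** 2,  # log-likelihood of one point
    lambda theta: -0.5 * theta ** 2 / 10.0,     # log-prior of the parameter
    (X,),
)
g_full = grad_log_post(0.0, X)       # full batch: 100.0
g_mini = grad_log_post(0.0, X[:10])  # minibatch of 10 points: also 100.0
```

Rescaling the minibatch mean by N makes the minibatch gradient an unbiased estimate of the full-batch gradient for uniformly sampled minibatches, which is why the same returned function can be called with either a minibatch or the full dataset.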

sgmcmcjax.util.progress_bar_scan(num_samples: int, message: Union[None, str] = None) → Callable

Decorator factory that builds a tqdm progress bar for a loop run with lax.scan. It returns a decorator to apply to the ‘body’ function passed to lax.scan.

Parameters
  • num_samples (int) – number of iterations that lax.scan will run

  • message (Union[None, str]) – message to display in the progress bar. Defaults to f’Running for {num_samples:,} iterations’

Returns

decorator to apply to the body function used in lax.scan

Return type

Callable
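The decorator pattern can be sketched with `jax.debug.callback`, which stages a host-side Python call into the compiled loop. This is an illustrative reimplementation under assumptions, not the sgmcmcjax source: `progress_bar_scan_sketch` is a hypothetical name, it requires `tqdm` to be installed, and it ticks the bar on every iteration, whereas a production version would likely update less often to limit host round-trips.

```python
import jax
import jax.numpy as jnp
from jax import lax
from tqdm import tqdm


def progress_bar_scan_sketch(num_samples, message=None):
    """Hypothetical sketch: return a decorator that ticks a tqdm bar once per scan step."""
    if message is None:
        message = f"Running for {num_samples:,} iterations"
    # Host-side state shared by the callback; it lives outside the traced code.
    state = {"bar": None, "count": 0}

    def _tick(_):
        # Runs on the host. Open the bar lazily on the first step.
        if state["bar"] is None:
            state["bar"] = tqdm(total=num_samples, desc=message)
        state["bar"].update(1)
        state["count"] += 1
        # Close and reset once every step has been counted.
        if state["count"] == num_samples:
            state["bar"].close()
            state["bar"], state["count"] = None, 0

    def decorator(body):
        def wrapped(carry, x):
            # Staged into the compiled loop; fires once per iteration at run time.
            jax.debug.callback(_tick, x)
            return body(carry, x)
        return wrapped

    return decorator


# Usage: decorate a scan body, then run lax.scan as usual.
@progress_bar_scan_sketch(1000)
def body(carry, x):
    return carry + x, None


final, _ = lax.scan(body, jnp.zeros(()), jnp.arange(1000.0))
```

Counting ticks on the host rather than reading an iteration index keeps the sketch independent of what `xs` contains and tolerant of callbacks arriving out of order.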

sgmcmcjax.util.run_loop(f: Callable, state: Any, xs: Iterable, compiled: bool = True) → Any

Loop over an iterable and keep only the final state. The function f must follow the lax.scan body convention and return (state, None).

Parameters
  • f (Callable) – function to apply at every iteration

  • state (Any) – state of the system

  • xs (Iterable) – iterable to loop over

  • compiled (bool, optional) – if True, run the loop with lax.scan; otherwise run a native Python loop. Defaults to True.

Returns

final state of the system

Return type

Any
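A minimal sketch of the two execution modes described above. `run_loop_sketch` is a hypothetical stand-in that assumes only what this entry states: `f` follows the lax.scan body convention and returns `(state, None)`.

```python
from typing import Any, Callable, Iterable

import jax.numpy as jnp
from jax import lax


def run_loop_sketch(f: Callable, state: Any, xs: Iterable, compiled: bool = True) -> Any:
    """Hypothetical sketch: apply f over xs and return only the final state."""
    if compiled:
        # lax.scan stacks f's second output across steps; since f returns None there,
        # nothing is accumulated and only the final carry is kept.
        state, _ = lax.scan(f, state, xs)
        return state
    # Uncompiled fallback: an ordinary Python loop with the same body convention.
    for x in xs:
        state, _ = f(state, x)
    return state


def step(state, x):
    # Body in lax.scan form: return (new_state, None).
    return state + x, None


final_compiled = run_loop_sketch(step, jnp.zeros(()), jnp.arange(10.0))
final_python = run_loop_sketch(step, jnp.zeros(()), range(10), compiled=False)
```

lax.scan requires `xs` to be an array (or a pytree of arrays) with a leading iteration axis, while the Python branch accepts any iterable, such as a generator of minibatches. The uncompiled mode also lets ordinary `print` statements inside `f` run on every step, which can help when debugging.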