Optimizer

sgmcmcjax.optimizer.build_adam_optimizer(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int) → Callable

Build an Adam optimizer using the JAX optimizers module.

Parameters
  • dt (float) – step size

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the parameters

  • data (Tuple) – tuple of data. It should either have a single array (for unsupervised problems) or have two arrays (for supervised problems)

  • batch_size (int) – batch size

Returns

optimizer function with the signature:

  Args:
      key (PRNGKey): random key
      Niters (int): number of iterations
      params_IC (PyTree): initial parameters

  Returns:
      PyTree: final parameters
      jnp.ndarray: array of log-posterior values during the optimization

Return type

Callable
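
A minimal usage sketch for build_adam_optimizer. The Gaussian model, toy dataset, and step size below are purely illustrative assumptions; only the builder call and the returned optimizer's (key, Niters, params_IC) signature come from the documentation above.

    import jax.numpy as jnp
    from jax import random
    from sgmcmcjax.optimizer import build_adam_optimizer

    # Hypothetical model: x_i ~ N(theta, 1) with a standard normal prior on theta
    def loglikelihood(theta, x):
        return -0.5 * jnp.sum((x - theta) ** 2)

    def logprior(theta):
        return -0.5 * jnp.sum(theta ** 2)

    # Unsupervised problem: data is a tuple containing a single array
    key = random.PRNGKey(0)
    X = random.normal(key, shape=(1000, 5))
    data = (X,)

    # dt is the step size; gradients are estimated on mini-batches of 100 points
    optimizer = build_adam_optimizer(1e-2, loglikelihood, logprior, data, batch_size=100)

    # Run for 5000 iterations from an all-zeros initial PyTree
    params, log_post = optimizer(key, 5000, jnp.zeros(5))

The returned log_post array can be inspected (e.g. plotted) to check that the optimization has converged.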

sgmcmcjax.optimizer.build_optax_optimizer(optimizer: optax.GradientTransformation, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int) → Callable

Build an optimizer from an Optax GradientTransformation. See Optax: https://github.com/deepmind/optax

Parameters
  • optimizer (optax.GradientTransformation) – Optax GradientTransformation object

  • loglikelihood (Callable) – log-likelihood for a single data point

  • logprior (Callable) – log-prior of the parameters

  • data (Tuple) – tuple of data. It should either have a single array (for unsupervised problems) or have two arrays (for supervised problems)

  • batch_size (int) – batch size

Returns

optimizer function with the signature:

  Args:
      key (PRNGKey): random key
      Niters (int): number of iterations
      params_IC (PyTree): initial parameters

  Returns:
      PyTree: final parameters
      jnp.ndarray: array of log-posterior values during the optimization

Return type

Callable
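
A minimal usage sketch for build_optax_optimizer, reusing the same illustrative Gaussian model as above; the choice of optax.adamw with this learning rate is an arbitrary example of an Optax GradientTransformation, not a library recommendation.

    import jax.numpy as jnp
    import optax
    from jax import random
    from sgmcmcjax.optimizer import build_optax_optimizer

    # Hypothetical model: x_i ~ N(theta, 1) with a standard normal prior on theta
    def loglikelihood(theta, x):
        return -0.5 * jnp.sum((x - theta) ** 2)

    def logprior(theta):
        return -0.5 * jnp.sum(theta ** 2)

    key = random.PRNGKey(0)
    X = random.normal(key, shape=(1000, 5))
    data = (X,)

    # Any Optax GradientTransformation works here; adamw is just one example
    opt = optax.adamw(learning_rate=1e-2)
    optimizer = build_optax_optimizer(opt, loglikelihood, logprior, data, batch_size=100)

    params, log_post = optimizer(key, 5000, jnp.zeros(5))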