Optimizer
sgmcmcjax.optimizer.build_adam_optimizer(dt: float, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int) → Callable
Build an Adam optimizer using the JAX optimizers module.
- Parameters
dt (float) – step size
loglikelihood (Callable) – log-likelihood for a single data point
logprior (Callable) – log-prior, a function of the parameters only (it does not depend on the data)
data (Tuple) – tuple of data arrays. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)
batch_size (int) – batch size
- Returns
An optimizer function with signature:
    Args:
        key (PRNGKey): random key
        Niters (int): number of iterations
        params_IC (PyTree): initial parameters
    Returns:
        PyTree: final parameters
        jnp.ndarray: array of log-posterior values recorded during the optimization
- Return type
Callable
sgmcmcjax.optimizer.build_optax_optimizer(optimizer: optax._src.base.GradientTransformation, loglikelihood: Callable, logprior: Callable, data: Tuple, batch_size: int) → Callable
Build an optimizer from an Optax GradientTransformation. See Optax: https://github.com/deepmind/optax
- Parameters
optimizer (optax.GradientTransformation) – Optax GradientTransformation object
loglikelihood (Callable) – log-likelihood for a single data point
logprior (Callable) – log-prior, a function of the parameters only (it does not depend on the data)
data (Tuple) – tuple of data arrays. It should contain either a single array (for unsupervised problems) or two arrays (for supervised problems)
batch_size (int) – batch size
- Returns
An optimizer function with signature:
    Args:
        key (PRNGKey): random key
        Niters (int): number of iterations
        params_IC (PyTree): initial parameters
    Returns:
        PyTree: final parameters
        jnp.ndarray: array of log-posterior values recorded during the optimization
- Return type
Callable