Kernel Stein Discrepancy

sgmcmcjax.ksd.imq_KSD(samples: Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], grads: Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray]) → float

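Estimate the Kernel Stein Discrepancy (KSD) between a set of MCMC samples and the target distribution, using the inverse multiquadric (IMQ) kernel.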

Parameters
  • samples (Array) – MCMC samples

  • grads (Array) – gradients of the log-posterior evaluated at each MCMC sample

Returns

estimate of the KSD

Return type

float
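A minimal JAX sketch of what an IMQ-based KSD estimate computes, assuming the usual V-statistic form KSD = sqrt(Σ_{i,j} k_0(x_i, x_j)) / N, which equals the square root of (1/N²) Σ_{i,j} k_0(x_i, x_j). The names `stein_kernel_imq`, `ksd_imq` and `log_density` are hypothetical and this is not the library's code; it also assumes the IMQ base kernel (c² + ||x − y||²)^β from Gorham and Mackey (2017), with samples stored one per row. The library may use a different loop structure (this sketch builds the full N × N matrix).

```python
import jax
import jax.numpy as jnp


def stein_kernel_imq(x, y, grad_x, grad_y, c=1.0, beta=-0.5):
    # Hypothetical closed-form Stein kernel k_0(x, y) for the IMQ base
    # kernel (c^2 + ||x - y||^2)^beta; grad_x, grad_y are grad log p at x, y
    r = x - y
    d = x.shape[0]
    u = c**2 + jnp.dot(r, r)
    return (
        jnp.dot(grad_x, grad_y) * u**beta
        + 2.0 * beta * u ** (beta - 1.0) * jnp.dot(grad_y - grad_x, r)
        - 2.0 * d * beta * u ** (beta - 1.0)
        - 4.0 * beta * (beta - 1.0) * u ** (beta - 2.0) * jnp.dot(r, r)
    )


def ksd_imq(samples, grads, c=1.0, beta=-0.5):
    # V-statistic KSD estimate: sqrt(sum_{i,j} k_0(x_i, x_j)) / N
    k = lambda x, y, gx, gy: stein_kernel_imq(x, y, gx, gy, c, beta)
    # Inner vmap maps over (y, grad_y); outer vmap maps over (x, grad_x)
    pairwise = jax.vmap(jax.vmap(k, in_axes=(None, 0, None, 0)), in_axes=(0, None, 0, None))
    gram = pairwise(samples, samples, grads, grads)  # shape (N, N), O(N^2) memory
    return jnp.sqrt(jnp.sum(gram)) / samples.shape[0]


# Usage: score draws against a standard Gaussian target, log p(x) = -0.5 ||x||^2
def log_density(x):
    return -0.5 * jnp.sum(x**2)


samples = jax.random.normal(jax.random.PRNGKey(0), (500, 2))
grads = jax.vmap(jax.grad(log_density))(samples)  # one gradient per row of samples
ksd_value = ksd_imq(samples, grads)
```

The same `samples` and `grads` arrays are what `imq_KSD(samples, grads)` takes. Gradients computed for a mismatched target (for example a Gaussian centred at 3) should give a clearly larger value than the matched one.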

sgmcmcjax.ksd.k_0_fun(parm1: Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], parm2: Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], gradlogp1: Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], gradlogp2: Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], c: float = 1.0, beta: float = -0.5) → float

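Stein kernel k_0 used by the KSD, obtained by applying the Langevin Stein operator to the IMQ kernel with the Euclidean (2-)norm, following Gorham and Mackey (2017): http://proceedings.mlr.press/v70/gorham17a/gorham17a.pdf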
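For reference, the kernel can be written out as below. This assumes the IMQ base kernel is parameterised as (c² + ||x − y||²)^β, as in Gorham and Mackey (2017); whether the implementation squares c is an assumption here. b(x) = ∇ log p(x) is the gradient passed as gradlogp1 or gradlogp2, r = x − y, and d is the parameter dimension. The last line is the closed form obtained by differentiating the base kernel.

```latex
k(x, y) = \left(c^2 + \lVert x - y \rVert_2^2\right)^{\beta}

k_0(x, y) = \nabla_x \cdot \nabla_y k(x, y) + b(x)^\top \nabla_y k(x, y) + b(y)^\top \nabla_x k(x, y) + b(x)^\top b(y)\, k(x, y)

u = c^2 + \lVert r \rVert_2^2: \quad
k_0(x, y) = b(x)^\top b(y)\, u^{\beta} + 2\beta\, u^{\beta - 1} \left(b(y) - b(x)\right)^\top r - 2 d \beta\, u^{\beta - 1} - 4 \beta (\beta - 1)\, u^{\beta - 2} \lVert r \rVert_2^2
```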

Parameters
  • parm1 (Array) – first sampled parameter vector

  • parm2 (Array) – second sampled parameter vector

  • gradlogp1 (Array) – gradient of the log-posterior evaluated at parm1

  • gradlogp2 (Array) – gradient of the log-posterior evaluated at parm2

  • c (float, optional) – intercept parameter in the IMQ kernel. Defaults to 1.

  • beta (float, optional) – exponent parameter in the IMQ kernel. Defaults to -0.5.

Returns

value of the Stein kernel k_0 for the pair of samples

Return type

float
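A minimal sketch of the pairwise kernel, assuming the IMQ base kernel (c² + ||x − y||²)^β from Gorham and Mackey (2017). It pairs a hypothetical closed-form version with a generic autodiff construction of the Langevin Stein kernel, so the expanded formula can be checked against its definition. All function names here are illustrative and are not part of sgmcmcjax.

```python
import jax
import jax.numpy as jnp


def imq_base(x, y, c=1.0, beta=-0.5):
    # IMQ base kernel k(x, y) = (c^2 + ||x - y||_2^2)^beta
    return (c**2 + jnp.sum((x - y) ** 2)) ** beta


def stein_kernel_closed_form(x, y, grad_x, grad_y, c=1.0, beta=-0.5):
    # Expanded k_0(x, y) for the IMQ kernel; d is the parameter dimension
    r = x - y
    d = x.shape[0]
    u = c**2 + jnp.dot(r, r)
    return (
        jnp.dot(grad_x, grad_y) * u**beta
        + 2.0 * beta * u ** (beta - 1.0) * jnp.dot(grad_y - grad_x, r)
        - 2.0 * d * beta * u ** (beta - 1.0)
        - 4.0 * beta * (beta - 1.0) * u ** (beta - 2.0) * jnp.dot(r, r)
    )


def stein_kernel_autodiff(x, y, grad_x, grad_y, c=1.0, beta=-0.5):
    # Same quantity from the definition, differentiating the base kernel with JAX
    k = lambda a, b: imq_base(a, b, c, beta)
    grad_kx = jax.grad(k, argnums=0)(x, y)  # nabla_x k
    grad_ky = jax.grad(k, argnums=1)(x, y)  # nabla_y k
    # Trace of the mixed second derivative d^2 k / (dx dy) is nabla_x . nabla_y k
    mixed = jax.jacfwd(jax.grad(k, argnums=1), argnums=0)(x, y)
    return (
        jnp.trace(mixed)
        + jnp.dot(grad_x, grad_ky)
        + jnp.dot(grad_y, grad_kx)
        + jnp.dot(grad_x, grad_y) * k(x, y)
    )
```

The closed form avoids nested differentiation, which matters because k_0 is evaluated N² times in a KSD estimate; the autodiff version is mainly useful as a correctness check or for base kernels without a convenient closed form.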