Kernel Stein Discrepancy
sgmcmcjax.ksd.imq_KSD(samples: Union[numpy.ndarray, jax.numpy.ndarray], grads: Union[numpy.ndarray, jax.numpy.ndarray]) → float
Estimate the kernel Stein discrepancy (KSD) of a set of MCMC samples using the inverse multiquadric (IMQ) kernel.
- Parameters
samples (Array) – MCMC samples
grads (Array) – gradients of the log-density evaluated at each MCMC sample
- Returns
estimate of the KSD
- Return type
float
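A minimal JAX sketch of how an IMQ-based KSD estimate can be computed, assuming the common V-statistic form KSD = (1/N) sqrt(Σ_{i,j} k_0(x_i, x_j)) and inputs of shape (N, d). The names imq_stein_kernel and ksd_vstat are hypothetical, and this is not necessarily what imq_KSD does internally; the defaults c = 1.0 and beta = -0.5 simply mirror those of k_0_fun.

```python
import jax
import jax.numpy as jnp


def imq_stein_kernel(x, y, gx, gy, c=1.0, beta=-0.5):
    """Closed-form Langevin Stein kernel k_0 for the IMQ base kernel (hypothetical helper).

    x, y: single samples of shape (d,); gx, gy: grad log p evaluated at x and y.
    """
    r = x - y
    d = x.shape[0]
    u = c**2 + jnp.dot(r, r)  # base kernel is k(x, y) = u ** beta
    return (
        u**beta * jnp.dot(gx, gy)                              # k(x, y) * gx . gy
        + 2.0 * beta * u ** (beta - 1) * jnp.dot(gy - gx, r)   # gradient cross terms
        - 2.0 * d * beta * u ** (beta - 1)                     # trace of mixed derivative, part 1
        - 4.0 * beta * (beta - 1) * u ** (beta - 2) * jnp.dot(r, r)  # part 2
    )


@jax.jit
def ksd_vstat(samples, grads):
    """Assumed V-statistic KSD estimate sqrt(sum_ij k_0(x_i, x_j)) / N for (N, d) inputs."""
    # Inner vmap maps over the second sample (and its gradient) -> shape (N,).
    row = jax.vmap(imq_stein_kernel, in_axes=(None, 0, None, 0))
    # Outer vmap maps over the first sample -> full (N, N) Stein Gram matrix.
    gram = jax.vmap(row, in_axes=(0, None, 0, None))(samples, samples, grads, grads)
    return jnp.sqrt(jnp.sum(gram)) / samples.shape[0]


# Usage: target p = N(0, I), so grad log p(x) = -x.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (500, 2))
print(ksd_vstat(x, -x))                # samples drawn from p
print(ksd_vstat(x + 1.0, -(x + 1.0)))  # shifted samples; expected to be larger
```

Building the full N × N Gram matrix keeps the sketch short but uses O(N²) memory; for large N, accumulating one row at a time (for example with jax.lax.scan) avoids materialising it.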
sgmcmcjax.ksd.k_0_fun(parm1: Union[numpy.ndarray, jax.numpy.ndarray], parm2: Union[numpy.ndarray, jax.numpy.ndarray], gradlogp1: Union[numpy.ndarray, jax.numpy.ndarray], gradlogp2: Union[numpy.ndarray, jax.numpy.ndarray], c: float = 1.0, beta: float = -0.5) → float
Stein kernel k_0 used by the KSD, built from the IMQ base kernel with the Euclidean (2-)norm, following Gorham and Mackey (2017): http://proceedings.mlr.press/v70/gorham17a/gorham17a.pdf
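For reference, the standard Langevin Stein kernel from Gorham and Mackey (2017) with the IMQ base kernel is written out below, with x = parm1, y = parm2, g_x = gradlogp1, g_y = gradlogp2, r = x − y, u = c² + ‖r‖₂² and d the dimension. The expanded last line is our own derivation, and we assume (without having checked the source) that k_0_fun evaluates this expression.

```latex
\begin{aligned}
k(x, y) &= \left(c^2 + \lVert x - y \rVert_2^2\right)^{\beta}, \qquad g_x = \nabla_x \log p(x), \\
k_0(x, y) &= \nabla_x \cdot \nabla_y k(x, y) + \nabla_x k(x, y) \cdot g_y
            + \nabla_y k(x, y) \cdot g_x + k(x, y)\, g_x \cdot g_y \\
          &= u^{\beta}\, g_x \cdot g_y
            + 2\beta\, u^{\beta - 1} (g_y - g_x) \cdot r
            - 2 d \beta\, u^{\beta - 1}
            - 4 \beta (\beta - 1)\, u^{\beta - 2} \lVert r \rVert_2^2
\end{aligned}
```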
- Parameters
parm1 (Array) – first sample
parm2 (Array) – second sample
gradlogp1 (Array) – gradient of the log-density evaluated at parm1
gradlogp2 (Array) – gradient of the log-density evaluated at parm2
c (float, optional) – intercept parameter in the IMQ kernel. Defaults to 1.
beta (float, optional) – exponent parameter in the IMQ kernel. Defaults to -0.5.
- Returns
value of the Stein kernel for the pair of samples
- Return type
float
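A closed-form Stein kernel is easy to get subtly wrong, so one way to cross-check it is to assemble the same k_0 directly from automatic-differentiation derivatives of the IMQ base kernel. This is a hedged sketch: imq_base and stein_kernel_autodiff are hypothetical names, not part of sgmcmcjax, and the base kernel form (c² + ‖x − y‖₂²)^beta is assumed.

```python
import jax
import jax.numpy as jnp


def imq_base(x, y, c=1.0, beta=-0.5):
    """Assumed IMQ base kernel k(x, y) = (c**2 + ||x - y||_2**2) ** beta."""
    r = x - y
    return (c**2 + jnp.dot(r, r)) ** beta


def stein_kernel_autodiff(x, y, gx, gy, c=1.0, beta=-0.5):
    """Langevin Stein kernel k_0(x, y) built from autodiff derivatives of k.

    gx, gy: gradients of log p evaluated at x and y.
    """
    k = lambda a, b: imq_base(a, b, c, beta)
    grad_x_k = jax.grad(k, argnums=0)(x, y)  # nabla_x k(x, y)
    grad_y_k = jax.grad(k, argnums=1)(x, y)  # nabla_y k(x, y)
    # cross[i, j] = d/dx_j (dk/dy_i); its trace is nabla_x . nabla_y k.
    cross = jax.jacfwd(jax.grad(k, argnums=1), argnums=0)(x, y)
    return (
        jnp.trace(cross)
        + jnp.dot(grad_x_k, gy)
        + jnp.dot(grad_y_k, gx)
        + k(x, y) * jnp.dot(gx, gy)
    )


# Usage: with x == y the gradient terms vanish, leaving gx . gy + d = 14 + 3.
x = jnp.array([1.0, 2.0, 3.0])
print(stein_kernel_autodiff(x, x, x, x))
```

The autodiff version is slower per pair than a hand-expanded formula, but it lets you swap in a different base kernel without re-deriving k_0, which makes it a convenient reference when testing a closed form.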