import functools
from functools import partial
from typing import Callable, Optional
from jax import random
from jax._src.util import safe_map, unzip2
from jax.tree_util import tree_flatten, tree_unflatten
from .types import DiffusionState
# Shadow the builtin ``map`` with JAX's ``safe_map`` for the rest of this module.
# NOTE(review): the code below calls ``len()`` on the result of ``map(...)``,
# so it relies on ``safe_map`` returning a sized sequence rather than a lazy
# iterator -- confirm against the installed JAX version.
map = safe_map
def diffusion_factory(
    is_palindrome: Optional[bool] = False, is_sghmc: Optional[bool] = False
) -> Callable:
    """Build a decorator that makes a diffusion factory work on PyTrees.

    The returned decorator wraps a "sampler maker": a function returning
    per-leaf ``init_fn``, ``update`` and ``get_params`` callables. The wrapped
    maker returns tree-level versions of those callables, which flatten the
    incoming PyTree, apply the per-leaf callable to every leaf and pack the
    per-leaf results into a ``DiffusionState``.

    Args:
        is_palindrome (Optional[bool], optional): whether or not the diffusion
            is a palindrome method (ie: splitting scheme). When true, the
            sampler maker must return a pair ``(update, update2)`` in place of
            the single ``update`` callable, and a pair of tree-level updates is
            returned in the same position. Defaults to False.
        is_sghmc (Optional[bool], optional): whether or not the sampler will
            include momentum resampling. When true, the sampler maker must
            return a fourth callable ``resample_momentum``, and a tree-level
            version of it is returned as a fourth element. Defaults to False.

    Returns:
        Callable: decorator to apply to a diffusion factory.
    """

    def _diffusion(sampler_maker):
        @functools.wraps(sampler_maker)
        def tree_sampler_maker(*args, **kwargs):
            made = sampler_maker(*args, **kwargs)

            # The two flags are independent: `is_sghmc` adds a 4th element,
            # `is_palindrome` turns the update slot into a pair.
            if is_sghmc:
                init_fn, update, get_params, resample_momentum = made
            else:
                init_fn, update, get_params = made
            if is_palindrome:
                update, update2 = update

            @functools.wraps(init_fn)
            def tree_init(x0_tree):
                # Initialise every leaf separately; each per-leaf state is itself
                # flattened so the DiffusionState only holds flat leaves + treedefs.
                x0_flat, tree = tree_flatten(x0_tree)
                initial_states = [init_fn(x0) for x0 in x0_flat]
                states_flat, subtrees = unzip2(
                    safe_map(tree_flatten, initial_states)
                )
                return DiffusionState(states_flat, tree, subtrees)

            def _lift_update(update_fn, label):
                # Build the tree-level counterpart of one per-leaf update function.
                # `label` ("update" / "update2") only appears in error messages.
                @functools.wraps(update_fn)
                def lifted_update(i, key, grad_tree, sampler_state):
                    states_flat, tree, subtrees = sampler_state
                    grad_flat, tree2 = tree_flatten(grad_tree)
                    if tree2 != tree:
                        msg = (
                            "sampler {} function was passed a gradient tree that did "
                            "not match the parameter tree structure with which it was "
                            "initialized: parameter tree {} and grad tree {}."
                        )
                        raise TypeError(msg.format(label, tree, tree2))
                    states = safe_map(tree_unflatten, subtrees, states_flat)
                    # One independent PRNG key per leaf.
                    keys = random.split(key, len(states))
                    new_states = safe_map(
                        partial(update_fn, i), keys, grad_flat, states
                    )
                    new_states_flat, subtrees2 = unzip2(
                        safe_map(tree_flatten, new_states)
                    )
                    # The per-leaf state structure must be preserved by the update.
                    for subtree, subtree2 in zip(subtrees, subtrees2):
                        if subtree2 != subtree:
                            msg = (
                                "sampler {} function produced an output structure that "
                                "did not match its input structure: input {} and output {}."
                            )
                            raise TypeError(msg.format(label, subtree, subtree2))
                    return DiffusionState(new_states_flat, tree, subtrees)

                return lifted_update

            tree_update = _lift_update(update, "update")

            @functools.wraps(get_params)
            def tree_get_params(sampler_state):
                states_flat, tree, subtrees = sampler_state
                states = safe_map(tree_unflatten, subtrees, states_flat)
                params = safe_map(get_params, states)
                return tree_unflatten(tree, params)

            if is_sghmc:

                @functools.wraps(resample_momentum)
                def tree_resample_momentum(i, k, state):
                    # Rebuild every per-leaf state from its current parameters.
                    x_tree = tree_get_params(state)
                    x_flat, tree = tree_flatten(x_tree)
                    keys = random.split(k, len(x_flat))
                    states = [
                        resample_momentum(i, key, x)
                        for (key, x) in zip(keys, x_flat)
                    ]
                    states_flat, subtree = unzip2(safe_map(tree_flatten, states))
                    return DiffusionState(states_flat, tree, subtree)

            if is_palindrome:
                tree_update2 = _lift_update(update2, "update2")
                updates = (tree_update, tree_update2)
            else:
                updates = tree_update

            if is_sghmc:
                return tree_init, updates, tree_get_params, tree_resample_momentum
            return tree_init, updates, tree_get_params

        return tree_sampler_maker

    return _diffusion
# Ready-made decorators for the common flag combinations of `diffusion_factory`.
# `diffusion`: the sampler maker returns (init_fn, update, get_params).
diffusion = diffusion_factory()
# `diffusion_sghmc`: the sampler maker additionally returns a fourth callable,
# resample_momentum.
diffusion_sghmc = diffusion_factory(is_sghmc=True)
# `diffusion_palindrome`: the sampler maker returns
# (init_fn, (update, update2), get_params).
diffusion_palindrome = diffusion_factory(is_palindrome=True)