KFAC-JAX internals
Tracing and detecting layers
Unlike black-box optimizers, K-FAC is not fully agnostic to the model and loss function being optimized. The algorithm needs to know what type of loss function it is optimizing, as this determines the exact form of the curvature matrix. It also needs to know which parameters are grouped together into the same block, and what type of layer that block corresponds to, in order to select the appropriate block approximation. As demonstrated in the quickstart section of the guides, you must explicitly register your loss function via library calls. Layers are detected and registered automatically by default; however, there is a mechanism to register particular layers manually.
Computing relevant statistics for curvature estimation
In contrast to standard first-order optimizers, K-FAC requires gradients
with respect to the outputs of each layer in order to compute the curvature
estimate.
This is particularly difficult in JAX, as it does not allow computing
gradients with respect to intermediate expressions inside a function.
To deal with this, the library traces the provided value_and_grad_func up to
the registered loss function, discovers all of the layers, and creates a new
function with additional auxiliary inputs. These inputs are set to zero and
added to the outputs of each layer, so they leave the computed values unchanged.
Differentiating the new function with respect to the auxiliary inputs then
yields the gradients with respect to the layer outputs.
This is what kfac_jax.layer_tags_vjp() performs, and it is at the core of the
curvature matrix estimator.
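The auxiliary-input trick can be illustrated without any KFAC-JAX machinery. The following is only a minimal sketch in plain JAX, not the library's implementation: the two-layer model, the function name loss_with_aux and all shapes are hypothetical and chosen purely for illustration. Zero-valued arrays are added to each layer output, and differentiating with respect to them recovers the gradients with respect to those outputs.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(params, aux, x, targets):
  """Two-layer MLP loss with a zero-valued `aux` term added to each layer output."""
  w1, w2 = params
  s1 = x @ w1 + aux[0]   # output of layer 1, shifted by an auxiliary input
  h1 = jnp.tanh(s1)
  s2 = h1 @ w2 + aux[1]  # output of layer 2, shifted by an auxiliary input
  return 0.5 * jnp.sum((s2 - targets) ** 2)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (8, 3))
targets = jax.random.normal(k2, (8, 2))
params = (jax.random.normal(k3, (3, 4)), jax.random.normal(k4, (4, 2)))

# The auxiliary inputs are all zeros, so the loss value itself is unchanged.
aux = (jnp.zeros((8, 4)), jnp.zeros((8, 2)))

# Gradients w.r.t. `aux` are exactly the gradients w.r.t. the layer outputs.
param_grads, output_grads = jax.grad(loss_with_aux, argnums=(0, 1))(
    params, aux, x, targets)
ds1, ds2 = output_grads
```

Here ds1 and ds2 have the same shapes as the layer outputs, which gives per-example output gradients of the kind a curvature estimator consumes alongside the layer inputs.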
Extending KFAC-JAX
Below we provide a short guide on how to extend the library to new types of layers, and to implement new curvature blocks.
Creating new curvature blocks
1. Create a new curvature block class by extending
kfac_jax.CurvatureBlock.
2. Tell the optimizer which tags should use the new curvature block by providing
a mapping between the name of the tags and the class you created in the previous
step through the layer_tag_to_block_ctor argument of
kfac_jax.Optimizer.
Below is an example of how one might add a standard Kronecker-factored block approximation of dense layers (which is of course already supported by KFAC-JAX):
from typing import Mapping, Optional, Sequence

import chex
import jax
import jax.numpy as jnp
import kfac_jax

# Step 1
class DenseTwoKroneckerFactored(kfac_jax.TwoKroneckerFactored):
  """A :class:`~TwoKroneckerFactored` block specifically for dense layers."""

  def input_size(self) -> int:
    """The size of the Kronecker-factor corresponding to inputs."""
    if self.has_bias:
      # The bias is handled by appending a constant 1 to every input.
      return self.parameters_shapes[0][0] + 1
    else:
      return self.parameters_shapes[0][0]

  def output_size(self) -> int:
    """The size of the Kronecker-factor corresponding to outputs."""
    return self.parameters_shapes[0][1]

  def update_curvature_matrix_estimate(
      self,
      state: kfac_jax.TwoKroneckerFactored.State,
      estimation_data: Mapping[str, Sequence[chex.Array]],
      ema_old: chex.Numeric,
      ema_new: chex.Numeric,
      batch_size: int,
      pmap_axis_name: Optional[str],
  ) -> kfac_jax.TwoKroneckerFactored.State:
    del pmap_axis_name
    # Layer inputs and the gradients with respect to the layer outputs.
    x, = estimation_data["inputs"]
    dy, = estimation_data["outputs_tangent"]
    # `utils` is the library's utilities module.
    assert kfac_jax.utils.first_dim_is_size(batch_size, x, dy)

    if self.has_bias:
      x_one = jnp.ones_like(x[:, :1])
      x = jnp.concatenate([x, x_one], axis=1)

    # Batch-averaged second moments of the inputs and the output gradients.
    input_stats = jnp.matmul(x.T, x) / batch_size
    output_stats = jnp.matmul(dy.T, dy) / batch_size

    state.inputs_factor.update(input_stats, ema_old, ema_new)
    state.outputs_factor.update(output_stats, ema_old, ema_new)
    return state

# Step 2
optimizer = kfac_jax.Optimizer(
    # ... (other Optimizer arguments elided)
    layer_tag_to_block_ctor=dict(dense_tag=DenseTwoKroneckerFactored),
    # ... (other Optimizer arguments elided)
)
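To make the statistics gathered in the block above more concrete, here is a small standalone sketch in plain jax.numpy that computes the same two batch-averaged factors for toy data. It does not use KFAC-JAX at all; the helper name dense_kronecker_factors and the toy arrays are hypothetical, and the moving-average update of the factors is deliberately left out.

```python
import jax.numpy as jnp


def dense_kronecker_factors(x, dy, has_bias):
  """Returns batch-averaged second moments of the inputs and output gradients."""
  batch_size = x.shape[0]
  if has_bias:
    # Append a constant 1 so the bias is treated as an extra row of the weights.
    x = jnp.concatenate([x, jnp.ones_like(x[:, :1])], axis=1)
  input_factor = x.T @ x / batch_size     # (in_dim [+ 1], in_dim [+ 1])
  output_factor = dy.T @ dy / batch_size  # (out_dim, out_dim)
  return input_factor, output_factor


# Toy data: batch of 4 examples, 3 input features, 2 output features.
x = jnp.arange(12.0).reshape(4, 3)
dy = jnp.ones((4, 2))

input_factor, output_factor = dense_kronecker_factors(x, dy, has_bias=True)
```

With a bias, the input factor gains one extra row and column: its last row holds the batch mean of each input feature followed by a 1, which is why input_size() above returns the input dimension plus one.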
See the FermiNet project for another example of how to add a curvature block using the above steps.