KFAC-JAX internals

Tracing and detecting layers

Compared to black-box optimizers, K-FAC is not fully agnostic to the model and loss function being optimized. The algorithm needs to know what type of loss function it is optimizing, as this determines the exact form of the curvature matrix. It also needs to know which parameters are grouped together into the same block, and what type of layer that block corresponds to, in order to select the appropriate block approximation. As demonstrated in the quickstart section of the guides, the loss function must be registered explicitly via library calls. Layers, on the other hand, are detected and registered automatically by default, although there is also a mechanism for registering particular layers manually.
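To make the role of the loss concrete: for common losses such as the two below, the curvature matrices K-FAC approximates (the Fisher and the generalized Gauss-Newton) depend on the loss through its Hessian with respect to the model outputs. The minimal plain-JAX sketch below uses two hand-written losses (squared_error and softmax_cross_entropy are illustrative helpers, not KFAC-JAX's registration functions) and compares jax.hessian against the known closed forms: the identity for squared error, and diag(p) - p p^T for softmax cross-entropy, where p is the softmax of the logits.

```python
import jax
import jax.numpy as jnp


def squared_error(y, t):
  """Illustrative loss: 0.5 * ||y - t||^2 for a single example."""
  return 0.5 * jnp.sum((y - t) ** 2)


def softmax_cross_entropy(logits, t_onehot):
  """Illustrative loss: cross-entropy of softmax(logits) vs a one-hot target."""
  return -jnp.sum(t_onehot * jax.nn.log_softmax(logits))


logits = jnp.array([1.0, -0.5, 2.0])
target = jnp.array([0.0, 0.0, 1.0])

# Hessians of each loss with respect to the model output (argument 0).
h_squared_error = jax.hessian(squared_error)(logits, target)
h_cross_entropy = jax.hessian(softmax_cross_entropy)(logits, target)

# Closed forms: identity for squared error, diag(p) - p p^T for softmax CE.
p = jax.nn.softmax(logits)
ce_closed_form = jnp.diag(p) - jnp.outer(p, p)
print(h_squared_error)
print(h_cross_entropy)
```

The two losses give structurally different output-space curvature (constant versus logit-dependent), which is why the optimizer cannot infer the curvature form without knowing which loss was registered.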

Computing relevant statistics for curvature estimation

In contrast to standard first-order optimizers, K-FAC requires gradients with respect to the outputs of each layer in order to compute the curvature estimate. This is particularly difficult in JAX, as it does not allow computing gradients with respect to intermediate expressions inside a function. To deal with this, the library traces the provided value_and_grad_func up to the registered loss function and discovers all of the layers. It then creates a new function with additional auxiliary inputs, one per layer output, which are set to zero and added to the outputs of the corresponding layers. Because these inputs are zero they leave the function's value unchanged, while the gradients with respect to them are exactly the gradients with respect to the layer outputs. This is what kfac_jax.layer_tags_vjp() does, and it is at the core of the curvature matrix estimator.
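The zero-auxiliary-input trick can be reproduced by hand in plain JAX. In the sketch below (a two-layer toy model; forward, loss_fn and aux are illustrative names, not part of the KFAC-JAX API, which builds the equivalent function automatically from the traced graph), a zero array is added to each layer's output and the loss is differentiated with respect to these arrays as well as the parameters.

```python
import jax
import jax.numpy as jnp


def forward(params, x, aux):
  """Two dense layers; aux[i] (all zeros) is added to the output of layer i."""
  (w1, b1), (w2, b2) = params
  h = x @ w1 + b1 + aux[0]            # output of layer 1 (pre-activation)
  y = jnp.tanh(h) @ w2 + b2 + aux[1]  # output of layer 2
  return y


def loss_fn(params, x, targets, aux):
  y = forward(params, x, aux)
  # Squared error, summed over outputs and averaged over the batch.
  return 0.5 * jnp.mean(jnp.sum((y - targets) ** 2, axis=-1))


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
batch, d_in, d_hidden, d_out = 8, 5, 4, 3
x = jax.random.normal(k1, (batch, d_in))
targets = jax.random.normal(k2, (batch, d_out))
params = ((jax.random.normal(k3, (d_in, d_hidden)), jnp.zeros(d_hidden)),
          (jax.random.normal(k4, (d_hidden, d_out)), jnp.zeros(d_out)))
# Zero-valued auxiliary inputs, one per layer output.
aux = (jnp.zeros((batch, d_hidden)), jnp.zeros((batch, d_out)))

# Differentiate w.r.t. the parameters (argnum 0) and the auxiliary inputs
# (argnum 3); evaluated at aux == 0, i.e. on the unmodified model.
loss, (param_grads, output_grads) = jax.value_and_grad(
    loss_fn, argnums=(0, 3))(params, x, targets, aux)
```

Since the auxiliary inputs are zero, the loss value is unchanged, and output_grads[i] holds the per-example gradient of the loss with respect to the output of layer i, which is the quantity the curvature estimator needs.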

Extending KFAC-JAX

Below we provide a short guide on how to extend the library to new types of layers, and to implement new curvature blocks.

Creating new layer tags

1. Create a new instance of kfac_jax.LayerTag, specifying the number of inputs and outputs of the corresponding computation.

2. Write a function that takes the parameters, inputs, and outputs of your new layer's computation, and binds them to the layer tag you defined in the previous step.

3. If you want your tag to be automatically detectable by the library, you will need to create a new "graph pattern", which tells the library what a computation for this layer looks like and how to interpret it. This process can be broken down into the following smaller steps:

a. Create a function which performs the computation done by your new layer in isolation.

b. Create a function that extracts any additional parameters that may be relevant for computing the curvature block approximations. (For example, for convolutional layers we have to capture things like the dilation and other parameters that identify whether this is a standard or a separable convolution.) Note that the input to this function is going to be the sequence of JAX equations returned by calling jax.make_jaxpr() on the function performing the computation.

c. Using the functions created in the above steps, create an instance of kfac_jax.GraphPattern.

4. Provide your graph pattern to the optimizer as part of the auto_register_kwargs argument.

Below is an example of how one might add support for dense/fully-connected layers (which are of course already supported by KFAC-JAX) using the above steps:

from typing import Any, Mapping, Optional, Sequence

import chex
import jax
from jax import core
import jax.numpy as jnp
import kfac_jax
import numpy as np

# Step 1: a layer tag with one input (x) and one output (y).
dense = kfac_jax.LayerTag(name="dense_tag", num_inputs=1, num_outputs=1)

# Step 2
def register_dense(
    y: chex.Array,
    x: chex.Array,
    w: chex.Array,
    b: Optional[chex.Array] = None,
    **kwargs,
) -> chex.Array:
  """Registers a dense layer: ``y = matmul(x, w) + b``."""
  if b is None:
    return dense.bind(y, x, w, **kwargs)
  return dense.bind(y, x, w, b, **kwargs)

# Step 3.a
def _dense(x: chex.Array, params: Sequence[chex.Array]) -> chex.Array:
  """Example of a dense layer function."""
  w, *opt_b = params
  y = jnp.matmul(x, w)
  return y if not opt_b else y + opt_b[0]

# Step 3.b: pull the dot_general parameters out of the traced equations.
def _dense_parameter_extractor(
    eqns: Sequence[core.JaxprEqn],
) -> Mapping[str, Any]:
  """Extracts all parameters from the dot_general operator."""
  for eqn in eqns:
    if eqn.primitive.name == "dot_general":
      return dict(**eqn.params)
  raise ValueError("Expected a dot_general equation in the dense pattern.")

# Step 3.c: combine the tag, compute function, extractor and example inputs.
dense_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
    name="dense_with_bias",
    tag_primitive=dense,
    precedence=0,
    compute_func=_dense,
    parameters_extractor_func=_dense_parameter_extractor,
    example_args=[np.zeros([11, 13]), [np.zeros([13, 7]), np.zeros([7])]],
)

# Step 4
optimizer = kfac_jax.Optimizer(
    ...
    auto_register_kwargs=dict(
        graph_patterns=((dense_with_bias_pattern,) +
                        kfac_jax.tag_graph_matcher.DEFAULT_GRAPH_PATTERNS),
    ),
    ...
)
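To see what the Step 3.b extractor receives, one can run jax.make_jaxpr on the Step 3.a function directly. The plain-JAX sketch below repeats _dense and an equivalent extractor so that it runs without KFAC-JAX, and traces them on the same example arguments as Step 3.c. The exact list of primitives and the parameter names attached to dot_general depend on the installed JAX version, so no output is shown.

```python
from typing import Any, Mapping, Sequence

import jax
import jax.numpy as jnp
import numpy as np


def _dense(x, params):
  """The Step 3.a computation: y = x @ w (+ b)."""
  w, *opt_b = params
  y = jnp.matmul(x, w)
  return y if not opt_b else y + opt_b[0]


def _dense_parameter_extractor(eqns: Sequence[Any]) -> Mapping[str, Any]:
  """Returns the parameters of the first dot_general equation."""
  for eqn in eqns:
    if eqn.primitive.name == "dot_general":
      return dict(eqn.params)
  raise ValueError("No dot_general equation found.")


# Same example arguments as the Step 3.c pattern.
example_args = [np.zeros([11, 13]), [np.zeros([13, 7]), np.zeros([7])]]
closed_jaxpr = jax.make_jaxpr(_dense)(*example_args)
eqns = closed_jaxpr.jaxpr.eqns  # the equation sequence the extractor sees

primitive_names = [eqn.primitive.name for eqn in eqns]
dot_params = _dense_parameter_extractor(eqns)
print(primitive_names)
print(sorted(dot_params))
```

Printing the full jaxpr (print(closed_jaxpr)) is often the quickest way to decide which equation holds the parameters a new layer's extractor should capture.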

See the FermiNet project for another example of how to add a new layer tag using the above steps.

Creating new curvature blocks

1. Create a new curvature block class by extending kfac_jax.CurvatureBlock.

2. Tell the optimizer which tags should use the new curvature block by passing a mapping from tag names to the class created in the previous step via the layer_tag_to_block_ctor argument of kfac_jax.Optimizer.

Below is an example of how one might add a standard Kronecker-factored block approximation of dense layers (which is of course already supported by KFAC-JAX):

from typing import Mapping, Optional, Sequence

import chex
import jax.numpy as jnp
import kfac_jax

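# Step 1: a block whose two Kronecker factors come from the layer inputs and
# the gradients with respect to the layer outputs.
class DenseTwoKroneckerFactored(kfac_jax.TwoKroneckerFactored):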
  """A :class:`~TwoKroneckerFactored` block specifically for dense layers."""

  def input_size(self) -> int:
    """The size of the Kronecker-factor corresponding to inputs."""
    if self.has_bias:
      return self.parameters_shapes[0][0] + 1
    else:
      return self.parameters_shapes[0][0]

  def output_size(self) -> int:
    """The size of the Kronecker-factor corresponding to outputs."""
    return self.parameters_shapes[0][1]

  def update_curvature_matrix_estimate(
      self,
      state: kfac_jax.TwoKroneckerFactored.State,
      estimation_data: Mapping[str, Sequence[chex.Array]],
      ema_old: chex.Numeric,
      ema_new: chex.Numeric,
      batch_size: int,
      pmap_axis_name: Optional[str],
  ) -> kfac_jax.TwoKroneckerFactored.State:
    del pmap_axis_name  # Unused in this example.
    x, = estimation_data["inputs"]
    dy, = estimation_data["outputs_tangent"]
    assert kfac_jax.utils.first_dim_is_size(batch_size, x, dy)

    if self.has_bias:
      # Append a column of ones so the bias is covered by the input factor.
      x_one = jnp.ones_like(x[:, :1])
      x = jnp.concatenate([x, x_one], axis=1)
    # Batch averages of the input and output-gradient outer products.
    input_stats = jnp.matmul(x.T, x) / batch_size
    output_stats = jnp.matmul(dy.T, dy) / batch_size
    # Moving-average update of the two Kronecker factors.
    state.inputs_factor.update(input_stats, ema_old, ema_new)
    state.outputs_factor.update(output_stats, ema_old, ema_new)
    return state

# Step 2
optimizer = kfac_jax.Optimizer(
    ...
    layer_tag_to_block_ctor=dict(dense_tag=DenseTwoKroneckerFactored),
    ...
)
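The two statistics computed in update_curvature_matrix_estimate can be sanity-checked in plain JAX. The sketch below (kronecker_factors is a hypothetical helper mirroring the method above, not a KFAC-JAX function) checks two standard facts about the Kronecker-factored approximation of a dense layer with bias: for a single example, the Kronecker product of the input and output factors equals the outer product of the flattened [W; b] gradient with itself; and the inverse of a Kronecker product can be applied to a gradient using only the two small factors. The multiple of the identity added to each factor only keeps the example well conditioned and is not KFAC-JAX's damping scheme.

```python
import jax
import jax.numpy as jnp


def kronecker_factors(x, dy, has_bias=True):
  """Hypothetical helper mirroring update_curvature_matrix_estimate.

  Returns the input factor A = E[x x^T] (x augmented with a ones column when
  has_bias) and the output factor G = E[dy dy^T], averaged over the batch.
  """
  batch_size = x.shape[0]
  if has_bias:
    x = jnp.concatenate([x, jnp.ones_like(x[:, :1])], axis=1)
  a = x.T @ x / batch_size
  g = dy.T @ dy / batch_size
  return a, g


keys = jax.random.split(jax.random.PRNGKey(0), 5)

# 1) Single example: kron(A, G) equals the outer product of the row-major
#    flattened gradient of [W; b], where dW = x^T dy and db = dy.
x1 = jax.random.normal(keys[0], (1, 4))
dy1 = jax.random.normal(keys[1], (1, 3))
a1, g1 = kronecker_factors(x1, dy1)
grad_wb = jnp.concatenate([x1, jnp.ones((1, 1))], axis=1).T @ dy1  # (5, 3)
exact_block = jnp.outer(grad_wb.ravel(), grad_wb.ravel())
kron_block = jnp.kron(a1, g1)

# 2) Preconditioning with only the small factors:
#    (A kron G)^{-1} vec(V) = vec(A^{-1} V G^{-1}) for symmetric A and G.
x = jax.random.normal(keys[2], (16, 4))
dy = jax.random.normal(keys[3], (16, 3))
a, g = kronecker_factors(x, dy)
a_d = a + 0.1 * jnp.eye(a.shape[0])  # identity shift only for conditioning,
g_d = g + 0.1 * jnp.eye(g.shape[0])  # not KFAC-JAX's damping scheme
v = jax.random.normal(keys[4], (5, 3))  # e.g. a gradient for [W; b]

via_factors = jnp.linalg.solve(a_d, v) @ jnp.linalg.inv(g_d)  # 5x5 and 3x3
via_full_matrix = jnp.linalg.solve(
    jnp.kron(a_d, g_d), v.ravel()).reshape(v.shape)  # one 15x15 solve
print(jnp.max(jnp.abs(kron_block - exact_block)))
print(jnp.max(jnp.abs(via_factors - via_full_matrix)))
```

The second check is the reason for the Kronecker structure: the block for a layer with n inputs and m outputs is (n+1)m by (n+1)m, but it only ever needs to be inverted through one (n+1) by (n+1) and one m by m matrix.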

See the FermiNet project for another example of how to add a curvature block using the above steps.