Curvature Estimation
KFAC-JAX supports curvature estimation for neural networks. Below we list the main components of the library used for this.
Curvature estimator subclasses
The kfac_jax.CurvatureEstimator class is
the main abstract class for different estimators of the curvature matrix of a
neural network.
The library currently implements the following three concrete subclasses:
1. kfac_jax.BlockDiagonalCurvature - this estimator approximates the
curvature as block diagonal, with blocks corresponding to layers.
This estimator is used by the main optimizer class kfac_jax.Optimizer
for computing the preconditioner.
2. kfac_jax.ExplicitExactCurvature - this estimator represents the
curvature as an explicit dense array.
Note that its state is very large (quadratic in the number of parameters),
which makes it infeasible for even moderately sized models.
By default it will compute the exact curvature matrix, which can also be
computationally expensive when the model output or batch size is large.
It is possible to compute stochastic approximations by using different values
of the estimation_mode argument.
This class is mainly intended for small-scale toy and demonstration problems.
3. kfac_jax.ImplicitExactCurvature - a special estimator which provides
matrix-vector products with the full curvature matrix, without
materializing it in memory.
This can be very useful in some circumstances, but the class does not implement
products with the inverse of the matrix, or with any power of it other than 1.
This estimator is used during optimization for automatic selection of learning
rate and momentum via the quadratic model.
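The difference between an explicit and an implicit estimator can be illustrated on a toy problem. The sketch below is not how kfac_jax computes these products (the library relies on JAX automatic differentiation); it assumes a linear least-squares model with inputs X, whose curvature matrix is X^T X / n, and all function names are hypothetical. The implicit version only ever applies X and X^T to vectors, while the explicit version stores a dense matrix whose size is quadratic in the number of parameters.

```python
def matvec(rows, vec):
    """Return rows @ vec for a matrix stored as a list of rows."""
    return [sum(r_i * v_i for r_i, v_i in zip(row, vec)) for row in rows]


def rmatvec(rows, vec):
    """Return rows.T @ vec without building the transpose."""
    out = [0.0] * len(rows[0])
    for row, scale in zip(rows, vec):
        for j, r_j in enumerate(row):
            out[j] += r_j * scale
    return out


def implicit_curvature_vector_product(inputs, vec):
    """Compute (X^T X / n) @ vec with two passes over X, never forming X^T X."""
    n = len(inputs)
    return [c / n for c in rmatvec(inputs, matvec(inputs, vec))]


def explicit_curvature(inputs):
    """Materialize X^T X / n as a dense (d x d) matrix."""
    n, d = len(inputs), len(inputs[0])
    return [[sum(x[i] * x[j] for x in inputs) / n for j in range(d)]
            for i in range(d)]


# Toy data: 3 examples, 2 parameters.
inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
vec = [1.0, -1.0]

implicit = implicit_curvature_vector_product(inputs, vec)
explicit = matvec(explicit_curvature(inputs), vec)
```

Both routes give the same product; only the explicit one needs memory that grows quadratically with the number of parameters.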
Block approximations
The block diagonal curvature approximation is a very general class, which allows users to choose different approximations for different layers. There are three major types of curvature blocks (with corresponding subclasses):
1. kfac_jax.Diagonal - approximates the curvature block for the given
layer using only a diagonal matrix.
This type of approximation only considers the curvature for each entry of each
parameter, and ignores interactions between different entries/parameters.
It is used by default for “scale and shift” layers, such as those that commonly
appear in batch normalization and layer normalization.
2. kfac_jax.TwoKroneckerFactored - approximates the curvature block for
the given layer using the Kronecker product of two matrices as in the
K-FAC paper.
Unlike diagonal approximations, this captures interactions between different
parameter entries within a layer. However, it is only an approximation: it
essentially factorizes these interactions into interactions between input
vector entries and interactions between output vector entries. It is practical
for most layers, as long as they aren’t very wide, and is used by default for
dense/fully-connected and convolutional layers.
3. kfac_jax.Full - explicitly computes the full curvature matrix for
the block parameters.
When the number of parameters is large, using this block can be too memory
intensive, as its state size is quadratic in the number of parameters.
As a result, it is most suitable for layers with small parameterizations.
By default it is not used for any layer type.
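To make the memory trade-off between the three block types concrete, here is a rough sketch that counts the number of stored entries for a single dense layer. It is an illustration under simplifying assumptions: the bias is ignored, only the curvature entries themselves are counted (not cached inverses or other bookkeeping the library may keep), and the function name and dictionary keys are hypothetical labels rather than library identifiers.

```python
def curvature_block_state_sizes(d_in, d_out):
    """Count stored curvature entries for a dense layer with a d_out x d_in weight."""
    n_params = d_in * d_out
    return {
        # One entry per parameter.
        "diagonal": n_params,
        # One (d_in x d_in) factor and one (d_out x d_out) factor.
        "two_kronecker_factored": d_in ** 2 + d_out ** 2,
        # A dense (n_params x n_params) matrix.
        "full": n_params ** 2,
    }


sizes = curvature_block_state_sizes(d_in=1024, d_out=512)
for name, size in sizes.items():
    print(f"{name}: {size:,} entries")
```

For this layer size the Kronecker-factored block needs only a few times more entries than the diagonal one, while the full block is larger by several orders of magnitude.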
Supported layers
Currently, the library only includes support for the three most common types of layers used in practice:
Dense layers, corresponding to \(y = Wx + b\).
2D convolution layers, corresponding to \(y = W \star x + b\).
Scale and shift layers, corresponding to \(y = w \odot x + b\).
Here \(\star\) corresponds to convolution and \(\odot\) to elementwise product. Parameter reuse, such as in recurrent networks and attention layers, is not currently supported.
If you want to extend the library with your own layers, refer to the relevant section in the advanced documentation for how to do this.
Supported losses
Currently, the library only includes support for the three most common types of loss functions used in practice:
1. kfac_jax.register_sigmoid_cross_entropy_loss() specifies that the model
outputs a vector of logits and is trained using the standard sigmoid cross
entropy loss.
This can be interpreted as predicting a factorized Bernoulli distribution over
output labels.
An alias to this is
kfac_jax.register_multi_bernoulli_predictive_distribution().
2. kfac_jax.register_softmax_cross_entropy_loss() specifies
that the model outputs a vector of logits and is trained using the standard
softmax cross entropy loss.
This can be interpreted as predicting a Categorical distribution over output
labels.
An alias to this is
kfac_jax.register_categorical_predictive_distribution().
3. kfac_jax.register_squared_error_loss() specifies
that the model outputs a vector and is trained using the standard squared loss.
This can be interpreted as predicting a Gaussian with a variance of 0.5.
An alias to this is kfac_jax.register_normal_predictive_distribution().
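The variance of 0.5 mentioned for the squared error loss can be checked with a short derivation. It assumes that the "standard squared loss" is the unscaled squared norm between the prediction \(\mu\) and the target \(y\) (no factor of one half), and that the Gaussian has a fixed isotropic variance \(\sigma^2\) over \(d\) output dimensions; the symbols \(\mu\), \(y\), \(\sigma^2\) and \(d\) are introduced here for illustration.

```latex
-\log \mathcal{N}\left(y \mid \mu, \sigma^2 I\right)
  = \frac{\lVert y - \mu \rVert^2}{2\sigma^2} + \frac{d}{2}\log\left(2\pi\sigma^2\right)
\quad\Longrightarrow\quad
\sigma^2 = \tfrac{1}{2}:\;
-\log \mathcal{N}\left(y \mid \mu, \tfrac{1}{2} I\right)
  = \lVert y - \mu \rVert^2 + \frac{d}{2}\log \pi
```

Under these assumptions the negative log-likelihood equals the squared error up to an additive constant that does not depend on the model, so both have the same gradients and curvature.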
If you want to extend the library with your own loss functions, check out the relevant section in the advanced documentation for how to do this.
Optimizer
The optimization algorithm implemented in kfac_jax.Optimizer follows
the K-FAC paper.
Throughout optimization the Optimizer instance keeps the following persistent
state:
If we denote the current minibatch of data by \(\bm{x}_t\), the current parameters by \(\bm{\theta}_t\), the L2 regularizer by \(\gamma\) and the loss function (which includes the L2 regularizer) by \(\mathcal{L}\), a high-level pseudocode for a single step of the optimizer is:
Steps 1, 2, 3, 5 and 6 are standard for any second-order optimization algorithm. Steps 4 and 7 are described in more detail below.
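The following toy sketch shows one plausible reading of such a seven-step update. It is not the library's implementation, and the step numbering and formulas in the comments are assumptions inferred from the surrounding text. It uses a diagonal quadratic loss so that everything fits in plain Python, takes the update coefficients as fixed user-supplied values, and leaves the damping unchanged. All names are hypothetical.

```python
def toy_optimizer_step(params, velocity, curvature, hessian_diag, *,
                       l2_reg, damping, learning_rate, momentum, ema_decay):
    """One step on the toy loss 0.5 * sum((h_i + l2_reg) * p_i ** 2)."""
    # (1) Gradient of the regularized loss.
    grads = [(h + l2_reg) * p for h, p in zip(hessian_diag, params)]
    # (2) Update the running (here: diagonal) curvature estimate.
    curvature = [ema_decay * c + (1.0 - ema_decay) * h
                 for c, h in zip(curvature, hessian_diag)]
    # (3) Precondition the gradient with the damped inverse curvature
    #     (assumed form: (C + (damping + l2_reg) * I)^-1 g).
    precond = [g / (c + damping + l2_reg) for g, c in zip(grads, curvature)]
    # (4) Update coefficients, supplied manually in this sketch.
    alpha, beta = -learning_rate, momentum
    # (5) New velocity from the preconditioned gradient and the old velocity.
    velocity = [alpha * pg + beta * v for pg, v in zip(precond, velocity)]
    # (6) Parameter update.
    params = [p + v for p, v in zip(params, velocity)]
    # (7) Damping update omitted: the damping is kept fixed here.
    return params, velocity, curvature


params, velocity, curvature = toy_optimizer_step(
    params=[1.0, -2.0], velocity=[0.0, 0.0], curvature=[0.0, 0.0],
    hessian_diag=[4.0, 0.5], l2_reg=0.1, damping=0.0,
    learning_rate=1.0, momentum=0.0, ema_decay=0.0)
```

With an exact curvature estimate, no damping, a unit learning rate and no momentum, a single step lands on the minimizer of this quadratic, which is the behaviour expected from a second-order method.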
Computing the update coefficients (4)
The update coefficients \(\alpha_t\) and \(\beta_t\) in step 4 can
either be provided manually by the user at each step, or can be computed
automatically from the local quadratic model.
This is controlled by the optimizer arguments use_adaptive_learning_rate
and use_adaptive_momentum.
Note that these features don’t currently work very well unless you use a very
large batch size, and/or increase the batch size dynamically during training
(as was done in the original K-FAC paper).
Automatic selection of update coefficients
The procedure to automatically select the update coefficients uses the local quadratic model defined as:
where \(\bm{C}\) is usually the exact curvature matrix. To compute \(\alpha_t\) and \(\beta_t\), we minimize \(q(\alpha_t \hat{\bm{g}}_t + \beta_t \bm{v}_t)\) with respect to the two scalars, treating \(\hat{\bm{g}}_t\) and \(\bm{v}_t\) as fixed vectors. Since this is a simple two-dimensional quadratic problem that requires only matrix-vector products with \(\bm{C}\), it can be solved efficiently. For further details, see Section 7 of the original K-FAC paper.
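The two-dimensional problem can be written out explicitly. The sketch below assumes the quadratic model has the common damped form q(δ) = gᵀδ + ½ δᵀ(C + λI)δ, with the constant term dropped; this form is an assumption, not taken from the text above. Setting the derivatives with respect to the two scalars to zero gives a 2x2 linear system, solved here with Cramer's rule. The function names and the `curvature_matvec` callback are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def solve_update_coefficients(grad, precond_grad, velocity,
                              curvature_matvec, damping):
    """Minimize q(alpha * precond_grad + beta * velocity) over (alpha, beta).

    Assumes q(d) = grad . d + 0.5 * d . (C + damping * I) d and a nonzero
    precond_grad. Only matrix-vector products with C are required.
    """
    def damped_matvec(vec):
        return [c + damping * v for c, v in zip(curvature_matvec(vec), vec)]

    b_g = damped_matvec(precond_grad)
    b_v = damped_matvec(velocity)
    # Entries of the symmetric 2x2 system matrix and of the linear term.
    m_gg = dot(precond_grad, b_g)
    m_gv = dot(precond_grad, b_v)
    m_vv = dot(velocity, b_v)
    r_g, r_v = dot(grad, precond_grad), dot(grad, velocity)

    det = m_gg * m_vv - m_gv * m_gv
    if abs(det) < 1e-12:
        # Degenerate case (e.g. zero velocity on the first step):
        # optimize over alpha only. The threshold is a crude heuristic.
        return -r_g / m_gg, 0.0
    alpha = -(r_g * m_vv - r_v * m_gv) / det
    beta = -(r_v * m_gg - r_g * m_gv) / det
    return alpha, beta


# Toy example with C = diag(2, 1) and no damping.
diag_matvec = lambda vec: [2.0 * vec[0], 1.0 * vec[1]]
alpha, beta = solve_update_coefficients(
    grad=[1.0, 1.0], precond_grad=[1.0, 0.0], velocity=[0.0, 1.0],
    curvature_matvec=diag_matvec, damping=0.0)
```

In this example the two directions span the whole space, so the resulting update coincides with the exact minimizer of the quadratic model.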
Updating the damping (7)
The damping is updated using the Levenberg-Marquardt heuristic, which is based on the reduction ratio:
where \(q\) is the quadratic model value induced by either the exact or
approximate curvature matrix.
If the optimizer uses either learning rate or momentum adaptation, or
always_use_exact_qmodel_for_damping_adjustment is set to True, the
optimizer will use the exact curvature matrix; otherwise it will use the
approximate curvature.
If the value of \(\rho\) deviates too much from 1 we either increase or
decrease the damping \(\lambda\) as described in Section 6.5 of the original
K-FAC paper.
Whether the damping is adapted, or provided by the user at each single step, is
controlled by the optimizer argument use_adaptive_damping.
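A Levenberg-Marquardt style adjustment can be sketched as follows. The reduction ratio is assumed here to be the observed change in the loss divided by the change predicted by the quadratic model. The thresholds of 1/4 and 3/4 and the multiplicative factor of 0.95 follow the rule given in the K-FAC paper; the defaults and argument names used by kfac_jax may differ, and the function and its parameters are hypothetical.

```python
def update_damping(damping, loss_change, qmodel_change, *,
                   lower_threshold=0.25, upper_threshold=0.75, decay=0.95,
                   min_damping=1e-8, max_damping=float("inf")):
    """Return the adjusted damping and the reduction ratio rho.

    loss_change:   L(theta_new) - L(theta_old), negative if the loss improved.
    qmodel_change: change predicted by the quadratic model, assumed nonzero
                   (and negative for a descent step).
    """
    rho = loss_change / qmodel_change
    if rho > upper_threshold:
        # The model predicted the improvement well: trust it more.
        damping *= decay
    elif rho < lower_threshold:
        # The model was too optimistic (or the loss went up): trust it less.
        damping /= decay
    return min(max(damping, min_damping), max_damping), rho


good_damping, good_rho = update_damping(1.0, loss_change=-0.9, qmodel_change=-1.0)
poor_damping, poor_rho = update_damping(1.0, loss_change=-0.1, qmodel_change=-1.0)
same_damping, same_rho = update_damping(1.0, loss_change=-0.5, qmodel_change=-1.0)
```

A ratio close to 1 means the quadratic model is a good local description of the loss, so the damping can shrink; a small or negative ratio means the step was taken in a region where the model is unreliable, so the damping grows.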
Amortizing expensive computations
When running the optimizer, several of the steps involved can have a noticeable computational overhead. For this reason, the optimizer class allows these to be performed every K steps, and to cache the values across iterations. This has been found to work well in practice without significant drawbacks in training performance. This is applied to computing the inverse of the estimated approximate curvature (step 3), and to the updates of the damping (step 7).
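The caching pattern described above can be sketched in a few lines. This is an illustration of the general idea rather than the optimizer's actual mechanism: the class, its `period` argument and the toy diagonal inverse are all hypothetical.

```python
class PeriodicCache:
    """Recompute an expensive value every `period` steps and reuse it in between."""

    def __init__(self, compute_fn, period):
        self._compute_fn = compute_fn
        self._period = period
        self._cached = None
        self.num_computations = 0

    def get(self, step, *args):
        # Recompute on the first call and whenever the step hits the period.
        if self._cached is None or step % self._period == 0:
            self._cached = self._compute_fn(*args)
            self.num_computations += 1
        return self._cached


def damped_diagonal_inverse(curvature_diag, damping):
    """Stand-in for an expensive inverse: a diagonal matrix is trivially inverted."""
    return [1.0 / (c + damping) for c in curvature_diag]


inverse_cache = PeriodicCache(damped_diagonal_inverse, period=5)
for step in range(12):
    curvature_diag = [1.0 + step, 2.0 + step]  # the estimate changes every step
    inverse = inverse_cache.get(step, curvature_diag, 0.5)
```

Over these 12 steps the inverse is recomputed only at steps 0, 5 and 10; at the other steps a slightly stale inverse is reused, which is the trade-off the amortization accepts.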