Standard API
Optimizers
The K-FAC optimizer.
Optimizer
- class kfac_jax.Optimizer(value_and_grad_func, l2_reg, value_func_has_aux=False, value_func_has_state=False, value_func_has_rng=False, use_adaptive_learning_rate=False, learning_rate_schedule=None, use_adaptive_momentum=False, momentum_schedule=None, use_adaptive_damping=False, damping_schedule=None, initial_damping=None, min_damping=1e-08, max_damping=inf, include_damping_in_quad_change=False, damping_adaptation_interval=5, damping_adaptation_decay=0.9, damping_lower_threshold=0.25, damping_upper_threshold=0.75, always_use_exact_qmodel_for_damping_adjustment=False, precon_damping_mult=1.0, norm_constraint=None, num_burnin_steps=10, estimation_mode=None, custom_estimator_ctor=None, curvature_ema=0.95, curvature_update_period=1, inverse_update_period=5, use_exact_inverses=False, batch_process_func=None, register_only_generic=False, patterns_to_skip=(), use_automatic_registration=True, auto_register_kwargs=None, layer_tag_to_block_ctor=None, multi_device=False, debug=False, batch_size_extractor=<function default_batch_size_extractor>, pmap_axis_name='kfac_axis', forbid_setting_attributes_after_finalize=True, modifiable_attribute_exceptions=(), include_norms_in_stats=False, include_per_param_norms_in_stats=False, include_registered_loss_in_stats=False, distributed_precon_apply=True, distributed_inverses=True, num_estimator_samples=1, should_vmap_estimator_samples=False, norm_to_scale_identity_weight_per_block=None)[source]
The K-FAC optimizer.
- class State(velocities, estimator_state, damping, data_seen, step_counter)[source]
Persistent state of the optimizer.
- velocities
The update to the parameters from the previous step - \(\theta_t - \theta_{t-1}\).
- Type
Params
- estimator_state
The persistent state for the curvature estimator.
- damping
When using damping adaptation, this will contain the current value.
- Type
Optional[Array]
- data_seen
The number of training cases that the optimizer has processed.
- Type
Numeric
- step_counter
An integer giving the current step number \(t\).
- Type
Numeric
- __init__(velocities, estimator_state, damping, data_seen, step_counter)
- __init__(value_and_grad_func, l2_reg, value_func_has_aux=False, value_func_has_state=False, value_func_has_rng=False, use_adaptive_learning_rate=False, learning_rate_schedule=None, use_adaptive_momentum=False, momentum_schedule=None, use_adaptive_damping=False, damping_schedule=None, initial_damping=None, min_damping=1e-08, max_damping=inf, include_damping_in_quad_change=False, damping_adaptation_interval=5, damping_adaptation_decay=0.9, damping_lower_threshold=0.25, damping_upper_threshold=0.75, always_use_exact_qmodel_for_damping_adjustment=False, precon_damping_mult=1.0, norm_constraint=None, num_burnin_steps=10, estimation_mode=None, custom_estimator_ctor=None, curvature_ema=0.95, curvature_update_period=1, inverse_update_period=5, use_exact_inverses=False, batch_process_func=None, register_only_generic=False, patterns_to_skip=(), use_automatic_registration=True, auto_register_kwargs=None, layer_tag_to_block_ctor=None, multi_device=False, debug=False, batch_size_extractor=<function default_batch_size_extractor>, pmap_axis_name='kfac_axis', forbid_setting_attributes_after_finalize=True, modifiable_attribute_exceptions=(), include_norms_in_stats=False, include_per_param_norms_in_stats=False, include_registered_loss_in_stats=False, distributed_precon_apply=True, distributed_inverses=True, num_estimator_samples=1, should_vmap_estimator_samples=False, norm_to_scale_identity_weight_per_block=None)[source]
Initializes the K-FAC optimizer with the provided settings.
NOTE: Please read the docstring for this constructor carefully, especially the description of value_and_grad_func.
A note on the “damping” parameter:
One of the main complications of using second-order optimizers like K-FAC is the “damping” parameter. This parameter is multiplied by the identity matrix and (approximately) added to the curvature matrix (i.e. the Fisher or GGN) before it is inverted and multiplied by the gradient when computing the update (before any learning rate scaling). The damping should follow the scale of the objective, so that if you multiply your loss by some factor you should do the same for the damping. Roughly speaking, larger damping values constrain the update vector to a smaller region around zero, which is needed in general since the second-order approximations that underlie second-order methods can break down for large updates. (In gradient descent the learning rate plays an analogous role.) The relationship between the damping parameter and the radius of this region is complicated and depends on the scale of the objective, among other things.
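The effect described above can be seen on a toy problem. The following is a minimal sketch, not kfac_jax code: it assumes a diagonal 2x2 curvature matrix so that the damped inverse can be applied coordinate-wise, and all names (damped_update, curvature_diag) are hypothetical.

```python
import math

# Toy illustration (not kfac_jax code): F is diagonal, so
# (F + damping * I)^-1 g can be computed one coordinate at a time.
def damped_update(curvature_diag, grad, damping):
    """Returns -(F + damping * I)^-1 g for a diagonal F."""
    return [-g / (c + damping) for c, g in zip(curvature_diag, grad)]

def norm(vec):
    return math.sqrt(sum(x * x for x in vec))

curvature_diag = [4.0, 0.01]  # one stiff and one nearly flat direction
grad = [1.0, 1.0]

# Larger damping values give a shorter update vector.
norms = [norm(damped_update(curvature_diag, grad, d)) for d in (1e-3, 1e-1, 10.0)]

# Multiplying the loss by k multiplies both F and g by k. Scaling the
# damping by k as well leaves the update unchanged.
k = 100.0
base = damped_update(curvature_diag, grad, 0.1)
scaled = damped_update([k * c for c in curvature_diag],
                       [k * g for g in grad], k * 0.1)
```

In this toy case the nearly flat direction dominates the update when the damping is small, which is where a quadratic approximation is least trustworthy.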
The optimizer provides a system for adjusting the damping automatically via the use_adaptive_damping argument, although this system is not reliable, especially for highly stochastic objectives. Using a fixed value or a manually tuned schedule can work as well or better for some problems, while it can be a very poor choice for others (like deep autoencoders). Empirically we have found that using a fixed value works well enough for common architectures like convnets and transformers.
- Parameters
value_and_grad_func (ValueAndGradFunc) – Python callable. This function should return the value of the loss to be optimized and its gradients, and optionally the model state and auxiliary information in the form of a dict mapping strings to scalar arrays (usually statistics to log). Note that it should not be jitted/pmapped or otherwise compiled by JAX, as this can lead to errors. (Compilation is done internally by the optimizer.) The interface of this function should be: out_args, loss_grads = value_and_grad_func(*in_args). Here, in_args is (params, func_state, rng, batch), with rng omitted if value_func_has_rng is False, and with func_state omitted if value_func_has_state is False. Meanwhile, out_args is (loss, (func_state, aux)) if value_func_has_state and value_func_has_aux are both True, (loss, func_state) if value_func_has_state is True and value_func_has_aux is False, (loss, aux) if value_func_has_state is False and value_func_has_aux is True, and finally loss if value_func_has_state and value_func_has_aux are both False. This should be consistent with how JAX’s value_and_grad API function is typically used.
l2_reg (Numeric) – Scalar. Set this value to tell the optimizer what L2 regularization coefficient you are using (if any). Note the coefficient appears in the regularizer as coeff / 2 * sum(param**2). This adds an additional diagonal term to the curvature and hence will affect the quadratic model when using adaptive damping. Note that the user is still responsible for adding regularization to the loss.
value_func_has_aux (bool) – Boolean. Specifies whether the provided callable value_and_grad_func returns auxiliary data. (Default: False)
value_func_has_state (bool) – Boolean. Specifies whether the provided callable value_and_grad_func has a persistent state that is passed in and out. (Default: False)
value_func_has_rng (bool) – Boolean. Specifies whether the provided callable value_and_grad_func additionally takes as input an rng key. (Default: False)
use_adaptive_learning_rate (bool) – Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the learning rate at each step. Note that this won’t work well for stochastic objectives. If this is False, the user must use the learning_rate argument of the step function, or the constructor argument learning_rate_schedule. (Default: False)
learning_rate_schedule (Optional[ScheduleType]) – Callable. A schedule for the learning rate. This should take as input the current step number, and optionally the amount of data seen so far as a keyword argument data_seen, and return a single array that represents the learning rate. (Default: None)
use_adaptive_momentum (bool) – Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the momentum “decay” parameter at each step. Note that this won’t work well for stochastic objectives. If this is False, the user must use the momentum argument of the step function, or the constructor argument momentum_schedule. (Default: False)
momentum_schedule (Optional[ScheduleType]) – Callable. A schedule for the momentum parameter. This should take as input the current step number, and optionally the amount of data seen so far as a keyword argument data_seen, and return a single array that represents the momentum. (Default: None)
use_adaptive_damping (bool) – Boolean. Specifies whether the optimizer will use the Levenberg-Marquardt method to automatically adjust the damping every damping_adaptation_interval iterations. If this is set to False the user must provide a value to the damping argument of the step function at each iteration, or use the damping_schedule constructor argument. Note that the effectiveness of this technique seems to vary between problems. (Default: False)
damping_schedule (Optional[ScheduleType]) – Callable. A schedule for the damping. This should take as input the current step number, and optionally the amount of data seen so far as a keyword argument data_seen, and return a single array that represents the damping. (Default: None)
initial_damping (Optional[Numeric]) – Scalar or None. This specifies the initial value of the damping that the optimizer will use when using automatic damping adaptation. (Default: None)
min_damping (Numeric) – Scalar. Minimum value the damping parameter can take when using automatic damping adaptation. Note that the default value of 1e-8 is quite arbitrary, and you may have to adjust this up or down for your particular problem. If you are using a non-zero value of l2_reg you may be able to set this to zero. (Default: 1e-8)
max_damping (Numeric) – Scalar. Maximum value the damping parameter can take when using automatic damping adaptation. (Default: Infinity)
include_damping_in_quad_change (bool) – Boolean. Whether to include the contribution of the damping in the quadratic model for the purpose of computing the reduction ratio (“rho”) in the Levenberg-Marquardt scheme used for adapting the damping. Note that the contribution from the l2_reg argument is always included. (Default: False)
damping_adaptation_interval (int) – Int. The number of steps in between adapting the damping parameter. (Default: 5)
damping_adaptation_decay (Numeric) – Scalar. The damping parameter will be adjusted up or down by damping_adaptation_decay ** damping_adaptation_interval, or remain unchanged, every damping_adaptation_interval iterations. (Default: 0.9)
damping_lower_threshold (Numeric) – Scalar. The damping parameter is increased if the reduction ratio is below this threshold. (Default: 0.25)
damping_upper_threshold (Numeric) – Scalar. The damping parameter is decreased if the reduction ratio is above this threshold. (Default: 0.75)
always_use_exact_qmodel_for_damping_adjustment (bool) – Boolean. When using learning rate and/or momentum adaptation, the quadratic model change used for damping adaptation is always computed using the exact curvature matrix. Otherwise, there is an option to use either the exact or approximate curvature matrix to compute the quadratic model change, which is what this argument controls. When True, the exact curvature matrix will be used, which is more expensive, but could possibly produce a better damping schedule. (Default: False)
precon_damping_mult (Numeric) – Scalar. Multiplies the damping used in the preconditioner (vs the exact quadratic model) by this value. (Default: 1.0)
norm_constraint (Optional[Numeric]) – Scalar. If specified, the update is scaled down so that its approximate squared Fisher norm v^T F v is at most the specified value. (Note that here F is the approximate curvature matrix, not the exact.) May only be used when use_adaptive_learning_rate is False. (Default: None)
num_burnin_steps (int) – Int. At the start of optimization, i.e. on the first step, before performing the actual step the optimizer will perform this many updates to the curvature approximation without updating the actual parameters. (Default: 10)
estimation_mode (Optional[str]) – String. The type of estimator to use for the curvature matrix. See the documentation for CurvatureEstimator for a detailed description of the possible options. If None, the default estimation mode of the CurvatureEstimator subclass in use is applied, which is typically “fisher_gradients”. (Default: None)
custom_estimator_ctor (Optional[Callable[..., curvature_estimator.BlockDiagonalCurvature]]) – Optional constructor for a subclass of BlockDiagonalCurvature. If specified, the optimizer will use this constructor instead of the default BlockDiagonalCurvature. (Default: None)
curvature_ema (Numeric) – The decay factor used when calculating the covariance estimate moving averages. (Default: 0.95)
curvature_update_period (int) – Int. The number of steps in between updating the curvature estimates. (Default: 1)
inverse_update_period (int) – Int. The number of steps in between updating the computation of the inverse curvature approximation. (Default: 5)
use_exact_inverses (bool) – Bool. If True, preconditioner inverses are computed “exactly” without the pi-adjusted factored damping approach. Note that this involves the use of eigendecompositions, which can sometimes be much more expensive. (Default: False)
batch_process_func (Optional[Callable[[Batch], Batch]]) – Callable. A function to be called on each batch before it is fed to the K-FAC optimizer on device. This could be useful for specific device input optimizations. (Default: None)
register_only_generic (bool) – Boolean. Whether, when running the auto-tagger, to register only generic parameters, or to allow it to use the graph matcher to automatically pick up any kind of layer tags. (Default: False)
patterns_to_skip (Sequence[str]) – Tuple. A list of any patterns that should be skipped by the graph matcher when auto-tagging. (Default: ())
use_automatic_registration (bool) – Bool. If True, the optimizer will try to automatically register the layers of your network. (Default: True)
auto_register_kwargs (Optional[Dict[str, Any]]) – Any additional kwargs to be passed down to auto_register_tags(), which is called by the curvature estimator. (Default: None)
layer_tag_to_block_ctor (Optional[Dict[str, curvature_estimator.CurvatureBlockCtor]]) – Dictionary. A mapping from layer tags to block classes, which overrides the default choice of block approximation for each specific tag. See the documentation for CurvatureEstimator for a more detailed description. (Default: None)
multi_device (bool) – Boolean. Whether to use pmap and run the optimizer on multiple devices. (Default: False)
debug (bool) – Boolean. If True, neither the step nor the init function will be jitted. Note that this also overrides multi_device and prevents using pmap. (Default: False)
batch_size_extractor (Callable[[Batch], Numeric]) – A function that takes as input the function arguments and returns the batch size for a single device. (Default: kfac.utils.default_batch_size_extractor)
pmap_axis_name (str) – String. The name of the pmap axis to use when multi_device is set to True. (Default: kfac_axis)
forbid_setting_attributes_after_finalize (bool) – Boolean. By default, after the object is finalized, you cannot set any of its properties. This is done in order to protect the user from making changes to the object attributes that would not be picked up by various internal methods after they have been compiled. However, if you are extending this class, and clearly understand the risks of modifying attributes, setting this to False will remove the restriction. (Default: True)
modifiable_attribute_exceptions (Sequence[str]) – Sequence of strings. Gives a list of names for attributes that can be modified after finalization even when forbid_setting_attributes_after_finalize is True. (Default: ())
include_norms_in_stats (bool) – Boolean. If True, the vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: False)
include_per_param_norms_in_stats (bool) – Boolean. If True, the per-parameter vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: False)
include_registered_loss_in_stats (bool) – Boolean. If True, we include the loss, as computed from the registered losses, in the stats. Also included is the relative difference between this and the loss computed from value_and_grad_func. This is useful for debugging registration errors. Note that for this option to work it’s required that the targets are passed for each loss function registration. (Default: False)
distributed_precon_apply (bool) – Boolean. Whether to distribute the application of the preconditioner across the different devices in a layer-wise fashion. If False, each device will (redundantly) perform the required operations for all the layers. (Default: True)
distributed_inverses (bool) – Boolean. Whether to distribute the inverse computations (required to compute the preconditioner) across the different devices in a layer-wise fashion. If False, each device will (redundantly) perform the required computations for all the layers. (Default: True)
num_estimator_samples (int) – Number of samples (per case) to use when computing stochastic curvature matrix estimates. This option is only used when estimation_mode == 'fisher_gradients' or estimation_mode == '[fisher,ggn]_curvature_prop'. (Default: 1)
should_vmap_estimator_samples (bool) – Whether to use jax.vmap to compute samples when num_estimator_samples > 1. (Default: False)
norm_to_scale_identity_weight_per_block (Optional[str]) – The name of a norm to use to compute extra per-block scaling for the damping. See psd_matrix_norm() in utils/math.py for the definition of these. Note that this will not affect the exact quadratic model that is used as part of the “adaptive” learning rate, momentum, and damping methods. (Default: None)
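The calling convention described for value_and_grad_func can be illustrated with a small stand-in. This is a sketch under stated assumptions: it uses plain Python floats and a hand-written gradient instead of JAX arrays, it covers only the case value_func_has_state=True, value_func_has_aux=True, value_func_has_rng=False, and the model, state keys and aux keys are all hypothetical. With JAX, a function of this shape would typically be obtained from jax.value_and_grad(loss_fn, has_aux=True).

```python
# Hypothetical stand-in with the documented shape:
#   out_args, loss_grads = value_and_grad_func(params, func_state, batch)
# where out_args is (loss, (func_state, aux)).
def value_and_grad_func(params, func_state, batch):
    """Squared-error loss and gradients for the scalar model y = w * x + b."""
    w, b = params["w"], params["b"]
    xs, ys = batch
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / (2 * n)
    # Gradients share the structure of params.
    grads = {
        "w": sum(r * x for r, x in zip(residuals, xs)) / n,
        "b": sum(residuals) / n,
    }
    # Persistent state is passed in and returned updated.
    new_state = {"num_calls": func_state["num_calls"] + 1}
    # Auxiliary output: a dict mapping strings to scalars to log.
    aux = {"max_abs_residual": max(abs(r) for r in residuals)}
    return (loss, (new_state, aux)), grads

params = {"w": 1.0, "b": 0.0}
batch = ([1.0, 2.0], [2.0, 4.0])
(loss, (func_state, aux)), grads = value_and_grad_func(
    params, {"num_calls": 0}, batch)
```

If value_func_has_rng were True, an rng key would appear between func_state and batch in the argument list, as described for in_args above.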
- property num_burnin_steps: int
The number of burnin steps to run before the first parameter update.
- Return type
int
- property l2_reg: Numeric
The weight of the additional diagonal term added to the curvature.
- Return type
Numeric
- property estimator: curvature_estimator.BlockDiagonalCurvature
The underlying curvature estimator used by the optimizer.
- Return type
- property damping_decay_factor: Numeric
How fast to decay the damping, when using damping adaptation.
- Return type
Numeric
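The role of the decay factor in damping adaptation can be sketched as follows. This is an illustrative reading of the constructor arguments, not the library's implementation: it assumes the usual Levenberg-Marquardt convention (raise the damping when the reduction ratio rho is low, lower it when rho is high), and the function name adapt_damping is hypothetical.

```python
def adapt_damping(damping, rho, decay=0.9, interval=5,
                  lower_threshold=0.25, upper_threshold=0.75,
                  min_damping=1e-8, max_damping=float("inf")):
    """One Levenberg-Marquardt style adjustment of the damping (a sketch)."""
    # Assumed decay factor: damping_adaptation_decay ** damping_adaptation_interval.
    factor = decay ** interval
    if rho < lower_threshold:
        # The quadratic model over-predicted the improvement: trust it less.
        damping = damping / factor
    elif rho > upper_threshold:
        # The quadratic model was accurate: allow larger updates.
        damping = damping * factor
    # Keep the result inside [min_damping, max_damping].
    return min(max(damping, min_damping), max_damping)

increased = adapt_damping(1.0, rho=0.1)
unchanged = adapt_damping(1.0, rho=0.5)
decreased = adapt_damping(1.0, rho=0.9)
```

With the default values the factor is 0.9 ** 5, roughly 0.59, so one adjustment changes the damping by a bit less than a factor of two.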
- should_update_damping(state)[source]
Whether at the current step the optimizer should update the damping.
- Return type
Array
- should_update_estimate_curvature(state)[source]
Whether at the current step the optimizer should update the curvature estimates.
- Return type
Union[Array, bool]
- should_update_inverse_cache(state)[source]
Whether at the current step the optimizer should update the inverse curvature approximation.
- Return type
Union[Array, bool]
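The three should_update_* checks gate periodic work on the step counter. A minimal sketch of such a gate is shown below; the modulo rule and its offset are assumptions made for illustration, and the library may use a different offset (for example, firing on the last step of each period rather than the first).

```python
def is_due(step_counter, period):
    """Assumed rule: work is due on steps 0, period, 2 * period, ..."""
    return step_counter % period == 0

# With curvature_update_period=1 and inverse_update_period=5 (the defaults),
# the curvature is refreshed every step and the inverses every fifth step.
curvature_steps = [t for t in range(12) if is_due(t, 1)]
inverse_steps = [t for t in range(12) if is_due(t, 5)]
```

Under this assumed rule the preconditioner applied on steps 1 to 4 is the one computed at step 0, which is the staleness that inverse_update_period trades for speed.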
- compute_loss_value(func_args)[source]
Computes the value of the loss function being optimized.
- Return type
Array
- verify_args_and_get_step_counter(step_counter, learning_rate=None, momentum=None, damping=None, global_step_int=None)[source]
Verifies that the arguments passed to the step function are correct.
- Return type
int
- init(params, rng, batch, func_state=None)[source]
Initializes the optimizer and returns the appropriate optimizer state.
- Return type
- burnin(num_steps, params, state, rng, data_iterator, func_state=None)[source]
Runs all burnin steps required.
- Return type
Tuple[‘Optimizer.State’, Optional[FuncState]]
- step(params, state, rng, data_iterator=None, batch=None, func_state=None, learning_rate=None, momentum=None, damping=None, global_step_int=None)[source]
Performs a single update step using the optimizer.
NOTE: please do not jit/pmap or otherwise compile this function with JAX, as this can lead to errors. Compilation is handled internally by the optimizer.
- Parameters
params (Params) – The current parameters of the model.
state (Optimizer.State) – The current state of the optimizer.
rng (PRNGKey) – A Jax PRNG key. Should be different for each iteration and each Jax process/host.
data_iterator (Optional[Iterator[Batch]]) – A data iterator to use (if not passing batch).
batch (Optional[Batch]) – A single batch used to compute the update. Should only pass one of data_iterator or batch.
func_state (Optional[FuncState]) – Any function state that gets passed in and returned.
learning_rate (Optional[Array]) – Learning rate to use if the optimizer was created with use_adaptive_learning_rate=False and without a learning_rate_schedule, None otherwise.
momentum (Optional[Array]) – Momentum to use if the optimizer was created with use_adaptive_momentum=False and without a momentum_schedule, None otherwise.
damping (Optional[Array]) – Damping to use if the optimizer was created with use_adaptive_damping=False and without a damping_schedule, None otherwise. See the discussion of damping in the constructor docstring for more information.
global_step_int (Optional[int]) – The global step as a python int. Note that this must match the step internal to the optimizer that is part of its state.
- Return type
ReturnEither
- Returns
(params, state, stats) if value_func_has_state=False and (params, state, func_state, stats) otherwise, where
params is the updated model parameters.
state is the updated optimizer state.
func_state is the updated function state.
stats is a dictionary of useful statistics including the loss.
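When the optimizer is created without the adaptive rules and without constructor schedules, the learning_rate and damping arguments of step are supplied on every call. One way to produce them is with host-side helpers driven by a Python step counter. The sketch below is illustrative only: the function names, the warmup-then-cosine shape and all constants are hypothetical choices, not defaults of the library.

```python
import math

def learning_rate_for_step(step, peak=1e-2, warmup_steps=100, total_steps=1000):
    """Linear warmup to `peak`, then cosine decay to zero (hypothetical)."""
    if step < warmup_steps:
        return peak * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return 0.5 * peak * (1.0 + math.cos(math.pi * progress))

def damping_for_step(step, initial=1e-1, decay=0.99, floor=1e-3):
    """Exponential decay of the damping down to a fixed floor (hypothetical)."""
    return max(floor, initial * decay ** step)

# Values that would be passed as learning_rate= and damping= at these steps.
per_step_values = [(learning_rate_for_step(t), damping_for_step(t))
                   for t in (0, 99, 1000)]
```

These helpers use Python branching, so they are meant to run outside compiled code. A callable passed as a constructor schedule may be invoked with JAX values instead (an assumption about the internals), in which case it would need array operations rather than if statements.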
- compute_l2_quad_matrix(vectors)[source]
Computes the matrix corresponding to the prior/regularizer.
- Parameters
vectors (Sequence[Params]) – A sequence of parameter-like PyTree structures, each one representing a different vector.
- Return type
Array
- Returns
A matrix with i,j entry equal to
self.l2_reg * v_i^T v_j.
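The returned matrix is a scaled Gram matrix of the input vectors. A minimal sketch of the same computation on flat Python lists follows; the real method operates on parameter-structured PyTrees, and the helper name l2_quad_matrix is hypothetical.

```python
def l2_quad_matrix(vectors, l2_reg):
    """Returns M with M[i][j] = l2_reg * <v_i, v_j> for flat vectors."""
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))
    return [[l2_reg * dot(u, v) for v in vectors] for u in vectors]

vectors = [[1.0, 2.0], [3.0, -1.0]]
matrix = l2_quad_matrix(vectors, l2_reg=0.5)
```

The result is symmetric, and its diagonal holds l2_reg times the squared norm of each vector.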
- compute_exact_quad_model(vectors, grads, func_args, state=None)[source]
Computes the components of the exact quadratic model.
- Return type
Tuple[Array, Array, Array]
- compute_approx_quad_model(state, vectors, grads)[source]
Computes the components of the approximate quadratic model.
- Return type
Tuple[Array, Array, Array]
Curvature Estimators
An abstract curvature estimator class.
Block diagonal curvature estimator class.
Explicit exact full curvature estimator class.
Represents all exact curvature matrices never constructed explicitly.
Sets the default curvature block constructor for the given tag.
Returns the default curvature block constructor for the given tag name.
CurvatureEstimator
- class kfac_jax.CurvatureEstimator(func, params_index=0, default_estimation_mode='fisher_gradients')[source]
An abstract curvature estimator class.
This is a class that abstracts away the process of estimating a curvature matrix and provides many useful functionalities for interacting with it. The state of the estimator contains two parts: the internal representation of the estimated curvature, as well as potential cached values of different expressions involving the curvature matrix (for example matrix powers). The cached values are only updated once you call the method update_cache(). Multiple methods contain the keyword argument use_cached, which specifies whether you want to compute the corresponding expression using the current curvature estimate or using a cached version.
- func
The model evaluation function.
- params_index
The index of the parameters argument in the arguments list of func.
- default_estimation_mode
The estimation mode to use by default when calling update_curvature_matrix_estimate().
- __init__(func, params_index=0, default_estimation_mode='fisher_gradients')[source]
Initializes the CurvatureEstimator instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in the arguments list of func.
default_estimation_mode (str) – The estimation mode to use by default when calling update_curvature_matrix_estimate().
- property default_mat_type: str
The type of matrix that this estimator is approximating.
- Return type
str
- abstract property dim: int
The number of elements of all parameter variables together.
- Return type
int
- abstract init(rng, func_args, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues=False)[source]
Initializes the state for the estimator.
- Parameters
rng (PRNGKey) – The PRNGKey to be used for any randomness of the initialization.
func_args (utils.FuncArgs) – Example function arguments, to be used to trace the model function and initialize the state.
exact_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, which specify the exact matrix powers that each block should be caching. Matrix powers for which you intend to call self.multiply_matrix_power, self.multiply_inverse or self.multiply with exact_power=True and use_cached=True must be provided here.
approx_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, which specify the approximate matrix powers that each block should be caching. Matrix powers for which you intend to call self.multiply_matrix_power, self.multiply_inverse or self.multiply with exact_power=False and use_cached=True must be provided here.
cache_eigenvalues (bool) – Specifies whether each block should be caching the eigenvalues of its approximate curvature.
- Return type
StateType
- Returns
The initialized state of the estimator.
- abstract sync(state, pmap_axis_name)[source]
Synchronizes across devices the state of the estimator.
- Return type
StateType
- abstract multiply_matpower(state, parameter_structured_vector, identity_weight, power, exact_power, use_cached, pmap_axis_name, norm_to_scale_identity_weight_per_block=None)[source]
Computes (CurvatureMatrix + identity_weight I)**power times vector.
- Parameters
state (StateType) – The state of the estimator.
parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.
identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
power (Scalar) – The power to which you want to raise the matrix (EstimateCurvature + identity_weight I).
exact_power (bool) – When set to True the matrix power of EstimateCurvature + identity_weight I is computed exactly. Otherwise this method might use a cheaper approximation, which may vary across different blocks.
use_cached (bool) – Whether to use a cached (and possibly stale) version of the curvature matrix estimate.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
norm_to_scale_identity_weight_per_block (Optional[str]) – The name of a norm to use to compute extra per-block scaling for identity_weight. See psd_matrix_norm() in utils/math.py for the definition of these.
- Return type
utils.Params
- Returns
A parameter structured vector containing the product.
- multiply(state, parameter_structured_vector, identity_weight, exact_power, use_cached, pmap_axis_name, norm_to_scale_identity_weight_per_block=None)[source]
Computes (CurvatureMatrix + identity_weight I) times vector.
- Return type
utils.Params
- multiply_inverse(state, parameter_structured_vector, identity_weight, exact_power, use_cached, pmap_axis_name, norm_to_scale_identity_weight_per_block=None)[source]
Computes (CurvatureMatrix + identity_weight I)^-1 times vector.
- Return type
utils.Params
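The relationship between multiply_matpower, multiply and multiply_inverse can be seen in the simplest possible setting. The sketch below assumes a diagonal curvature matrix stored as a flat list, so the matrix power reduces to a per-coordinate scalar power. The function names are hypothetical, and none of the estimator's state, caching or device handling is modelled.

```python
def multiply_matpower_diag(curvature_diag, vector, identity_weight, power):
    """(C + identity_weight * I) ** power times vector, for diagonal C."""
    return [(c + identity_weight) ** power * v
            for c, v in zip(curvature_diag, vector)]

def multiply_diag(curvature_diag, vector, identity_weight):
    # Plain multiplication is the special case power = 1.
    return multiply_matpower_diag(curvature_diag, vector, identity_weight, 1)

def multiply_inverse_diag(curvature_diag, vector, identity_weight):
    # Multiplication by the damped inverse is the special case power = -1.
    return multiply_matpower_diag(curvature_diag, vector, identity_weight, -1)

curvature_diag = [3.0, 1.0]
vector = [2.0, 4.0]
product = multiply_diag(curvature_diag, vector, identity_weight=1.0)
preconditioned = multiply_inverse_diag(curvature_diag, vector, identity_weight=1.0)
round_trip = multiply_inverse_diag(curvature_diag, product, identity_weight=1.0)
```

Applying the inverse to the product recovers the original vector, which is a useful sanity check when both operations use the same identity_weight.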
- abstract eigenvalues(state, use_cached)[source]
Computes the eigenvalues of the curvature matrix.
- Parameters
state (StateType) – The state of the estimator.
use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is going to be at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Array
- Returns
A single array containing the eigenvalues of the curvature matrix.
- abstract update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, estimation_mode=None)[source]
Updates the estimator’s curvature estimates.
- Parameters
state (StateType) – The state of the estimator to update.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size.
rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.
func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor) to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.
estimation_mode (Optional[str]) – The type of curvature estimator to use. By default (i.e. if None) will use self.default_estimation_mode. One of:
fisher_gradients - the basic estimation approach from the original K-FAC paper.
fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.
fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.
fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.
ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).
ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).
- Return type
StateType
- Returns
The updated state.
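The ema_old and ema_new arguments weight the previous estimate and the fresh one in a moving average. A minimal sketch of that combination on flat lists is given below. The helper name ema_update is hypothetical, and pairing ema_old=0.95 with ema_new=0.05 is an assumption chosen to mirror the optimizer's default curvature_ema, not a statement of what the optimizer passes.

```python
def ema_update(old_estimate, new_sample, ema_old, ema_new):
    """Weighted combination: ema_old * old + ema_new * new, element-wise."""
    return [ema_old * o + ema_new * n
            for o, n in zip(old_estimate, new_sample)]

# Starting from zeros and repeatedly seeing the same sample, the estimate
# approaches that sample geometrically.
estimate = [0.0, 0.0]
for sample in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    estimate = ema_update(estimate, sample, ema_old=0.95, ema_new=0.05)
```

After three identical samples the estimate has covered a fraction 1 - 0.95 ** 3 of the distance to the sample, which shows why an estimate started from zeros is biased toward zero early in training.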
- abstract update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]
Updates the estimator cached values.
- Parameters
state (StateType) – The state of the estimator to update.
identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.
approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
StateType
- Returns
The updated state.
BlockDiagonalCurvature
- class kfac_jax.BlockDiagonalCurvature(func, params_index=0, default_estimation_mode=None, layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, distributed_multiplies=True, distributed_cache_updates=True, num_samples=1, should_vmap_samples=False, **auto_register_kwargs)[source]
Block diagonal curvature estimator class.
- class State(synced, blocks_states)[source]
Persistent state of the estimator.
- synced
A Jax boolean, specifying if the state has been synced across devices (this does not include the cache, which is never explicitly synced).
- Type
Array
- blocks_states
A tuple of the state of the estimator corresponding to each block.
- Type
Tuple[curvature_blocks.CurvatureBlock.State, …]
- __init__(synced, blocks_states)
- __init__(func, params_index=0, default_estimation_mode=None, layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, distributed_multiplies=True, distributed_cache_updates=True, num_samples=1, should_vmap_samples=False, **auto_register_kwargs)[source]
Initializes the curvature instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in the argument list of func.
default_estimation_mode (Optional[str]) – The estimation mode to use by default when calling self.update_curvature_matrix_estimate. If None, this will be 'fisher_gradients'.
layer_tag_to_block_ctor (Optional[Mapping[str, CurvatureBlockCtor]]) – An optional dict mapping tags to specific classes of block approximations, which override the default ones.
index_to_block_ctor (Optional[Mapping[Tuple[int, ...], CurvatureBlockCtor]]) – An optional dict mapping specific block parameter indices to specific classes of block approximations, which override the default ones. To get the correct indices, check estimator.indices_to_block_map.
auto_register_tags (bool) – Whether to automatically register layer tags for parameters that have not been manually registered. For further details see tag_graph_matcher.auto_register_tags.
distributed_multiplies (bool) – Whether to distribute the curvature matrix multiplication operations across the different devices in a block-wise fashion. If False, each device will (redundantly) perform the operations for all of the blocks.
distributed_cache_updates (bool) – Whether to distribute the cache update multiplication operations across the different devices in a block-wise fashion. If False, each device will (redundantly) perform the operations for all of the blocks.
num_samples (int) – Number of samples (per case) to use when computing stochastic curvature matrix estimates. This option is only used when estimation_mode == 'fisher_gradients' or estimation_mode == '[fisher,ggn]_curvature_prop'.
should_vmap_samples (bool) – Whether to use jax.vmap to compute samples when num_samples > 1.
**auto_register_kwargs (Any) – Any keyword arguments to pass into the auto registration function.
- property blocks: Optional[Tuple[curvature_blocks.CurvatureBlock]]
The tuple of CurvatureBlock instances used for each layer.
- Return type
Optional[Tuple[curvature_blocks.CurvatureBlock]]
- property num_blocks: int
The number of separate blocks that this estimator has.
- Return type
int
- property block_dims: Shape
The number of elements of all parameter variables for each block.
- Return type
Shape
- property dim: int
The number of elements of all parameter variables together.
- Return type
int
- property params_structure_vector_of_indices: utils.Params
A tree structure with parameters replaced by their indices.
- Return type
utils.Params
- property indices_to_block_map: Mapping[Tuple[int, ...], curvature_blocks.CurvatureBlock]
A mapping of parameter indices to their associated blocks.
- Return type
Mapping[Tuple[int, …], curvature_blocks.CurvatureBlock]
- property params_block_index: utils.Params
A structure showing, for each parameter, the block to which it corresponds.
- Return type
utils.Params
- Returns
A parameter-like structure, where each parameter is replaced by an integer index. This index specifies the block (found by self.blocks[index]) which approximates the part of the curvature matrix associated with the parameter.
- property num_params_variables: int
The number of separate parameter variables of the model.
- Return type
int
- params_vector_to_blocks_vectors(parameter_structured_vector)[source]
Splits the parameters to values for each corresponding block.
- Return type
Tuple[Tuple[Array, …]]
- blocks_vectors_to_params_vector(blocks_vectors)[source]
Reverses the effect of self.params_vector_to_blocks_vectors.
- Return type
utils.Params
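The round trip between a parameter-structured vector and per-block vectors can be sketched in plain Python. The helper names `to_blocks` and `from_blocks`, the flat list of parameters and the block indices are all hypothetical; the real methods operate on parameter trees and arrays, and only the grouping idea is shown here.

```python
# Hypothetical flat parameter vector (one entry per parameter variable) and a
# block index per parameter, in the spirit of `params_block_index`.
params_vector = ["w1", "b1", "w2", "b2"]
params_block_index = [0, 0, 1, 1]


def to_blocks(vector, block_index):
    """Groups the entries of `vector` by the block each one belongs to."""
    num_blocks = max(block_index) + 1
    blocks = [[] for _ in range(num_blocks)]
    for value, idx in zip(vector, block_index):
        blocks[idx].append(value)
    return tuple(tuple(b) for b in blocks)


def from_blocks(blocks, block_index):
    """Inverse of `to_blocks`: restores the original parameter ordering."""
    iterators = [iter(b) for b in blocks]
    return [next(iterators[idx]) for idx in block_index]


blocks_vectors = to_blocks(params_vector, params_block_index)
round_trip = from_blocks(blocks_vectors, params_block_index)
```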
- init(rng, func_args, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues=False)[source]
Initializes the state for the estimator.
- Parameters
rng (PRNGKey) – The PRNGKey to be used for any randomness of the initialization.
func_args (utils.FuncArgs) – Example function arguments, to be used to trace the model function and initialize the state.
exact_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, specifying which exact matrix powers each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=True and use_cached=True must be provided here.
approx_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, specifying which approximate matrix powers each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=False and use_cached=True must be provided here.
cache_eigenvalues (bool) – Specifies whether each block should cache the eigenvalues of its approximate curvature.
- Return type
BlockDiagonalCurvature.State
- Returns
The initialized state of the estimator.
- sync(state, pmap_axis_name)[source]
Synchronizes across devices the state of the estimator.
- Return type
BlockDiagonalCurvature.State
- multiply_matpower(state, parameter_structured_vector, identity_weight, power, exact_power, use_cached, pmap_axis_name, norm_to_scale_identity_weight_per_block=None)[source]
Computes (CurvatureMatrix + identity_weight I)**power times vector.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator.
parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.
identity_weight (Union[Numeric, Sequence[Numeric]]) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
power (Scalar) – The power to which you want to raise the matrix (EstimateCurvature + identity_weight I).
exact_power (bool) – When set to True the matrix power of EstimateCurvature + identity_weight I is computed exactly. Otherwise this method might use a cheaper approximation, which may vary across different blocks.
use_cached (bool) – Whether to use a cached (and possibly stale) version of the curvature matrix estimate.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
norm_to_scale_identity_weight_per_block (Optional[str]) – The name of a norm to use to compute extra per-block scaling for identity_weight. See psd_matrix_norm() in utils/math.py for the definition of these.
- Return type
utils.Params
- Returns
A parameter structured vector containing the product.
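As a rough illustration of the quantity computed here, the sketch below applies `(C + identity_weight I)**power` to a vector block by block, using an eigendecomposition of each symmetric block. The helper `matpower_times_vector` and the two small blocks are hypothetical, NumPy stands in for JAX, and this mirrors only the exact-power case; it says nothing about how the library caches or approximates the result.

```python
import numpy as np


def matpower_times_vector(curvature, vector, identity_weight, power):
    """Returns (curvature + identity_weight * I)**power @ vector for a symmetric PSD block."""
    eigvals, eigvecs = np.linalg.eigh(curvature)
    scaled = (eigvals + identity_weight) ** power
    return eigvecs @ (scaled * (eigvecs.T @ vector))


# Two made-up curvature blocks and the matching pieces of a parameter vector.
blocks = [np.array([[2.0, 1.0], [1.0, 2.0]]), np.array([[4.0]])]
vectors = [np.array([1.0, -1.0]), np.array([3.0])]
identity_weight = [0.1, 0.5]  # one weight per block; a scalar would be shared by all blocks

result = [matpower_times_vector(c, v, w, power=-1)
          for c, v, w in zip(blocks, vectors, identity_weight)]
```

With `power=-1` this reduces to solving a damped linear system per block, which is the usual preconditioning step.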
- block_eigenvalues(state, use_cached)[source]
Computes the eigenvalues for each block of the curvature estimator.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator.
use_cached (bool) – Whether to use cached versions of the eigenvalues or to compute them from the most recent curvature estimates. The cached versions are at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Tuple[Array, …]
- Returns
A tuple of arrays containing the eigenvalues for each block. The order of this tuple corresponds to the ordering of self.blocks. To understand which parameters correspond to which block you can inspect self.params_block_index.
- eigenvalues(state, use_cached)[source]
Computes the eigenvalues of the curvature matrix.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator.
use_cached (bool) – Whether to use cached versions of the eigenvalues or to compute them from the most recent curvature estimates. The cached versions are at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Array
- Returns
A single array containing the eigenvalues of the curvature matrix.
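The relation between `block_eigenvalues` and `eigenvalues` follows from the block-diagonal structure: the spectrum of the full matrix is the collection of the per-block spectra. The sketch below checks this on two made-up blocks with NumPy; the ordering of the values in the library's returned array is not assumed here.

```python
import numpy as np

# Two made-up symmetric blocks of a block-diagonal curvature matrix.
blocks = [np.array([[2.0, 1.0], [1.0, 2.0]]), np.array([[4.0]])]

# Per-block eigenvalues: one array per block, in block order.
block_eigs = tuple(np.linalg.eigvalsh(b) for b in blocks)

# Eigenvalues of the whole matrix: all per-block eigenvalues taken together.
all_eigs = np.concatenate(block_eigs)

# Dense block-diagonal matrix, built only to cross-check the claim.
dense = np.zeros((3, 3))
dense[:2, :2] = blocks[0]
dense[2:, 2:] = blocks[1]
```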
- update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, estimation_mode=None)[source]
Updates the estimator’s curvature estimates.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator to update.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size.
rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.
func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor), to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.
estimation_mode (Optional[str]) –
The type of curvature estimator to use. By default (e.g. if None) will use self.default_estimation_mode. One of:
fisher_gradients - the basic estimation approach from the original K-FAC paper.
fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.
fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.
fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.
ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).
ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).
- Return type
BlockDiagonalCurvature.State
- Returns
The updated state.
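The roles of `ema_old` and `ema_new` can be sketched with a scalar weighted moving average. This is an assumption-laden illustration rather than the library's implementation: the helper `update_moving_average`, the running `weight` used for normalization, and the numeric values are all hypothetical.

```python
def update_moving_average(raw_value, weight, new_value, ema_old, ema_new):
    """One weighted-moving-average step; returns the new (raw_value, weight)."""
    raw_value = ema_old * raw_value + ema_new * new_value
    weight = ema_old * weight + ema_new
    return raw_value, weight


raw_value, weight = 0.0, 0.0
for batch_estimate in [4.0, 2.0, 3.0]:  # made-up per-batch curvature estimates
    raw_value, weight = update_moving_average(
        raw_value, weight, batch_estimate, ema_old=0.95, ema_new=1.0)

# Dividing by the accumulated weight removes the bias towards the zero start.
estimate = raw_value / weight
```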
- update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name, norm_to_scale_identity_weight_per_block=None)[source]
Updates the estimator cached values.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator to update.
identity_weight (Union[Numeric, Sequence[Numeric]]) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.
approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
BlockDiagonalCurvature.State
- Returns
The updated state.
ExplicitExactCurvature
- class kfac_jax.ExplicitExactCurvature(func, batch_index=1, default_estimation_mode=None, layer_tag_to_block_ctor=None, auto_register_tags=False, param_order=None, **kwargs)[source]
Explicit exact full curvature estimator class.
This class estimates the full curvature matrix by looping over the batch dimension of the input data: for each single example it computes an estimate of the curvature matrix, and it then averages over all examples in the input data. This implies that the computation scales linearly (without parallelism) with the batch size. The class stores the estimated curvature as a dense matrix, hence its memory requirement is (number of parameters)^2. If estimation_mode is fisher_exact or ggn_exact then this computes the exact curvature, but other modes are also supported. As a result of looping over the input data, this class needs to know the index of the batch in the arguments to the model function. Additionally, since the loop is achieved through indexing, each array leaf of that argument must have the same first dimension size, which will be interpreted as the batch size.
- __init__(func, batch_index=1, default_estimation_mode=None, layer_tag_to_block_ctor=None, auto_register_tags=False, param_order=None, **kwargs)[source]
Initializes the curvature instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
batch_index (int) – Specifies the index in the inputs to func at which the batch is found, representing the data over which we average the curvature.
default_estimation_mode (Optional[str]) – The estimation mode to use by default when calling self.update_curvature_matrix_estimate. If None, this will be 'fisher_exact'.
layer_tag_to_block_ctor (Optional[Mapping[str, CurvatureBlockCtor]]) – An optional dict mapping tags to specific classes of block approximations, which override the default ones.
auto_register_tags (bool) – This argument will be ignored since this subclass doesn’t use automatic registration.
param_order (Optional[Tuple[int]]) – An optional tuple of ints specifying the order of parameters (with the reference order being the one used by func). If not specified, the reference order is used. The parameter order will determine the order of blocks returned by to_diagonal_block_dense_matrix, and the order of the rows and columns of to_dense_matrix.
**kwargs (Any) – Additional keyword arguments passed to the superclass BlockDiagonalCurvature.
- update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, estimation_mode=None)[source]
Updates the estimator’s curvature estimates.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator to update.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size.
rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.
func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor), to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.
estimation_mode (Optional[str]) –
The type of curvature estimator to use. By default (e.g. if None) will use self.default_estimation_mode. One of:
fisher_gradients - the basic estimation approach from the original K-FAC paper.
fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.
fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.
fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.
ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).
ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).
- Return type
BlockDiagonalCurvature.State
- Returns
The updated state.
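The per-example loop described for ExplicitExactCurvature can be sketched on a toy problem. The example assumes a linear model `y = x @ theta` with a unit-variance squared error loss, for which the per-example GGN is `J^T J` with `J = x`; the data values are made up and NumPy stands in for JAX.

```python
import numpy as np

# Hypothetical batch of 4 inputs for a linear model with 3 parameters.
x_batch = np.array([[1.0, 0.0, 2.0],
                    [0.0, 1.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [2.0, 0.0, 1.0]])
num_params = x_batch.shape[1]

# Dense storage: num_params ** 2 entries, regardless of the batch size.
curvature = np.zeros((num_params, num_params))

# One pass per example, found by indexing the leading (batch) dimension.
for x in x_batch:
    jacobian = x[None, :]               # d(prediction)/d(theta), shape (1, num_params)
    curvature += jacobian.T @ jacobian  # per-example GGN for unit-variance squared error

curvature /= x_batch.shape[0]           # average over the batch
```

The loop makes the linear scaling in batch size explicit, and the dense `curvature` array shows why memory grows quadratically with the number of parameters.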
ImplicitExactCurvature
- class kfac_jax.ImplicitExactCurvature(func, params_index=0, batch_size_extractor=<function default_batch_size_extractor>)[source]
Represents all exact curvature matrices without ever constructing them explicitly.
- __init__(func, params_index=0, batch_size_extractor=<function default_batch_size_extractor>)[source]
Initializes the ImplicitExactCurvature instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in the argument list of func.
batch_size_extractor (Callable[[utils.Batch], Numeric]) – A function that takes as input the function arguments and returns the batch size for a single device. (Default: kfac.utils.default_batch_size_extractor)
- batch_size(func_args)[source]
The expected batch size given a list of loss instances.
- Return type
Numeric
- multiply_hessian(func_args, parameter_structured_vector)[source]
Multiplies the vector with the Hessian matrix of the total loss.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Hessian matrix.
parameter_structured_vector (utils.Params) – The vector which to multiply with the Hessian matrix.
- Return type
utils.Params
- Returns
The product Hv.
- multiply_jacobian(func_args, parameter_structured_vector, return_loss_objects=False)[source]
Multiplies a vector by the model’s Jacobian.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.
return_loss_objects (bool) – If set to True will return as an additional output the loss objects evaluated at the provided function arguments.
- Return type
Union[LossFunctionInputsTuple, Tuple[LossFunctionInputsTuple, LossFunctionsTuple]]
- Returns
The product J v, where J is the model’s Jacobian and v is given by parameter_structured_vector.
- multiply_jacobian_transpose(func_args, loss_input_vectors, return_loss_objects=False)[source]
Multiplies a vector by the model’s transposed Jacobian.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
loss_input_vectors (LossFunctionInputsSequence) – A sequence over losses of sequences of arrays that are the size of the loss’s inputs. This represents the vector to be multiplied.
return_loss_objects (bool) – If set to True will return as an additional output the loss objects evaluated at the provided function arguments.
- Return type
Union[utils.Params, Tuple[utils.Params, LossFunctionsTuple]]
- Returns
The product J^T v, where J is the model’s Jacobian and v is given by loss_input_vectors.
- multiply_fisher(func_args, parameter_structured_vector)[source]
Multiplies the vector with the Fisher matrix of the total loss.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.
parameter_structured_vector (utils.Params) – The vector which to multiply with the Fisher matrix.
- Return type
utils.Params
- Returns
The product Fv.
- multiply_ggn(func_args, parameter_structured_vector)[source]
Multiplies the vector with the GGN matrix of the total loss.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.
parameter_structured_vector (utils.Params) – The vector which to multiply with the GGN matrix.
- Return type
utils.Params
- Returns
The product Gv.
- multiply_fisher_factor_transpose(func_args, parameter_structured_vector)[source]
Multiplies the vector with the transposed factor of the Fisher matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.
parameter_structured_vector (utils.Params) – The vector which to multiply with the Fisher matrix.
- Return type
Tuple[Array, …]
- Returns
The product B^T v, where F = BB^T.
- multiply_ggn_factor_transpose(func_args, parameter_structured_vector)[source]
Multiplies the vector with the transposed factor of the GGN matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.
parameter_structured_vector (utils.Params) – The vector which to multiply with the GGN matrix.
- Return type
Tuple[Array, …]
- Returns
The product B^T v, where G = BB^T.
- multiply_fisher_factor(func_args, loss_inner_vectors)[source]
Multiplies the vector with the factor of the Fisher matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.
loss_inner_vectors (Sequence[Array]) – The vector which to multiply with the Fisher factor matrix.
- Return type
utils.Params
- Returns
The product Bv, where F = BB^T.
- multiply_ggn_factor(func_args, loss_inner_vectors)[source]
Multiplies the vector with the factor of the GGN matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.
loss_inner_vectors (Sequence[Array]) – The vector which to multiply with the GGN factor matrix.
- Return type
utils.Params
- Returns
The product Bv, where G = BB^T.
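How the factor products fit together can be sketched with explicit matrices. Everything here is hypothetical: `J` plays the role of the model's Jacobian, `B_loss` is a made-up loss-level factor, and the composition `B = J^T B_loss` is an assumption chosen to match the input and output types documented above. The point is only that a curvature-vector product can be assembled from a factor-transpose product followed by a factor product, without forming the curvature matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
J = rng.normal(size=(3, 5))       # hypothetical Jacobian: 3 loss inputs, 5 parameters
B_loss = rng.normal(size=(3, 3))  # hypothetical loss-level factor


def multiply_factor_transpose(param_vector):
    """B^T v: maps a parameter-shaped vector to a loss-inner vector."""
    return B_loss.T @ (J @ param_vector)


def multiply_factor(inner_vector):
    """B u: maps a loss-inner vector back to a parameter-shaped vector."""
    return J.T @ (B_loss @ inner_vector)


def multiply_curvature(param_vector):
    """(B B^T) v, computed without ever building the 5 x 5 matrix."""
    return multiply_factor(multiply_factor_transpose(param_vector))


v = rng.normal(size=5)
dense = J.T @ B_loss @ B_loss.T @ J  # built here only to check the result
```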
- get_loss_inner_vector_shapes_and_batch_size(func_args, mode)[source]
Get shapes of loss inner vectors, and the batch size.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
mode (str) – A string representing the type of curvature matrix for the loss inner vectors. Can be “fisher” or “ggn”.
- Return type
Tuple[Tuple[Shape, …], int]
- Returns
Shapes of loss inner vectors in a tuple, and the batch size as an int.
- get_loss_input_shapes_and_batch_size(func_args)[source]
Get shapes of loss input vectors, and the batch size.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
- Return type
Tuple[Tuple[Tuple[Shape, …], …], int]
- Returns
A tuple over losses of tuples containing the shapes of their different inputs, and the batch size (as an int).
set_default_tag_to_block_ctor
get_default_tag_to_block_ctor
Loss Functions
|
Abstract base class for loss functions. |
|
Base class for loss functions that represent negative log-probability. |
|
Negative log-probability loss that uses a Distrax distribution. |
|
Negative log prob loss for a normal distribution parameterized by a mean vector. |
|
Negative log prob loss for a normal distribution with mean and variance. |
|
Negative log prob loss for multiple Bernoulli distributions parametrized by logits. |
|
Negative log prob loss for a categorical distribution parameterized by logits. |
|
Neg log prob loss for a categorical distribution with onehot targets. |
|
Registers a sigmoid cross-entropy loss function. |
|
Registers a multi-Bernoulli predictive distribution. |
|
Registers a softmax cross-entropy loss function. |
|
Registers a categorical predictive distribution. |
|
Registers a squared error loss function. |
|
Registers a normal predictive distribution. |
LossFunction
- class kfac_jax.LossFunction(weight)[source]
Abstract base class for loss functions.
Note that, unlike typical loss functions used in neural networks, these are neither summed nor averaged over the batch, and the output of evaluate() will not be a scalar. It is then up to the user to manipulate them correctly as needed.
- __init__(weight)[source]
Initializes the loss instance.
- Parameters
weight (Numeric) – The relative weight attributed to the loss.
- property weight: Numeric
The relative weight of the loss.
- Return type
Numeric
- abstract property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- abstract property parameter_dependants: Tuple[Array, ...]
All the parameter dependent arrays of the loss.
- Return type
Tuple[Array, …]
- property num_parameter_dependants: int
Number of parameter dependent arrays of the loss.
- Return type
int
- abstract property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property num_parameter_independants: int
Number of parameter independent arrays of the loss.
- Return type
int
- copy_with_different_inputs(parameter_dependants)[source]
Creates a copy of the loss function object, but with different inputs.
- Return type
LossFunction
- evaluate(targets=None, coefficient_mode='regular')[source]
Evaluates the loss function on the targets.
- Parameters
targets (Optional[Array]) – The targets on which to evaluate the loss. If this is set to None, self.targets will be used instead.
coefficient_mode (str) –
Specifies how to use the relative weight of the loss in the returned value. There are three options:
’regular’ - returns self.weight * loss(targets)
’sqrt’ - returns sqrt(self.weight) * loss(targets)
’off’ - returns loss(targets)
- Return type
Array
- Returns
The value of the loss scaled appropriately by self.weight according to the coefficient mode.
- Raises
ValueError – If both targets and self.targets are None.
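The three coefficient modes, and the fact that the loss is returned per case rather than as a scalar, can be sketched in plain Python. The function `evaluate_sketch` and the squared error values are hypothetical stand-ins and are not the library's implementation.

```python
import math


def evaluate_sketch(per_case_loss, weight, coefficient_mode="regular"):
    """Scales a list of per-case loss values according to the coefficient mode."""
    if coefficient_mode == "regular":
        multiplier = weight
    elif coefficient_mode == "sqrt":
        multiplier = math.sqrt(weight)
    elif coefficient_mode == "off":
        multiplier = 1.0
    else:
        raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode!r}.")
    return [multiplier * value for value in per_case_loss]


predictions = [1.0, 2.0, 4.0]
targets = [1.5, 2.0, 3.0]
# One loss value per case; nothing is summed or averaged over the batch.
losses = [0.5 * (p - t) ** 2 for p, t in zip(predictions, targets)]

regular = evaluate_sketch(losses, weight=4.0, coefficient_mode="regular")
sqrt_mode = evaluate_sketch(losses, weight=4.0, coefficient_mode="sqrt")
off = evaluate_sketch(losses, weight=4.0, coefficient_mode="off")

# Reducing to a scalar training loss is left to the caller, for example a batch mean.
mean_loss = sum(regular) / len(regular)
```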
- grad_of_evaluate(targets, coefficient_mode)[source]
Evaluates the gradient of the loss function, w.r.t. its inputs.
- Parameters
targets (Optional[Array]) – The targets at which to evaluate the loss. If this is None, self.targets will be used instead.
coefficient_mode (str) – The coefficient mode to use for evaluation. See self.evaluate for more details.
- Return type
Tuple[Array, …]
- Returns
The gradient of the loss function w.r.t. its inputs, at the provided targets.
- multiply_ggn(vector)[source]
Right-multiplies a vector by the GGN of the loss function.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by the GGN. Will have the same shape(s) as self.inputs.
- abstract multiply_ggn_unweighted(vector)[source]
Unweighted version of multiply_ggn().
- Return type
Tuple[Array, …]
- multiply_ggn_factor(vector)[source]
Right-multiplies a vector by a factor B of the GGN.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.
Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.
- Parameters
vector (Array) – The vector to multiply. Must be of the shape(s) given by self.ggn_factor_inner_shape.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will be of the same shape(s) as self.inputs.
- abstract multiply_ggn_factor_unweighted(vector)[source]
Unweighted version of multiply_ggn_factor().
- Return type
Tuple[Array, …]
- multiply_ggn_factor_transpose(vector)[source]
Right-multiplies a vector by the transpose of a factor B of the GGN.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.
Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Array
- Returns
The vector right-multiplied by B^T. Will be of the shape(s) given by self.ggn_factor_inner_shape.
- abstract multiply_ggn_factor_transpose_unweighted(vector)[source]
Unweighted version of multiply_ggn_factor_transpose().
- Return type
Array
- multiply_ggn_factor_replicated_one_hot(index)[source]
Right-multiplies a replicated-one-hot vector by a factor B of the GGN.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.
A replicated-one-hot vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere.
Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.
- Parameters
index (Sequence[int]) – A tuple representing the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the ggn_factor_inner_shape tensor minus one.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will be of the same shape(s) as the inputs property.
- abstract multiply_ggn_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of multiply_ggn_factor_replicated_one_hot().
- Return type
Tuple[Array, …]
- abstract property ggn_factor_inner_shape: Shape
The shape of the array returned by self.multiply_ggn_factor_transpose.
- Return type
Shape
NegativeLogProbLoss
- class kfac_jax.NegativeLogProbLoss(weight)[source]
Base class for loss functions that represent negative log-probability.
- property parameter_dependants: Tuple[Array, ...]
All the parameter dependent arrays of the loss.
- Return type
Tuple[Array, …]
- abstract property params: Tuple[Array, ...]
Parameters to the underlying distribution.
- Return type
Tuple[Array, …]
- multiply_fisher(vector)[source]
Right-multiplies a vector by the Fisher.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by the Fisher. Will have the same shape(s) as self.inputs.
- abstract multiply_fisher_unweighted(vector)[source]
Unweighted version of multiply_fisher().
- Return type
Tuple[Array, …]
- multiply_fisher_factor(vector)[source]
Right-multiplies a vector by a factor B of the Fisher.
Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.
- Parameters
vector (Array) – The vector to multiply. Must have the same shape(s) as self.fisher_factor_inner_shape.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will have the same shape(s) as self.inputs.
- abstract multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor().
- Return type
Tuple[Array, …]
- multiply_fisher_factor_transpose(vector)[source]
Right-multiplies a vector by the transpose of a factor B of the Fisher.
Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Array
- Returns
The vector right-multiplied by B^T. Will have the shape given by self.fisher_factor_inner_shape.
- abstract multiply_fisher_factor_transpose_unweighted(vector)[source]
Unweighted version of
multiply_fisher_factor_transpose().- Return type
Array
- multiply_fisher_factor_replicated_one_hot(index)[source]
Right-multiplies a replicated-one-hot vector by a factor B of the Fisher.
Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.
A replicated-one-hot vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere.
Note that B can be any matrix satisfying B * B^T = F, where F is the Fisher, but it will agree with the one used in the other methods of this class.
- Parameters
index (Sequence[int]) – A tuple representing the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the fisher_factor_inner_shape tensor minus one.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will have the same shape(s) as
self.inputs.
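As an illustration of the replicated-one-hot definition above, the following plain-Python sketch builds such a tensor as nested lists. The helper name `replicated_one_hot` and the example inner shape are hypothetical and not part of the library; the sketch also assumes that the first entry of the inner shape is the batch dimension, so that `len(index)` equals the length of the shape minus one.

```python
def replicated_one_hot(inner_shape, index):
    """Builds a nested-list tensor of shape `inner_shape` that is 1.0 at
    `index` in every batch slice (dimension 0) and 0.0 elsewhere."""
    if len(index) != len(inner_shape) - 1:
        raise ValueError("len(index) must equal len(inner_shape) - 1")

    def build(shape, prefix):
        # Recursively build the entries of a single batch slice.
        if not shape:
            return 1.0 if prefix == tuple(index) else 0.0
        return [build(shape[1:], prefix + (i,)) for i in range(shape[0])]

    batch_size, slice_shape = inner_shape[0], tuple(inner_shape[1:])
    return [build(slice_shape, ()) for _ in range(batch_size)]


# Hypothetical inner shape: a batch of 3 cases with 4 entries each.
one_hot = replicated_one_hot((3, 4), index=(2,))
for row in one_hot:
    print(row)
```

Every batch slice carries the same one-hot pattern, which is what "replicated" refers to.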
- abstract multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of
multiply_fisher_factor_replicated_one_hot().- Return type
Tuple[Array, …]
- abstract property fisher_factor_inner_shape: Shape
The shape of the array accepted by multiply_fisher_factor() and returned by multiply_fisher_factor_transpose().
- Return type
Shape
- grad_of_evaluate_on_sample(rng, coefficient_mode)[source]
Evaluates the gradient of the log probability on a random sample.
- Parameters
rng (Array) – Jax PRNG key for sampling.
coefficient_mode (str) – The coefficient mode to use for evaluation.
- Return type
Tuple[Array, …]
- Returns
The gradient of the log probability of targets sampled from the distribution.
DistributionNegativeLogProbLoss
- class kfac_jax.DistributionNegativeLogProbLoss(weight)[source]
Negative log-probability loss that uses a Distrax distribution.
- abstract property dist: distrax.Distribution
The underlying Distrax distribution.
- Return type
distrax.Distribution
- property fisher_factor_inner_shape: Shape
The shape of the array returned by
multiply_fisher_factor().- Return type
Shape
NormalMeanNegativeLogProbLoss
- class kfac_jax.NormalMeanNegativeLogProbLoss(mean, targets=None, variance=0.5, weight=1.0, normalize_log_prob=True)[source]
Negative log prob loss for a normal distribution parameterized by a mean vector.
Note that with the default variance=0.5 the covariance is the identity divided by 2. Also note that the Fisher for such a normal distribution with respect to the mean parameter is given by:
F = (1 / variance) * I
See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
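The formula F = (1 / variance) * I can be checked numerically in one dimension. For a normal distribution with fixed variance, the second derivative of the negative log-density with respect to the mean does not depend on the target, so it coincides with the Fisher. The sketch below uses plain Python and a finite difference; the function names are hypothetical and do not come from the library.

```python
import math


def normal_neg_log_prob(mean, target, variance):
    """Negative log-density of a 1-D normal, including the normalization constant."""
    return (0.5 * math.log(2.0 * math.pi * variance)
            + (target - mean) ** 2 / (2.0 * variance))


def second_derivative_wrt_mean(mean, target, variance, eps=1e-3):
    """Central finite-difference estimate of d^2/dmean^2 of the negative log-density."""
    f = lambda m: normal_neg_log_prob(m, target, variance)
    return (f(mean + eps) - 2.0 * f(mean) + f(mean - eps)) / eps ** 2


for variance in (0.5, 1.0, 4.0):
    curvature = second_derivative_wrt_mean(mean=0.3, target=1.7, variance=variance)
    # The two printed numbers should agree up to finite-difference error.
    print(variance, curvature, 1.0 / variance)
```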
- __init__(mean, targets=None, variance=0.5, weight=1.0, normalize_log_prob=True)[source]
Initializes the loss instance.
- Parameters
mean (Array) – The mean of the normal distribution.
targets (Optional[Array]) – Optional targets to use for evaluation.
variance (Numeric) – The scalar variance of the normal distribution.
weight (Numeric) – The relative weight of the loss.
normalize_log_prob (bool) – Whether the log prob should include the standard normalization constant for Gaussians (which is additive and depends on the variance).
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.MultivariateNormalDiag
The underlying Distrax distribution.
- Return type
distrax.MultivariateNormalDiag
- property params: Tuple[Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array]
- multiply_fisher_unweighted(vector)[source]
Unweighted version of
multiply_fisher().- Return type
Tuple[Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of
multiply_fisher_factor().- Return type
Tuple[Array]
NormalMeanVarianceNegativeLogProbLoss
- class kfac_jax.NormalMeanVarianceNegativeLogProbLoss(mean, variance, targets=None, weight=1.0)[source]
Negative log prob loss for a normal distribution with mean and variance.
This class parameterizes a multivariate normal distribution with n independent dimensions. Unlike NormalMeanNegativeLogProbLoss, this class does not assume the variance is held constant. The Fisher Information for n = 1 is given by:
F = [[1 / variance, 0],
     [0, 0.5 / variance^2]]
where the parameters of the distribution are concatenated into a single vector as [mean, variance]. For n > 1, the mean parameter vector is concatenated with the variance parameter vector. For further details check out the Wikipedia page.
- __init__(mean, variance, targets=None, weight=1.0)[source]
Initializes the loss instance.
- Parameters
mean (Array) – The mean of the normal distribution.
variance (Array) – The variance of the normal distribution.
targets (Optional[Array]) – Optional targets to use for evaluation.
weight (Numeric) – The relative weight of the loss.
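The n = 1 Fisher matrix quoted in the class description can be checked against the definition of the Fisher as the expected outer product of score vectors. The following is a minimal Monte Carlo sketch in plain Python; the function names, sample count, and seed are illustrative choices and not part of the library.

```python
import math
import random


def score(x, mean, variance):
    """Gradient of log N(x; mean, variance) with respect to (mean, variance)."""
    d_mean = (x - mean) / variance
    d_variance = -0.5 / variance + 0.5 * (x - mean) ** 2 / variance ** 2
    return d_mean, d_variance


def monte_carlo_fisher(mean, variance, num_samples=100_000, seed=0):
    """Estimates E[score * score^T] from samples drawn from the distribution itself."""
    rng = random.Random(seed)
    std = math.sqrt(variance)
    f00 = f01 = f11 = 0.0
    for _ in range(num_samples):
        g_m, g_v = score(rng.gauss(mean, std), mean, variance)
        f00 += g_m * g_m
        f01 += g_m * g_v
        f11 += g_v * g_v
    return [[f00 / num_samples, f01 / num_samples],
            [f01 / num_samples, f11 / num_samples]]


mean, variance = 1.0, 2.0
estimate = monte_carlo_fisher(mean, variance)
expected = [[1.0 / variance, 0.0], [0.0, 0.5 / variance ** 2]]
print("estimate:", estimate)
print("expected:", expected)
```

The estimate only matches the closed form up to sampling noise, so any comparison should use a tolerance.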
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.MultivariateNormalDiag
The underlying Distrax distribution.
- Return type
distrax.MultivariateNormalDiag
- property params: Tuple[Array, Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array, Array]
- multiply_fisher_unweighted(vector)[source]
Unweighted version of
multiply_fisher().- Return type
Tuple[Array, Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of
multiply_fisher_factor().- Return type
Tuple[Array, Array]
- multiply_fisher_factor_transpose_unweighted(vector)[source]
Unweighted version of
multiply_fisher_factor_transpose().- Return type
Array
- multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of
multiply_fisher_factor_replicated_one_hot().- Return type
Tuple[Array, Array]
- property fisher_factor_inner_shape: Shape
The shape of the array returned by
multiply_fisher_factor().- Return type
Shape
- multiply_ggn_unweighted(vector)[source]
Unweighted version of
multiply_ggn().- Return type
Tuple[Array, …]
- multiply_ggn_factor_unweighted(vector)[source]
Unweighted version of
multiply_ggn_factor().- Return type
Tuple[Array, …]
- multiply_ggn_factor_transpose_unweighted(vector)[source]
Unweighted version of
multiply_ggn_factor_transpose().- Return type
Array
- multiply_ggn_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of
multiply_ggn_factor_replicated_one_hot().- Return type
Tuple[Array, …]
- property ggn_factor_inner_shape: Shape
The shape of the array returned by self.multiply_ggn_factor.
- Return type
Shape
MultiBernoulliNegativeLogProbLoss
- class kfac_jax.MultiBernoulliNegativeLogProbLoss(logits, targets=None, weight=1.0)[source]
Negative log prob loss for multiple Bernoulli distributions parametrized by logits.
Represents N independent Bernoulli distributions where N = len(logits). Its Fisher Information matrix is given by F = diag(p * (1-p)), where p = sigmoid(logits). As F is diagonal with positive entries, its factor B is B = diag(sqrt(p * (1-p))).
- __init__(logits, targets=None, weight=1.0)[source]
Initializes the loss instance.
- Parameters
logits (Array) – The logits of the Bernoulli distribution.
targets (Optional[Array]) – Optional targets to use for evaluation.
weight (Numeric) – The relative weight of the loss.
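The diagonal Fisher and its factor from the class description can be verified exactly, because each Bernoulli variable has only two outcomes. The sketch below enumerates both outcomes in plain Python; the helper names and example logits are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bernoulli_fisher_diag(logits):
    """Exact Fisher diagonal w.r.t. the logits: E_y[(y - p)^2] for y ~ Bernoulli(p)."""
    diag = []
    for logit in logits:
        p = sigmoid(logit)
        # d/dlogit log P(y) = y - p, so enumerate y in {0, 1} weighted by P(y).
        diag.append((1.0 - p) * (0.0 - p) ** 2 + p * (1.0 - p) ** 2)
    return diag


logits = [-2.0, 0.0, 0.5, 3.0]
fisher_diag = bernoulli_fisher_diag(logits)
factor_diag = [math.sqrt(sigmoid(l) * (1.0 - sigmoid(l))) for l in logits]

# Each pair should agree: the Fisher entry and the squared factor entry.
for f, b in zip(fisher_diag, factor_diag):
    print(f, b * b)
```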
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.Bernoulli
The underlying Distrax distribution.
- Return type
distrax.Bernoulli
- property params: Tuple[Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array]
- multiply_fisher_unweighted(vector)[source]
Unweighted version of
multiply_fisher().- Return type
Tuple[Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of
multiply_fisher_factor().- Return type
Tuple[Array]
CategoricalLogitsNegativeLogProbLoss
- class kfac_jax.CategoricalLogitsNegativeLogProbLoss(logits, targets=None, mask=None, weight=1.0)[source]
Negative log prob loss for a categorical distribution parameterized by logits.
Note that the Fisher (for a single case) of a categorical distribution, with respect to the natural parameters (i.e. the logits), is given by
F = diag(p) - p*p^T, where p = softmax(logits). F can be factorized as F = B * B^T, where B = diag(q) - p*q^T and q is the entry-wise square root of p. This is easy to verify using the fact that q^T*q = 1.
- __init__(logits, targets=None, mask=None, weight=1.0)[source]
Initializes the loss instance.
- Parameters
logits (Array) – The logits of the Categorical distribution of shape (batch_size, output_size).
targets (Optional[Array]) – Optional targets to use for evaluation, which specify an integer index of the correct class. Must be of shape (batch_size,).
mask (Optional[Array]) – Optional mask to apply to losses over the batch. Should be 0/1-valued and of shape (batch_size,). The tensors returned by evaluate and grad_of_evaluate, as well as the various matrix vector products, will be multiplied by mask (with broadcasting to later dimensions).
weight (Numeric) – The relative weight of the loss.
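The factorization of the categorical Fisher stated in the class description, F = B * B^T with B = diag(q) - p*q^T, can be confirmed numerically for a single case. This is a small plain-Python sketch with hypothetical helper names and example logits; it does not call the library.

```python
import math


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - shift) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_fisher(p):
    """F = diag(p) - p p^T."""
    n = len(p)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)]
            for i in range(n)]


def categorical_fisher_factor(p):
    """B = diag(q) - p q^T with q the entry-wise square root of p."""
    q = [math.sqrt(x) for x in p]
    n = len(p)
    return [[(q[i] if i == j else 0.0) - p[i] * q[j] for j in range(n)]
            for i in range(n)]


def times_own_transpose(b):
    """Returns B B^T."""
    n = len(b)
    return [[sum(b[i][k] * b[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]


p = softmax([1.0, -0.5, 2.0])
fisher = categorical_fisher(p)
reconstructed = times_own_transpose(categorical_fisher_factor(p))
max_abs_err = max(abs(fisher[i][j] - reconstructed[i][j])
                  for i in range(3) for j in range(3))
print(max_abs_err)  # should be at the level of floating-point round-off
```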
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.Categorical
The underlying Distrax distribution.
- Return type
distrax.Categorical
- property params: Tuple[Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array]
- property fisher_factor_inner_shape: Shape
The shape of the array returned by
multiply_fisher_factor().- Return type
Shape
- multiply_fisher_unweighted(vector)[source]
Unweighted version of
multiply_fisher().- Return type
Tuple[Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of
multiply_fisher_factor().- Return type
Tuple[Array]
OneHotCategoricalLogitsNegativeLogProbLoss
- class kfac_jax.OneHotCategoricalLogitsNegativeLogProbLoss(logits, targets=None, mask=None, weight=1.0)[source]
Neg log prob loss for a categorical distribution with onehot targets.
Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying distribution is OneHotCategorical as opposed to Categorical.
- property dist: distrax.OneHotCategorical
The underlying Distrax distribution.
- Return type
distrax.OneHotCategorical
register_sigmoid_cross_entropy_loss
- kfac_jax.register_sigmoid_cross_entropy_loss(logits, targets=None, weight=1.0)[source]
Registers a sigmoid cross-entropy loss function.
This assumes a sigmoid cross-entropy loss of the form
weight * jnp.sum(sigmoid_cross_entropy(logits, targets)) / batch_size.
NOTE: this function assumes you are not averaging over non-batch dimensions when computing the loss. E.g. if dimension 0 were the batch dimension, this corresponds to
jnp.mean(jnp.sum(sigmoid_cross_entropy(logits, targets), axis=range(1, targets.ndim)), axis=0)
and not
jnp.mean(sigmoid_cross_entropy(logits, targets)).
If your loss is of the latter form you can compensate for this by passing the appropriate value to weight.
NOTE: this function is distinct from register_softmax_cross_entropy_loss() and should not be confused with it. It is similar to register_multi_bernoulli_predictive_distribution() but without the explicit probabilistic interpretation. It behaves identically for now.
- Parameters
logits (Array) – The input logits of the loss as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Must be of the same shape as logits. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
weight (Numeric) – The constant scalar coefficient which this loss is multiplied by. Note that this must be constant and independent of the network’s parameters. (Default: 1.0)
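To make the note about averaging concrete, the sketch below compares the two loss forms on a small hypothetical batch and derives the compensating value for weight. It is written in plain Python rather than JAX so that the arithmetic is explicit; the data and the helper `sigmoid_cross_entropy` are illustrative and not taken from the library.

```python
import math


def sigmoid_cross_entropy(logit, target):
    """Element-wise sigmoid cross-entropy, written in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Hypothetical batch of 2 cases with 3 outputs each.
logits = [[0.2, -1.0, 3.0], [1.5, 0.0, -0.7]]
targets = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
batch_size, num_outputs = len(logits), len(logits[0])

elementwise = [[sigmoid_cross_entropy(l, t) for l, t in zip(l_row, t_row)]
               for l_row, t_row in zip(logits, targets)]

# Form assumed by the registration function with weight=1.0:
# sum over the non-batch dimension, then average over the batch.
sum_then_mean = sum(sum(row) for row in elementwise) / batch_size

# A loss averaged over every element instead.
mean_over_all = sum(sum(row) for row in elementwise) / (batch_size * num_outputs)

# The two differ by a constant factor, which is the value to pass as `weight`.
compensating_weight = 1.0 / num_outputs
print(mean_over_all, compensating_weight * sum_then_mean)
```

For 2D logits the factor is one over the size of the non-batch dimension; with more non-batch dimensions it would be one over the product of their sizes.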
register_multi_bernoulli_predictive_distribution
- kfac_jax.register_multi_bernoulli_predictive_distribution(logits, targets=None, weight=1.0)[source]
Registers a multi-Bernoulli predictive distribution.
This corresponds to a sigmoid cross-entropy loss of the form
weight * jnp.sum(sigmoid_cross_entropy(logits, targets)) / batch_size.
NOTE: this function assumes you are not averaging over non-batch dimensions when computing the loss. E.g. if dimension 0 were the batch dimension, this corresponds to
jnp.mean(jnp.sum(sigmoid_cross_entropy(logits, targets), axis=range(1, targets.ndim)), axis=0)
and not
jnp.mean(sigmoid_cross_entropy(logits, targets)).
If your loss is of the latter form you can compensate for it by passing the appropriate value to weight.
NOTE: this is distinct from register_categorical_predictive_distribution() and should not be confused with it.
- Parameters
logits (Array) – The logits of the distribution (i.e. its parameters) as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
weight (Numeric) – The constant scalar coefficient that the log prob loss associated with this distribution is multiplied by. This is NOT equivalent to changing the temperature of the distribution since we don’t renormalize the log prob in the objective function. Note that this must be constant and independent of the network’s parameters. (Default: 1.0)
register_softmax_cross_entropy_loss
- kfac_jax.register_softmax_cross_entropy_loss(logits, targets=None, mask=None, weight=1.0)[source]
Registers a softmax cross-entropy loss function.
This assumes a softmax cross-entropy loss of the form
weight * jnp.sum(softmax_cross_entropy(logits, targets)) / batch_size.
NOTE: this is distinct from register_sigmoid_cross_entropy_loss() and should not be confused with it. It is similar to register_categorical_predictive_distribution() but without the explicit probabilistic interpretation. It behaves identically for now.
- Parameters
logits (Array) – The input logits of the loss as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator). The second dimension is the one over which the softmax is computed.
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Must be a 1D array of integers with shape (logits.shape[0],). Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
mask (Optional[Array]) – (OPTIONAL) Mask to apply to losses. Should be 0/1-valued and of shape (logits.shape[0],). Losses corresponding to mask values of False will be treated as constant and equal to 0. (Default: None)
weight (Numeric) – The constant scalar coefficient which this loss is multiplied by. Note that this must be constant and independent of the network’s parameters. (Default: 1.0)
register_categorical_predictive_distribution
- kfac_jax.register_categorical_predictive_distribution(logits, targets=None, mask=None, weight=1.0)[source]
Registers a categorical predictive distribution.
This corresponds to a softmax cross-entropy loss of the form
weight * jnp.sum(softmax_cross_entropy(logits, targets)) / batch_size.
NOTE: this is distinct from register_multi_bernoulli_predictive_distribution() and should not be confused with it.
- Parameters
logits (Array) – The logits of the distribution (i.e. its parameters) as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator). The second dimension is the one over which the softmax is computed.
targets (Optional[Array]) – (OPTIONAL) The values at which the log probability of this distribution is evaluated (to give the loss). Must be a 1D array of integers with shape (logits.shape[0],). Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
mask (Optional[Array]) – (OPTIONAL) Mask to apply to log probabilities generated by the distribution. Should be 0/1-valued and of shape (logits.shape[0],). Log probabilities corresponding to mask values of False will be treated as constant and equal to 0. (Default: None)
weight (Numeric) – The constant scalar coefficient that the log prob loss associated with this distribution is multiplied by. This is NOT equivalent to changing the temperature of the distribution since we don’t renormalize the log prob in the objective function. Note that this must be constant and independent of the network’s parameters. (Default: 1.0)
register_squared_error_loss
- kfac_jax.register_squared_error_loss(prediction, targets=None, weight=1.0)[source]
Registers a squared error loss function.
This assumes a squared error loss of the form
weight * jnp.sum((targets - prediction)**2) / batch_size.
If your loss uses a coefficient of 0.5 you need to set the weight argument to reflect this.
NOTE: this function assumes you are not averaging over non-batch dimensions when computing the loss. E.g. if dimension 0 were the batch dimension, this corresponds to
jnp.mean(jnp.sum((targets - prediction)**2, axis=range(1, targets.ndim)), axis=0)
and not
jnp.mean((targets - prediction)**2).
If your loss is of the latter form you can compensate for it by passing the appropriate value to weight.
NOTE: even though prediction and targets are interchangeable in the definition of the squared error loss, they are not interchangeable in this function. prediction must be the output of your parameterized function (e.g. neural network), and targets must not depend on the parameters. Mixing the two up could lead to a silent failure of the curvature estimation.
- Parameters
prediction (Array) – The prediction made by the network (i.e. its output). The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
weight (Numeric) – The constant scalar coefficient which this loss is multiplied by. Note that this must be constant and independent of the network’s parameters. (Default: 1.0)
register_normal_predictive_distribution
- kfac_jax.register_normal_predictive_distribution(mean, targets=None, variance=0.5, weight=1.0, normalize_log_prob=True)[source]
Registers a normal predictive distribution.
This corresponds to a squared error loss of the form
weight/(2*var) * jnp.sum((targets - mean)**2) / batch_size.
NOTE: this function assumes you are not averaging over non-batch dimensions when computing the loss. E.g. if dimension 0 were the batch dimension, this corresponds to
jnp.mean(jnp.sum((targets - prediction)**2, axis=range(1, targets.ndim)), axis=0)
and not
jnp.mean((targets - prediction)**2).
If your loss is of the latter form you can compensate for it by passing the appropriate value to weight.
- Parameters
mean (Array) – A tensor defining the mean vector of the distribution. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
variance (float) – The variance of the distribution. Must be a constant scalar, independent of the network’s parameters. Note that the default value of 0.5 corresponds to a standard squared error loss weight * jnp.sum((targets - prediction)**2). If you want your squared error loss to be of the form 0.5*coeff*jnp.sum((targets - prediction)**2) you should use variance=1.0. (Default: 0.5)
weight (Numeric) – A constant scalar coefficient that the log prob loss associated with this distribution is multiplied by. In general this is NOT equivalent to changing the temperature of the distribution, but in the case of normal distributions it may be. Note that this must be constant and independent of the network’s parameters. (Default: 1.0)
normalize_log_prob (bool) – Whether the negative log prob loss associated with this distribution should include the additive normalization constant (which is constant and depends on variance) that makes it a true log prob, and not just a squared error loss. Note that this has no effect on the behavior of the optimizer, except in niche situations where the loss value is computed from the registrations, e.g. when include_registered_loss_in_stats=True is used. (Default: True)
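The relationship between variance, weight, and the coefficient of a squared error loss can be checked with a few lines of arithmetic. The sketch below evaluates the form weight/(2*var) * sum((targets - mean)**2) / batch_size in plain Python, ignoring the additive normalization constant; the data, the value of `coeff`, and the helper names are hypothetical.

```python
def sum_of_squares(prediction, targets):
    """Sum of squared differences over every entry of a 2-D nested list."""
    return sum((t - p) ** 2
               for p_row, t_row in zip(prediction, targets)
               for p, t in zip(p_row, t_row))


def normal_loss_without_constant(mean, targets, variance, weight):
    """weight / (2 * variance) * sum((targets - mean)**2) / batch_size."""
    return weight / (2.0 * variance) * sum_of_squares(mean, targets) / len(mean)


mean = [[0.1, 0.4], [1.0, -2.0], [0.0, 0.3]]      # hypothetical network outputs
targets = [[0.0, 0.5], [1.5, -1.0], [0.2, 0.3]]   # hypothetical targets
batch_size = len(mean)
coeff = 3.0

# Default variance=0.5: matches coeff * sum((targets - mean)**2) / batch_size.
plain = coeff * sum_of_squares(mean, targets) / batch_size
default_variance = normal_loss_without_constant(mean, targets, variance=0.5, weight=coeff)

# variance=1.0: matches 0.5 * coeff * sum((targets - mean)**2) / batch_size.
halved = 0.5 * coeff * sum_of_squares(mean, targets) / batch_size
unit_variance = normal_loss_without_constant(mean, targets, variance=1.0, weight=coeff)

print(plain, default_variance)
print(halved, unit_variance)
```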
Curvature Blocks
|
Abstract class for curvature approximation blocks. |
|
A block that assumes that the curvature is a scaled identity matrix. |
|
An abstract class for approximating only the diagonal of curvature. |
|
An abstract class for approximating the block matrix with a full matrix. |
|
A Kronecker factored block for layers with weights and an optional bias. |
|
Approximates the diagonal of the curvature in the most obvious way. |
|
Approximates the full curvature in the most obvious way. |
|
A Diagonal block specifically for dense layers. |
|
A Full block specifically for dense layers. |
|
A |
|
A |
|
A |
|
A |
|
A diagonal approximation specifically for scale and shift layers. |
|
A full dense approximation specifically for scale and shift layers. |
|
Sets the default value of maximum parallel elements in the module. |
Returns the default value of maximum parallel elements in the module. |
|
Sets the default value of the eigen decomposition threshold. |
|
Returns the default value of the eigen decomposition threshold. |
CurvatureBlock
- class kfac_jax.CurvatureBlock(layer_tag_eq, name)[source]
Abstract class for curvature approximation blocks.
A CurvatureBlock defines a curvature matrix to be estimated, and gives methods to multiply powers of this with a vector. Powers can be computed exactly or with a class-determined approximation. Cached versions of the powers can be pre-computed to make repeated multiplications cheaper. During initialization, you would have to explicitly specify all powers that you will need to cache.
- class State(cache)[source]
Persistent state of the block.
Any subclasses of CurvatureBlock should also internally extend this class, with any attributes needed for the curvature estimation.
- cache
A dictionary, containing any state data that is updated at irregular intervals, such as inverses, eigenvalues, etc. Elements of this are updated via calls to update_cache(), and do not necessarily correspond to the most up-to-date curvature estimate.
- Type
Optional[Dict[str, Union[Array, Dict[str, Array]]]]
- __init__(cache)
- __init__(layer_tag_eq, name)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.
name (str) – The name of this block.
- property layer_tag_primitive: tags.LayerTag
The jax.core.Primitive corresponding to the block’s tag equation.
- Return type
tags.LayerTag
- property parameter_variables: Tuple[jax.core.Var, ...]
The parameter variables of the underlying Jax equation.
- Return type
Tuple[jax.core.Var, …]
- property outputs_shapes: Tuple[Shape, ...]
The shapes of the output variables of the block’s tag equation.
- Return type
Tuple[Shape, …]
- property inputs_shapes: Tuple[Shape, ...]
The shapes of the input variables of the block’s tag equation.
- Return type
Tuple[Shape, …]
- property parameters_shapes: Tuple[Shape, ...]
The shapes of the parameter variables of the block’s tag equation.
- Return type
Tuple[Shape, …]
- property parameters_canonical_order: Tuple[int, ...]
The canonical order of the parameter variables.
- Return type
Tuple[int, …]
- property layer_tag_extra_params: Dict[str, Any]
Any extra parameters passed into the Jax primitive of this block.
- Return type
Dict[str, Any]
- property number_of_parameters: int
Number of parameter variables of this block.
- Return type
int
- property dim: int
The number of elements of all parameter variables together.
- Return type
int
- scale(state, use_cache)[source]
A scalar pre-factor of the curvature approximation.
Importantly, all methods assume that whenever a user requests cached values, any state-dependent scale is taken into account by the cache (e.g. either stored explicitly and used or mathematically added to values).
- Parameters
state (CurvatureBlock.State) – The state for this block.
use_cache (bool) – Whether the method requesting this is using cached values or not.
- Return type
Numeric
- Returns
A scalar value to be multiplied with any unscaled block representation.
- fixed_scale()[source]
A fixed scalar pre-factor of the curvature (e.g. constant).
- Return type
Numeric
- state_dependent_scale(state)[source]
A scalar pre-factor of the curvature, computed from the most recent curvature estimate.
- Return type
Numeric
- init(rng, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues)[source]
Initializes the state for this block.
- Parameters
rng (PRNGKey) – The PRNGKey to be used for any randomness in the initialization.
exact_powers_to_cache (Optional[ScalarOrSequence]) – A single value, or multiple values in a list, which specify which exact matrix powers the block should be caching. Matrix powers which are expected to be used in multiply_matpower(), multiply_inverse() or multiply() with exact_power=True and use_cached=True must be provided here.
approx_powers_to_cache (Optional[ScalarOrSequence]) – A single value, or multiple values in a list, which specify which approximate matrix powers the block should be caching. Matrix powers which are expected to be used in multiply_matpower(), multiply_inverse() or multiply() with exact_power=False and use_cached=True must be provided here.
cache_eigenvalues (bool) – Specifies whether the block should be caching the eigenvalues of its approximate curvature.
- Return type
- Returns
A dictionary with the initialized state.
- abstract sync(state, pmap_axis_name)[source]
Syncs the state across different devices (does not sync the cache).
- Return type
- multiply_matpower(state, vector, identity_weight, power, exact_power, use_cached)[source]
Computes (BlockMatrix + identity_weight I)**power times vector.
- Parameters
state (CurvatureBlock.State) – The state for this block.
vector (Sequence[Array]) – A tuple of arrays that should have the same shapes as the block’s parameters_shapes, which represent the vector you want to multiply.
identity_weight (Numeric) – A scalar specifying the weight on the identity matrix that is added to the block matrix before raising it to a power. If use_cached=False it is guaranteed that this argument will be used in the computation. When returning cached values, this argument may be ignored in favor of whatever value was last passed to update_cache(). The precise semantics of this depend on the concrete subclass and its particular behavior in regard to caching.
power (Scalar) – The power to which to raise the matrix.
exact_power (bool) – Specifies whether to compute the exact matrix power of BlockMatrix + identity_weight I. When this argument is False the exact behaviour will depend on the concrete subclass and the result will in general be an approximation to (BlockMatrix + identity_weight I)^power, although some subclasses may still compute the exact matrix power.
use_cached (bool) – Whether to use a cached version for computing the product or to use the most recent curvature estimates. The cached version is going to be at least as fresh as the value provided to the last call to update_cache() with the same value of power.
- Return type
Tuple[Array, …]
- Returns
A tuple of arrays, representing the result of the matrix-vector product.
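For intuition about what this product computes, consider the simplest case of a block whose curvature matrix is diagonal, where the matrix power reduces to an element-wise power. The following plain-Python sketch is a hypothetical stand-in and does not reproduce the library's method signature (it takes the diagonal directly instead of a state object, and ignores caching).

```python
def diagonal_multiply_matpower(diagonal, vector, identity_weight, power):
    """Computes (diag(diagonal) + identity_weight * I)**power @ vector."""
    return [(d + identity_weight) ** power * v for d, v in zip(diagonal, vector)]


diagonal = [4.0, 1.0, 0.25]   # hypothetical diagonal curvature estimate
vector = [1.0, 2.0, 3.0]
damping = 0.1                 # plays the role of identity_weight

# power=-1 corresponds to multiplying by the damped inverse.
preconditioned = diagonal_multiply_matpower(diagonal, vector, damping, power=-1)

# Applying power=1 afterwards should recover the original vector.
roundtrip = diagonal_multiply_matpower(diagonal, preconditioned, damping, power=1)
print(preconditioned)
print(roundtrip)
```

For non-diagonal blocks the same quantity would typically require an eigendecomposition or a linear solve, which is why cached powers are useful.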
- multiply(state, vector, identity_weight, exact_power, use_cached)[source]
Computes (BlockMatrix + identity_weight I) times vector.
- Return type
Tuple[Array, …]
- multiply_inverse(state, vector, identity_weight, exact_power, use_cached)[source]
Computes (BlockMatrix + identity_weight I)^-1 times vector.
- Return type
Tuple[Array, …]
- eigenvalues(state, use_cached)[source]
Computes the eigenvalues for this block approximation.
- Parameters
state (CurvatureBlock.State) – The state dict for this block.
use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is going to be at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Array
- Returns
An array containing the eigenvalues of the block.
- abstract update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the estimation_data provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.
- Parameters
state (CurvatureBlock.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in estimation_data.
- Return type
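The roles of ema_old and ema_new can be illustrated with a plain weighted sum. This is only a sketch under the assumption that the update is ema_old * old + ema_new * new; the library's own moving-average implementation may additionally track and normalize by accumulated weights. The function name and data are hypothetical.

```python
def update_moving_average(old_estimate, new_value, ema_old, ema_new):
    """Returns ema_old * old_estimate + ema_new * new_value, element-wise."""
    return [ema_old * o + ema_new * n for o, n in zip(old_estimate, new_value)]


# Hypothetical per-batch statistics for a two-parameter diagonal estimate.
estimate = [0.0, 0.0]
for batch_statistic in ([1.0, 2.0], [3.0, 2.0], [5.0, 2.0]):
    estimate = update_moving_average(estimate, batch_statistic,
                                     ema_old=0.95, ema_new=0.05)
print(estimate)

# With ema_old=0 and ema_new=1 the history is discarded entirely.
latest_only = update_moving_average(estimate, [5.0, 2.0], ema_old=0, ema_new=1)
print(latest_only)
```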
- update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues)[source]
Updates the cached estimates of the different powers specified.
- Parameters
state (CurvatureBlock.State) – The state dict for this block to update.
identity_weight (Numeric) – The weight of the identity added to the block’s curvature matrix before computing the cached matrix power.
exact_powers (Optional[ScalarOrSequence]) – Specifies any cached exact matrix powers to be updated.
approx_powers (Optional[ScalarOrSequence]) – Specifies any cached approximate matrix powers to be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of the block. If they have not been cached before, this will create an entry with them in the block’s cache.
- Return type
- Returns
The updated state.
ScaledIdentity
- class kfac_jax.ScaledIdentity(layer_tag_eq, name, scale=1.0)[source]
A block that assumes that the curvature is a scaled identity matrix.
- __init__(layer_tag_eq, name, scale=1.0)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.
name (str) – The name of this block.
scale (Numeric) – The scale of the identity matrix.
- fixed_scale()[source]
A fixed scalar pre-factor of the curvature (e.g. constant).
- Return type
Numeric
- sync(state, pmap_axis_name)[source]
Syncs the state across different devices (does not sync the cache).
- Return type
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (CurvatureBlock.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
Diagonal
- class kfac_jax.Diagonal(layer_tag_eq, name)[source]
An abstract class for approximating only the diagonal of curvature.
- class State(cache, diagonal_factors)[source]
Persistent state of the block.
- diagonal_factors
A tuple of the moving averages of the estimated diagonals of the curvature for each parameter that is part of the associated layer.
- Type
Tuple[utils.WeightedMovingAverage]
- __init__(cache, diagonal_factors)
Full
- class kfac_jax.Full(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
An abstract class for approximating the block matrix with a full matrix.
- class State(cache, matrix)[source]
Persistent state of the block.
- matrix
A moving average of the estimated curvature matrix for all parameters that are part of the associated layer.
- Type
utils.WeightedMovingAverage
- __init__(cache, matrix)
- __init__(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag that this block will approximate the curvature to.
name (str) – The name of this block.
eigen_decomposition_threshold (Optional[int]) – During calls to init and update_cache, if a higher number of matrix powers than this threshold is requested, the block computes the eigen-decomposition directly (which provides access to any matrix power) instead of computing individual approximate powers. If this is None, the value returned from get_default_eigen_decomposition_threshold() is used.
- parameters_list_to_single_vector(parameters_shaped_list)[source]
Converts values corresponding to parameters of the block to vector.
- Return type
Array
- single_vector_to_parameters_list(vector)[source]
Reverses the transformation self.parameters_list_to_single_vector.
- Return type
Tuple[Array, …]
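The round trip between a list of shaped parameters and a single vector can be sketched in NumPy as follows. The function names and the explicit shapes argument are assumptions made for this illustration; the block itself knows its parameter shapes, and its flattening order may differ.

```python
import numpy as np


def to_single_vector(parameters_shaped_list):
    """Flattens each parameter and concatenates them (hypothetical helper)."""
    return np.concatenate([p.reshape(-1) for p in parameters_shaped_list])


def to_parameters_list(vector, shapes):
    """Splits a vector back into arrays of the given shapes (hypothetical helper)."""
    parameters, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        parameters.append(vector[offset:offset + size].reshape(shape))
        offset += size
    return tuple(parameters)


weights = np.arange(6.0).reshape(3, 2)
bias = np.array([10.0, 20.0])
vector = to_single_vector([weights, bias])
restored = to_parameters_list(vector, [weights.shape, bias.shape])
```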
TwoKroneckerFactored
- class kfac_jax.TwoKroneckerFactored(layer_tag_eq, name)[source]
A Kronecker factored block for layers with weights and an optional bias.
- __init__(layer_tag_eq, name)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag that this block will approximate the curvature to.
name (str) – The name of this block.
- property has_bias: bool
Whether this layer’s equation has a bias.
- Return type
bool
NaiveDiagonal
- class kfac_jax.NaiveDiagonal(layer_tag_eq, name)[source]
Approximates the diagonal of the curvature in the most obvious way.
The update to the curvature estimate is computed by (sum_i g_i) ** 2 / N, where g_i is the gradient of each individual data point, and N is the batch size.
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (NaiveDiagonal.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
NaiveDiagonal.State
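The naive diagonal and full updates, and the way ema_old and ema_new combine an old estimate with a new one, can be sketched in NumPy as below. This is a simplified illustration under stated assumptions: the per-example gradients are random stand-ins, ema_update is a hypothetical helper, and the library's moving-average class may carry extra bookkeeping (such as a running weight) that is omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_params = 8, 3
# Row i plays the role of g_i, the gradient of a single data point.
per_example_grads = rng.normal(size=(batch_size, num_params))

summed = per_example_grads.sum(axis=0)
naive_diagonal = summed**2 / batch_size                # (sum_i g_i) ** 2 / N
naive_full = np.outer(summed, summed) / batch_size     # (sum_i g_i) (sum_i g_i)^T / N


def ema_update(old_estimate, new_value, ema_old, ema_new):
    """Weighted combination of old and new values (hypothetical helper)."""
    return ema_old * old_estimate + ema_new * new_value


old_estimate = np.ones(num_params)
updated = ema_update(old_estimate, naive_diagonal, ema_old=0.95, ema_new=0.05)
# With ema_old=0 and ema_new=1 the old estimate is discarded entirely.
no_average = ema_update(old_estimate, naive_diagonal, ema_old=0.0, ema_new=1.0)
```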
NaiveFull
- class kfac_jax.NaiveFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
Approximates the full curvature in the most obvious way.
The update to the curvature estimate is computed by (sum_i g_i) (sum_i g_i)^T / N, where g_i is the gradient of each individual data point, and N is the batch size.
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Full.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
DenseDiagonal
- class kfac_jax.DenseDiagonal(layer_tag_eq, name)[source]
A Diagonal block specifically for dense layers.
- property has_bias: bool
Whether the layer has a bias parameter.
- Return type
bool
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Diagonal.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
DenseFull
- class kfac_jax.DenseFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
A Full block specifically for dense layers.
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Full.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
DenseTwoKroneckerFactored
- class kfac_jax.DenseTwoKroneckerFactored(layer_tag_eq, name)[source]
A TwoKroneckerFactored block specifically for dense layers.
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (KroneckerFactored.State) – The state dict for this block to update.
estimation_data (Mapping[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
KroneckerFactored.State
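For a dense layer, a two-factor Kronecker approximation is commonly written as the Kronecker product of an input factor and an output-tangent factor, which makes applying the inverse cheap. The NumPy sketch below illustrates that identity on random stand-in data. It is an assumption-laden illustration, not the library's implementation: in particular, adding the same damping to each factor is a simplification, and kfac_jax may split damping between the factors differently.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, d_in, d_out = 32, 3, 2
inputs = rng.normal(size=(batch_size, d_in))            # layer inputs
output_tangents = rng.normal(size=(batch_size, d_out))  # gradients w.r.t. layer outputs

damping = 1e-3  # simplification: same damping added to both factors
input_factor = inputs.T @ inputs / batch_size + damping * np.eye(d_in)
output_factor = (
    output_tangents.T @ output_tangents / batch_size + damping * np.eye(d_out)
)

weight_grad = rng.normal(size=(d_in, d_out))

# Factored inverse: two small solves instead of one (d_in * d_out)-sized solve.
preconditioned = np.linalg.solve(input_factor, weight_grad) @ np.linalg.inv(output_factor)

# Dense reference using the explicit Kronecker product (row-major flattening).
dense = np.kron(input_factor, output_factor)
reference = np.linalg.solve(dense, weight_grad.reshape(-1)).reshape(d_in, d_out)
```

The factored form only ever inverts a d_in x d_in and a d_out x d_out matrix, which is the practical appeal of the approximation.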
Conv2DDiagonal
- class kfac_jax.Conv2DDiagonal(layer_tag_eq, name, max_elements_for_vmap=None)[source]
A Diagonal block specifically for 2D convolution layers.
- __init__(layer_tag_eq, name, max_elements_for_vmap=None)[source]
Initializes the block.
Since there is no ‘nice’ formula for computing the average of the tangents for a 2D convolution, the block uses a function, self.conv2d_tangent_squared, that computes the square of the tangents of the convolution kernel for a single feature map. There are two ways to average over the batch: vmap, or a sequential loop over the batch using scan. This block provides a trade-off between the two by letting you specify the maximum batch size to vmap over. The maximum memory usage is then max_batch_size_for_vmap times the memory needed for a single call to self.conv2d_tangent_squared, and the vmap is called ceil(total_batch_size / max_batch_size_for_vmap) times in a loop to compute the final average.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag, that this block will approximate the curvature to.
name (str) – The name of this block.
max_elements_for_vmap (Optional[int]) – The threshold used for determining how much computation to do in parallel and how much serially. If None, the value returned by get_max_parallel_elements() is used.
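The chunking trade-off described above can be sketched in plain NumPy: process at most a fixed number of examples at once, accumulate the sum of squares, and divide by the total batch size at the end. The function name and arguments are hypothetical, and the vectorised chunk computation here merely stands in for the vmap call; the library's actual loop is not reproduced.

```python
import math

import numpy as np


def chunked_mean_of_squares(per_example_tangents, max_elements_per_chunk):
    """Averages elementwise squares over the batch, one chunk at a time.

    Hypothetical helper: each chunk stands in for one vmap call, so peak memory
    scales with max_elements_per_chunk rather than with the full batch size.
    """
    total = per_example_tangents.shape[0]
    num_chunks = math.ceil(total / max_elements_per_chunk)
    accumulator = np.zeros(per_example_tangents.shape[1:])
    for index in range(num_chunks):
        start = index * max_elements_per_chunk
        chunk = per_example_tangents[start:start + max_elements_per_chunk]
        accumulator += (chunk**2).sum(axis=0)
    return accumulator / total, num_chunks


rng = np.random.default_rng(0)
tangents = rng.normal(size=(10, 5))
chunked, num_chunks = chunked_mean_of_squares(tangents, max_elements_per_chunk=4)
```

With a batch of 10 and at most 4 elements per chunk, the loop runs ceil(10 / 4) times, and the last chunk is simply shorter.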
- conv2d_tangent_squared(image_features_map, output_tangent)[source]
Computes the elementwise square of a tangent for a single feature map.
- Return type
Array
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Diagonal.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
Conv2DFull
- class kfac_jax.Conv2DFull(layer_tag_eq, name, max_elements_for_vmap=None)[source]
A Full block specifically for 2D convolution layers.
- __init__(layer_tag_eq, name, max_elements_for_vmap=None)[source]
Initializes the block.
Since there is no ‘nice’ formula for computing the average of the tangents for a 2D convolution, the block uses a function, self.conv2d_tangent_squared, that computes the square of the tangents of the convolution kernel for a single feature map. There are two ways to average over the batch: vmap, or a sequential loop over the batch using scan. This block provides a trade-off between the two by letting you specify the maximum batch size that is handled in a single iteration of the loop. The maximum memory usage is then max_batch_size_for_vmap times the memory needed for a single call to self.conv2d_tangent_squared, and the vmap is called ceil(total_batch_size / max_batch_size_for_vmap) times in a loop to compute the final average.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag, that this block will approximate the curvature to.
name (str) – The name of this block.
max_elements_for_vmap (Optional[int]) – The threshold used for determining how much computation to do in parallel and how much serially. If None, the value returned by get_max_parallel_elements() is used.
- conv2d_tangent_outer_product(inputs, tangent_of_outputs)[source]
Computes the outer product of a tangent for a single feature map.
- Return type
Array
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Full.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
Conv2DTwoKroneckerFactored
- class kfac_jax.Conv2DTwoKroneckerFactored(layer_tag_eq, name)[source]
A TwoKroneckerFactored block specifically for 2D convolution layers.
- fixed_scale()[source]
A fixed scalar pre-factor of the curvature (e.g. constant).
- Return type
Numeric
- property outputs_channel_index: int
The channels index in the outputs of the layer.
- Return type
int
- property inputs_channel_index: int
The channels index in the inputs of the layer.
- Return type
int
- property weights_output_channel_index: int
The channels index in the weights of the layer.
- Return type
int
- property weights_spatial_size: int
The spatial filter size of the weights.
- Return type
int
- property num_locations: int
The number of spatial locations that each filter is applied to.
- Return type
int
- property num_inputs_channels: int
The number of channels in the inputs to the layer.
- Return type
int
- property num_outputs_channels: int
The number of channels in the outputs to the layer.
- Return type
int
- compute_inputs_stats(inputs, weighting_array=None)[source]
Computes the statistics for the inputs factor.
- Return type
Array
- compute_outputs_stats(tangent_of_output)[source]
Computes the statistics for the outputs factor.
- Return type
Array
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (TwoKroneckerFactored.State) – The state dict for this block to update.
estimation_data (Mapping[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
TwoKroneckerFactored.State
ScaleAndShiftDiagonal
- class kfac_jax.ScaleAndShiftDiagonal(layer_tag_eq, name)[source]
A diagonal approximation specifically for scale and shift layers.
- property has_scale: bool
Whether this layer’s equation has a scale.
- Return type
bool
- property has_shift: bool
Whether this layer’s equation has a shift.
- Return type
bool
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Diagonal.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
ScaleAndShiftFull
- class kfac_jax.ScaleAndShiftFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
A full dense approximation specifically for scale and shift layers.
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size)[source]
Updates the block’s curvature estimates using the info provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average, you can set ema_old=0 and ema_new=1.
- Parameters
state (Full.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the names of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in info.
- Return type
set_max_parallel_elements
- kfac_jax.set_max_parallel_elements(value)[source]
Sets the default value of maximum parallel elements in the module.
This value is used to determine the parallel-to-memory tradeoff in the curvature estimation procedure of Conv2DDiagonal and Conv2DFull. See their corresponding docs for further details.
- Parameters
value (int) – The default value for maximum number of parallel elements.
get_max_parallel_elements
- kfac_jax.get_max_parallel_elements()[source]
Returns the default value of maximum parallel elements in the module.
This value is used to determine the parallel-to-memory tradeoff in the curvature estimation procedure of Conv2DDiagonal and Conv2DFull. See their corresponding docs for further details.
- Return type
int
- Returns
The default value for maximum number of parallel elements.
set_default_eigen_decomposition_threshold
- kfac_jax.set_default_eigen_decomposition_threshold(value)[source]
Sets the default value of the eigen decomposition threshold.
This value is used in Full, when updating the cache, to determine at what number of different powers the implementation switches from a simple matrix power to an eigenvector decomposition.
- Parameters
value (int) – The default value for eigen decomposition threshold.
get_default_eigen_decomposition_threshold
- kfac_jax.get_default_eigen_decomposition_threshold()[source]
Returns the default value of the eigen decomposition threshold.
This value is used in Full, when updating the cache, to determine at what number of different powers the implementation switches from a simple matrix power to an eigenvector decomposition.
- Return type
int
- Returns
The default value of the eigen decomposition threshold.
Advanced Features API
Function tracing and Jacobian computation
|
A wrapper around Jaxpr, with useful additional data. |
|
Creates a function for the vector-Jacobian product w.r.t. |
|
Creates a function for the Jacobian-vector product w.r.t. |
|
Creates a function for the Hessian-vector product w.r.t. |
|
Creates a function for primal values and tangents w.r.t. |
ProcessedJaxpr
- class kfac_jax.ProcessedJaxpr(jaxpr, consts, in_tree, params_index, allow_left_out_params=False)[source]
A wrapper around Jaxpr, with useful additional data.
- jaxpr
The original Jaxpr that is being wrapped.
- consts
The constants returned from the tracing of the original Jaxpr.
- in_tree
The PyTree structure of the inputs to the function that the original Jaxpr has been created from.
- params_index
Specifies which inputs to the function are to be considered a parameter variable. Specifically, inputs[params_index].
- loss_tags
A tuple of all of the loss tags in the original Jaxpr.
- layer_tags
A sorted tuple of all of the layer tags in the original Jaxpr. The sorting order is based on the indices of the parameters associated with each layer tag.
- layer_indices
A sequence of tuples, where each tuple has the indices of the parameters associated with the corresponding layer tag.
- __init__(jaxpr, consts, in_tree, params_index, allow_left_out_params=False)[source]
Initializes the instance.
- Parameters
jaxpr (jax.core.Jaxpr) – The raw Jaxpr.
consts (Sequence[Any]) – The constants needed for evaluation of the raw Jaxpr.
in_tree (utils.PyTreeDef) – The PyTree structure of the inputs to the function that the jaxpr has been created from.
params_index (int) – Specifies which inputs to the function are to be considered a parameter variable. Specifically, inputs[params_index].
allow_left_out_params (bool) – If False, an error is raised if any of the parameter variables is not included in any layer tag.
- property in_vars_flat: List[Var]
A flat list of all of the abstract input variables.
- Return type
List[Var]
- property in_vars: utils.PyTree[Var]
The abstract input variables, as an unflattened structure.
- Return type
utils.PyTree[Var]
- property params_vars: utils.PyTree[Var]
The abstract parameter variables, as an unflattened structure.
- Return type
utils.PyTree[Var]
- property params_vars_flat: List[Var]
A flat list of all abstract parameter variables.
- Return type
List[Var]
- property params_tree: utils.PyTreeDef
The PyTree structure of the parameter variables.
- Return type
utils.PyTreeDef
- classmethod make_from_func(func, func_args, params_index=0, auto_register_tags=True, allow_left_out_params=False, **auto_registration_kwargs)[source]
Constructs a ProcessedJaxpr from the given function.
- Parameters
func (utils.Func) – The model function, which will be traced.
func_args (FuncArgs) – Function arguments to use for tracing.
params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.
auto_register_tags (bool) – Whether to run an automatic layer registration on the function (e.g. auto_register_tags()).
allow_left_out_params (bool) – If this is set to False, an error is raised if there are any model parameters that have not been assigned to a layer tag.
**auto_registration_kwargs (Any) – Any additional keyword arguments to be passed to the automatic registration pass.
- Return type
ProcJaxpr
- Returns
A ProcessedJaxpr representing the model function.
Patches Second Moments
|
Computes the first and second moment of the convolutional patches. |
|
The exact same functionality as |
patches_moments
- kfac_jax.patches_moments(inputs, kernel_spatial_shape, strides=1, padding='VALID', data_format='NHWC', dim_numbers=None, inputs_dilation=None, kernel_dilation=None, feature_group_count=1, batch_group_count=1, unroll_loop=False, precision=None, weighting_array=None)[source]
Computes the first and second moment of the convolutional patches.
Since the code is written to support arbitrary convolution data formats, e.g. both NHWC and NCHW, the comments above each procedure in the implementation describe a simplified version of what the statements below them do, assuming the data format is fixed to NHWC.
- Parameters
inputs (Array) – The batch of images.
kernel_spatial_shape (Union[int, Shape]) – The spatial dimensions of the filter (int or list of ints).
strides (Union[int, Shape]) – The spatial dimensions of the strides (int or list of ints).
padding (PaddingVariants) – The padding (str or list of pairs of ints).
data_format (Optional[str]) – The data format of the inputs (None, NHWC, NCHW).
dim_numbers (Optional[Union[DimNumbers, lax.ConvDimensionNumbers]]) – Instance of jax.lax.ConvDimensionNumbers instead of data_format.
inputs_dilation (Optional[Sequence[int]]) – An integer or sequence of integers, specifying the dilation for the image. Currently, patches_moments does not support dilation, so the only allowed values are None, 1, (1,1).
kernel_dilation (Optional[Sequence[int]]) – An integer or sequence of integers, specifying the dilation for the kernel. Currently, patches_moments does not support dilation, so the only allowed values are None, 1, (1,1).
feature_group_count (int) – The feature grouping for grouped convolutions. Currently, patches_moments supports only 1 and number of channels of the inputs.
batch_group_count (int) – The batch grouping for grouped convolutions. Currently, patches_moments supports only 1.
unroll_loop (bool) – Whether to unroll the loop in python.
precision (Optional[jax.lax.Precision]) – In what precision to run the computation. For more details please read the Jax documentation of jax.lax.conv_general_dilated().
weighting_array (Optional[Array]) – A tensor specifying additional weighting of each element of the moment’s average.
- Return type
Tuple[Array, Array]
- Returns
The tensors of the patches’ second and first moments, as a pair. The tensor of the patches’ second moment has a shape kernel_spatial_shape + (channels,) + kernel_spatial_shape + (channels,). The tensor of the patches’ first moment has a shape kernel_spatial_shape + (channels,).
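To illustrate the shapes involved, here is a minimal NumPy sketch that extracts patches explicitly and averages their first and second moments. It assumes NHWC inputs, stride 1, VALID padding, no dilation, and averaging over both the batch and the spatial locations; the function name is hypothetical, and the normalization used by patches_moments itself may differ.

```python
import numpy as np


def explicit_patches_moments(inputs, kernel_spatial_shape):
    """First and second moments of convolutional patches (hypothetical helper).

    Assumes NHWC layout, stride 1 and VALID padding.
    """
    _, height, width, channels = inputs.shape
    kh, kw = kernel_spatial_shape
    patches = []
    for i in range(height - kh + 1):
        for j in range(width - kw + 1):
            patches.append(inputs[:, i:i + kh, j:j + kw, :])
    # Shape (batch, locations, kh, kw, channels), then merge batch and locations.
    flat = np.stack(patches, axis=1).reshape(-1, kh, kw, channels)
    first = flat.mean(axis=0)
    second = np.einsum("nabc,ndef->abcdef", flat, flat) / flat.shape[0]
    return second, first


rng = np.random.default_rng(0)
images = rng.normal(size=(2, 4, 4, 3))
second_moment, first_moment = explicit_patches_moments(images, (2, 2))
```

Explicit extraction materialises every patch, which is why the explicit variant can use more memory than one that never forms the patches.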
patches_moments_explicit
- kfac_jax.patches_moments_explicit(inputs, kernel_spatial_shape, strides=1, padding='VALID', data_format='NHWC', dim_numbers=None, inputs_dilation=None, kernel_dilation=None, feature_group_count=1, batch_group_count=1, unroll_loop=False, precision=None, weighting_array=None)[source]
The exact same functionality as patches_moments(), but explicitly extracts the patches via jax.lax.conv_general_dilated_patches(), potentially having a higher memory usage.
- Return type
Tuple[Array, Array]