Standard API
Optimizers
The K-FAC optimizer.
Optimizer
- class kfac_jax.Optimizer(value_and_grad_func, l2_reg, value_func_has_aux=False, value_func_has_state=False, value_func_has_rng=False, use_adaptive_learning_rate=False, learning_rate_schedule=None, use_adaptive_momentum=False, momentum_schedule=None, use_adaptive_damping=False, damping_schedule=None, initial_damping=None, min_damping=1e-08, max_damping=inf, include_damping_in_quad_change=False, damping_adaptation_interval=5, damping_adaptation_decay=0.9, damping_lower_threshold=0.25, damping_upper_threshold=0.75, always_use_exact_qmodel_for_damping_adjustment=False, norm_constraint=None, num_burnin_steps=10, estimation_mode='fisher_gradients', curvature_ema=0.95, inverse_update_period=5, use_exact_inverses=False, batch_process_func=None, register_only_generic=False, patterns_to_skip=(), auto_register_kwargs=None, layer_tag_to_block_ctor=None, multi_device=False, debug=False, batch_size_extractor=<function default_batch_size_extractor>, pmap_axis_name='kfac_axis', forbid_setting_attributes_after_finalize=True, modifiable_attribute_exceptions=(), include_norms_in_stats=False, include_per_param_norms_in_stats=False, distributed_precon_apply=True, distributed_inverses=True)[source]
The K-FAC optimizer.
- class State(velocities, estimator_state, damping, data_seen, step_counter)[source]
Persistent state of the optimizer.
- velocities
The update to the parameters from the previous step - \(\theta_t - \theta_{t-1}\).
- Type
Params
- estimator_state
The persistent state for the curvature estimator.
- damping
When using damping adaptation, this will contain the current value.
- Type
Optional[Array]
- data_seen
The number of training cases that the optimizer has processed.
- Type
Numeric
- step_counter
An integer giving the current step number \(t\).
- Type
Numeric
- __init__(velocities, estimator_state, damping, data_seen, step_counter)
- __init__(value_and_grad_func, l2_reg, value_func_has_aux=False, value_func_has_state=False, value_func_has_rng=False, use_adaptive_learning_rate=False, learning_rate_schedule=None, use_adaptive_momentum=False, momentum_schedule=None, use_adaptive_damping=False, damping_schedule=None, initial_damping=None, min_damping=1e-08, max_damping=inf, include_damping_in_quad_change=False, damping_adaptation_interval=5, damping_adaptation_decay=0.9, damping_lower_threshold=0.25, damping_upper_threshold=0.75, always_use_exact_qmodel_for_damping_adjustment=False, norm_constraint=None, num_burnin_steps=10, estimation_mode='fisher_gradients', curvature_ema=0.95, inverse_update_period=5, use_exact_inverses=False, batch_process_func=None, register_only_generic=False, patterns_to_skip=(), auto_register_kwargs=None, layer_tag_to_block_ctor=None, multi_device=False, debug=False, batch_size_extractor=<function default_batch_size_extractor>, pmap_axis_name='kfac_axis', forbid_setting_attributes_after_finalize=True, modifiable_attribute_exceptions=(), include_norms_in_stats=False, include_per_param_norms_in_stats=False, distributed_precon_apply=True, distributed_inverses=True)[source]
Initializes the K-FAC optimizer with the provided settings.
A note on damping:
One of the main complications of using second-order optimizers like K-FAC is the “damping” parameter. This parameter is multiplied by the identity matrix and (approximately) added to the curvature matrix (i.e. the Fisher or GGN) before it is inverted and multiplied by the gradient when computing the update (before any learning rate scaling). The damping should follow the scale of the objective, so that if you multiply your loss by some factor you should do the same for the damping. Roughly speaking, larger damping values constrain the update vector to a smaller region around zero, which is needed in general since the second-order approximations that underlie second-order methods can break down for large updates. (In gradient descent the learning rate plays an analogous role.) The relationship between the damping parameter and the radius of this region is complicated and depends on the scale of the objective, amongst other things.
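As a rough illustration of the paragraph above, the following sketch applies damping to a small dense curvature matrix with NumPy. The helper `damped_update` and the 2x2 matrix are hypothetical and are not part of kfac_jax, which works with factored approximations rather than dense solves.

```python
import numpy as np


def damped_update(curvature, grad, damping):
    """Returns -(curvature + damping * I)^{-1} grad for a dense matrix."""
    dim = curvature.shape[0]
    return -np.linalg.solve(curvature + damping * np.eye(dim), grad)


# Hypothetical ill-conditioned curvature matrix and gradient.
curvature = np.array([[10.0, 0.0], [0.0, 0.01]])
grad = np.array([1.0, 1.0])

# Larger damping values give a shorter update vector.
for damping in (1e-3, 1e-1, 1e1):
    step = damped_update(curvature, grad, damping)
    print(damping, np.linalg.norm(step))

# Scaling the loss by c scales both the curvature and the gradient by c.
# Scaling the damping by the same factor leaves the update unchanged.
c = 7.0
same = np.allclose(
    damped_update(c * curvature, c * grad, c * 0.1),
    damped_update(curvature, grad, 0.1),
)
print(same)
```

The last check is the reason the damping should follow the scale of the objective: without rescaling it, the same model with a rescaled loss would take a different step.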
The optimizer provides a system for adjusting the damping automatically via the use_adaptive_damping argument, although this system is not completely reliable. Using a fixed value or a manually tuned schedule can work as well or better for some problems, while it can be a very poor choice for others (like deep autoencoders). Empirically we have found that using a fixed value works well enough for common architectures like convnets and transformers.
- Parameters
value_and_grad_func (ValueAndGradFunc) – Python callable. The function should return the value of the loss to be optimized and its gradients, and optionally the model state and auxiliary information (usually statistics to log). The interface should be: out_args, loss_grads = value_and_grad_func(*in_args). Here, in_args is (params, func_state, rng, batch), with rng omitted if value_func_has_rng is False, and with func_state omitted if value_func_has_state is False. Meanwhile, out_args is (loss, func_state, aux), with func_state omitted if value_func_has_state is False, and with aux omitted if value_func_has_aux is False. If both value_func_has_state and value_func_has_aux are False, out_args should just be loss and not (loss,).
l2_reg (Numeric) – Scalar. Set this value to tell the optimizer what L2 regularization coefficient you are using (if any). Note the coefficient appears in the regularizer as coeff / 2 * sum(param**2). This adds an additional diagonal term to the curvature and hence will affect the quadratic model when using adaptive damping. Note that the user is still responsible for adding regularization to the loss.
value_func_has_aux (bool) – Boolean. Specifies whether the provided callable value_and_grad_func returns auxiliary data. (Default: False)
value_func_has_state (bool) – Boolean. Specifies whether the provided callable value_and_grad_func has a persistent state that is passed in and out. (Default: False)
value_func_has_rng (bool) – Boolean. Specifies whether the provided callable value_and_grad_func additionally takes as input an rng key. (Default: False)
use_adaptive_learning_rate (bool) – Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the learning rate at each step. Note that this won’t work well for stochastic objectives. If this is False, the user must use the learning_rate argument of the step function, or the constructor argument learning_rate_schedule. (Default: False)
learning_rate_schedule (Optional[ScheduleType]) – Callable. A schedule for the learning rate. This should take as input the current step number and return a single array that represents the learning rate. (Default: None)
use_adaptive_momentum (bool) – Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the momentum “decay” parameter at each step. Note that this won’t work well for stochastic objectives. If this is False, the user must use the momentum argument of the step function, or the constructor argument momentum_schedule. (Default: False)
momentum_schedule (Optional[ScheduleType]) – Callable. A schedule for the momentum parameter. This should take as input the current step number and return a single array that represents the momentum. (Default: None)
use_adaptive_damping (bool) – Boolean. Specifies whether the optimizer will use the Levenberg-Marquardt method to automatically adjust the damping every damping_adaptation_interval iterations. If this is set to False the user must provide a value to the damping argument of the step function at each iteration, or use the damping_schedule constructor argument. Note that the effectiveness of this technique seems to vary between problems. (Default: False)
damping_schedule (Optional[ScheduleType]) – Callable. A schedule for the damping. This should take as input the current step number and return a single array that represents the damping. (Default: None)
initial_damping (Optional[Numeric]) – Scalar or None. This specifies the initial value of the damping that the optimizer will use when using automatic damping adaptation. (Default:
None)
min_damping (Numeric) – Scalar. Minimum value the damping parameter can take when using automatic damping adaptation. Note that the default value of 1e-8 is quite arbitrary, and you may have to adjust this up or down for your particular problem. If you are using a non-zero value of l2_reg you may be able to set this to zero. (Default: 1e-8)
max_damping (Numeric) – Scalar. Maximum value the damping parameter can take when using automatic damping adaptation. (Default: Infinity)
include_damping_in_quad_change (bool) – Boolean. Whether to include the contribution of the damping in the quadratic model for the purposes of computing the reduction ratio (“rho”) in the Levenberg-Marquardt scheme used for adapting the damping. Note that the contribution from the l2_reg argument is always included. (Default: False)
damping_adaptation_interval (int) – Int. The number of steps in between adapting the damping parameter. (Default: 5)
damping_adaptation_decay (Numeric) – Scalar. The damping parameter will be adjusted up or down by damping_adaptation_decay ** damping_adaptation_interval, or remain unchanged, every damping_adaptation_interval iterations. (Default: 0.9)
damping_lower_threshold (Numeric) – Scalar. The damping parameter is increased if the reduction ratio is below this threshold. (Default: 0.25)
damping_upper_threshold (Numeric) – Scalar. The damping parameter is decreased if the reduction ratio is above this threshold. (Default: 0.75)
always_use_exact_qmodel_for_damping_adjustment (bool) – Boolean. When using learning rate and/or momentum adaptation, the quadratic model change used for damping adaptation is always computed using the exact curvature matrix. Otherwise, there is an option to use either the exact or approximate curvature matrix to compute the quadratic model change, which is what this argument controls. When True, the exact curvature matrix will be used, which is more expensive, but could possibly produce a better damping schedule. (Default: False)
norm_constraint (Optional[Numeric]) – Scalar. If specified, the update is scaled down so that its approximate squared Fisher norm v^T F v is at most the specified value. (Note that here F is the approximate curvature matrix, not the exact.) May only be used when use_adaptive_learning_rate is False. (Default: None)
num_burnin_steps (int) – Int. At the start of optimization, i.e. on the first step, before performing the actual step, the optimizer will perform this many updates to the curvature approximation without updating the actual parameters. (Default: 10)
estimation_mode (str) – String. The type of estimator to use for the curvature matrix. See the documentation for CurvatureEstimator for a detailed description of the possible options. (Default: fisher_gradients)
curvature_ema (Numeric) – The decay factor used when calculating the covariance estimate moving averages. (Default: 0.95)
inverse_update_period (int) – Int. The number of steps in between updates of the inverse curvature approximation. (Default:
5)
use_exact_inverses (bool) – Bool. If True, preconditioner inverses are computed “exactly” without the pi-adjusted factored damping approach. Note that this involves the use of eigendecompositions, which can sometimes be much more expensive. (Default: False)
batch_process_func (Optional[Callable[[Batch], Batch]]) – Callable. A function to be called on each batch before feeding it to the K-FAC optimizer on device. This could be useful for specific device input optimizations. (Default: None)
register_only_generic (bool) – Boolean. Whether the auto-tagger should register only generic parameters, or be allowed to use the graph matcher to automatically pick up any kind of layer tags. (Default: False)
patterns_to_skip (Sequence[str]) – Tuple. A list of any patterns that should be skipped by the graph matcher when auto-tagging. (Default: ())
auto_register_kwargs (Optional[Dict[str, Any]]) – Any additional kwargs to be passed down to auto_register_tags(), which is called by the curvature estimator. (Default: None)
layer_tag_to_block_ctor (Optional[Dict[str, curvature_estimator.CurvatureBlockCtor]]) – Dictionary. A mapping from layer tags to block classes, which overrides the default choice of block approximation for each specified tag. See the documentation for CurvatureEstimator for a more detailed description. (Default: None)
multi_device (bool) – Boolean. Whether to use pmap and run the optimizer on multiple devices. (Default: False)
debug (bool) – Boolean. If True, neither the step nor the init function is jitted. Note that this also overrides multi_device and prevents using pmap. (Default: False)
batch_size_extractor (Callable[[Batch], Numeric]) – A function that takes as input the function arguments and returns the batch size for a single device. (Default: kfac.utils.default_batch_size_extractor)
pmap_axis_name (str) – String. The name of the pmap axis to use when multi_device is set to True. (Default: kfac_axis)
forbid_setting_attributes_after_finalize (bool) – Boolean. By default, after the object is finalized you cannot set any of its properties. This is done in order to protect the user from making changes to the object attributes that would not be picked up by various internal methods after they have been compiled. However, if you are extending this class, and clearly understand the risks of modifying attributes, setting this to False will remove the restriction. (Default: True)
modifiable_attribute_exceptions (Sequence[str]) – Sequence of strings. Gives a list of names for attributes that can be modified after finalization even when forbid_setting_attributes_after_finalize is True. (Default: ())
include_norms_in_stats (bool) – Boolean. If True, the vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: False)
include_per_param_norms_in_stats (bool) – Boolean. If True, the per-parameter vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: False)
distributed_precon_apply (bool) – Boolean. Whether to distribute the application of the preconditioner across the different devices in a layer-wise fashion. If False, each device will (redundantly) perform the required operations for all of the layers. (Default: True)
distributed_inverses (bool) – Boolean. Whether to distribute the inverse computations (required to compute the preconditioner) across the different devices in a layer-wise fashion. If False, each device will (redundantly) perform the required computations for all of the layers. (Default: True)
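The damping adaptation arguments above can be read as a Levenberg-Marquardt style rule. The sketch below is one plausible reading of those descriptions, not the library's implementation. The function `adapt_damping` is hypothetical, and it assumes that "adjusting up" means dividing by damping_adaptation_decay ** damping_adaptation_interval and "adjusting down" means multiplying by it.

```python
def adapt_damping(
    damping,
    rho,
    decay=0.9,
    interval=5,
    lower_threshold=0.25,
    upper_threshold=0.75,
    min_damping=1e-8,
    max_damping=float("inf"),
):
    """Returns the new damping given the reduction ratio rho (assumed rule)."""
    factor = decay ** interval
    if rho < lower_threshold:
        # The quadratic model over-predicted the improvement: be more cautious.
        damping = damping / factor
    elif rho > upper_threshold:
        # The quadratic model was accurate: allow larger updates.
        damping = damping * factor
    # Keep the result inside [min_damping, max_damping].
    return min(max(damping, min_damping), max_damping)


print(adapt_damping(1.0, rho=0.1))  # increased
print(adapt_damping(1.0, rho=0.5))  # unchanged
print(adapt_damping(1.0, rho=0.9))  # decreased
```

With the default values, the damping changes by a factor of 0.9 ** 5 (about 0.59) at most once every 5 steps, so the adaptation is deliberately slow.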
- property num_burnin_steps: int
The number of burnin steps to run before the first parameter update.
- Return type
int
- property l2_reg: Array
The weight of the additional diagonal term added to the curvature.
- Return type
Array
- property estimator: curvature_estimator.BlockDiagonalCurvature
The underlying curvature estimator used by the optimizer.
- Return type
curvature_estimator.BlockDiagonalCurvature
- property damping_decay_factor: Numeric
How fast to decay the damping, when using damping adaptation.
- Return type
Numeric
- should_update_damping(state)[source]
Whether at the current step the optimizer should update the damping.
- Return type
Array
- compute_loss_value(func_args)[source]
Computes the value of the loss function being optimized.
- Return type
Array
- verify_args_and_get_step_counter(step_counter, learning_rate=None, momentum=None, damping=None, global_step_int=None)[source]
Verifies that the arguments passed to the step function are correct.
- Return type
int
- init(params, rng, batch, func_state=None)[source]
Initializes the optimizer and returns the appropriate optimizer state.
- Return type
‘Optimizer.State’
- burnin(num_steps, params, state, rng, data_iterator, func_state=None)[source]
Runs all burnin steps required.
- Return type
Tuple[‘Optimizer.State’, Optional[FuncState]]
- step(params, state, rng, data_iterator=None, batch=None, func_state=None, learning_rate=None, momentum=None, damping=None, global_step_int=None)[source]
Performs a single update step using the optimizer.
- Parameters
params (Params) – The current parameters of the model.
state ('Optimizer.State') – The current state of the optimizer.
rng (PRNGKey) – A Jax PRNG key. Should be different for each iteration and each Jax process/host.
data_iterator (Optional[Iterator[Batch]]) – A data iterator to use (if not passing batch).
batch (Optional[Batch]) – A single batch used to compute the update. Should only pass one of data_iterator or batch.
func_state (Optional[FuncState]) – Any function state that gets passed in and returned.
learning_rate (Optional[Array]) – Learning rate to use if the optimizer was created with use_adaptive_learning_rate=False and without a learning_rate_schedule, None otherwise.
momentum (Optional[Array]) – Momentum to use if the optimizer was created with use_adaptive_momentum=False and without a momentum_schedule, None otherwise.
damping (Optional[Array]) – Damping to use if the optimizer was created with use_adaptive_damping=False and without a damping_schedule, None otherwise. See the note on damping in the constructor documentation for more information.
global_step_int (Optional[int]) – The global step as a python int. Note that this must match the step internal to the optimizer that is part of its state.
- Return type
ReturnEither
- Returns
(params, state, stats) if value_func_has_state=False and (params, state, func_state, stats) otherwise, where:
params is the updated model parameters.
state is the updated optimizer state.
func_state is the updated function state.
stats is a dictionary of useful statistics including the loss.
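Because the length of the returned tuple depends on value_func_has_state, calling code often normalizes it first. The helper below is a hypothetical convenience wrapper, not part of kfac_jax. It only illustrates the two shapes described above.

```python
def unpack_step_output(output, value_func_has_state):
    """Normalizes the step output to (params, state, func_state, stats)."""
    if value_func_has_state:
        params, state, func_state, stats = output
    else:
        params, state, stats = output
        func_state = None  # No function state is returned in this case.
    return params, state, func_state, stats


# Placeholder values standing in for real parameters and states.
without_state = ("params", "opt_state", {"loss": 0.5})
with_state = ("params", "opt_state", "func_state", {"loss": 0.5})

print(unpack_step_output(without_state, value_func_has_state=False))
print(unpack_step_output(with_state, value_func_has_state=True))
```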
- compute_l2_quad_matrix(vectors)[source]
Computes the matrix corresponding to the prior/regularizer.
- Parameters
vectors (Sequence[Params]) – A sequence of parameter-like PyTree structures, each one representing a different vector.
- Return type
Array
- Returns
A matrix with i,j entry equal to
self.l2_reg * v_i^T v_j.
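The returned matrix is a scaled Gram matrix of the input vectors. A minimal NumPy sketch follows, assuming each vector has already been flattened to a 1-D array (the real method takes PyTrees). The function `l2_quad_matrix` is a hypothetical stand-in.

```python
import numpy as np


def l2_quad_matrix(vectors, l2_reg):
    """Returns the matrix whose (i, j) entry is l2_reg * v_i^T v_j."""
    stacked = np.stack([np.ravel(v) for v in vectors])  # shape (k, dim)
    return l2_reg * (stacked @ stacked.T)


# Two example vectors and a regularization coefficient of 0.5.
vectors = [np.array([1.0, 0.0]), np.array([1.0, 1.0])]
quad = l2_quad_matrix(vectors, l2_reg=0.5)
print(quad)
```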
- compute_exact_quad_model(vectors, grads, func_args=None)[source]
Computes the components of the exact quadratic model.
- Return type
Tuple[Array, Array, Array]
- compute_approx_quad_model(state, vectors, grads)[source]
Computes the components of the approximate quadratic model.
- Return type
Tuple[Array, Array, Array]
Curvature Estimators
An abstract curvature estimator class.
Block diagonal curvature estimator class.
Explicit exact full curvature estimator class.
Represents all exact curvature matrices never constructed explicitly.
Sets the default curvature block constructor for the given tag.
Returns the default curvature block constructor for the given tag name.
CurvatureEstimator
- class kfac_jax.CurvatureEstimator(func, params_index=0, default_estimation_mode='fisher_gradients')[source]
An abstract curvature estimator class.
This is a class that abstracts away the process of estimating a curvature matrix and provides many useful functionalities for interacting with it. The state of the estimator contains two parts: the internal representation of the estimated curvature, as well as potential cached values of different expressions involving the curvature matrix (for example matrix powers). The cached values are only updated once you call the method update_cache(). Multiple methods contain the keyword argument use_cached, which specifies whether you want to compute the corresponding expression using the current curvature estimate or use a cached version.
- func
The model evaluation function.
- params_index
The index of the parameters argument in arguments list of
func.
- default_estimation_mode
The estimation mode which to use by default when calling
update_curvature_matrix_estimate().
- __init__(func, params_index=0, default_estimation_mode='fisher_gradients')[source]
Initializes the CurvatureEstimator instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in arguments list of func.
default_estimation_mode (str) – The estimation mode to use by default when calling update_curvature_matrix_estimate().
- property default_mat_type: str
The type of matrix that this estimator is approximating.
- Return type
str
- abstract property dim: int
The number of elements of all parameter variables together.
- Return type
int
- abstract init(rng, func_args, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues=False)[source]
Initializes the state for the estimator.
- Parameters
rng (PRNGKey) – The PRNGKey to be used for any randomness of the initialization.
func_args (utils.FuncArgs) – Example function arguments, to be used to trace the model function and initialize the state.
exact_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, which specify the exact matrix powers that each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=True and use_cached=True must be provided here.
approx_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, which specify the approximate matrix powers that each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=False and use_cached=True must be provided here.
cache_eigenvalues (bool) – Specifies whether each block should cache the eigenvalues of its approximate curvature.
- Return type
StateType
- Returns
The initialized state of the estimator.
- abstract multiply_matpower(state, parameter_structured_vector, identity_weight, power, exact_power, use_cached, pmap_axis_name)[source]
Computes (CurvatureMatrix + identity_weight I)**power times vector.
- Parameters
state (StateType) – The state of the estimator.
parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.
identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
power (Scalar) – The power to which you want to raise the matrix (EstimateCurvature + identity_weight I).
exact_power (bool) – When set to True the matrix power of EstimateCurvature + identity_weight I is computed exactly. Otherwise this method might use a cheaper approximation, which may vary across different blocks.
use_cached (bool) – Whether to use a cached (and possibly stale) version of the curvature matrix estimate.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
utils.Params
- Returns
A parameter structured vector containing the product.
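For intuition, the product described here can be written out for a small dense symmetric matrix. The sketch below is a hypothetical dense reference, not how the estimator computes it (the library operates block-wise on structured approximations). It uses an eigendecomposition so that any real power is supported.

```python
import numpy as np


def multiply_matpower_dense(curvature, vector, identity_weight, power):
    """Computes (curvature + identity_weight * I)**power @ vector."""
    # For a symmetric matrix, shifting by identity_weight * I shifts every
    # eigenvalue by identity_weight and keeps the eigenvectors unchanged.
    eigvals, eigvecs = np.linalg.eigh(curvature)
    scaled = (eigvals + identity_weight) ** power
    return eigvecs @ (scaled * (eigvecs.T @ vector))


# Hypothetical symmetric positive definite curvature matrix.
curvature = np.array([[2.0, 1.0], [1.0, 2.0]])
vector = np.array([1.0, -1.0])

# power=-1 corresponds to multiply_inverse, power=1 to multiply.
inverse_product = multiply_matpower_dense(curvature, vector, 0.5, -1)
plain_product = multiply_matpower_dense(curvature, vector, 0.5, 1)
print(inverse_product, plain_product)
```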
- multiply(state, parameter_structured_vector, identity_weight, exact_power, use_cached, pmap_axis_name)[source]
Computes (CurvatureMatrix + identity_weight I) times vector.
- Return type
utils.Params
- multiply_inverse(state, parameter_structured_vector, identity_weight, exact_power, use_cached, pmap_axis_name)[source]
Computes (CurvatureMatrix + identity_weight I)^-1 times vector.
- Return type
utils.Params
- abstract eigenvalues(state, use_cached)[source]
Computes the eigenvalues of the curvature matrix.
- Parameters
state (StateType) – The state of the estimator.
use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached values are at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Array
- Returns
A single array containing the eigenvalues of the curvature matrix.
- abstract update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, pmap_axis_name, estimation_mode=None)[source]
Updates the estimator’s curvature estimates.
- Parameters
state (StateType) – The state of the estimator to update.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size.
rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.
func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor) to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.
pmap_axis_name (Optional[str]) – When calling this method within a pmap context this argument specifies the axis name over which to aggregate across multiple devices/hosts.
estimation_mode (Optional[str]) – The type of curvature estimator to use. By default (i.e. if None) will use self.default_estimation_mode. One of:
fisher_gradients - the basic estimation approach from the original K-FAC paper.
fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.
fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.
fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.
ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).
ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).
- Return type
StateType
- Returns
The updated state.
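The roles of ema_old and ema_new can be illustrated with a plain weighted moving average. This is a sketch of the arithmetic only. The function `ema_update` is hypothetical, and pairing ema_old=0.95 with ema_new=0.05 is an assumption made for the example rather than something this page specifies.

```python
import numpy as np


def ema_update(old_estimate, new_estimate, ema_old, ema_new):
    """Blends the running estimate with the statistics of the new batch."""
    return ema_old * old_estimate + ema_new * new_estimate


# Hypothetical running estimate and fresh mini-batch statistics.
old_estimate = np.array([[1.0, 0.0], [0.0, 1.0]])
new_estimate = np.array([[3.0, 1.0], [1.0, 3.0]])

updated = ema_update(old_estimate, new_estimate, ema_old=0.95, ema_new=0.05)
print(updated)
```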
- abstract update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]
Updates the estimator cached values.
- Parameters
state (StateType) – The state of the estimator to update.
identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.
approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
StateType
- Returns
The updated state.
BlockDiagonalCurvature
- class kfac_jax.BlockDiagonalCurvature(func, params_index=0, default_estimation_mode='fisher_gradients', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, distributed_multiplies=True, distributed_cache_updates=True, **auto_register_kwargs)[source]
Block diagonal curvature estimator class.
- class State(blocks_states)[source]
Persistent state of the estimator.
- blocks_states
A tuple of the state of the estimator corresponding to each block.
- Type
Tuple[curvature_blocks.CurvatureBlock.State, …]
- __init__(blocks_states)
- __init__(func, params_index=0, default_estimation_mode='fisher_gradients', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, distributed_multiplies=True, distributed_cache_updates=True, **auto_register_kwargs)[source]
Initializes the curvature instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in arguments list of func.
default_estimation_mode (str) – The estimation mode to use by default when calling self.update_curvature_matrix_estimate.
layer_tag_to_block_ctor (Optional[Mapping[str, CurvatureBlockCtor]]) – An optional dict mapping tags to specific classes of block approximations, which override the default ones.
index_to_block_ctor (Optional[Mapping[Tuple[int, ...], CurvatureBlockCtor]]) – An optional dict mapping specific block parameter indices to specific classes of block approximation, which override the default ones. To get the correct indices check estimator.indices_to_block_map.
auto_register_tags (bool) – Whether to automatically register layer tags for parameters that have not been manually registered. For further details see tag_graph_matcher.auto_register_tags.
distributed_multiplies (bool) – Whether to distribute the curvature matrix multiplication operations across the different devices in a block-wise fashion. If False, each device will (redundantly) perform the operations for all of the blocks.
distributed_cache_updates (bool) – Whether to distribute the cache update multiplication operations across the different devices in a block-wise fashion. If False, each device will (redundantly) perform the operations for all of the blocks.
**auto_register_kwargs – Any keyword arguments to pass into the auto registration function.
- property blocks: Optional[Tuple[curvature_blocks.CurvatureBlock]]
The tuple of CurvatureBlock instances used for each layer.
- Return type
Optional[Tuple[curvature_blocks.CurvatureBlock]]
- property num_blocks: int
The number of separate blocks that this estimator has.
- Return type
int
- property block_dims: Shape
The number of elements of all parameter variables for each block.
- Return type
Shape
- property dim: int
The number of elements of all parameter variables together.
- Return type
int
- property params_structure_vector_of_indices: utils.Params
A tree structure with parameters replaced by their indices.
- Return type
utils.Params
- property indices_to_block_map: Mapping[Tuple[int, ...], curvature_blocks.CurvatureBlock]
A mapping of parameter indices to their associated blocks.
- Return type
Mapping[Tuple[int, …], curvature_blocks.CurvatureBlock]
- property params_block_index: utils.Params
A structure showing which block each parameter corresponds to.
- Return type
utils.Params
- Returns
A parameter-like structure, where each parameter is replaced by an integer index. This index specifies the block (found by
self.blocks[index]) which approximates the part of the curvature matrix associated with the parameter.
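To make the shape of this structure concrete, the sketch below uses a nested dict as a stand-in for the parameter PyTree and inverts it into a mapping from block index to parameter paths. The tree contents and the helper `group_params_by_block` are hypothetical. Real parameter structures depend on the model.

```python
from collections import defaultdict


def group_params_by_block(index_tree, prefix=()):
    """Maps each block index to the list of parameter paths it covers."""
    groups = defaultdict(list)
    for name, value in index_tree.items():
        path = prefix + (name,)
        if isinstance(value, dict):
            # Recurse into sub-trees and merge their groups.
            for block, paths in group_params_by_block(value, path).items():
                groups[block].extend(paths)
        else:
            # Leaf: the integer is the index of the block for this parameter.
            groups[value].append(path)
    return dict(groups)


# Hypothetical index tree for a two-layer network: each layer's weights and
# biases share a single curvature block.
params_block_index = {
    "dense_1": {"w": 0, "b": 0},
    "dense_2": {"w": 1, "b": 1},
}
blocks_to_params = group_params_by_block(params_block_index)
print(blocks_to_params)
```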
- property num_params_variables: int
The number of separate parameter variables of the model.
- Return type
int
- params_vector_to_blocks_vectors(parameter_structured_vector)[source]
Splits the parameters to values for each corresponding block.
- Return type
Tuple[Tuple[Array, …]]
- blocks_vectors_to_params_vector(blocks_vectors)[source]
Reverses the effect of self.params_vector_to_blocks_vectors.
- Return type
utils.Params
- init(rng, func_args, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues=False)[source]
Initializes the state for the estimator.
- Parameters
rng (PRNGKey) – The PRNGKey to be used for any randomness of the initialization.
func_args (utils.FuncArgs) – Example function arguments, to be used to trace the model function and initialize the state.
exact_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, which specify the exact matrix powers that each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=True and use_cached=True must be provided here.
approx_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, which specify the approximate matrix powers that each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=False and use_cached=True must be provided here.
cache_eigenvalues (bool) – Specifies whether each block should cache the eigenvalues of its approximate curvature.
- Return type
‘BlockDiagonalCurvature.State’
- Returns
The initialized state of the estimator.
- multiply_matpower(state, parameter_structured_vector, identity_weight, power, exact_power, use_cached, pmap_axis_name)[source]
Computes (CurvatureMatrix + identity_weight I)**power times vector.
- Parameters
state ('BlockDiagonalCurvature.State') – The state of the estimator.
parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.
identity_weight (Union[Numeric, Sequence[Numeric]]) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
power (Scalar) – The power to which you want to raise the matrix (EstimateCurvature + identity_weight I).
exact_power (bool) – When set to True the matrix power of EstimateCurvature + identity_weight I is computed exactly. Otherwise this method might use a cheaper approximation, which may vary across different blocks.
use_cached (bool) – Whether to use a cached (and possibly stale) version of the curvature matrix estimate.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
utils.Params
- Returns
A parameter structured vector containing the product.
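To make the quantity (CurvatureMatrix + identity_weight I)**power times vector concrete, here is a minimal sketch under a strong simplifying assumption: every block is diagonal, so the matrix power reduces to an elementwise power. The function name multiply_matpower_diag and the list-based data layout are hypothetical and are not part of kfac_jax.

```python
def multiply_matpower_diag(block_diagonals, block_vectors, identity_weight, power):
    """Computes (C + identity_weight * I) ** power @ v for diagonal blocks.

    block_diagonals: one list of diagonal entries per block (assumption:
      every block is diagonal, which real blocks generally are not).
    block_vectors: one list per block with matching lengths.
    identity_weight: a scalar, or a sequence with one weight per block.
    """
    if isinstance(identity_weight, (int, float)):
        identity_weight = [identity_weight] * len(block_diagonals)
    result = []
    for diag, vec, weight in zip(block_diagonals, block_vectors, identity_weight):
        # For a diagonal matrix the matrix power is an elementwise power.
        result.append([(c + weight) ** power * v for c, v in zip(diag, vec)])
    return result


curvature = [[3.0, 1.0], [0.5]]   # two blocks
vector = [[8.0, 4.0], [3.0]]

# power=-1 gives the damped inverse applied to the vector.
damped_inverse = multiply_matpower_diag(curvature, vector, 1.0, -1)

# A sequence of weights applies a different identity weight to each block.
per_block = multiply_matpower_diag(curvature, vector, [1.0, 0.5], -1)
```

With power=-1 this corresponds to a damped inverse-curvature product; the per-block call mirrors the list/tuple form of identity_weight.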
- block_eigenvalues(state, use_cached)[source]
Computes the eigenvalues for each block of the curvature estimator.
- Parameters
state ('BlockDiagonalCurvature.State') – The state of the estimator.
use_cached (bool) – Whether to use cached versions of the eigenvalues or to compute them from the most recent curvature estimates. The cached versions are at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Tuple[Array, …]
- Returns
A tuple of arrays containing the eigenvalues for each block. The order of this tuple corresponds to the ordering of self.blocks. To understand which parameters correspond to which block you can inspect self.params_block_index.
- eigenvalues(state, use_cached)[source]
Computes the eigenvalues of the curvature matrix.
- Parameters
state ('BlockDiagonalCurvature.State') – The state of the estimator.
use_cached (bool) – Whether to use cached versions of the eigenvalues or to compute them from the most recent curvature estimates. The cached versions are at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Array
- Returns
A single array containing the eigenvalues of the curvature matrix.
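The relationship between block_eigenvalues and eigenvalues follows from a standard fact: the eigenvalues of a block-diagonal matrix are the eigenvalues of its blocks taken together. The sketch below illustrates this with two hypothetical 2x2 symmetric blocks; the helper function and the flattening order are assumptions made for illustration, not the library's code.

```python
import math


def symmetric_2x2_eigenvalues(a, b, d):
    """Eigenvalues of [[a, b], [b, d]] in ascending order."""
    mean = (a + d) / 2.0
    radius = math.sqrt(((a - d) / 2.0) ** 2 + b ** 2)
    return [mean - radius, mean + radius]


# Two hypothetical symmetric curvature blocks, each given as (a, b, d).
blocks = [(2.0, 1.0, 2.0), (5.0, 0.0, 7.0)]

# Analogue of block_eigenvalues: one entry per block, in block order.
per_block = tuple(symmetric_2x2_eigenvalues(*block) for block in blocks)

# Analogue of eigenvalues: one flat collection for the whole matrix.
all_eigenvalues = [value for block in per_block for value in block]
```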
- update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, pmap_axis_name, estimation_mode=None)[source]
Updates the estimator’s curvature estimates.
- Parameters
state ('BlockDiagonalCurvature.State') – The state of the estimator to update.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size.
rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.
func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor) to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.
pmap_axis_name (Optional[str]) – When calling this method within a pmap context this argument specifies the axis name over which to aggregate across multiple devices/hosts.
estimation_mode (Optional[str]) –
The type of curvature estimator to use. By default (i.e. if None) will use self.default_estimation_mode. One of:
fisher_gradients - the basic estimation approach from the original K-FAC paper.
fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.
fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.
fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.
ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).
ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).
- Return type
‘BlockDiagonalCurvature.State’
- Returns
The updated state.
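The roles of ema_old and ema_new can be illustrated with a minimal moving-average sketch. The helper ema_update and the choice ema_new = 1 - ema_old are assumptions made for this example; the value 0.95 is the curvature_ema default listed in the Optimizer signature.

```python
def ema_update(old_estimate, new_estimate, ema_old, ema_new):
    """Weighted combination of the previous estimate and a fresh one."""
    return [ema_old * old + ema_new * new
            for old, new in zip(old_estimate, new_estimate)]


curvature_ema = 0.95          # weight of the old value
old = [1.0, 2.0]              # hypothetical stored curvature statistics
new = [3.0, 0.0]              # hypothetical statistics from the current batch

# Assumption: the new value receives the complementary weight.
updated = ema_update(old, new, curvature_ema, 1.0 - curvature_ema)
```

Each entry of updated lies between the corresponding old and new values, much closer to the old one because ema_old dominates.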
- update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]
Updates the estimator cached values.
- Parameters
state ('BlockDiagonalCurvature.State') – The state of the estimator to update.
identity_weight (Union[Numeric, Sequence[Numeric]]) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.
approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
‘BlockDiagonalCurvature.State’
- Returns
The updated state.
ExplicitExactCurvature
- class kfac_jax.ExplicitExactCurvature(func, params_index=0, batch_index=1, default_estimation_mode='fisher_exact', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, **auto_register_kwargs)[source]
Explicit exact full curvature estimator class.
This class estimates the full curvature matrix by looping over the batch dimension of the input data: for each single example it computes an estimate of the curvature matrix, and then it averages over all examples in the input data. This implies that the computation scales linearly (without parallelism) with the batch size. The class stores the estimated curvature as a dense matrix, hence its memory requirement is (number of parameters)^2. If estimation_mode is fisher_exact or ggn_exact then this computes the exact curvature, but other modes are also supported. As a result of looping over the input data, this class needs to know the index of the batch in the arguments to the model function. Additionally, since the loop is achieved through indexing, each array leaf of that argument must have the same first dimension size, which will be interpreted as the batch size.
- __init__(func, params_index=0, batch_index=1, default_estimation_mode='fisher_exact', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, **auto_register_kwargs)[source]
Initializes the curvature instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in the arguments list of func.
batch_index (int) – Specifies the index, among the inputs to func, of the batch, which represents the data over which we average the curvature.
default_estimation_mode (str) – The estimation mode to use by default when calling self.update_curvature_matrix_estimate.
layer_tag_to_block_ctor (Optional[Mapping[str, CurvatureBlockCtor]]) – An optional dict mapping tags to specific classes of block approximations, which override the default ones.
index_to_block_ctor (Optional[Mapping[Tuple[int, ...], CurvatureBlockCtor]]) – An optional dict mapping specific block parameter indices to specific classes of block approximation, which override the default ones. To get the correct indices check estimator.indices_to_block_map.
auto_register_tags (bool) – Whether to automatically register layer tags for parameters that have not been manually registered. For further details see auto_register_tags.
**auto_register_kwargs – Any keyword arguments to pass into the auto registration function.
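The looping and averaging described for this class, and its quadratic memory cost, can be sketched in plain Python. This is an illustration only: it assumes each example contributes the outer product of a per-example vector as its curvature estimate, which stands in for whatever the chosen estimation mode actually computes, and the helper names are hypothetical.

```python
def outer(u):
    """Dense outer product u u^T as a list of rows."""
    return [[a * b for b in u] for a in u]


def average_dense_curvature(per_example_vectors):
    """Loops over the batch and averages per-example dense estimates."""
    n = len(per_example_vectors[0])
    total = [[0.0] * n for _ in range(n)]
    for vec in per_example_vectors:        # one iteration per example
        estimate = outer(vec)              # assumed per-example estimate
        for i in range(n):
            for j in range(n):
                total[i][j] += estimate[i][j]
    batch_size = len(per_example_vectors)
    return [[x / batch_size for x in row] for row in total]


# Hypothetical per-example vectors for a model with 2 parameters.
per_example = [[1.0, 0.0], [0.0, 2.0]]
curvature = average_dense_curvature(per_example)

# The dense matrix stores (number of parameters)^2 entries.
num_entries = len(curvature) * len(curvature[0])
```

The loop runs once per example, so the cost grows linearly with the batch size, and the stored matrix has (number of parameters)^2 entries.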
- property batch_index: int
The index in the inputs of the model function, which is the batch.
- Return type
int
- params_vector_to_blocks_vectors(parameter_structured_vector)[source]
Splits the parameters to values for each corresponding block.
- Return type
Tuple[Tuple[Array, …]]
- blocks_vectors_to_params_vector(blocks_vectors)[source]
Reverses the effect of self.params_vector_to_blocks_vectors.
- Return type
utils.Params
- update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, pmap_axis_name, estimation_mode=None)[source]
Updates the estimator’s curvature estimates.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator to update.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size.
rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.
func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor) to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.
pmap_axis_name (Optional[str]) – When calling this method within a pmap context this argument specifies the axis name over which to aggregate across multiple devices/hosts.
estimation_mode (Optional[str]) –
The type of curvature estimator to use. By default (i.e. if None) will use self.default_estimation_mode. One of:
fisher_gradients - the basic estimation approach from the original K-FAC paper.
fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.
fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.
fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.
ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).
ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).
- Return type
BlockDiagonalCurvature.State
- Returns
The updated state.
- update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]
Updates the estimator cached values.
- Parameters
state (BlockDiagonalCurvature.State) – The state of the estimator to update.
identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.
exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.
approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.
- Return type
BlockDiagonalCurvature.State
- Returns
The updated state.
ImplicitExactCurvature
- class kfac_jax.ImplicitExactCurvature(func, params_index=0, batch_size_extractor=<function default_batch_size_extractor>)[source]
Represents all exact curvature matrices implicitly, without ever constructing them explicitly.
- __init__(func, params_index=0, batch_size_extractor=<function default_batch_size_extractor>)[source]
Initializes the ImplicitExactCurvature instance.
- Parameters
func (utils.Func) – The model function, which should have at least one registered loss.
params_index (int) – The index of the parameters argument in the arguments list of func.
batch_size_extractor (Callable[[utils.Batch], Numeric]) – A function that takes as input the function arguments and returns the batch size for a single device. (Default: kfac.utils.default_batch_size_extractor)
- batch_size(func_args)[source]
The expected batch size given a list of loss instances.
- Return type
Numeric
- multiply_hessian(func_args, parameter_structured_vector)[source]
Multiplies the vector with the Hessian matrix of the total loss.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Hessian matrix.
parameter_structured_vector (utils.Params) – The vector to multiply with the Hessian matrix.
- Return type
utils.Params
- Returns
The product Hv.
- multiply_fisher(func_args, parameter_structured_vector)[source]
Multiplies the vector with the Fisher matrix of the total loss.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.
parameter_structured_vector (utils.Params) – The vector to multiply with the Fisher matrix.
- Return type
utils.Params
- Returns
The product Fv.
- multiply_ggn(func_args, parameter_structured_vector)[source]
Multiplies the vector with the GGN matrix of the total loss.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.
parameter_structured_vector (utils.Params) – The vector to multiply with the GGN matrix.
- Return type
utils.Params
- Returns
The product Gv.
- multiply_fisher_factor_transpose(func_args, parameter_structured_vector)[source]
Multiplies the vector with the transposed factor of the Fisher matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.
parameter_structured_vector (utils.Params) – The vector to multiply with the transposed factor of the Fisher matrix.
- Return type
Tuple[Array, …]
- Returns
The product B^T v, where F = BB^T.
- multiply_ggn_factor_transpose(func_args, parameter_structured_vector)[source]
Multiplies the vector with the transposed factor of the GGN matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.
parameter_structured_vector (utils.Params) – The vector to multiply with the transposed factor of the GGN matrix.
- Return type
Tuple[Array, …]
- Returns
The product B^T v, where G = BB^T.
- multiply_fisher_factor(func_args, loss_inner_vectors)[source]
Multiplies the vector with the factor of the Fisher matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.
loss_inner_vectors (Sequence[Array]) – The vector to multiply with the Fisher factor matrix.
- Return type
utils.Params
- Returns
The product Bv, where F = BB^T.
- multiply_ggn_factor(func_args, loss_inner_vectors)[source]
Multiplies the vector with the factor of the GGN matrix.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.
loss_inner_vectors (Sequence[Array]) – The vector to multiply with the GGN factor matrix.
- Return type
utils.Params
- Returns
The product Bv, where G = BB^T.
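The factor methods above fit together through the identity F v = B (B^T v), and likewise for G. The sketch below checks it on a small hypothetical factor B with plain Python lists; the helper functions and the numbers are illustrative assumptions, and no kfac_jax call is used.

```python
def matvec(matrix, vector):
    """Dense matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def transpose(matrix):
    return [list(column) for column in zip(*matrix)]


# Hypothetical factor B: 3 parameters by 2 "inner" dimensions.
B = [[1.0, 0.0],
     [2.0, 1.0],
     [0.0, 3.0]]
B_T = transpose(B)

# F = B B^T, formed explicitly here only to check the identity.
F = [[sum(a * b for a, b in zip(row_i, row_j)) for row_j in B] for row_i in B]

v = [1.0, -1.0, 2.0]                 # parameter-structured vector
inner = matvec(B_T, v)               # analogue of multiply_*_factor_transpose
via_factors = matvec(B, inner)       # analogue of multiply_*_factor
direct = matvec(F, v)                # analogue of multiply_fisher / multiply_ggn
```

The intermediate inner lives in the loss inner-vector space (length 2 here), while via_factors and direct are parameter-structured and agree.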
- multiply_jacobian_transpose(func_args, loss_input_vectors)[source]
Multiplies a vector by the model’s transposed Jacobian.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
loss_input_vectors (Sequence[Sequence[Array]]) – A sequence over losses of sequences of arrays that are the size of the loss’s inputs. This represents the vector to be multiplied.
- Return type
utils.Params
- Returns
The product J^T v, where J is the model’s Jacobian and v is given by loss_input_vectors.
- get_loss_inner_vector_shapes_and_batch_size(func_args, mode)[source]
Get shapes of loss inner vectors, and the batch size.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
mode (str) – A string representing the type of curvature matrix for the loss inner vectors. Can be “fisher” or “ggn”.
- Return type
Tuple[Tuple[Shape, …], int]
- Returns
Shapes of loss inner vectors in a tuple, and the batch size as an int.
- get_loss_input_shapes_and_batch_size(func_args)[source]
Get shapes of loss input vectors, and the batch size.
- Parameters
func_args (utils.FuncArgs) – The inputs to the model function.
- Return type
Tuple[Tuple[Tuple[Shape, …], …], int]
- Returns
A tuple over losses of tuples containing the shapes of their different inputs, and the batch size (as an int).
set_default_tag_to_block_ctor
Loss Functions
|
Abstract base class for loss functions. |
|
Base class for loss functions that represent negative log-probability. |
|
Negative log-probability loss that uses a Distrax distribution. |
|
Negative log prob loss for a normal distribution parameterized by a mean vector. |
|
Negative log prob loss for a normal distribution with mean and variance. |
|
Negative log prob loss for multiple Bernoulli distributions parametrized by logits. |
Negative log prob loss for a categorical distribution parameterized by logits. |
|
Neg log prob loss for a categorical distribution with onehot targets. |
|
Registers a sigmoid cross-entropy loss function. |
|
Registers a multi-Bernoulli predictive distribution. |
|
Registers a softmax cross-entropy loss function. |
|
Registers a categorical predictive distribution. |
|
|
Registers a squared error loss function. |
Registers a normal predictive distribution. |
LossFunction
- class kfac_jax.LossFunction(weight)[source]
Abstract base class for loss functions.
Note that unlike typical loss functions used in neural networks these are neither summed nor averaged over the batch, and the output of evaluate() will not be a scalar. It is then up to the user to manipulate them correctly as needed.
- __init__(weight)[source]
Initializes the loss instance.
- Parameters
weight (Numeric) – The relative weight attributed to the loss.
- property weight: Numeric
The relative weight of the loss.
- Return type
Numeric
- abstract property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- abstract property parameter_dependants: Tuple[Array, ...]
All the parameter dependent arrays of the loss.
- Return type
Tuple[Array, …]
- property num_parameter_dependants: int
Number of parameter dependent arrays of the loss.
- Return type
int
- abstract property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property num_parameter_independants: int
Number of parameter independent arrays of the loss.
- Return type
int
- abstract copy_with_different_inputs(parameter_dependants)[source]
Creates a copy of the loss function object, but with different inputs.
- Return type
‘LossFunction’
- evaluate(targets=None, coefficient_mode='regular')[source]
Evaluates the loss function on the targets.
- Parameters
targets (Optional[Array]) – The targets on which to evaluate the loss. If this is set to None, self.targets is used instead.
coefficient_mode (str) –
Specifies how to use the relative weight of the loss in the returned value. There are three options:
'regular' - returns self.weight * loss(targets)
'sqrt' - returns sqrt(self.weight) * loss(targets)
'off' - returns loss(targets)
- Return type
Array
- Returns
The value of the loss scaled appropriately by self.weight according to the coefficient mode.
- Raises
ValueError – If both targets and self.targets are None.
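The three coefficient modes amount to different multipliers applied to the per-example loss values. The following sketch mirrors the formulas listed above; the function scale_loss is a hypothetical stand-in, not the library's evaluate method.

```python
import math


def scale_loss(loss_values, weight, coefficient_mode="regular"):
    """Scales unreduced (per-example) loss values by the chosen multiplier."""
    if coefficient_mode == "regular":
        multiplier = weight
    elif coefficient_mode == "sqrt":
        multiplier = math.sqrt(weight)
    elif coefficient_mode == "off":
        multiplier = 1.0
    else:
        raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode!r}")
    return [multiplier * value for value in loss_values]


per_example_loss = [2.0, 6.0]   # not summed or averaged over the batch
regular = scale_loss(per_example_loss, 4.0, "regular")
sqrt_scaled = scale_loss(per_example_loss, 4.0, "sqrt")
unscaled = scale_loss(per_example_loss, 4.0, "off")
```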
- grad_of_evaluate(targets, coefficient_mode)[source]
Evaluates the gradient of the loss function, w.r.t. its inputs.
- Parameters
targets (Optional[Array]) – The targets at which to evaluate the loss. If this is None, self.targets is used instead.
coefficient_mode (str) – The coefficient mode to use for evaluation. See self.evaluate for more details.
- Return type
Tuple[Array, …]
- Returns
The gradient of the loss function w.r.t. its inputs, at the provided targets.
- multiply_ggn(vector)[source]
Right-multiplies a vector by the GGN of the loss function.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by the GGN. Will have the same shape(s) as self.inputs.
- abstract multiply_ggn_unweighted(vector)[source]
Unweighted version of multiply_ggn().
- Return type
Tuple[Array, …]
- multiply_ggn_factor(vector)[source]
Right-multiplies a vector by a factor B of the GGN.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.
Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.
- Parameters
vector (Array) – The vector to multiply. Must be of the shape(s) given by self.ggn_factor_inner_shape.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will be of the same shape(s) as self.inputs.
- abstract multiply_ggn_factor_unweighted(vector)[source]
Unweighted version of multiply_ggn_factor().
- Return type
Tuple[Array, …]
- multiply_ggn_factor_transpose(vector)[source]
Right-multiplies a vector by the transpose of a factor B of the GGN.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.
Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Array
- Returns
The vector right-multiplied by B^T. Will be of the shape(s) given by self.ggn_factor_inner_shape.
- abstract multiply_ggn_factor_transpose_unweighted(vector)[source]
Unweighted version of multiply_ggn_factor_transpose().
- Return type
Array
- multiply_ggn_factor_replicated_one_hot(index)[source]
Right-multiplies a replicated-one-hot vector by a factor B of the GGN.
Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.
A replicated-one-hot vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere.
Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.
- Parameters
index (Sequence[int]) – A tuple representing the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the ggn_factor_inner_shape tensor minus one.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will be of the same shape(s) as the inputs property.
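A replicated-one-hot vector, as defined above, can be built explicitly for the simplest case of a two-dimensional tensor (batch dimension first, one further dimension). The helper below is a hypothetical illustration using nested lists; the method itself takes only the index and never needs the caller to construct this tensor.

```python
def replicated_one_hot(batch_size, num_entries, index):
    """Tensor of shape (batch_size, num_entries) with the same one-hot row.

    index is a tuple with one entry, because the batch dimension is not
    indexed: every slice along dimension 0 receives the same pattern.
    """
    (position,) = index
    return [
        [1.0 if j == position else 0.0 for j in range(num_entries)]
        for _ in range(batch_size)
    ]


vector = replicated_one_hot(batch_size=3, num_entries=4, index=(2,))
```

Every row of vector is identical, which is what "replicated" refers to.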
- abstract multiply_ggn_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of multiply_ggn_factor_replicated_one_hot().
- Return type
Tuple[Array, …]
- abstract property ggn_factor_inner_shape: Shape
The shape of the array returned by self.multiply_ggn_factor.
- Return type
Shape
NegativeLogProbLoss
- class kfac_jax.NegativeLogProbLoss(weight)[source]
Base class for loss functions that represent negative log-probability.
- property parameter_dependants: Tuple[Array, ...]
All the parameter dependent arrays of the loss.
- Return type
Tuple[Array, …]
- abstract property params: Tuple[Array, ...]
Parameters to the underlying distribution.
- Return type
Tuple[Array, …]
- multiply_fisher(vector)[source]
Right-multiplies a vector by the Fisher.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by the Fisher. Will have the same shape(s) as self.inputs.
- abstract multiply_fisher_unweighted(vector)[source]
Unweighted version of multiply_fisher().
- Return type
Tuple[Array, …]
- multiply_fisher_factor(vector)[source]
Right-multiplies a vector by a factor B of the Fisher.
Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.
- Parameters
vector (Array) – The vector to multiply. Must have the same shape(s) as self.fisher_factor_inner_shape.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will have the same shape(s) as self.inputs.
- abstract multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor().
- Return type
Tuple[Array, …]
- multiply_fisher_factor_transpose(vector)[source]
Right-multiplies a vector by the transpose of a factor B of the Fisher.
Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.
- Parameters
vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.
- Return type
Array
- Returns
The vector right-multiplied by B^T. Will have the shape given by self.fisher_factor_inner_shape.
- abstract multiply_fisher_factor_transpose_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor_transpose().
- Return type
Array
- multiply_fisher_factor_replicated_one_hot(index)[source]
Right-multiplies a replicated-one-hot vector by a factor B of the Fisher.
Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.
A replicated-one-hot vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.
- Parameters
index (Sequence[int]) – A tuple representing the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the fisher_factor_inner_shape tensor minus one.
- Return type
Tuple[Array, …]
- Returns
The vector right-multiplied by B. Will have the same shape(s) as self.inputs.
- abstract multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of multiply_fisher_factor_replicated_one_hot().
- Return type
Tuple[Array, …]
- abstract property fisher_factor_inner_shape: Shape
The shape of the array returned by multiply_fisher_factor().
- Return type
Shape
- grad_of_evaluate_on_sample(rng, coefficient_mode)[source]
Evaluates the gradient of the log probability on a random sample.
- Parameters
rng (Array) – Jax PRNG key for sampling.
coefficient_mode (str) – The coefficient mode to use for evaluation.
- Return type
Tuple[Array, …]
- Returns
The gradient of the log probability of targets sampled from the distribution.
DistributionNegativeLogProbLoss
- class kfac_jax.DistributionNegativeLogProbLoss(weight)[source]
Negative log-probability loss that uses a Distrax distribution.
- abstract property dist: distrax.Distribution
The underlying Distrax distribution.
- Return type
distrax.Distribution
- property fisher_factor_inner_shape: Shape
The shape of the array returned by multiply_fisher_factor().
- Return type
Shape
NormalMeanNegativeLogProbLoss
- class kfac_jax.NormalMeanNegativeLogProbLoss(mean, targets=None, variance=0.5, weight=1.0)[source]
Negative log prob loss for a normal distribution parameterized by a mean vector.
Note that by default (variance=0.5) the covariance is treated as the identity divided by 2. Also note that the Fisher for such a normal distribution with respect to the mean parameter is given by:
F = (1 / variance) * I
See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
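Because F = (1 / variance) * I, multiplying by the Fisher is just an elementwise division by the variance. The sketch below shows this, together with one valid choice of factor, B = I / sqrt(variance), which satisfies B * B^T = F. Both function names are hypothetical, and the factor shown is an assumption rather than a statement about which factor the library uses.

```python
import math


def multiply_fisher_normal_mean(vector, variance=0.5):
    """Computes F v with F = (1 / variance) * I."""
    return [v / variance for v in vector]


def multiply_fisher_factor_normal_mean(vector, variance=0.5):
    """Computes B v for the assumed factor B = I / sqrt(variance)."""
    return [v / math.sqrt(variance) for v in vector]


v = [1.0, -3.0]
fisher_product = multiply_fisher_normal_mean(v)

# B is symmetric here, so applying it twice gives B (B^T v) = F v.
twice_factored = multiply_fisher_factor_normal_mean(
    multiply_fisher_factor_normal_mean(v))
```

With the default variance of 0.5 the Fisher is 2 * I, so fisher_product doubles each entry of v.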
- __init__(mean, targets=None, variance=0.5, weight=1.0)[source]
Initializes the loss instance.
- Parameters
mean (Array) – The mean of the normal distribution.
targets (Optional[Array]) – Optional targets to use for evaluation.
variance (Numeric) – The scalar variance of the normal distribution.
weight (Numeric) – The relative weight of the loss.
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.MultivariateNormalDiag
The underlying Distrax distribution.
- Return type
distrax.MultivariateNormalDiag
- property params: Tuple[Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array]
- copy_with_different_inputs(parameter_dependants)[source]
Creates the same LossFunction object, but with different inputs.
- Parameters
parameter_dependants (Sequence[Array]) – The inputs to use to the constructor of a class instance. This must be a sequence of length 1.
- Return type
‘NormalMeanNegativeLogProbLoss’
- Returns
An instance of NormalMeanNegativeLogProbLoss with the provided inputs.
- Raises
ValueError – If the inputs are a sequence of length other than 1.
- multiply_fisher_unweighted(vector)[source]
Unweighted version of multiply_fisher().
- Return type
Tuple[Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor().
- Return type
Tuple[Array]
NormalMeanVarianceNegativeLogProbLoss
- class kfac_jax.NormalMeanVarianceNegativeLogProbLoss(mean, variance, targets=None, weight=1.0)[source]
Negative log prob loss for a normal distribution with mean and variance.
This class parameterizes a multivariate normal distribution with n independent dimensions. Unlike NormalMeanNegativeLogProbLoss, this class does not assume the variance is held constant. The Fisher Information for n = 1 is given by:
F = [[1 / variance, 0],
     [0, 0.5 / variance^2]]
where the parameters of the distribution are concatenated into a single vector as [mean, variance]. For n > 1, the mean parameter vector is concatenated with the variance parameter vector. For further details, check out the Wikipedia page.
- __init__(mean, variance, targets=None, weight=1.0)[source]
Initializes the loss instance.
- Parameters
mean (Array) – The mean of the normal distribution.
variance (Array) – The variance of the normal distribution.
targets (Optional[Array]) – Optional targets to use for evaluation.
weight (Numeric) – The relative weight of the loss.
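The Fisher matrix stated above for n = 1 can be checked numerically. The following is a minimal NumPy sketch, not part of kfac_jax: the helper name normal_fisher_1d is hypothetical, and it computes the expectation of the outer product of the score with a Gauss-Hermite quadrature rule.

```python
import numpy as np

def normal_fisher_1d(mean, variance, num_nodes=8):
    """Fisher information of N(mean, variance) w.r.t. [mean, variance] (hypothetical helper)."""
    # Probabilists' Gauss-Hermite rule: exact for polynomials of degree < 2 * num_nodes
    # integrated against exp(-z^2 / 2).
    nodes, weights = np.polynomial.hermite_e.hermegauss(num_nodes)
    weights = weights / np.sqrt(2.0 * np.pi)  # normalize: expectation over z ~ N(0, 1)
    x = mean + np.sqrt(variance) * nodes
    # Score: gradient of the log density with respect to [mean, variance].
    d_mean = (x - mean) / variance
    d_var = -0.5 / variance + 0.5 * (x - mean) ** 2 / variance ** 2
    score = np.stack([d_mean, d_var], axis=0)  # shape (2, num_nodes)
    # Fisher = E[score score^T].
    return (score * weights) @ score.T

fisher = normal_fisher_1d(mean=0.3, variance=2.0)
expected = np.array([[1 / 2.0, 0.0], [0.0, 0.5 / 2.0 ** 2]])
```

Under these assumptions fisher agrees with expected up to floating-point error, matching the 2x2 matrix in the class description.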
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.MultivariateNormalDiag
The underlying Distrax distribution.
- Return type
distrax.MultivariateNormalDiag
- property params: Tuple[Array, Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array, Array]
- copy_with_different_inputs(parameter_dependants)[source]
Creates the same LossFunction object, but with different inputs.
- Parameters
parameter_dependants (Sequence[Array]) – The inputs to use to the constructor of a class instance. This must be a sequence of length 2.
- Return type
‘NormalMeanVarianceNegativeLogProbLoss’
- Returns
An instance of NormalMeanVarianceNegativeLogProbLoss with the provided inputs.
- Raises
ValueError – If the inputs are a sequence of length other than 2.
- multiply_fisher_unweighted(vector)[source]
Unweighted version of multiply_fisher().
- Return type
Tuple[Array, Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor().
- Return type
Tuple[Array, Array]
- multiply_fisher_factor_transpose_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor_transpose().
- Return type
Array
- multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of multiply_fisher_factor_replicated_one_hot().
- Return type
Tuple[Array, Array]
- property fisher_factor_inner_shape: Shape
The shape of the array returned by multiply_fisher_factor().
- Return type
Shape
- multiply_ggn_unweighted(vector)[source]
Unweighted version of multiply_ggn().
- Return type
Tuple[Array, …]
- multiply_ggn_factor_unweighted(vector)[source]
Unweighted version of multiply_ggn_factor().
- Return type
Tuple[Array, …]
- multiply_ggn_factor_transpose_unweighted(vector)[source]
Unweighted version of multiply_ggn_factor_transpose().
- Return type
Array
- multiply_ggn_factor_replicated_one_hot_unweighted(index)[source]
Unweighted version of multiply_ggn_factor_replicated_one_hot().
- Return type
Tuple[Array, …]
- property ggn_factor_inner_shape: Shape
The shape of the array returned by self.multiply_ggn_factor.
- Return type
Shape
MultiBernoulliNegativeLogProbLoss
- class kfac_jax.MultiBernoulliNegativeLogProbLoss(logits, targets=None, weight=1.0)[source]
Negative log prob loss for multiple Bernoulli distributions parametrized by logits.
Represents N independent Bernoulli distributions where N = len(logits). Its Fisher Information matrix is given by F = diag(p * (1-p)), where p = sigmoid(logits). As F is diagonal with positive entries, its factor B is B = diag(sqrt(p * (1-p))).
- __init__(logits, targets=None, weight=1.0)[source]
Initializes the loss instance.
- Parameters
logits (Array) – The logits of the Bernoulli distribution.
targets (Optional[Array]) – Optional targets to use for evaluation.
weight (Numeric) – The relative weight of the loss.
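As a quick sanity check of the formulas above, the sketch below builds F and B for a small logits vector in plain NumPy and compares them against the expected squared score. It does not call kfac_jax; all variable names are illustrative.

```python
import numpy as np

logits = np.array([-2.0, 0.0, 1.5])
p = 1.0 / (1.0 + np.exp(-logits))  # p = sigmoid(logits)

fisher = np.diag(p * (1.0 - p))             # F = diag(p * (1 - p))
factor = np.diag(np.sqrt(p * (1.0 - p)))    # B = diag(sqrt(p * (1 - p)))

# The score of a Bernoulli w.r.t. its logit is (y - p). Averaging its square
# over y = 1 (probability p) and y = 0 (probability 1 - p) gives the Fisher diagonal.
expected_sq_score = p * (1.0 - p) ** 2 + (1.0 - p) * (0.0 - p) ** 2
```

Here factor @ factor.T reproduces fisher, and the diagonal of fisher equals expected_sq_score.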
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.Bernoulli
The underlying Distrax distribution.
- Return type
distrax.Bernoulli
- property params: Tuple[Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array]
- copy_with_different_inputs(parameter_dependants)[source]
Creates a copy of the loss function object, but with different inputs.
- Return type
‘MultiBernoulliNegativeLogProbLoss’
- multiply_fisher_unweighted(vector)[source]
Unweighted version of multiply_fisher().
- Return type
Tuple[Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor().
- Return type
Tuple[Array]
CategoricalLogitsNegativeLogProbLoss
- class kfac_jax.CategoricalLogitsNegativeLogProbLoss(logits, targets=None, mask=None, weight=1.0)[source]
Negative log prob loss for a categorical distribution parameterized by logits.
Note that the Fisher (for a single case) of a categorical distribution, with respect to the natural parameters (i.e. the logits), is given by F = diag(p) - p*p^T, where p = softmax(logits). F can be factorized as F = B * B^T, where B = diag(q) - p*q^T and q is the entry-wise square root of p. This is easy to verify using the fact that q^T*q = 1.
- __init__(logits, targets=None, mask=None, weight=1.0)[source]
Initializes the loss instance.
- Parameters
logits (Array) – The logits of the Categorical distribution of shape (batch_size, output_size).
targets (Optional[Array]) – Optional targets to use for evaluation, which specify an integer index of the correct class. Must be of shape (batch_size,).
mask (Optional[Array]) – Optional mask to apply to losses over the batch. Should be 0/1-valued and of shape (batch_size,). The tensors returned by evaluate and grad_of_evaluate, as well as the various matrix vector products, will be multiplied by mask (with broadcasting to later dimensions).
weight (Numeric) – The relative weight of the loss.
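The factorization F = B * B^T described in the class docstring can be verified for a single case with a few lines of NumPy. This is only an illustrative check with made-up logits, not code from the library.

```python
import numpy as np

logits = np.array([0.5, -1.0, 2.0, 0.0])
p = np.exp(logits - logits.max())
p /= p.sum()      # p = softmax(logits)
q = np.sqrt(p)    # entry-wise square root, so q^T q = sum(p) = 1

fisher = np.diag(p) - np.outer(p, p)   # F = diag(p) - p p^T
factor = np.diag(q) - np.outer(p, q)   # B = diag(q) - p q^T
```

Expanding B * B^T gives diag(p) - 2 p p^T + (q^T q) p p^T, which collapses to F because q^T q = 1; numerically, factor @ factor.T matches fisher.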
- property targets: Optional[Array]
The targets (if present) used for evaluating the loss.
- Return type
Optional[Array]
- property parameter_independants: Tuple[Numeric, ...]
All the parameter independent arrays of the loss.
- Return type
Tuple[Numeric, …]
- property dist: distrax.Categorical
The underlying Distrax distribution.
- Return type
distrax.Categorical
- property params: Tuple[Array]
Parameters to the underlying distribution.
- Return type
Tuple[Array]
- property fisher_factor_inner_shape: Shape
The shape of the array returned by multiply_fisher_factor().
- Return type
Shape
- copy_with_different_inputs(parameter_dependants)[source]
Creates a copy of the loss function object, but with different inputs.
- Return type
‘CategoricalLogitsNegativeLogProbLoss’
- multiply_fisher_unweighted(vector)[source]
Unweighted version of multiply_fisher().
- Return type
Tuple[Array]
- multiply_fisher_factor_unweighted(vector)[source]
Unweighted version of multiply_fisher_factor().
- Return type
Tuple[Array]
OneHotCategoricalLogitsNegativeLogProbLoss
- class kfac_jax.OneHotCategoricalLogitsNegativeLogProbLoss(logits, targets=None, mask=None, weight=1.0)[source]
Neg log prob loss for a categorical distribution with onehot targets.
Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying distribution is OneHotCategorical as opposed to Categorical.
- property dist: distrax.OneHotCategorical
The underlying Distrax distribution.
- Return type
distrax.OneHotCategorical
register_sigmoid_cross_entropy_loss
- kfac_jax.register_sigmoid_cross_entropy_loss(logits, targets=None, weight=1.0)[source]
Registers a sigmoid cross-entropy loss function.
Note that this is distinct from register_softmax_cross_entropy_loss() and should not be confused with it. It is similar to register_multi_bernoulli_predictive_distribution() but without the explicit probabilistic interpretation. It behaves identically for now.
- Parameters
logits (Array) – The input logits of the loss as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Must be of the same shape as logits. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the loss function by. (Default: 1.0)
register_multi_bernoulli_predictive_distribution
- kfac_jax.register_multi_bernoulli_predictive_distribution(logits, targets=None, weight=1.0)[source]
Registers a multi-Bernoulli predictive distribution.
Note that this is distinct from register_categorical_predictive_distribution() and should not be confused with it.
- Parameters
logits (Array) – The logits of the distribution (i.e. its parameters) as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. This is NOT equivalent to changing the temperature of the distribution since we don’t renormalize the log prob in the objective function. (Default: 1.0)
register_softmax_cross_entropy_loss
- kfac_jax.register_softmax_cross_entropy_loss(logits, targets=None, mask=None, weight=1.0)[source]
Registers a softmax cross-entropy loss function.
Note that this is distinct from register_sigmoid_cross_entropy_loss() and should not be confused with it. It is similar to register_categorical_predictive_distribution() but without the explicit probabilistic interpretation. It behaves identically for now.
- Parameters
logits (Array) – The input logits of the loss as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator). The second dimension is the one over which the softmax is computed.
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Must be a 1D array of integers with shape (logits.shape[0],). Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
mask (Optional[Array]) – (OPTIONAL) Mask to apply to losses. Should be 0/1-valued and of shape (logits.shape[0],). Losses corresponding to mask values of False will be treated as constant and equal to 0. (Default: None)
weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the loss function by. (Default: 1.0)
- Return type
Array
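To make the expected shapes and the mask semantics concrete, here is a plain NumPy sketch of a masked softmax cross-entropy. The function masked_softmax_cross_entropy is a hypothetical stand-in written for illustration; it is not the kfac_jax implementation and performs no registration.

```python
import numpy as np

def masked_softmax_cross_entropy(logits, targets, mask=None, weight=1.0):
    """Per-example softmax cross-entropy; masked-out examples contribute 0."""
    # logits: (batch, classes); targets: (batch,) integer class indices.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    losses = -log_probs[np.arange(logits.shape[0]), targets]
    if mask is not None:
        losses = losses * mask  # mask: (batch,), 0/1-valued
    return weight * losses

logits = np.array([[2.0, 0.0, -1.0], [0.0, 0.0, 0.0]])
targets = np.array([0, 2])
mask = np.array([1.0, 0.0])
losses = masked_softmax_cross_entropy(logits, targets, mask)
```

The second example is masked out, so its loss is exactly 0 regardless of its logits, while the first keeps its usual cross-entropy value.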
register_categorical_predictive_distribution
- kfac_jax.register_categorical_predictive_distribution(logits, targets=None, mask=None, weight=1.0)[source]
Registers a categorical predictive distribution.
Note that this is distinct from register_multi_bernoulli_predictive_distribution() and should not be confused with it.
- Parameters
logits (Array) – The logits of the distribution (i.e. its parameters) as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator). The second dimension is the one over which the softmax is computed.
targets (Optional[Array]) – (OPTIONAL) The values at which the log probability of this distribution is evaluated (to give the loss). Must be a 1D array of integers with shape (logits.shape[0],). Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
mask (Optional[Array]) – (OPTIONAL) Mask to apply to log probabilities generated by the distribution. Should be 0/1-valued and of shape (logits.shape[0],). Log probabilities corresponding to mask values of False will be treated as constant and equal to 0. (Default: None)
weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. This is NOT equivalent to changing the temperature of the distribution since we don’t renormalize the log prob in the objective function. (Default: 1.0)
register_squared_error_loss
- kfac_jax.register_squared_error_loss(prediction, targets=None, weight=1.0)[source]
Registers a squared error loss function.
This assumes the squared error loss of the form ||target - prediction||^2, averaged across the mini-batch. If your loss uses a coefficient of 0.5 you need to set the “weight” argument to reflect this.
- Parameters
prediction (Array) – The prediction made by the network (i.e. its output). The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
weight (Numeric) – A float coefficient to multiply the loss function by. (Default: 1.0)
- Return type
Array
register_normal_predictive_distribution
- kfac_jax.register_normal_predictive_distribution(mean, targets=None, variance=0.5, weight=1.0)[source]
Registers a normal predictive distribution.
This corresponds to a squared error loss of the form weight/(2*var) * ||target - mean||^2
- Parameters
mean (Array) – A tensor defining the mean vector of the distribution. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).
targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)
variance (float) – The variance of the distribution. Note that the default value of 0.5 corresponds to a standard squared error loss weight * ||target - prediction||^2. If you want your squared error loss to be of the form 0.5*coeff*||target - prediction||^2 you should use variance=1.0. (Default: 0.5)
weight (Numeric) – A scalar coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. In general this is NOT equivalent to changing the temperature of the distribution, but in the case of normal distributions it may be. (Default: 1.0)
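The relationship between variance and the squared error coefficient can be illustrated numerically. The sketch below is plain NumPy with a hypothetical helper normal_neg_log_prob; it evaluates the negative log density of an isotropic normal and shows that, apart from an additive constant that does not depend on the mean, variance=0.5 gives ||target - mean||^2 and variance=1.0 gives 0.5 * ||target - mean||^2.

```python
import numpy as np

def normal_neg_log_prob(mean, targets, variance):
    """Negative log density of N(mean, variance * I) evaluated at targets."""
    d = mean.size
    sq_err = np.sum((targets - mean) ** 2)
    return sq_err / (2.0 * variance) + 0.5 * d * np.log(2.0 * np.pi * variance)

mean = np.array([0.2, -1.0, 3.0])
targets = np.array([0.0, 0.5, 2.0])
sq_err = np.sum((targets - mean) ** 2)

# Additive constants: independent of the mean, so they vanish from gradients.
const_half = 0.5 * mean.size * np.log(2.0 * np.pi * 0.5)
const_one = 0.5 * mean.size * np.log(2.0 * np.pi * 1.0)

loss_half = normal_neg_log_prob(mean, targets, 0.5) - const_half  # equals sq_err
loss_one = normal_neg_log_prob(mean, targets, 1.0) - const_one    # equals 0.5 * sq_err
```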
Curvature Blocks
|
Abstract class for curvature approximation blocks. |
|
A block that assumes that the curvature is a scaled identity matrix. |
|
An abstract class for approximating only the diagonal of curvature. |
|
An abstract class for approximating the block matrix with a full matrix. |
|
A Kronecker factored block for layers with weights and an optional bias. |
|
Approximates the diagonal of the curvature in the most obvious way. |
|
Approximates the full curvature in the most obvious way. |
|
A Diagonal block specifically for dense layers. |
|
A Full block specifically for dense layers. |
|
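A TwoKroneckerFactored block specifically for dense layers. |
|
A Diagonal block specifically for 2D convolution layers. |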
|
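A Full block specifically for 2D convolution layers. |
|
A TwoKroneckerFactored block specifically for 2D convolution layers. |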
|
A diagonal approximation specifically for scale and shift layers. |
|
A full dense approximation specifically for scale and shift layers. |
|
Sets the default value of maximum parallel elements in the module. |
|
Returns the default value of maximum parallel elements in the module. |
|
Sets the default value of the eigen decomposition threshold. |
|
Returns the default value of the eigen decomposition threshold. |
CurvatureBlock
- class kfac_jax.CurvatureBlock(layer_tag_eq, name)[source]
Abstract class for curvature approximation blocks.
A CurvatureBlock defines a curvature matrix to be estimated, and gives methods to multiply powers of this with a vector. Powers can be computed exactly or with a class-determined approximation. Cached versions of the powers can be pre-computed to make repeated multiplications cheaper. During initialization, you would have to explicitly specify all powers that you will need to cache.
- class State(cache)[source]
Persistent state of the block.
Any subclasses of CurvatureBlock should also internally extend this class, with any attributes needed for the curvature estimation.
- cache
A dictionary, containing any state data that is updated at irregular intervals, such as inverses, eigenvalues, etc. Elements of this are updated via calls to update_cache(), and do not necessarily correspond to the most up-to-date curvature estimate.
- Type
Optional[Dict[str, Union[Array, Dict[str, Array]]]]
- __init__(cache)
- __init__(layer_tag_eq, name)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.
name (str) – The name of this block.
- property layer_tag_primitive: tags.LayerTag
The jax.core.Primitive corresponding to the block’s tag equation.
- Return type
tags.LayerTag
- property parameter_variables: Tuple[jax.core.Var, ...]
The parameter variables of the underlying Jax equation.
- Return type
Tuple[jax.core.Var, …]
- property outputs_shapes: Tuple[Shape, ...]
The shapes of the output variables of the block’s tag equation.
- Return type
Tuple[Shape, …]
- property inputs_shapes: Tuple[Shape, ...]
The shapes of the input variables of the block’s tag equation.
- Return type
Tuple[Shape, …]
- property parameters_shapes: Tuple[Shape, ...]
The shapes of the parameter variables of the block’s tag equation.
- Return type
Tuple[Shape, …]
- property parameters_canonical_order: Tuple[int, ...]
The canonical order of the parameter variables.
- Return type
Tuple[int, …]
- property layer_tag_extra_params: Dict[str, Any]
Any extra parameters passed into the Jax primitive of this block.
- Return type
Dict[str, Any]
- property number_of_parameters: int
Number of parameter variables of this block.
- Return type
int
- property dim: int
The number of elements of all parameter variables together.
- Return type
int
- scale(state, use_cache)[source]
A scalar pre-factor of the curvature approximation.
Importantly, all methods assume that whenever a user requests cached values, any state-dependent scale is already taken into account by the cache (e.g. either stored explicitly and used, or mathematically folded into the cached values).
- Parameters
state ('CurvatureBlock.State') – The state for this block.
use_cache (bool) – Whether the method requesting this is using cached values or not.
- Return type
Numeric
- Returns
A scalar value to be multiplied with any unscaled block representation.
- fixed_scale()[source]
A fixed scalar pre-factor of the curvature (e.g. constant).
- Return type
Numeric
- state_dependent_scale(state)[source]
A scalar pre-factor of the curvature, computed from the most recent curvature estimate.
- Return type
Numeric
- init(rng, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues)[source]
Initializes the state for this block.
- Parameters
rng (PRNGKey) – The PRNGKey to be used for any randomness in the initialization.
exact_powers_to_cache (Optional[ScalarOrSequence]) – A single value, or multiple values in a list, which specify which exact matrix powers the block should be caching. Matrix powers that are expected to be used in multiply_matpower(), multiply_inverse() or multiply() with exact_power=True and use_cached=True must be provided here.
approx_powers_to_cache (Optional[ScalarOrSequence]) – A single value, or multiple values in a list, which specify which approximate matrix powers the block should be caching. Matrix powers that are expected to be used in multiply_matpower(), multiply_inverse() or multiply() with exact_power=False and use_cached=True must be provided here.
cache_eigenvalues (bool) – Specifies whether the block should be caching the eigenvalues of its approximate curvature.
- Return type
‘CurvatureBlock.State’
- Returns
A dictionary with the initialized state.
- multiply_matpower(state, vector, identity_weight, power, exact_power, use_cached)[source]
Computes (BlockMatrix + identity_weight I)**power times vector.
- Parameters
state ('CurvatureBlock.State') – The state for this block.
vector (Sequence[Array]) – A tuple of arrays that should have the same shapes as the block’s parameters_shapes, which represent the vector you want to multiply.
identity_weight (Numeric) – A scalar specifying the weight on the identity matrix that is added to the block matrix before raising it to a power. If use_cached=False it is guaranteed that this argument will be used in the computation. When returning cached values, this argument may be ignored in favor of whatever value was last passed to update_cache(). The precise semantics of this depend on the concrete subclass and its particular behavior in regard to caching.
power (Scalar) – The power to which to raise the matrix.
exact_power (bool) – Specifies whether to compute the exact matrix power of BlockMatrix + identity_weight I. When this argument is False the exact behaviour will depend on the concrete subclass and the result will in general be an approximation to (BlockMatrix + identity_weight I)^power, although some subclasses may still compute the exact matrix power.
use_cached (bool) – Whether to use a cached version for computing the product or to use the most recent curvature estimates. The cached version is going to be at least as fresh as the value provided to the last call to update_cache() with the same value of power.
- Return type
Tuple[Array, …]
- Returns
A tuple of arrays, representing the result of the matrix-vector product.
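For intuition about what an exact matrix power product computes, the following NumPy sketch evaluates (BlockMatrix + identity_weight I)**power times a vector through an eigendecomposition. It is a simplified illustration under stated assumptions: the block matrix is a dense symmetric positive semi-definite array, the vector is flat rather than a tuple of parameter-shaped arrays, and the free function multiply_matpower below is hypothetical rather than the library method.

```python
import numpy as np

def multiply_matpower(block_matrix, vector, identity_weight, power):
    """Returns (block_matrix + identity_weight * I)**power @ vector."""
    # Assumes block_matrix is symmetric positive semi-definite.
    eigvals, eigvecs = np.linalg.eigh(block_matrix)
    damped = eigvals + identity_weight  # eigenvalues of the damped matrix
    return eigvecs @ (damped ** power * (eigvecs.T @ vector))

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
block_matrix = a @ a.T
vector = rng.normal(size=4)

inv_product = multiply_matpower(block_matrix, vector, identity_weight=0.1, power=-1)
direct = np.linalg.solve(block_matrix + 0.1 * np.eye(4), vector)
plain_product = multiply_matpower(block_matrix, vector, identity_weight=0.1, power=1)
```

With power=-1 the result matches a direct linear solve, and with power=1 it matches an ordinary matrix-vector product with the damped matrix, mirroring the roles of multiply_inverse() and multiply().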
- multiply(state, vector, identity_weight, exact_power, use_cached)[source]
Computes (BlockMatrix + identity_weight I) times vector.
- Return type
Tuple[Array, …]
- multiply_inverse(state, vector, identity_weight, exact_power, use_cached)[source]
Computes (BlockMatrix + identity_weight I)^-1 times vector.
- Return type
Tuple[Array, …]
- eigenvalues(state, use_cached)[source]
Computes the eigenvalues for this block approximation.
- Parameters
state ('CurvatureBlock.State') – The state dict for this block.
use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is going to be at least as fresh as the last time you called update_cache() with eigenvalues=True.
- Return type
Array
- Returns
An array containing the eigenvalues of the block.
- abstract update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]
Updates the block’s curvature estimates using the estimation_data provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.
- Parameters
state ('CurvatureBlock.State') – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in estimation_data.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.
- Return type
‘CurvatureBlock.State’
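The roles of ema_old and ema_new can be sketched with a simple weighted combination. This is an assumption-laden illustration: update_moving_average is a hypothetical helper, and the library's internal bookkeeping (for example, any normalization of accumulated weights) may differ from this linear form.

```python
import numpy as np

def update_moving_average(old_estimate, new_value, ema_old, ema_new):
    """Hypothetical helper: weighted combination of the old estimate and a new value."""
    return ema_old * old_estimate + ema_new * new_value

old_estimate = np.array([[2.0, 0.0], [0.0, 2.0]])
new_value = np.array([[1.0, 0.5], [0.5, 1.0]])

# A blend that mostly keeps the old estimate.
blended = update_moving_average(old_estimate, new_value, ema_old=0.9, ema_new=0.1)
# ema_old=0 and ema_new=1 discards the history, i.e. no moving average.
replaced = update_moving_average(old_estimate, new_value, ema_old=0.0, ema_new=1.0)
```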
- update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues)[source]
Updates the cached estimates of the different powers specified.
- Parameters
state ('CurvatureBlock.State') – The state dict for this block to update.
identity_weight (Numeric) – The weight of the identity added to the block’s curvature matrix before computing the cached matrix power.
exact_powers (Optional[ScalarOrSequence]) – Specifies any cached exact matrix powers to be updated.
approx_powers (Optional[ScalarOrSequence]) – Specifies any cached approximate matrix powers to be updated.
eigenvalues (bool) – Specifies whether to update the cached eigenvalues of the block. If they have not been cached before, this will create an entry with them in the block’s cache.
- Return type
‘CurvatureBlock.State’
- Returns
The updated state.
ScaledIdentity
- class kfac_jax.ScaledIdentity(layer_tag_eq, name, scale=1.0)[source]
A block that assumes that the curvature is a scaled identity matrix.
- __init__(layer_tag_eq, name, scale=1.0)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.
name (str) – The name of this block.
scale (Numeric) – The scale of the identity matrix.
- fixed_scale()[source]
A fixed scalar pre-factor of the curvature (e.g. constant).
- Return type
Numeric
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]
Updates the block’s curvature estimates using the estimation_data provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.
- Parameters
state (CurvatureBlock.State) – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in estimation_data.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.
- Return type
‘CurvatureBlock.State’
Diagonal
- class kfac_jax.Diagonal(layer_tag_eq, name)[source]
An abstract class for approximating only the diagonal of curvature.
- class State(cache, diagonal_factors)[source]
Persistent state of the block.
- diagonal_factors
A tuple of the moving averages of the estimated diagonals of the curvature for each parameter that is part of the associated layer.
- Type
Tuple[utils.WeightedMovingAverage]
- __init__(cache, diagonal_factors)
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]
Updates the block’s curvature estimates using the estimation_data provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.
- Parameters
state ('Diagonal.State') – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in estimation_data.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.
- Return type
‘Diagonal.State’
Full
- class kfac_jax.Full(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
An abstract class for approximating the block matrix with a full matrix.
- class State(cache, matrix)[source]
Persistent state of the block.
- matrix
A moving average of the estimated curvature matrix for all parameters that are part of the associated layer.
- Type
utils.WeightedMovingAverage
- __init__(cache, matrix)
- __init__(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.
name (str) – The name of this block.
eigen_decomposition_threshold (Optional[int]) – During calls to init and update_cache, if more matrix powers than this threshold are requested, the block computes the eigen-decomposition directly (which provides access to any matrix power) instead of computing individual approximate powers. If this is None, the value returned from get_default_eigen_decomposition_threshold() is used.
- parameters_list_to_single_vector(parameters_shaped_list)[source]
Converts values corresponding to parameters of the block to vector.
- Return type
Array
- single_vector_to_parameters_list(vector)[source]
Reverses the transformation self.parameters_list_to_single_vector.
- Return type
Tuple[Array, …]
- update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]
Updates the block’s curvature estimates using the estimation_data provided.
Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.
- Parameters
state ('Full.State') – The state dict for this block to update.
estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.
ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.
ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.
batch_size (Numeric) – The batch size used in computing the values in estimation_data.
pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.
- Return type
‘Full.State’
TwoKroneckerFactored
- class kfac_jax.TwoKroneckerFactored(layer_tag_eq, name)[source]
A Kronecker factored block for layers with weights and an optional bias.
- __init__(layer_tag_eq, name)[source]
Initializes the block.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.
name (str) – The name of this block.
- property has_bias: bool
Whether this layer’s equation has a bias.
- Return type
bool
NaiveDiagonal
NaiveFull
- class kfac_jax.NaiveFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]
Approximates the full curvature in the most obvious way.
The update to the curvature estimate is computed by (sum_i g_i) (sum_i g_i)^T / N, where g_i is the gradient of each individual data point, and N is the batch size.
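The formula above can be written out directly. The following is a minimal sketch in plain JAX, assuming the per-example gradients are already stacked into an array of shape (N, num_params); `naive_full_update` is a hypothetical helper, not part of the library.

```python
import jax.numpy as jnp


def naive_full_update(per_example_grads):
    # per_example_grads has shape (N, num_params): one gradient g_i per row.
    batch_size = per_example_grads.shape[0]
    summed = jnp.sum(per_example_grads, axis=0)      # sum_i g_i
    return jnp.outer(summed, summed) / batch_size    # (sum_i g_i)(sum_i g_i)^T / N


grads = jnp.array([[1.0, 0.0], [1.0, 2.0]])
update = naive_full_update(grads)
```

Note that this is the outer product of the summed gradient, which in general differs from the average of the per-example outer products.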
DenseDiagonal
DenseFull
DenseTwoKroneckerFactored
- class kfac_jax.DenseTwoKroneckerFactored(layer_tag_eq, name)[source]
A TwoKroneckerFactored block specifically for dense layers.
Conv2DDiagonal
- class kfac_jax.Conv2DDiagonal(layer_tag_eq, name, max_elements_for_vmap=None)[source]
A Diagonal block specifically for 2D convolution layers.
- __init__(layer_tag_eq, name, max_elements_for_vmap=None)[source]
Initializes the block.
Since there is no ‘nice’ formula for computing the average of the tangents for a 2D convolution, we use a function, self.conv2d_tangent_squared, that computes for a single feature map the square of the tangents for the kernel of the convolution. To average over the batch we have two choices: vmap, or loop over the batch sequentially using scan. This utility function provides a trade-off by letting you specify the maximum batch size that we can vmap over. This means that the maximum memory usage will be max_batch_size_for_vmap times the memory needed when calling self.conv2d_tangent_squared, and the actual vmap will be called ceil(total_batch_size / max_batch_size_for_vmap) times in a loop to find the final average.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag that this block will approximate the curvature to.
name (str) – The name of this block.
max_elements_for_vmap (Optional[int]) – The threshold used for determining how much computation to do in parallel and how much in a serial manner. If None, the value returned by get_max_parallel_elements() is used.
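The memory trade-off described above can be sketched in plain JAX. This is an illustration only: `chunked_mean` and `tangent_squared` are hypothetical stand-ins, and a Python loop is used here in place of scan for simplicity.

```python
import math

import jax
import jax.numpy as jnp


def chunked_mean(per_example_fn, batch, max_batch_size_for_vmap):
    # vmap over at most `max_batch_size_for_vmap` examples at a time and
    # accumulate the sums, so peak memory scales with the chunk size.
    total = batch.shape[0]
    num_chunks = math.ceil(total / max_batch_size_for_vmap)
    running_sum = None
    for i in range(num_chunks):
        start = i * max_batch_size_for_vmap
        chunk = batch[start:start + max_batch_size_for_vmap]
        chunk_sum = jnp.sum(jax.vmap(per_example_fn)(chunk), axis=0)
        running_sum = chunk_sum if running_sum is None else running_sum + chunk_sum
    return running_sum / total


def tangent_squared(x):
    # Stand-in for a per-example quantity such as a squared tangent.
    return x ** 2


batch = jnp.arange(10.0).reshape(5, 2)
chunked = chunked_mean(tangent_squared, batch, max_batch_size_for_vmap=2)
full = jnp.mean(jax.vmap(tangent_squared)(batch), axis=0)
```

With a batch of 5 and a chunk size of 2, the inner vmap runs ceil(5 / 2) = 3 times, and the result matches a single vmap over the whole batch.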
Conv2DFull
- class kfac_jax.Conv2DFull(layer_tag_eq, name, max_elements_for_vmap=None)[source]
A Full block specifically for 2D convolution layers.
- __init__(layer_tag_eq, name, max_elements_for_vmap=None)[source]
Initializes the block.
Since there is no ‘nice’ formula for computing the average of the tangents for a 2D convolution, we use a function, self.conv2d_tangent_squared, that computes for a single feature map the square of the tangents for the kernel of the convolution. To average over the batch we have two choices: vmap, or loop over the batch sequentially using scan. This utility function provides a trade-off by letting you specify the maximum batch size that will be handled in a single iteration of the loop. This means that the maximum memory usage will be max_batch_size_for_vmap times the memory needed when calling self.conv2d_tangent_squared, and the actual vmap will be called ceil(total_batch_size / max_batch_size_for_vmap) times in a loop to find the final average.
- Parameters
layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag that this block will approximate the curvature to.
name (str) – The name of this block.
max_elements_for_vmap (Optional[int]) – The threshold used for determining how much computation to do in parallel and how much in a serial manner. If None, the value returned by get_max_parallel_elements() is used.
Conv2DTwoKroneckerFactored
- class kfac_jax.Conv2DTwoKroneckerFactored(layer_tag_eq, name)[source]
A TwoKroneckerFactored block specifically for 2D convolution layers.
- fixed_scale()[source]
A fixed scalar pre-factor of the curvature (e.g. constant).
- Return type
Numeric
- property outputs_channel_index: int
The channels index in the outputs of the layer.
- Return type
int
- property inputs_channel_index: int
The channels index in the inputs of the layer.
- Return type
int
- property weights_output_channel_index: int
The channels index in the weights of the layer.
- Return type
int
- property weights_spatial_size: int
The spatial filter size of the weights.
- Return type
int
- property num_locations: int
The number of spatial locations that each filter is applied to.
- Return type
int
- property num_inputs_channels: int
The number of channels in the inputs to the layer.
- Return type
int
- property num_outputs_channels: int
The number of channels in the outputs to the layer.
- Return type
int
ScaleAndShiftDiagonal
- class kfac_jax.ScaleAndShiftDiagonal(layer_tag_eq, name)[source]
A diagonal approximation specifically for scale and shift layers.
- property has_scale: bool
Whether this layer’s equation has a scale.
- Return type
bool
- property has_shift: bool
Whether this layer’s equation has a shift.
- Return type
bool
ScaleAndShiftFull
set_max_parallel_elements
- kfac_jax.set_max_parallel_elements(value)[source]
Sets the default value of maximum parallel elements in the module.
This value is used to determine the parallel-to-memory tradeoff in the curvature estimation procedure of Conv2DDiagonal and Conv2DFull. See their corresponding docs for further details.
- Parameters
value (int) – The default value for maximum number of parallel elements.
get_max_parallel_elements
- kfac_jax.get_max_parallel_elements()[source]
Returns the default value of maximum parallel elements in the module.
This value is used to determine the parallel-to-memory tradeoff in the curvature estimation procedure of Conv2DDiagonal and Conv2DFull. See their corresponding docs for further details.
- Return type
int
- Returns
The default value for maximum number of parallel elements.
set_default_eigen_decomposition_threshold
- kfac_jax.set_default_eigen_decomposition_threshold(value)[source]
Sets the default value of the eigen decomposition threshold.
This value is used in Full to determine, when updating the cache, at what number of different powers to switch the implementation from a simple matrix power to an eigenvector decomposition.
- Parameters
value (int) – The default value for eigen decomposition threshold.
get_default_eigen_decomposition_threshold
- kfac_jax.get_default_eigen_decomposition_threshold()[source]
Returns the default value of the eigen decomposition threshold.
This value is used in Full to determine, when updating the cache, at what number of different powers to switch the implementation from a simple matrix power to an eigenvector decomposition.
- Return type
int
- Returns
The default value of the eigen decomposition threshold.
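To see why a single eigen-decomposition gives access to any matrix power, consider the following minimal sketch in plain JAX. It assumes a symmetric positive-definite matrix; `psd_matrix_powers` is a hypothetical helper and not the library's implementation.

```python
import jax.numpy as jnp


def psd_matrix_powers(matrix, powers):
    # One eigen-decomposition M = Q diag(s) Q^T serves every requested power,
    # since M^p = Q diag(s^p) Q^T for a symmetric positive-definite M.
    eigenvalues, eigenvectors = jnp.linalg.eigh(matrix)
    return {
        p: (eigenvectors * eigenvalues ** p) @ eigenvectors.T
        for p in powers
    }


matrix = jnp.array([[2.0, 1.0], [1.0, 2.0]])
results = psd_matrix_powers(matrix, powers=(-1, -0.5, 2))
```

Once several different powers are needed, the cost of the decomposition is shared across all of them, which is the motivation for a threshold on the number of requested powers.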
Advanced Features API
Layer and loss tags
| LossTag | A Jax primitive for tagging K-FAC losses. |
| LayerTag | A Jax primitive for tagging K-FAC layers. |
| register_generic | Registers a generic tag around the provided parameter array. |
| register_dense | Registers a dense layer. |
| register_conv2d | Registers a 2d convolution layer. |
| register_scale_and_shift | Registers a scale and shift layer. |
LossTag
- class kfac_jax.LossTag(cls, parameter_dependants, parameter_independants)[source]
A Jax primitive for tagging K-FAC losses.
The primitive is a no-op at runtime; its purpose is to tag (annotate) in the Jax computation graph which expression exactly is the loss and what type of loss it represents. This is the only way for K-FAC to know how to compute the curvature matrix.
- __init__(cls, parameter_dependants, parameter_independants)[source]
Initializes a loss tag primitive for the given LossFunction class.
When the primitive is created, the constructor automatically registers it with the standard Jax machinery for differentiation, jax.vmap() and XLA lowering. For further details, please take a look at the JAX documentation on primitives.
- Parameters
cls (Type[T]) – The corresponding class of LossFunction that this tag represents.
parameter_dependants (Sequence[str]) – The names of each of the parameter dependent inputs to the tag.
parameter_independants (Sequence[str]) – The names of each of the parameter independent inputs to the tag.
- property parameter_dependants_names: Tuple[str, ...]
The names of the parameter dependent inputs to the tag primitive.
- Return type
Tuple[str, …]
- property parameter_independants_names: Tuple[str, ...]
The names of the parameter independent inputs to the tag primitive.
- Return type
Tuple[str, …]
- loss(*args, args_names)[source]
Constructs an instance of the corresponding LossFunction class.
- Return type
T
LayerTag
- class kfac_jax.LayerTag(name, num_inputs, num_outputs)[source]
A Jax primitive for tagging K-FAC layers.
The primitive is a no-op at runtime; its purpose is to tag (annotate) in the Jax computation graph which expressions represent a single unique layer type. This is the only way for K-FAC to know how to compute the curvature matrix.
- __init__(name, num_inputs, num_outputs)[source]
Initializes a layer tag primitive with the given name.
Any layer tag primitive must have the following interface: layer_tag(*outputs, *inputs, *parameters, **kwargs). We refer collectively to inputs, outputs and parameters as operands. All operands must be Jax arrays, while any of the values in kwargs must be hashable fixed constants.
When the primitive is created, the constructor automatically registers it with the standard Jax machinery for differentiation, jax.vmap() and XLA lowering. For further details, please take a look at the JAX documentation on primitives.
- Parameters
name (str) – The name of the layer primitive.
num_inputs (int) – The number of inputs to the layer.
num_outputs (int) – The number of outputs to the layer.
- property num_outputs: int
The number of outputs of the layer tag that this primitive represents.
- Return type
int
- property num_inputs: int
The number of inputs of the layer tag that this primitive represents.
- Return type
int
register_generic
register_dense
register_conv2d
Automatic Tags Registration
| auto_register_tags | Transforms the function by automatically registering layer tags. |
auto_register_tags
- kfac_jax.auto_register_tags(func, func_args, params_index=0, register_only_generic=False, compute_only_loss_tags=True, patterns_to_skip=(), allow_multiple_registrations=False, graph_matcher_rules=<kfac_jax._src.tag_graph_matcher.GraphMatcherComparator object>, graph_patterns=<default graph patterns>)[source]
The default graph_patterns are, in order of precedence: dense_with_bias (dense_tag, compute_func=_dense), dense_with_bias (dense_tag, compute_func=_dense_with_reshape), dense_no_bias (dense_tag), conv2d_with_bias (conv2d_tag), conv2d_no_bias (conv2d_tag), and, all using scale_and_shift_tag: scale_and_shift_broadcast_1, scale_and_shift_broadcast_0, normalization_haiku_broadcast_1, normalization_haiku_broadcast_0, scale_only_broadcast_1, scale_only_broadcast_0, shift_only_broadcast_1, shift_only_broadcast_0.
Transforms the function by automatically registering layer tags.
- Parameters
func (utils.Func) – The original function to transform.
func_args (utils.FuncArgs) – Example arguments to func, which are to be used for tracing it.
params_index (int) – Specifies which inputs to the function are to be considered a parameter variable. Specifically - inputs[params_index].
register_only_generic (bool) – If True, registers all parameters not already in a layer tag with a generic tag, effectively ignoring graph_patterns.
compute_only_loss_tags (bool) – If set to True (default), the resulting function will only compute the loss tags in func, not its full computation and actual output.
patterns_to_skip (Sequence[str]) – The names of any patterns from the provided list which are to be skipped/not used during the pattern matching.
allow_multiple_registrations (bool) – Whether to raise an error if a parameter is registered with more than one layer tag.
graph_matcher_rules (GraphMatcherComparator) – A GraphMatcherRules instance, which is used for determining equivalence of individual Jax primitives.
graph_patterns (Sequence[GraphPattern]) – A sequence of GraphPattern objects containing all patterns, in order of precedence, to try to find in the graph before registering a parameter with a generic layer tag.
- Return type
TaggedFunction
- Returns
A transformed function as described above.
Function tracing and Jacobian computation
| ProcessedJaxpr | A wrapper around Jaxpr, with useful additional data. |
| loss_tags_vjp | Creates a function for the vector-Jacobian product w.r.t. all loss tags. |
| loss_tags_jvp | Creates a function for the Jacobian-vector product w.r.t. all loss tags. |
| loss_tags_hvp | Creates a function for the Hessian-vector product w.r.t. all loss tags. |
| layer_tags_vjp | Creates a function for primal values and tangents w.r.t. all layer tags. |
ProcessedJaxpr
- class kfac_jax.ProcessedJaxpr(jaxpr, consts, in_tree, params_index, allow_left_out_params=False)[source]
A wrapper around Jaxpr, with useful additional data.
- jaxpr
The original Jaxpr that is being wrapped.
- consts
The constants returned from the tracing of the original Jaxpr.
- in_tree
The PyTree structure of the inputs to the function that the original Jaxpr has been created from.
- params_index
Specifies which inputs to the function are to be considered a parameter variable. Specifically - inputs[params_index].
- loss_tags
A tuple of all of the loss tags in the original Jaxpr.
- layer_tags
A sorted tuple of all of the layer tags in the original Jaxpr. The sorting order is based on the indices of the parameters associated with each layer tag.
- layer_indices
A sequence of tuples, where each tuple has the indices of the parameters associated with the corresponding layer tag.
- __init__(jaxpr, consts, in_tree, params_index, allow_left_out_params=False)[source]
Initializes the instance.
- Parameters
jaxpr (jax.core.Jaxpr) – The raw Jaxpr.
consts (Sequence[Any]) – The constants needed for evaluation of the raw Jaxpr.
in_tree (utils.PyTreeDef) – The PyTree structure of the inputs to the function that the jaxpr has been created from.
params_index (int) – Specifies which inputs to the function are to be considered a parameter variable. Specifically - inputs[params_index].
allow_left_out_params (bool) – Whether to raise an error if any of the parameter variables is not included in any layer tag.
- property in_vars_flat: List[Var]
A flat list of all of the abstract input variables.
- Return type
List[Var]
- property in_vars: utils.PyTree[Var]
The abstract input variables, as an unflattened structure.
- Return type
utils.PyTree[Var]
- property params_vars: utils.PyTree[Var]
The abstract parameter variables, as an unflattened structure.
- Return type
utils.PyTree[Var]
- property params_vars_flat: List[Var]
A flat list of all abstract parameter variables.
- Return type
List[Var]
- property params_tree: utils.PyTreeDef
The PyTree structure of the parameter variables.
- Return type
utils.PyTreeDef
- classmethod make_from_func(func, func_args, params_index=0, auto_register_tags=True, allow_left_out_params=False, **auto_registration_kwargs)[source]
Constructs a ProcessedJaxpr from the given function.
- Parameters
func (utils.Func) – The model function, which will be traced.
func_args (FuncArgs) – Function arguments to use for tracing.
params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.
auto_register_tags (bool) – Whether to run an automatic layer registration on the function (e.g. auto_register_tags()).
allow_left_out_params (bool) – If this is set to False, an error is raised if there are any model parameters that have not been assigned to a layer tag.
**auto_registration_kwargs – Any additional keyword arguments, to be passed to the automatic registration pass.
- Return type
ProcJaxpr
- Returns
A ProcessedJaxpr representing the model function.
loss_tags_vjp
- kfac_jax.loss_tags_vjp(func, params_index=0)[source]
Creates a function for the vector-Jacobian product w.r.t. all loss tags.
The returned function has a similar interface to jax.vjp(). It takes as inputs the concrete values of the primals at which the Jacobian will be evaluated. It returns a pair (losses, losses_vjp), where losses is a tuple of LossFunction objects and losses_vjp is a function taking as inputs the concrete values of the tangents of the inputs for each loss tag (corresponding to a loss object in losses) and returning the corresponding tangents of the parameters.
- Parameters
func (utils.Func) – The model function, which must include at least one loss registration.
params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.
- Return type
TransformedFunction[LossTagsVjp]
- Returns
A function that computes the vector-Jacobian product with signature Callable[[FuncArgs], LossTagsVjp].
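For readers unfamiliar with the jax.vjp() interface that this function mirrors, here is a minimal sketch in plain JAX. The `model` function and its parameters are hypothetical, and the logits merely stand in for the inputs to a loss tag; no kfac_jax functions are called.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # Hypothetical model whose output plays the role of a loss tag's input.
    return x @ params["w"] + params["b"]


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.array([[1.0, 2.0, 3.0]])

# jax.vjp returns the primal output and a function that pulls tangents of
# that output back to tangents of the parameters.
logits, vjp_func = jax.vjp(lambda p: model(p, x), params)
(params_tangents,) = vjp_func(jnp.ones_like(logits))
```

The result `params_tangents` has the same structure as `params`, which is the sense in which the vector-Jacobian product returns "tangents of the parameters".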
loss_tags_jvp
- kfac_jax.loss_tags_jvp(func, params_index=0)[source]
Creates a function for the Jacobian-vector product w.r.t. all loss tags.
The returned function has a similar interface to jax.jvp(). It takes as inputs the concrete values of the primals at which the Jacobian will be evaluated and the concrete values of the tangents for the parameters, as specified by processed_jaxpr.params_index. It returns a pair (losses, losses_tangents), where losses is a tuple of LossFunction objects, and losses_tangents is a tuple containing the tangents of the inputs for each loss tag (corresponding to a loss object in losses).
- Parameters
func (utils.Func) – The model function, which must include at least one loss registration.
params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.
- Return type
…
- Returns
A function that computes the Jacobian-vector product with signature Callable[[FuncArgs, Params], LossTagsVjp].
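The forward-mode counterpart can be sketched the same way with jax.jvp() in plain JAX. As before, `model` and its parameters are hypothetical and the logits stand in for the inputs to a loss tag.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # Hypothetical model whose output plays the role of a loss tag's input.
    return x @ params["w"] + params["b"]


params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
params_tangents = {"w": jnp.ones((3, 2)), "b": jnp.ones(2)}
x = jnp.array([[1.0, 2.0, 3.0]])

# jax.jvp pushes tangents of the parameters forward to tangents of the output.
logits, logits_tangents = jax.jvp(
    lambda p: model(p, x), (params,), (params_tangents,)
)
```

Here the tangent pytree must have the same structure as `params`, and the returned `logits_tangents` has the same shape as `logits`.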
loss_tags_hvp
- kfac_jax.loss_tags_hvp(func, params_index=0)[source]
Creates a function for the Hessian-vector product w.r.t. all loss tags.
The returned function takes as inputs the concrete values of the primals for the function arguments at which the Hessian will be evaluated and the concrete values of the tangents for the parameters, as specified by processed_jaxpr.params_index. It returns the product of the Hessian with these tangents via backward-over-forward mode autodiff.
- Parameters
func (utils.Func) – The model function, which must include at least one loss registration.
params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.
- Return type
…
- Returns
A function that computes the Hessian-vector product and also returns all losses, with signature Callable[[FuncArgs, Params], Tuple[LossTagsVjp, Tuple[loss_functions.LossFunction, …]]].
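As a rough illustration of backward-over-forward mode, the sketch below computes a Hessian-vector product with plain JAX by applying reverse mode to a forward-mode directional derivative. The scalar loss, the matrix A, and the helper name hvp_backward_over_forward are hypothetical and are not part of kfac_jax.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])


def loss(p):
    # Hypothetical scalar loss whose Hessian depends on p: A + diag(6 * p).
    return 0.5 * p @ A @ p + jnp.sum(p ** 3)


def hvp_backward_over_forward(f, primal, tangent):
    # Forward mode gives the directional derivative <grad f(p), v> ...
    directional = lambda p: jax.jvp(f, (p,), (tangent,))[1]
    # ... and reverse mode differentiates that scalar to obtain H(p) @ v.
    return jax.grad(directional)(primal)


p = jnp.array([1.0, 2.0])
v = jnp.array([1.0, -1.0])
hv = hvp_backward_over_forward(loss, p, v)

# Reference value from the dense Hessian, feasible only for tiny problems.
hv_reference = jax.hessian(loss)(p) @ v
```

The product is obtained without ever forming the full Hessian, which is what makes this approach usable for models with many parameters.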
layer_tags_vjp
- kfac_jax.layer_tags_vjp(func, params_index=0, auto_register_tags=True, raise_error_on_diff_jaxpr=True, **auto_registration_kwargs)[source]
Creates a function for primal values and tangents w.r.t. all layer tags.
The returned function has a similar interface to jax.vjp(). It takes as inputs the concrete values of the primals at which the Jacobian will be evaluated. It returns a pair (losses, vjp_func), where losses is a tuple of LossFunction objects, and vjp_func is a function that takes as inputs the concrete values of the tangents of the inputs for each loss tag (corresponding to a loss object in losses) and returns a list of quantities computed for each layer tag in processed_jaxpr. Each entry of the list is a dictionary with the following self-explanatory keys: inputs, outputs, params, outputs_tangents, params_tangents.
- Parameters
func (utils.Func) – The model function, which must include at least one loss registration.
params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.
auto_register_tags (bool) – Whether to run an automatic layer registration on the function (e.g. auto_register_tags()).
raise_error_on_diff_jaxpr (bool) – Whether to raise an exception if, when tracing with different arguments, the returned jaxpr has a different graph.
**auto_registration_kwargs – Any additional keyword arguments, to be passed to the automatic registration pass.
- Return type
…
- Returns
Returns a function that computes primal values and tangents wrt all layer tags, with signature Callable[[FuncArgs, Params], LossTagsVjp].
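To make the five dictionary keys concrete, the following sketch uses plain jax.vjp() on a single hypothetical dense layer and assembles a record with the same key names by hand. The layer, the variable names, and the layer_info dictionary are illustrative assumptions; the record is not produced by kfac_jax.

```python
import jax
import jax.numpy as jnp


def dense(params, x):
    # Hypothetical single dense layer: outputs = inputs @ w + b.
    return x @ params["w"] + params["b"]


params = {"w": jnp.array([[1.0, 0.0], [0.0, 1.0]]), "b": jnp.zeros((2,))}
inputs = jnp.array([[1.0, 2.0]])

# jax.vjp returns the primal outputs and a function that pulls tangents of
# the outputs back to the differentiated arguments.
outputs, vjp_func = jax.vjp(lambda p: dense(p, inputs), params)
outputs_tangents = jnp.array([[1.0, -1.0]])
(params_tangents,) = vjp_func(outputs_tangents)

# Hand-built record mirroring the documented keys; not returned by kfac_jax.
layer_info = dict(
    inputs=inputs,
    outputs=outputs,
    params=params,
    outputs_tangents=outputs_tangents,
    params_tangents=params_tangents,
)
```

For this layer the weight tangent is the outer product of the inputs with the output tangents, and the bias tangent is the output tangents summed over the batch.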
Patches Second Moments
patches_moments | Computes the first and second moment of the convolutional patches. |
patches_moments_explicit | The exact same functionality as patches_moments(), but explicitly extracts the patches. |
patches_moments
- kfac_jax.patches_moments(inputs, kernel_spatial_shape, strides=1, padding='VALID', data_format='NHWC', dim_numbers=None, inputs_dilation=None, kernel_dilation=None, feature_group_count=1, batch_group_count=1, unroll_loop=False, precision=None, weighting_array=None)[source]
Computes the first and second moment of the convolutional patches.
Because the code supports arbitrary convolution data formats (e.g. both NHWC and NCHW), the comments above each step in the source describe a simplified version of what the statements below them do, assuming the data format is fixed to NHWC.
- Parameters
inputs (Array) – The batch of images.
kernel_spatial_shape (Union[int, Shape]) – The spatial dimensions of the filter (int or list of ints).
strides (Union[int, Shape]) – The spatial dimensions of the strides (int or list of ints).
padding (PaddingVariants) – The padding (str or list of pairs of ints).
data_format (Optional[str]) – The data format of the inputs (None, NHWC, NCHW).
dim_numbers (Optional[Union[DimNumbers, lax.ConvDimensionNumbers]]) – Instance of jax.lax.ConvDimensionNumbers instead of data_format.
inputs_dilation (Optional[Sequence[int]]) – An integer or sequence of integers, specifying the dilation for the image. Currently, patches_moments does not support dilation, so the only allowed values are None, 1, (1,1).
kernel_dilation (Optional[Sequence[int]]) – An integer or sequence of integers, specifying the dilation for the kernel. Currently, patches_moments does not support dilation, so the only allowed values are None, 1, (1,1).
feature_group_count (int) – The feature grouping for grouped convolutions. Currently, patches_moments supports only 1 and number of channels of the inputs.
batch_group_count (int) – The batch grouping for grouped convolutions. Currently, patches_moments supports only 1.
unroll_loop (bool) – Whether to unroll the loop in python.
precision (Optional[jax.lax.Precision]) – In what precision to run the computation. For more details please read Jax documentation of jax.lax.conv_general_dilated().
weighting_array (Optional[Array]) – A tensor specifying additional weighting of each element of the moment’s average.
- Return type
Tuple[Array, Array]
- Returns
The matrix of the patches’ second and first moment as a pair. The tensor of the patches’ second moment has a shape kernel_spatial_shape + (channels,) + kernel_spatial_shape + (channels,). The tensor of the patches’ first moment has a shape kernel_spatial_shape + (channels,).
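The returned shapes can be illustrated with a small reference computation in plain JAX. This is a sketch under simplifying assumptions (NHWC inputs, stride 1, 'VALID' padding, no dilation, no weighting_array), and it assumes the average is taken over the batch and all spatial locations; the library's own normalization may differ. The helper name patch_moments_nhwc is hypothetical.

```python
import jax.numpy as jnp


def patch_moments_nhwc(x, kernel_hw):
    # Assumes NHWC inputs, stride 1, 'VALID' padding and no dilation.
    kh, kw = kernel_hw
    n, h, w, c = x.shape
    oh, ow = h - kh + 1, w - kw + 1
    # One shifted view per kernel offset, each of shape (n, oh, ow, c).
    shifted = [x[:, i:i + oh, j:j + ow, :] for i in range(kh) for j in range(kw)]
    patches = jnp.stack(shifted, axis=3).reshape(n, oh, ow, kh, kw, c)
    count = n * oh * ow
    # Average over the batch and all spatial locations (an assumption here).
    first = patches.sum(axis=(0, 1, 2)) / count
    second = jnp.einsum("nhwabc,nhwdef->abcdef", patches, patches) / count
    return second, first


x = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)
second, first = patch_moments_nhwc(x, (2, 2))
```

With a 2x2 kernel and 3 channels, first has shape (2, 2, 3) and second has shape (2, 2, 3, 2, 2, 3), returned in the same (second, first) order as the documented pair.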
patches_moments_explicit
- kfac_jax.patches_moments_explicit(inputs, kernel_spatial_shape, strides=1, padding='VALID', data_format='NHWC', dim_numbers=None, inputs_dilation=None, kernel_dilation=None, feature_group_count=1, batch_group_count=1, unroll_loop=False, precision=None, weighting_array=None)[source]
The exact same functionality as patches_moments(), but explicitly extracts the patches via jax.lax.conv_general_dilated_patches(), potentially having a higher memory usage.
- Return type
Tuple[Array, Array]
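The explicit route can be sketched as follows: extract all patches with jax.lax.conv_general_dilated_patches(), flatten them into rows, and average. This is a minimal illustration and not the library's implementation. It assumes NHWC inputs, stride 1, 'VALID' padding, and an unweighted average over batch and locations, and it leaves the moments in flattened form. The helper name explicit_patch_moments is hypothetical.

```python
import jax
import jax.numpy as jnp


def explicit_patch_moments(x, kernel_hw):
    # Materializes every patch; the feature axis grows by a factor of
    # kh * kw, which is where the extra memory usage comes from.
    patches = jax.lax.conv_general_dilated_patches(
        x,
        filter_shape=kernel_hw,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    # One row per (batch, location) pair, one column per patch entry.
    flat = patches.reshape(-1, patches.shape[-1])
    first = flat.mean(axis=0)
    second = flat.T @ flat / flat.shape[0]
    return second, first, patches.shape


x = jnp.arange(2 * 5 * 5 * 3, dtype=jnp.float32).reshape(2, 5, 5, 3) / 150.0
second, first, patches_shape = explicit_patch_moments(x, (3, 3))
```

For a (2, 5, 5, 3) input and a 3x3 kernel, the patch tensor has shape (2, 3, 3, 27), so the flattened moments have shapes (27, 27) and (27,).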