Standard API

Optimizers

Optimizer(value_and_grad_func, l2_reg[, ...])

The K-FAC optimizer.

Optimizer

class kfac_jax.Optimizer(value_and_grad_func, l2_reg, value_func_has_aux=False, value_func_has_state=False, value_func_has_rng=False, use_adaptive_learning_rate=False, learning_rate_schedule=None, use_adaptive_momentum=False, momentum_schedule=None, use_adaptive_damping=False, damping_schedule=None, initial_damping=None, min_damping=1e-08, max_damping=inf, include_damping_in_quad_change=False, damping_adaptation_interval=5, damping_adaptation_decay=0.9, damping_lower_threshold=0.25, damping_upper_threshold=0.75, always_use_exact_qmodel_for_damping_adjustment=False, norm_constraint=None, num_burnin_steps=10, estimation_mode='fisher_gradients', curvature_ema=0.95, inverse_update_period=5, use_exact_inverses=False, batch_process_func=None, register_only_generic=False, patterns_to_skip=(), auto_register_kwargs=None, layer_tag_to_block_ctor=None, multi_device=False, debug=False, batch_size_extractor=<function default_batch_size_extractor>, pmap_axis_name='kfac_axis', forbid_setting_attributes_after_finalize=True, modifiable_attribute_exceptions=(), include_norms_in_stats=False, include_per_param_norms_in_stats=False, distributed_precon_apply=True, distributed_inverses=True)[source]

The K-FAC optimizer.

class State(velocities, estimator_state, damping, data_seen, step_counter)[source]

Persistent state of the optimizer.

velocities

The update to the parameters from the previous step - \(\theta_t - \theta_{t-1}\).

Type

Params

estimator_state

The persistent state for the curvature estimator.

Type

curvature_estimator.BlockDiagonalCurvature.State

damping

When using damping adaptation, this will contain the current value.

Type

Optional[Array]

data_seen

The number of training cases that the optimizer has processed.

Type

Numeric

step_counter

An integer giving the current step number \(t\).

Type

Numeric

__init__(velocities, estimator_state, damping, data_seen, step_counter)
__init__(value_and_grad_func, l2_reg, value_func_has_aux=False, value_func_has_state=False, value_func_has_rng=False, use_adaptive_learning_rate=False, learning_rate_schedule=None, use_adaptive_momentum=False, momentum_schedule=None, use_adaptive_damping=False, damping_schedule=None, initial_damping=None, min_damping=1e-08, max_damping=inf, include_damping_in_quad_change=False, damping_adaptation_interval=5, damping_adaptation_decay=0.9, damping_lower_threshold=0.25, damping_upper_threshold=0.75, always_use_exact_qmodel_for_damping_adjustment=False, norm_constraint=None, num_burnin_steps=10, estimation_mode='fisher_gradients', curvature_ema=0.95, inverse_update_period=5, use_exact_inverses=False, batch_process_func=None, register_only_generic=False, patterns_to_skip=(), auto_register_kwargs=None, layer_tag_to_block_ctor=None, multi_device=False, debug=False, batch_size_extractor=<function default_batch_size_extractor>, pmap_axis_name='kfac_axis', forbid_setting_attributes_after_finalize=True, modifiable_attribute_exceptions=(), include_norms_in_stats=False, include_per_param_norms_in_stats=False, distributed_precon_apply=True, distributed_inverses=True)[source]

Initializes the K-FAC optimizer with the provided settings.

A note on damping:

One of the main complications of using second-order optimizers like K-FAC is the “damping” parameter. This parameter is multiplied by the identity matrix and (approximately) added to the curvature matrix (i.e. the Fisher or GGN) before it is inverted and multiplied by the gradient when computing the update (before any learning rate scaling). The damping should follow the scale of the objective, so that if you multiply your loss by some factor you should do the same for the damping. Roughly speaking, larger damping values constrain the update vector to a smaller region around zero, which is needed in general since the second-order approximations that underlie second-order methods can break down for large updates. (In gradient descent the learning rate plays an analogous role.) The relationship between the damping parameter and the radius of this region is complicated and depends on the scale of the objective, among other things.
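As a rough illustration of the paragraph above, the following sketch applies a damped update to a toy diagonal curvature matrix in plain Python. The diagonal matrix, the numbers, and the helper names damped_update and norm are assumptions made for illustration and are not part of kfac_jax; the real optimizer works with structured curvature blocks and applies further scaling.

```python
import math

def damped_update(curvature_diag, grad, damping):
    """Returns -(C + damping * I)^{-1} g for a diagonal curvature C."""
    return [-g / (c + damping) for c, g in zip(curvature_diag, grad)]

def norm(vec):
    return math.sqrt(sum(v * v for v in vec))

curvature_diag = [4.0, 0.01]   # toy curvature with one nearly flat direction
grad = [1.0, 1.0]

small = damped_update(curvature_diag, grad, damping=1e-3)
large = damped_update(curvature_diag, grad, damping=1.0)
# Larger damping keeps the update closer to zero.
assert norm(large) < norm(small)

# Scaling the loss by a factor scales both the curvature and the gradient,
# so the damping must be scaled by the same factor to recover the same update.
scale = 10.0
scaled = damped_update([scale * c for c in curvature_diag],
                       [scale * g for g in grad],
                       damping=scale * 1.0)
assert all(math.isclose(s, l) for s, l in zip(scaled, large))
```

The nearly flat direction is where damping matters most: with a tiny damping the update along it is very large, which is exactly the regime where a local quadratic approximation is least trustworthy.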

The optimizer provides a system for adjusting the damping automatically via the use_adaptive_damping argument, although this system is not completely reliable. Using a fixed value or a manually tuned schedule can work as well or better for some problems, while it can be a very poor choice for others (like deep autoencoders). Empirically, we have found that using a fixed value works well enough for common architectures like convnets and transformers.
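For the manually tuned schedule mentioned above, the damping_schedule argument expects a callable that maps the current step number to a damping value. A minimal sketch follows; the exponential form, the constants, and the name exponential_damping_schedule are illustrative assumptions. It uses plain Python numbers, whereas a schedule passed to the optimizer would typically operate on JAX arrays (for example with jnp.maximum in place of max).

```python
def exponential_damping_schedule(step, initial=1e-2, decay=0.999, floor=1e-5):
    """Maps the current step number to a damping value (hypothetical form)."""
    return max(initial * decay ** step, floor)

# The damping starts at its initial value, decays, and never drops below the floor.
assert exponential_damping_schedule(0) == 1e-2
assert exponential_damping_schedule(1000) < exponential_damping_schedule(0)
assert exponential_damping_schedule(10**6) == 1e-5
```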

Parameters
  • value_and_grad_func (ValueAndGradFunc) – Python callable. The function should return the value of the loss to be optimized and its gradients, and optionally the model state and auxiliary information (usually statistics to log). The interface should be: out_args, loss_grads = value_and_grad_func(*in_args). Here, in_args is (params, func_state, rng, batch), with rng omitted if value_func_has_rng is False, and with func_state omitted if value_func_has_state is False. Meanwhile, out_args is (loss, func_state, aux), with func_state omitted if value_func_has_state is False, and with aux omitted if value_func_has_aux is False. If both value_func_has_state and value_func_has_aux are False, out_args should just be loss and not (loss,).

  • l2_reg (Numeric) – Scalar. Set this value to tell the optimizer what L2 regularization coefficient you are using (if any). Note the coefficient appears in the regularizer as coeff / 2 * sum(param**2). This adds an additional diagonal term to the curvature and hence will affect the quadratic model when using adaptive damping. Note that the user is still responsible for adding regularization to the loss.

  • value_func_has_aux (bool) – Boolean. Specifies whether the provided callable value_and_grad_func returns auxiliary data. (Default: False)

  • value_func_has_state (bool) – Boolean. Specifies whether the provided callable value_and_grad_func has a persistent state that is passed in and out. (Default: False)

  • value_func_has_rng (bool) – Boolean. Specifies whether the provided callable value_and_grad_func additionally takes as input an rng key. (Default: False)

  • use_adaptive_learning_rate (bool) – Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the learning rate at each step. Note that this won’t work well for stochastic objectives. If this is False, the user must use the learning_rate argument of the step function, or the constructor argument learning_rate_schedule. (Default: False)

  • learning_rate_schedule (Optional[ScheduleType]) – Callable. A schedule for the learning rate. This should take as input the current step number and return a single array that represents the learning rate. (Default: None)

  • use_adaptive_momentum (bool) – Boolean. Specifies whether to use the special rule from the original K-FAC paper for picking the momentum “decay” parameter at each step. Note that this won’t work well for stochastic objectives. If this is False, the user must use the momentum argument of the step function, or the constructor argument momentum_schedule. (Default: False)

  • momentum_schedule (Optional[ScheduleType]) – Callable. A schedule for the momentum parameter. This should take as input the current step number and return a single array that represents the momentum. (Default: None)

  • use_adaptive_damping (bool) – Boolean. Specifies whether the optimizer will use the Levenberg-Marquardt method to automatically adjust the damping every damping_adaptation_interval iterations. If this is set to False the user must provide a value to the damping argument of the step function at each iteration, or use the damping_schedule constructor argument. Note that the effectiveness of this technique seems to vary between problems. (Default: False)

  • damping_schedule (Optional[ScheduleType]) – Callable. A schedule for the damping. This should take as input the current step number and return a single array that represents the damping. (Default: None)

  • initial_damping (Optional[Numeric]) – Scalar or None. This specifies the initial value of the damping that the optimizer will use when using automatic damping adaptation. (Default: None)

  • min_damping (Numeric) – Scalar. Minimum value the damping parameter can take when using automatic damping adaptation. Note that the default value of 1e-8 is quite arbitrary, and you may have to adjust this up or down for your particular problem. If you are using a non-zero value of l2_reg you may be able to set this to zero. (Default: 1e-8)

  • max_damping (Numeric) – Scalar. Maximum value the damping parameter can take when using automatic damping adaptation. (Default: Infinity)

  • include_damping_in_quad_change (bool) – Boolean. Whether to include the contribution of the damping in the quadratic model for the purpose of computing the reduction ratio (“rho”) in the Levenberg-Marquardt scheme used for adapting the damping. Note that the contribution from the l2_reg argument is always included. (Default: False)

  • damping_adaptation_interval (int) – Int. The number of steps in between adapting the damping parameter. (Default: 5)

  • damping_adaptation_decay (Numeric) – Scalar. The damping parameter will be adjusted up or down by damping_adaptation_decay ** damping_adaptation_interval, or remain unchanged, every damping_adaptation_interval number of iterations. (Default: 0.9)

  • damping_lower_threshold (Numeric) – Scalar. The damping parameter is increased if the reduction ratio is below this threshold. (Default: 0.25)

  • damping_upper_threshold (Numeric) – Scalar. The damping parameter is decreased if the reduction ratio is above this threshold. (Default: 0.75)

  • always_use_exact_qmodel_for_damping_adjustment (bool) – Boolean. When using learning rate and/or momentum adaptation, the quadratic model change used for damping adaptation is always computed using the exact curvature matrix. Otherwise, there is an option to use either the exact or approximate curvature matrix to compute the quadratic model change, which is what this argument controls. When True, the exact curvature matrix will be used, which is more expensive, but could possibly produce a better damping schedule. (Default: False)

  • norm_constraint (Optional[Numeric]) – Scalar. If specified, the update is scaled down so that its approximate squared Fisher norm v^T F v is at most the specified value. (Note that here F is the approximate curvature matrix, not the exact.) May only be used when use_adaptive_learning_rate is False. (Default: None)

  • num_burnin_steps (int) – Int. At the start of optimization, i.e. on the first step, before performing the actual step the optimizer will perform this many updates to the curvature approximation without updating the actual parameters. (Default: 10)

  • estimation_mode (str) – String. The type of estimator to use for the curvature matrix. See the documentation for CurvatureEstimator for a detailed description of the possible options. (Default: fisher_gradients).

  • curvature_ema (Numeric) – The decay factor used when calculating the covariance estimate moving averages. (Default: 0.95)

  • inverse_update_period (int) – Int. The number of steps in between updates to the computation of the inverse curvature approximation. (Default: 5)

  • use_exact_inverses (bool) – Bool. If True, preconditioner inverses are computed “exactly” without the pi-adjusted factored damping approach. Note that this involves the use of eigendecompositions, which can sometimes be much more expensive. (Default: False)

  • batch_process_func (Optional[Callable[[Batch], Batch]]) – Callable. A function to be called on each batch before it is fed to K-FAC on device. This could be useful for specific device input optimizations. (Default: None)

  • register_only_generic (bool) – Boolean. Whether the auto-tagger should register only generic parameters, or be allowed to use the graph matcher to automatically pick up any kind of layer tags. (Default: False)

  • patterns_to_skip (Sequence[str]) – Tuple. A list of any patterns that should be skipped by the graph matcher when auto-tagging. (Default: ())

  • auto_register_kwargs (Optional[Dict[str, Any]]) – Any additional kwargs to be passed down to auto_register_tags(), which is called by the curvature estimator. (Default: None)

  • layer_tag_to_block_ctor (Optional[Dict[str, curvature_estimator.CurvatureBlockCtor]]) – Dictionary. A mapping from layer tags to block classes that override the default choice of block approximation for that specific tag. See the documentation for CurvatureEstimator for a more detailed description. (Default: None)

  • multi_device (bool) – Boolean. Whether to use pmap and run the optimizer on multiple devices. (Default: False)

  • debug (bool) – Boolean. If True, neither the step nor the init function is jitted. Note that this also overrides multi_device and prevents using pmap. (Default: False)

  • batch_size_extractor (Callable[[Batch], Numeric]) – A function that takes as input the function arguments and returns the batch size for a single device. (Default: kfac.utils.default_batch_size_extractor)

  • pmap_axis_name (str) – String. The name of the pmap axis to use when multi_device is set to True. (Default: kfac_axis)

  • forbid_setting_attributes_after_finalize (bool) – Boolean. By default, after the object is finalized you cannot set any of its properties. This is done in order to protect the user from making changes to the object attributes that would not be picked up by various internal methods after they have been compiled. However, if you are extending this class and clearly understand the risks of modifying attributes, setting this to False will remove the restriction. (Default: True)

  • modifiable_attribute_exceptions (Sequence[str]) – Sequence of strings. Gives a list of names for attributes that can be modified after finalization even when forbid_setting_attributes_after_finalize is True. (Default: ())

  • include_norms_in_stats (bool) – Boolean. If True, the vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: False)

  • include_per_param_norms_in_stats (bool) – Boolean. If True, the per-parameter vector norms of the gradient, preconditioned gradient, and parameter update are included in the statistics returned by the step function. (Default: False)

  • distributed_precon_apply (bool) – Boolean. Whether to distribute the application of the preconditioner across the different devices in a layer-wise fashion. If False, each device will (redundantly) perform the required operations for all of the layers. (Default: True)

  • distributed_inverses (bool) – Boolean. Whether to distribute the inverse computations (required to compute the preconditioner) across the different devices in a layer-wise fashion. If False, each device will (redundantly) perform the required computations for all of the layers. (Default: True)
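To make the value_and_grad_func interface concrete, here is a minimal sketch for the case value_func_has_aux=True with no state and no rng, so that in_args is (params, batch) and out_args is (loss, aux). The gradient is written by hand to keep the sketch free of dependencies; in practice such a function would normally be built with jax.value_and_grad(loss_fn, has_aux=True). The toy model, the data, and the aux key are illustrative assumptions, and the loss registration that the curvature estimator requires is omitted here.

```python
def value_and_grad_func(params, batch):
    """Squared-error loss of a 1-D linear model y = w * x + b.

    Returns ((loss, aux), grads), matching out_args = (loss, aux).
    """
    xs, ys = batch
    n = len(xs)
    residuals = [params["w"] * x + params["b"] - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / (2 * n)
    # Hand-derived gradients of the loss with respect to w and b.
    grads = {
        "w": sum(r * x for r, x in zip(residuals, xs)) / n,
        "b": sum(residuals) / n,
    }
    # Auxiliary output, typically statistics to log.
    aux = {"max_abs_residual": max(abs(r) for r in residuals)}
    return (loss, aux), grads

params = {"w": 0.0, "b": 0.0}
batch = ([1.0, 2.0], [2.0, 4.0])
(loss, aux), grads = value_and_grad_func(params, batch)
```

If value_func_has_aux were False as well, the first returned element would be the bare loss rather than a one-element tuple, as the description of value_and_grad_func states.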

property num_burnin_steps: int

The number of burnin steps to run before the first parameter update.

Return type

int

property l2_reg: Array

The weight of the additional diagonal term added to the curvature.

Return type

Array

property estimator: curvature_estimator.BlockDiagonalCurvature

The underlying curvature estimator used by the optimizer.

Return type

curvature_estimator.BlockDiagonalCurvature

property damping_decay_factor: Numeric

How fast to decay the damping, when using damping adaptation.

Return type

Numeric

should_update_damping(state)[source]

Whether at the current step the optimizer should update the damping.

Return type

Array
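To illustrate how should_update_damping relates to the damping arguments of the constructor, the sketch below shows an interval-gated Levenberg-Marquardt adjustment. Here rho is the reduction ratio, i.e. the actual change in the loss divided by the change predicted by the quadratic model. This is a simplified illustration assembled from the parameter descriptions, not the library's implementation: the function name, the exact step on which an update triggers, and the use of plain Python floats are assumptions.

```python
def adapt_damping(damping, rho, step, *, interval=5, decay=0.9,
                  lower=0.25, upper=0.75, min_damping=1e-8,
                  max_damping=float("inf")):
    """Adjusts the damping from the reduction ratio rho every `interval` steps."""
    if step % interval != 0:        # assumed gating rule
        return damping
    factor = decay ** interval      # 0.9 ** 5 with the defaults
    if rho < lower:                 # model over-predicted the improvement
        damping = damping / factor
    elif rho > upper:               # model predicted the improvement well
        damping = damping * factor
    # Keep the result inside [min_damping, max_damping].
    return min(max(damping, min_damping), max_damping)

increased = adapt_damping(1.0, rho=0.1, step=5)
decreased = adapt_damping(1.0, rho=0.9, step=5)
unchanged = adapt_damping(1.0, rho=0.5, step=5)
skipped = adapt_damping(1.0, rho=0.1, step=3)
```

Raising the decay to the power of the interval means that the per-step rate of change is roughly independent of how often the adjustment runs.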

compute_loss_value(func_args)[source]

Computes the value of the loss function being optimized.

Return type

Array

verify_args_and_get_step_counter(step_counter, learning_rate=None, momentum=None, damping=None, global_step_int=None)[source]

Verifies that the arguments passed to the step function are correct.

Return type

int

init(params, rng, batch, func_state=None)[source]

Initializes the optimizer and returns the appropriate optimizer state.

Return type

‘Optimizer.State’

burnin(num_steps, params, state, rng, data_iterator, func_state=None)[source]

Runs all burnin steps required.

Return type

Tuple[‘Optimizer.State’, Optional[FuncState]]

step(params, state, rng, data_iterator=None, batch=None, func_state=None, learning_rate=None, momentum=None, damping=None, global_step_int=None)[source]

Performs a single update step using the optimizer.

Parameters
  • params (Params) – The current parameters of the model.

  • state ('Optimizer.State') – The current state of the optimizer.

  • rng (PRNGKey) – A Jax PRNG key. Should be different for each iteration and each Jax process/host.

  • data_iterator (Optional[Iterator[Batch]]) – A data iterator to use (if not passing batch).

  • batch (Optional[Batch]) – A single batch used to compute the update. Should only pass one of data_iterator or batch.

  • func_state (Optional[FuncState]) – Any function state that gets passed in and returned.

  • learning_rate (Optional[Array]) – Learning rate to use if the optimizer was created with use_adaptive_learning_rate=False and without a learning_rate_schedule, None otherwise.

  • momentum (Optional[Array]) – Momentum to use if the optimizer was created with use_adaptive_momentum=False and without a momentum_schedule, None otherwise.

  • damping (Optional[Array]) – Damping to use if the optimizer was created with use_adaptive_damping=False and without a damping_schedule, None otherwise. See the note on damping in the constructor documentation for more information about damping.

  • global_step_int (Optional[int]) – The global step as a python int. Note that this must match the step internal to the optimizer that is part of its state.

Return type

ReturnEither

Returns

(params, state, stats) if value_func_has_state=False and (params, state, func_state, stats) otherwise, where

  • params is the updated model parameters.

  • state is the updated optimizer state.

  • func_state is the updated function state.

  • stats is a dictionary of useful statistics including the loss.
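Because the length of the returned tuple depends on value_func_has_state, calling code may want to normalize it. The helper below is a hypothetical convenience, not part of kfac_jax, and the values passed to it are stand-ins for what a real call to step would return.

```python
def unpack_step_output(outputs, value_func_has_state):
    """Normalizes the tuple returned by step to (params, state, func_state, stats)."""
    if value_func_has_state:
        params, state, func_state, stats = outputs
    else:
        params, state, stats = outputs
        func_state = None
    return params, state, func_state, stats

# Stand-in values; a real call would pass the result of optimizer.step(...).
fake_outputs = ({"w": 1.0}, "opt_state", {"loss": 0.5})
params, state, func_state, stats = unpack_step_output(fake_outputs, False)
```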

compute_l2_quad_matrix(vectors)[source]

Computes the matrix corresponding to the prior/regularizer.

Parameters
  • vectors (Sequence[Params]) – A sequence of parameter-like PyTree structures, each one representing a different vector.

Return type

Array

Returns

A matrix with i,j entry equal to self.l2_reg * v_i^T v_j.
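The returned matrix can be sketched in a few lines. For simplicity this version assumes each vector is a flat list of floats rather than a PyTree, and the function name l2_quad_matrix is hypothetical.

```python
def l2_quad_matrix(vectors, l2_reg):
    """Returns M with M[i][j] = l2_reg * <v_i, v_j> for flat vectors."""
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))
    return [[l2_reg * dot(u, v) for v in vectors] for u in vectors]

vectors = [[1.0, 2.0], [3.0, -1.0]]
matrix = l2_quad_matrix(vectors, l2_reg=0.5)
# matrix is symmetric: [[2.5, 0.5], [0.5, 5.0]]
```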

compute_exact_quad_model(vectors, grads, func_args=None)[source]

Computes the components of the exact quadratic model.

Return type

Tuple[Array, Array, Array]

compute_approx_quad_model(state, vectors, grads)[source]

Computes the components of the approximate quadratic model.

Return type

Tuple[Array, Array, Array]

compute_quadratic_model_value(a, a_damped, b, w)[source]

Computes the quadratic model value from the inputs provided.

Return type

Array
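This section does not spell out the formula, so the following is only a sketch of the standard quadratic model q(w) = 1/2 w^T A w + b^T w. It assumes that a holds the pairwise curvature products of the candidate vectors, a_damped the same with damping included, b the products of the vectors with the gradient, and w the coefficients of the vectors. This reading of the arguments, and the include_damping flag (modelled on include_damping_in_quad_change), are assumptions.

```python
def quadratic_model_value(a, a_damped, b, w, include_damping=False):
    """Evaluates 0.5 * w^T A w + b^T w for small dense inputs."""
    mat = a_damped if include_damping else a
    quad = sum(w[i] * mat[i][j] * w[j]
               for i in range(len(w)) for j in range(len(w)))
    lin = sum(bi * wi for bi, wi in zip(b, w))
    return 0.5 * quad + lin

a = [[2.0, 0.0], [0.0, 1.0]]
a_damped = [[3.0, 0.0], [0.0, 2.0]]
b = [-1.0, -1.0]
w = [0.5, 1.0]
undamped_value = quadratic_model_value(a, a_damped, b, w)
damped_value = quadratic_model_value(a, a_damped, b, w, include_damping=True)
```

A negative value means the model predicts a decrease in the objective; including the damping makes the predicted decrease smaller.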

weighted_sum_of_objects(objects, coefficients)[source]

Returns the weighted sum of the objects in the sequence.

Return type

utils.PyTree
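A weighted sum over parameter-like structures can be sketched as below. This version assumes every object is a flat dict of floats sharing the same keys, whereas the library operates on general PyTrees; the function name and the example coefficients are illustrative only.

```python
def weighted_sum(objects, coefficients):
    """Returns sum_i coefficients[i] * objects[i] for dicts of floats."""
    if len(objects) != len(coefficients):
        raise ValueError("objects and coefficients must have equal length")
    keys = objects[0].keys()
    return {k: sum(c * obj[k] for obj, c in zip(objects, coefficients))
            for k in keys}

direction = {"w": 2.0, "b": -1.0}   # e.g. a preconditioned gradient
velocity = {"w": 0.5, "b": 0.5}     # e.g. the previous update
update = weighted_sum([direction, velocity], [-0.1, 0.9])
```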

Curvature Estimators

CurvatureEstimator(func[, params_index, ...])

An abstract curvature estimator class.

BlockDiagonalCurvature(func[, params_index, ...])

Block diagonal curvature estimator class.

ExplicitExactCurvature(func[, params_index, ...])

Explicit exact full curvature estimator class.

ImplicitExactCurvature(func[, params_index, ...])

Represents exact curvature matrices that are never constructed explicitly.

set_default_tag_to_block_ctor(tag_name, ...)

Sets the default curvature block constructor for the given tag.

get_default_tag_to_block_ctor(tag_name)

Returns the default curvature block constructor for the given tag name.

CurvatureEstimator

class kfac_jax.CurvatureEstimator(func, params_index=0, default_estimation_mode='fisher_gradients')[source]

An abstract curvature estimator class.

This is a class that abstracts away the process of estimating a curvature matrix and provides many useful functionalities for interacting with it. The state of the estimator contains two parts: the internal representation of the estimated curvature, and potentially cached values of different expressions involving the curvature matrix (for example matrix powers). The cached values are only updated when you call the method update_cache(). Multiple methods take the keyword argument use_cached, which specifies whether you want to compute the corresponding expression using the current curvature estimate or use a cached version.
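The split between the current estimate and the cache can be illustrated with a toy scalar "curvature". The class below is a hypothetical stand-in: it mutates itself for brevity, whereas the estimator described here passes an explicit state object into and out of its methods.

```python
class ToyEstimator:
    """Keeps a scalar curvature estimate plus a cache of its powers."""

    def __init__(self, curvature):
        self.curvature = curvature
        self.cache = {}

    def update_cache(self, identity_weight, powers):
        # Recompute the cached powers from the current estimate.
        for p in powers:
            self.cache[p] = (self.curvature + identity_weight) ** p

    def multiply_matpower(self, vector, identity_weight, power, use_cached):
        if use_cached:
            scale = self.cache[power]   # may be stale
        else:
            scale = (self.curvature + identity_weight) ** power
        return scale * vector

est = ToyEstimator(curvature=1.0)
est.update_cache(identity_weight=1.0, powers=[-1])
est.curvature = 3.0                      # the estimate changes, the cache does not
stale = est.multiply_matpower(1.0, 1.0, -1, use_cached=True)
fresh = est.multiply_matpower(1.0, 1.0, -1, use_cached=False)
```

Here stale still reflects the curvature at the time of the last update_cache call, which is the trade-off that use_cached exposes: cheaper products in exchange for possibly outdated curvature.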

func

The model evaluation function.

params_index

The index of the parameters argument in the argument list of func.

default_estimation_mode

The estimation mode to use by default when calling update_curvature_matrix_estimate().

__init__(func, params_index=0, default_estimation_mode='fisher_gradients')[source]

Initializes the CurvatureEstimator instance.

Parameters
  • func (utils.Func) – The model function, which should have at least one registered loss.

  • params_index (int) – The index of the parameters argument in the argument list of func.

  • default_estimation_mode (str) – The estimation mode to use by default when calling update_curvature_matrix_estimate().

property default_mat_type: str

The type of matrix that this estimator is approximating.

Return type

str

abstract property dim: int

The number of elements of all parameter variables together.

Return type

int

abstract init(rng, func_args, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues=False)[source]

Initializes the state for the estimator.

Parameters
  • rng (PRNGKey) – The PRNGKey to be used for any randomness of the initialization.

  • func_args (utils.FuncArgs) – Example function arguments, which are used to trace the model function and initialize the state.

  • exact_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, specifying which exact matrix powers each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=True and use_cached=True must be provided here.

  • approx_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or multiple values in a list, specifying which approximate matrix powers each block should cache. Matrix powers for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=False and use_cached=True must be provided here.

  • cache_eigenvalues (bool) – Specifies whether each block should be caching the eigenvalues of its approximate curvature.

Return type

StateType

Returns

The initialized state of the estimator.

abstract multiply_matpower(state, parameter_structured_vector, identity_weight, power, exact_power, use_cached, pmap_axis_name)[source]

Computes (CurvatureMatrix + identity_weight I)**power times vector.

Parameters
  • state (StateType) – The state of the estimator.

  • parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.

  • identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.

  • power (Scalar) – The power to which you want to raise the matrix (EstimateCurvature + identity_weight I).

  • exact_power (bool) – When set to True the matrix power of EstimateCurvature + identity_weight I is computed exactly. Otherwise this method might use a cheaper approximation, which may vary across different blocks.

  • use_cached (bool) – Whether to use a cached (and possibly stale) version of the curvature matrix estimate.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.

Return type

utils.Params

Returns

A parameter structured vector containing the product.
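The two accepted forms of identity_weight, a single scalar or one scalar per block, can be sketched with toy diagonal blocks. The function name, the diagonal representation, and the length check are assumptions for illustration; real blocks are generally not diagonal, and with exact_power=False the library may use a cheaper approximation.

```python
def multiply_matpower_diag(block_diags, block_vectors, identity_weight, power):
    """Applies (C + w I)^power blockwise for diagonal blocks."""
    if isinstance(identity_weight, (int, float)):
        # A scalar weight is shared by all blocks.
        weights = [identity_weight] * len(block_diags)
    else:
        weights = list(identity_weight)
        if len(weights) != len(block_diags):
            raise ValueError("need one identity weight per block")
    return [
        [(c + w) ** power * v for c, v in zip(diag, vec)]
        for diag, vec, w in zip(block_diags, block_vectors, weights)
    ]

block_diags = [[1.0, 3.0], [7.0]]
block_vectors = [[2.0, 2.0], [4.0]]
shared = multiply_matpower_diag(block_diags, block_vectors, 1.0, -1)
per_block = multiply_matpower_diag(block_diags, block_vectors, (1.0, 9.0), -1)
```

With power=-1 this corresponds to multiply_inverse, and with power=1 to multiply.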

multiply(state, parameter_structured_vector, identity_weight, exact_power, use_cached, pmap_axis_name)[source]

Computes (CurvatureMatrix + identity_weight I) times vector.

Return type

utils.Params

multiply_inverse(state, parameter_structured_vector, identity_weight, exact_power, use_cached, pmap_axis_name)[source]

Computes (CurvatureMatrix + identity_weight I)^-1 times vector.

Return type

utils.Params

abstract eigenvalues(state, use_cached)[source]

Computes the eigenvalues of the curvature matrix.

Parameters
  • state (StateType) – The state of the estimator.

  • use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is at least as fresh as the last time you called update_cache() with eigenvalues=True.

Return type

Array

Returns

A single array containing the eigenvalues of the curvature matrix.

abstract update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, pmap_axis_name, estimation_mode=None)[source]

Updates the estimator’s curvature estimates.

Parameters
  • state (StateType) – The state of the estimator to update.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size.

  • rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.

  • func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor) to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.

  • pmap_axis_name (Optional[str]) – When calling this method within a pmap context this argument specifies the axis name over which to aggregate across multiple devices/hosts.

  • estimation_mode (Optional[str]) –

    The type of curvature estimator to use. By default (e.g. if None) will use self.default_estimation_mode. One of:

    • fisher_gradients - the basic estimation approach from the original K-FAC paper.

    • fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.

    • fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.

    • fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.

    • ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).

    • ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).

Return type

StateType

Returns

The updated state.
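The roles of ema_old and ema_new can be sketched as a weighted combination of the previous estimate and the newly computed statistics. How the optimizer derives the two weights from curvature_ema is not stated in this section, so the pairing ema_new = 1 - ema_old used below is an assumption, as are the function name and the flat-list representation.

```python
def ema_update(old_estimate, new_value, ema_old, ema_new):
    """Returns ema_old * old_estimate + ema_new * new_value elementwise."""
    return [ema_old * o + ema_new * n for o, n in zip(old_estimate, new_value)]

curvature_ema = 0.95  # the Optimizer default for curvature_ema
estimate = [1.0, 0.0]
# Assumed weighting: a plain exponential moving average.
estimate = ema_update(estimate, [3.0, 2.0], curvature_ema, 1.0 - curvature_ema)
```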

abstract update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]

Updates the estimator cached values.

Parameters
  • state (StateType) – The state of the estimator to update.

  • identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.

  • exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.

  • approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.

  • eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.

Return type

StateType

Returns

The updated state.

abstract to_dense_matrix(state)[source]

Returns an explicit dense array representing the curvature matrix.

Return type

Array

BlockDiagonalCurvature

class kfac_jax.BlockDiagonalCurvature(func, params_index=0, default_estimation_mode='fisher_gradients', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, distributed_multiplies=True, distributed_cache_updates=True, **auto_register_kwargs)[source]

Block diagonal curvature estimator class.

class State(blocks_states)[source]

Persistent state of the estimator.

blocks_states

A tuple of the state of the estimator corresponding to each block.

Type

Tuple[curvature_blocks.CurvatureBlock.State, …]

__init__(blocks_states)
__init__(func, params_index=0, default_estimation_mode='fisher_gradients', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, distributed_multiplies=True, distributed_cache_updates=True, **auto_register_kwargs)[source]

Initializes the curvature instance.

Parameters
  • func (utils.Func) – The model function, which should have at least one registered loss.

  • params_index (int) – The index of the parameters argument in the argument list of func.

  • default_estimation_mode (str) – The estimation mode to use by default when calling self.update_curvature_matrix_estimate.

  • layer_tag_to_block_ctor (Optional[Mapping[str, CurvatureBlockCtor]]) – An optional dict mapping tags to specific classes of block approximations, which override the default ones.

  • index_to_block_ctor (Optional[Mapping[Tuple[int, ...], CurvatureBlockCtor]]) – An optional dict mapping specific block parameter indices to specific classes of block approximation, which override the default ones. To get the correct indices check estimator.indices_to_block_map.

  • auto_register_tags (bool) – Whether to automatically register layer tags for parameters that have not been manually registered. For further details see tag_graph_matcher.auto_register_tags.

  • distributed_multiplies (bool) – Whether to distribute the curvature matrix multiplication operations across the different devices in a block-wise fashion. If False, each device will (redundantly) perform the operations for all of the blocks.

  • distributed_cache_updates (bool) – Whether to distribute the cache update multiplication operations across the different devices in a block-wise fashion. If False, each device will (redundantly) perform the operations for all of the blocks.

  • **auto_register_kwargs – Any keyword arguments to pass into the auto registration function.

property blocks: Optional[Tuple[curvature_blocks.CurvatureBlock]]

The tuple of CurvatureBlock instances used for each layer.

Return type

Optional[Tuple[curvature_blocks.CurvatureBlock]]

property num_blocks: int

The number of separate blocks that this estimator has.

Return type

int

property block_dims: Shape

The number of elements of all parameter variables for each block.

Return type

Shape

property dim: int

The number of elements of all parameter variables together.

Return type

int

property params_structure_vector_of_indices: utils.Params

A tree structure with parameters replaced by their indices.

Return type

utils.Params

property indices_to_block_map: Mapping[Tuple[int, ...], curvature_blocks.CurvatureBlock]

A mapping of parameter indices to their associated blocks.

Return type

Mapping[Tuple[int, …], curvature_blocks.CurvatureBlock]

property params_block_index: utils.Params

A structure showing which block each parameter corresponds to.

Return type

utils.Params

Returns

A parameter-like structure, where each parameter is replaced by an integer index. This index specifies the block (found by self.blocks[index]) which approximates the part of the curvature matrix associated with the parameter.

property num_params_variables: int

The number of separate parameter variables of the model.

Return type

int

params_vector_to_blocks_vectors(parameter_structured_vector)[source]

Splits the parameters to values for each corresponding block.

Return type

Tuple[Tuple[Array, …]]

blocks_vectors_to_params_vector(blocks_vectors)[source]

Reverses the effect of self.params_vector_to_blocks_vectors.

Return type

utils.Params
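
The round trip between a parameter-structured vector and per-block vectors can be illustrated with a minimal sketch in plain Python. The helpers split_by_block and merge_from_blocks, and the flat-list parameter structure, are hypothetical and not part of kfac_jax; the real methods operate on arbitrary parameter structures.

```python
# Hypothetical sketch: parameters are a flat list, and block_index[i] says
# which block owns parameter i (in the spirit of params_block_index).
def split_by_block(params, block_index, num_blocks):
    """Groups parameters into one tuple per block, preserving order."""
    blocks = [[] for _ in range(num_blocks)]
    for value, b in zip(params, block_index):
        blocks[b].append(value)
    return tuple(tuple(block) for block in blocks)


def merge_from_blocks(blocks_vectors, block_index):
    """Inverse of split_by_block: restores the original parameter order."""
    cursors = [0] * len(blocks_vectors)
    params = []
    for b in block_index:
        params.append(blocks_vectors[b][cursors[b]])
        cursors[b] += 1
    return params


params = ["w1", "b1", "w2", "b2"]   # stand-ins for arrays
block_index = [0, 0, 1, 1]          # layer 1 -> block 0, layer 2 -> block 1
blocks = split_by_block(params, block_index, num_blocks=2)
restored = merge_from_blocks(blocks, block_index)
```

Here blocks holds one tuple of parameters per block, and restored is equal to the original params list.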

init(rng, func_args, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues=False)[source]

Initializes the state for the estimator.

Parameters
  • rng (PRNGKey) – The PRNGKey to be used for any randomness of the initialization.

  • func_args (utils.FuncArgs) – Example function arguments, to be used to trace the model function and initialize the state.

  • exact_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or a list of values, specifying the exact matrix powers that each block should cache. Any matrix power for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=True and use_cached=True must be provided here.

  • approx_powers_to_cache (Optional[curvature_blocks.ScalarOrSequence]) – A single value, or a list of values, specifying the approximate matrix powers that each block should cache. Any matrix power for which you intend to call self.multiply_matpower, self.multiply_inverse or self.multiply with exact_power=False and use_cached=True must be provided here.

  • cache_eigenvalues (bool) – Specifies whether each block should be caching the eigenvalues of its approximate curvature.

Return type

‘BlockDiagonalCurvature.State’

Returns

The initialized state of the estimator.

multiply_matpower(state, parameter_structured_vector, identity_weight, power, exact_power, use_cached, pmap_axis_name)[source]

Computes (CurvatureMatrix + identity_weight I)**power times vector.

Parameters
  • state ('BlockDiagonalCurvature.State') – The state of the estimator.

  • parameter_structured_vector (utils.Params) – A vector in the same structure as the parameters of the model.

  • identity_weight (Union[Numeric, Sequence[Numeric]]) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.

  • power (Scalar) – The power to which you want to raise the matrix (EstimateCurvature + identity_weight I).

  • exact_power (bool) – When set to True the matrix power of EstimateCurvature + identity_weight I is computed exactly. Otherwise this method might use a cheaper approximation, which may vary across different blocks.

  • use_cached (bool) – Whether to use a cached (and possibly stale) version of the curvature matrix estimate.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.

Return type

utils.Params

Returns

A parameter structured vector containing the product.
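
As a rough illustration of the quantity being computed, the following NumPy sketch evaluates (C_b + identity_weight I)**power times a vector for each block b of a block-diagonal matrix, using a symmetric eigendecomposition. The function block_matpower_times_vector and the dense per-block matrices are assumptions made for the example; the library's blocks may use cheaper structured representations.

```python
import numpy as np


def block_matpower_times_vector(blocks, vectors, identity_weight, power):
    """Computes (C_b + identity_weight * I) ** power @ v_b for each block b."""
    results = []
    for c, v in zip(blocks, vectors):
        # Symmetric eigendecomposition: C = Q diag(s) Q^T.
        s, q = np.linalg.eigh(c)
        scaled = (s + identity_weight) ** power
        results.append(q @ (scaled * (q.T @ v)))
    return results


blocks = [np.array([[2.0, 0.0], [0.0, 4.0]]), np.array([[1.0]])]
vectors = [np.array([1.0, 1.0]), np.array([3.0])]
# power=-1 with identity_weight=1 is a damped inverse: (C + I)^{-1} v.
out = block_matpower_times_vector(blocks, vectors, identity_weight=1.0, power=-1)
```

With power=-1 this corresponds to a damped inverse applied block by block, which is the typical preconditioning use case.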

block_eigenvalues(state, use_cached)[source]

Computes the eigenvalues for each block of the curvature estimator.

Parameters
  • state ('BlockDiagonalCurvature.State') – The state of the estimator.

  • use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is at least as fresh as the last time you called update_cache() with eigenvalues=True.

Return type

Tuple[Array, …]

Returns

A tuple of arrays containing the eigenvalues for each block. The order of this tuple corresponds to the ordering of self.blocks. To understand which parameters correspond to which block, you can inspect self.params_block_index.

eigenvalues(state, use_cached)[source]

Computes the eigenvalues of the curvature matrix.

Parameters
  • state ('BlockDiagonalCurvature.State') – The state of the estimator.

  • use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is at least as fresh as the last time you called update_cache() with eigenvalues=True.

Return type

Array

Returns

A single array containing the eigenvalues of the curvature matrix.
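
The relationship between per-block eigenvalues and the eigenvalues of the whole matrix follows from the block-diagonal structure: the spectrum of the full matrix is the union of the block spectra. A small NumPy check, using made-up dense blocks purely for illustration:

```python
import numpy as np

blocks = [np.array([[2.0, 1.0], [1.0, 2.0]]), np.array([[5.0]])]

# Per-block eigenvalues, in the same order as the blocks.
per_block = tuple(np.linalg.eigvalsh(b) for b in blocks)

# Eigenvalues of the full block-diagonal matrix: the concatenation.
all_eigs = np.concatenate(per_block)

# Check against the explicit dense matrix.
dense = np.zeros((3, 3))
dense[:2, :2] = blocks[0]
dense[2:, 2:] = blocks[1]
dense_eigs = np.linalg.eigvalsh(dense)
```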

update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, pmap_axis_name, estimation_mode=None)[source]

Updates the estimator’s curvature estimates.

Parameters
  • state ('BlockDiagonalCurvature.State') – The state of the estimator to update.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size.

  • rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.

  • func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor), to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.

  • pmap_axis_name (Optional[str]) – When calling this method within a pmap context this argument specifies the axis name over which to aggregate across multiple devices/hosts.

  • estimation_mode (Optional[str]) –

    The type of curvature estimator to use. By default (i.e. if None) self.default_estimation_mode is used. One of:

    • fisher_gradients - the basic estimation approach from the original K-FAC paper.

    • fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.

    • fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.

    • fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.

    • ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).

    • ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).

Return type

‘BlockDiagonalCurvature.State’

Returns

The updated state.
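
The roles of ema_old and ema_new can be sketched as a weighted sum of the previous estimate and the freshly computed one. The helper ema_update below is hypothetical, works on flat lists of floats, and uses weights chosen only for illustration; it is not the library's implementation.

```python
def ema_update(old_estimate, new_estimate, ema_old, ema_new):
    """Returns ema_old * old + ema_new * new, elementwise over a flat list."""
    return [ema_old * o + ema_new * n for o, n in zip(old_estimate, new_estimate)]


old = [1.0, 2.0]   # previous curvature statistics
new = [3.0, 6.0]   # statistics computed from the current batch
updated = ema_update(old, new, ema_old=0.5, ema_new=0.5)
```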

update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]

Updates the estimator cached values.

Parameters
  • state ('BlockDiagonalCurvature.State') – The state of the estimator to update.

  • identity_weight (Union[Numeric, Sequence[Numeric]]) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.

  • exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.

  • approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.

  • eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.

Return type

‘BlockDiagonalCurvature.State’

Returns

The updated state.

to_diagonal_block_dense_matrix(state)[source]

Returns a tuple of arrays with explicit dense matrices of each block.

Return type

Tuple[Array, …]

to_dense_matrix(state)[source]

Returns an explicit dense array representing the curvature matrix.

Return type

Array

ExplicitExactCurvature

class kfac_jax.ExplicitExactCurvature(func, params_index=0, batch_index=1, default_estimation_mode='fisher_exact', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, **auto_register_kwargs)[source]

Explicit exact full curvature estimator class.

This class estimates the full curvature matrix by looping over the batch dimension of the input data: for each example it computes an estimate of the curvature matrix, and it then averages these estimates over all examples. Consequently, without parallelism the computation scales linearly with the batch size. The class stores the estimated curvature as a dense matrix, so its memory requirement is (number of parameters)^2. If estimation_mode is fisher_exact or ggn_exact, then the exact curvature is computed, but other modes are also supported. Because it loops over the input data, this class needs to know the index of the batch in the arguments to the model function. Additionally, since the loop is implemented through indexing, each array leaf of that argument must have the same first dimension size, which is interpreted as the batch size.
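
The looping strategy described above can be sketched in NumPy for a deliberately simple case. The sketch assumes a scalar linear model f(x) = w^T x with a unit-variance squared error loss, so that the per-example curvature with respect to w is x x^T; the function dense_curvature_by_looping is hypothetical and only mirrors the structure of the computation.

```python
import numpy as np


def dense_curvature_by_looping(inputs):
    """Averages per-example curvature x x^T over the batch dimension."""
    batch_size, num_params = inputs.shape
    total = np.zeros((num_params, num_params))
    for i in range(batch_size):          # cost is linear in the batch size
        x = inputs[i]                    # index a single example
        total += np.outer(x, x)          # per-example curvature
    return total / batch_size            # dense (num_params, num_params)


inputs = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0], [2.0, 0.0]])
curvature = dense_curvature_by_looping(inputs)
```

The result is a dense matrix with (number of parameters)^2 entries, which is why this approach is only practical for small models.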

__init__(func, params_index=0, batch_index=1, default_estimation_mode='fisher_exact', layer_tag_to_block_ctor=None, index_to_block_ctor=None, auto_register_tags=True, **auto_register_kwargs)[source]

Initializes the curvature instance.

Parameters
  • func (utils.Func) – The model function, which should have at least one registered loss.

  • params_index (int) – The index of the parameters argument in arguments list of func.

  • batch_index (int) – The index of the batch in the inputs to func. The batch represents the data over which the curvature is averaged.

  • default_estimation_mode (str) – The estimation mode to use by default when calling self.update_curvature_matrix_estimate.

  • layer_tag_to_block_ctor (Optional[Mapping[str, CurvatureBlockCtor]]) – An optional dict mapping tags to specific classes of block approximations, which override the default ones.

  • index_to_block_ctor (Optional[Mapping[Tuple[int, ...], CurvatureBlockCtor]]) – An optional dict mapping specific block parameter indices to specific classes of block approximations, which override the default ones. To get the correct indices, check estimator.indices_to_block_map.

  • auto_register_tags (bool) – Whether to automatically register layer tags for parameters that have not been manually registered. For further details see tag_graph_matcher.auto_register_tags.

  • **auto_register_kwargs – Any keyword arguments to pass into the auto registration function.

property batch_index: int

The index in the inputs of the model function, which is the batch.

Return type

int

params_vector_to_blocks_vectors(parameter_structured_vector)[source]

Splits the parameters to values for each corresponding block.

Return type

Tuple[Tuple[Array, …]]

blocks_vectors_to_params_vector(blocks_vectors)[source]

Reverses the effect of self.params_vector_to_blocks_vectors.

Return type

utils.Params

update_curvature_matrix_estimate(state, ema_old, ema_new, batch_size, rng, func_args, pmap_axis_name, estimation_mode=None)[source]

Updates the estimator’s curvature estimates.

Parameters
  • state (BlockDiagonalCurvature.State) – The state of the estimator to update.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size.

  • rng (PRNGKey) – A PRNGKey to be used for any potential sampling in the estimation process.

  • func_args (utils.FuncArgs) – A structure with the values of the inputs to the traced function (the tagged_func passed into the constructor), to be used for the estimation process. Should have the same structure as the argument func_args passed in the constructor.

  • pmap_axis_name (Optional[str]) – When calling this method within a pmap context this argument specifies the axis name over which to aggregate across multiple devices/hosts.

  • estimation_mode (Optional[str]) –

    The type of curvature estimator to use. By default (i.e. if None) self.default_estimation_mode is used. One of:

    • fisher_gradients - the basic estimation approach from the original K-FAC paper.

    • fisher_curvature_prop - method which estimates the Fisher using self-products of random 1/-1 vectors times “half-factors” of the Fisher, as described here.

    • fisher_exact - is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking.

    • fisher_empirical - computes the ‘empirical’ Fisher information matrix (which uses the data’s distribution for the targets, as opposed to the true Fisher which uses the model’s distribution) and requires that each registered loss have specified targets.

    • ggn_curvature_prop - Analogous to fisher_curvature_prop, but estimates the Generalized Gauss-Newton matrix (GGN).

    • ggn_exact - Analogous to fisher_exact, but estimates the Generalized Gauss-Newton matrix (GGN).

Return type

curvature_blocks.Full.State

Returns

The updated state.
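
The difference between the curvature_prop and exact estimation modes can be illustrated on a small dense factor. Given a matrix B with curvature B B^T, summing outer products of B e_i over one-hot vectors e_i recovers B B^T exactly, while averaging outer products of B s over random 1/-1 vectors s recovers it in expectation. The matrix B below is made up, and the sketch enumerates all sign vectors instead of sampling so that the expectation is computed exactly; a sampling-based estimator would only approximate it.

```python
import itertools
import numpy as np

B = np.array([[1.0, 2.0], [0.0, 1.0], [3.0, 0.0]])   # hypothetical factor
exact = B @ B.T

# "Exact" style: loop over one-hot vectors, one per inner coordinate.
one_hot_sum = np.zeros_like(exact)
for i in range(B.shape[1]):
    e = np.zeros(B.shape[1])
    e[i] = 1.0
    one_hot_sum += np.outer(B @ e, B @ e)

# "Curvature propagation" style: average over 1/-1 vectors. All sign
# vectors are enumerated here, so the average equals the expectation.
signs = list(itertools.product([-1.0, 1.0], repeat=B.shape[1]))
sign_avg = np.zeros_like(exact)
for s in signs:
    u = B @ np.array(s)
    sign_avg += np.outer(u, u) / len(signs)
```

The one-hot loop needs as many passes as there are inner coordinates, which matches the extra cost attributed to the exact modes.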

update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues, pmap_axis_name)[source]

Updates the estimator cached values.

Parameters
  • state (BlockDiagonalCurvature.State) – The state of the estimator to update.

  • identity_weight (Numeric) – Specifies the weight of the identity element that is added to the curvature matrix. This can be either a scalar value or a list/tuple of scalars, in which case each value specifies the weight individually for each block.

  • exact_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which exact matrix powers in the cache should be updated.

  • approx_powers (Optional[curvature_blocks.ScalarOrSequence]) – Specifies which approximate matrix powers in the cache should be updated.

  • eigenvalues (bool) – Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block’s cache.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which will be used for aggregating any computed values over multiple devices, as well as parallelizing the computation over devices in a block-wise fashion.

Return type

curvature_blocks.Full.State

Returns

The updated state.

ImplicitExactCurvature

class kfac_jax.ImplicitExactCurvature(func, params_index=0, batch_size_extractor=<function default_batch_size_extractor>)[source]

Represents all exact curvature matrices without ever constructing them explicitly.

__init__(func, params_index=0, batch_size_extractor=<function default_batch_size_extractor>)[source]

Initializes the ImplicitExactCurvature instance.

Parameters
  • func (utils.Func) – The model function, which should have at least one registered loss.

  • params_index (int) – The index of the parameters argument in arguments list of func.

  • batch_size_extractor (Callable[[utils.Batch], Numeric]) – A function that takes as input the function arguments and returns the batch size for a single device. (Default: kfac.utils.default_batch_size_extractor)

batch_size(func_args)[source]

The expected batch size given a list of loss instances.

Return type

Numeric

multiply_hessian(func_args, parameter_structured_vector)[source]

Multiplies the vector with the Hessian matrix of the total loss.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Hessian matrix.

  • parameter_structured_vector (utils.Params) – The vector to multiply with the Hessian matrix.

Return type

utils.Params

Returns

The product Hv.

multiply_fisher(func_args, parameter_structured_vector)[source]

Multiplies the vector with the Fisher matrix of the total loss.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.

  • parameter_structured_vector (utils.Params) – The vector to multiply with the Fisher matrix.

Return type

utils.Params

Returns

The product Fv.

multiply_ggn(func_args, parameter_structured_vector)[source]

Multiplies the vector with the GGN matrix of the total loss.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.

  • parameter_structured_vector (utils.Params) – The vector to multiply with the GGN matrix.

Return type

utils.Params

Returns

The product Gv.
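
The idea of multiplying by a curvature matrix without constructing it can be sketched for the GGN, which has the form G = J^T H J, where J is the model's Jacobian and H is the Hessian of the loss with respect to the model outputs. The dense J and H below are stand-ins for demonstration only; in an implicit implementation the products with J and J^T would typically come from forward-mode and reverse-mode automatic differentiation. The function ggn_vector_product is hypothetical.

```python
import numpy as np


def ggn_vector_product(jacobian, loss_hessian, v):
    """Returns J^T H (J v) using only matrix-vector products."""
    jv = jacobian @ v               # push v through the model
    hjv = loss_hessian @ jv         # curvature of the loss w.r.t. its inputs
    return jacobian.T @ hjv         # pull the result back to parameter space


J = np.array([[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]])   # 2 outputs, 3 parameters
H = np.array([[2.0, 0.0], [0.0, 1.0]])
v = np.array([1.0, 0.0, -1.0])
gv = ggn_vector_product(J, H, v)
```

Only vectors of parameter size and output size are ever stored, so the (number of parameters)^2 matrix G never needs to be materialized.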

multiply_fisher_factor_transpose(func_args, parameter_structured_vector)[source]

Multiplies the vector with the transposed factor of the Fisher matrix.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.

  • parameter_structured_vector (utils.Params) – The vector to multiply with the transposed factor of the Fisher matrix.

Return type

Tuple[Array, …]

Returns

The product B^T v, where F = BB^T.

multiply_ggn_factor_transpose(func_args, parameter_structured_vector)[source]

Multiplies the vector with the transposed factor of the GGN matrix.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.

  • parameter_structured_vector (utils.Params) – The vector to multiply with the transposed factor of the GGN matrix.

Return type

Tuple[Array, …]

Returns

The product B^T v, where G = BB^T.

multiply_fisher_factor(func_args, loss_inner_vectors)[source]

Multiplies the vector with the factor of the Fisher matrix.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the Fisher matrix.

  • loss_inner_vectors (Sequence[Array]) – The vector to multiply with the Fisher factor matrix.

Return type

utils.Params

Returns

The product Bv, where F = BB^T.

multiply_ggn_factor(func_args, loss_inner_vectors)[source]

Multiplies the vector with the factor of the GGN matrix.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function, on which to evaluate the GGN matrix.

  • loss_inner_vectors (Sequence[Array]) – The vector to multiply with the GGN factor matrix.

Return type

utils.Params

Returns

The product Bv, where G = BB^T.

multiply_jacobian_transpose(func_args, loss_input_vectors)[source]

Multiplies a vector by the model’s transposed Jacobian.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function.

  • loss_input_vectors (Sequence[Sequence[Array]]) – A sequence over losses of sequences of arrays that are the size of the loss’s inputs. This represents the vector to be multiplied.

Return type

utils.Params

Returns

The product J^T v, where J is the model’s Jacobian and v is given by loss_input_vectors.

get_loss_inner_vector_shapes_and_batch_size(func_args, mode)[source]

Get shapes of loss inner vectors, and the batch size.

Parameters
  • func_args (utils.FuncArgs) – The inputs to the model function.

  • mode (str) – A string representing the type of curvature matrix for the loss inner vectors. Can be “fisher” or “ggn”.

Return type

Tuple[Tuple[Shape, …], int]

Returns

Shapes of loss inner vectors in a tuple, and the batch size as an int.

get_loss_input_shapes_and_batch_size(func_args)[source]

Get shapes of loss input vectors, and the batch size.

Parameters

func_args (utils.FuncArgs) – The inputs to the model function.

Return type

Tuple[Tuple[Tuple[Shape, …], …], int]

Returns

A tuple over losses of tuples containing the shapes of their different inputs, and the batch size (as an int).

set_default_tag_to_block_ctor

kfac_jax.set_default_tag_to_block_ctor(tag_name, block_ctor)[source]

Sets the default curvature block constructor for the given tag.

Return type

None

get_default_tag_to_block_ctor

kfac_jax.get_default_tag_to_block_ctor(tag_name)[source]

Returns the default curvature block constructor for the given tag name.

Return type

Optional[CurvatureBlockCtor]

Loss Functions

LossFunction(weight)

Abstract base class for loss functions.

NegativeLogProbLoss(weight)

Base class for loss functions that represent negative log-probability.

DistributionNegativeLogProbLoss(weight)

Negative log-probability loss that uses a Distrax distribution.

NormalMeanNegativeLogProbLoss(mean[, ...])

Negative log prob loss for a normal distribution parameterized by a mean vector.

NormalMeanVarianceNegativeLogProbLoss(mean, ...)

Negative log prob loss for a normal distribution with mean and variance.

MultiBernoulliNegativeLogProbLoss(logits[, ...])

Negative log prob loss for multiple Bernoulli distributions parametrized by logits.

CategoricalLogitsNegativeLogProbLoss(logits)

Negative log prob loss for a categorical distribution parameterized by logits.

OneHotCategoricalLogitsNegativeLogProbLoss(logits)

Neg log prob loss for a categorical distribution with onehot targets.

register_sigmoid_cross_entropy_loss(logits)

Registers a sigmoid cross-entropy loss function.

register_multi_bernoulli_predictive_distribution(logits)

Registers a multi-Bernoulli predictive distribution.

register_softmax_cross_entropy_loss(logits)

Registers a softmax cross-entropy loss function.

register_categorical_predictive_distribution(logits)

Registers a categorical predictive distribution.

register_squared_error_loss(prediction[, ...])

Registers a squared error loss function.

register_normal_predictive_distribution(mean)

Registers a normal predictive distribution.

LossFunction

class kfac_jax.LossFunction(weight)[source]

Abstract base class for loss functions.

Note that, unlike typical loss functions used in neural networks, these are neither summed nor averaged over the batch, and the output of evaluate() will not be a scalar. It is up to the user to reduce them correctly as needed.

__init__(weight)[source]

Initializes the loss instance.

Parameters

weight (Numeric) – The relative weight attributed to the loss.

property weight: Numeric

The relative weight of the loss.

Return type

Numeric

abstract property targets: Optional[Array]

The targets (if present) used for evaluating the loss.

Return type

Optional[Array]

abstract property parameter_dependants: Tuple[Array, ...]

All the parameter dependent arrays of the loss.

Return type

Tuple[Array, …]

property num_parameter_dependants: int

Number of parameter dependent arrays of the loss.

Return type

int

abstract property parameter_independants: Tuple[Numeric, ...]

All the parameter independent arrays of the loss.

Return type

Tuple[Numeric, …]

property num_parameter_independants: int

Number of parameter independent arrays of the loss.

Return type

int

abstract copy_with_different_inputs(parameter_dependants)[source]

Creates a copy of the loss function object, but with different inputs.

Return type

‘LossFunction’

evaluate(targets=None, coefficient_mode='regular')[source]

Evaluates the loss function on the targets.

Parameters
  • targets (Optional[Array]) – The targets on which to evaluate the loss. If None, self.targets is used instead.

  • coefficient_mode (str) –

    Specifies how to use the relative weight of the loss in the returned value. There are three options:

    1. 'regular' - returns self.weight * loss(targets)

    2. 'sqrt' - returns sqrt(self.weight) * loss(targets)

    3. 'off' - returns loss(targets)

Return type

Array

Returns

The value of the loss scaled appropriately by self.weight according to the coefficient mode.

Raises

ValueError if both targets and self.targets are None.
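
The three coefficient modes can be sketched in plain Python as follows. The function apply_coefficient_mode is hypothetical and operates on a list of per-example loss values; it only mirrors the scaling rules listed above, and shows that the reduction over the batch is left to the caller.

```python
import math


def apply_coefficient_mode(unweighted_loss, weight, coefficient_mode="regular"):
    """Scales per-example loss values according to the coefficient mode."""
    if coefficient_mode == "regular":
        multiplier = weight
    elif coefficient_mode == "sqrt":
        multiplier = math.sqrt(weight)
    elif coefficient_mode == "off":
        multiplier = 1.0
    else:
        raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode}.")
    return [multiplier * value for value in unweighted_loss]


per_example = [1.0, 2.0, 3.0]     # one value per example, not reduced
regular = apply_coefficient_mode(per_example, weight=4.0)
sqrt_mode = apply_coefficient_mode(per_example, 4.0, "sqrt")
off = apply_coefficient_mode(per_example, 4.0, "off")
mean_loss = sum(regular) / len(regular)   # reduction is left to the caller
```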

grad_of_evaluate(targets, coefficient_mode)[source]

Evaluates the gradient of the loss function, w.r.t. its inputs.

Parameters
  • targets (Optional[Array]) – The targets at which to evaluate the loss. If None, self.targets is used instead.

  • coefficient_mode (str) – The coefficient mode to use for evaluation. See self.evaluate for more details.

Return type

Tuple[Array, …]

Returns

The gradient of the loss function w.r.t. its inputs, at the provided targets.

multiply_ggn(vector)[source]

Right-multiplies a vector by the GGN of the loss function.

Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs.

Parameters

vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.

Return type

Tuple[Array, …]

Returns

The vector right-multiplied by the GGN. Will have the same shape(s) as self.inputs.

abstract multiply_ggn_unweighted(vector)[source]

Unweighted version of multiply_ggn().

Return type

Tuple[Array, …]

multiply_ggn_factor(vector)[source]

Right-multiplies a vector by a factor B of the GGN.

Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.

Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.

Parameters

vector (Array) – The vector to multiply. Must be of the shape(s) given by ‘self.ggn_factor_inner_shape’.

Return type

Tuple[Array, …]

Returns

The vector right-multiplied by B. Will be of the same shape(s) as self.inputs.

abstract multiply_ggn_factor_unweighted(vector)[source]

Unweighted version of multiply_ggn_factor().

Return type

Tuple[Array, …]

multiply_ggn_factor_transpose(vector)[source]

Right-multiplies a vector by the transpose of a factor B of the GGN.

Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.

Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.

Parameters

vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.

Return type

Array

Returns

The vector right-multiplied by B^T. Will be of the shape(s) given by self.ggn_factor_inner_shape.

abstract multiply_ggn_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_ggn_factor_transpose().

Return type

Array

multiply_ggn_factor_replicated_one_hot(index)[source]

Right-multiplies a replicated-one-hot vector by a factor B of the GGN.

Here the GGN is the Generalized Gauss-Newton matrix (whose definition is somewhat flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases.

A replicated-one-hot vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere.

Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class.

Parameters

index (Sequence[int]) – A tuple representing the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the ggn_factor_inner_shape tensor minus one.

Return type

Tuple[Array, …]

Returns

The vector right-multiplied by B. Will be of the same shape(s) as the inputs property.
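
A replicated-one-hot vector, as defined above, can be constructed explicitly with NumPy. The helper replicated_one_hot is hypothetical; it takes the full inner shape (with the batch as dimension 0) and an index with one entry per non-batch dimension.

```python
import numpy as np


def replicated_one_hot(shape, index):
    """Builds a tensor of `shape` that is 1.0 at `index` in every batch slice."""
    if len(index) != len(shape) - 1:
        raise ValueError("len(index) must equal len(shape) - 1.")
    out = np.zeros(shape)
    out[(slice(None),) + tuple(index)] = 1.0   # dimension 0 is the batch
    return out


vec = replicated_one_hot((3, 4), index=(2,))   # batch of 3, inner size 4
```

Every row of vec is the same one-hot vector, which is why only the index needs to be passed to the method.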

abstract multiply_ggn_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_ggn_factor_replicated_one_hot().

Return type

Tuple[Array, …]

abstract property ggn_factor_inner_shape: Shape

The shape of the array returned by self.multiply_ggn_factor_transpose, which is also the shape of the vector accepted by self.multiply_ggn_factor.

Return type

Shape

NegativeLogProbLoss

class kfac_jax.NegativeLogProbLoss(weight)[source]

Base class for loss functions that represent negative log-probability.

property parameter_dependants: Tuple[Array, ...]

All the parameter dependent arrays of the loss.

Return type

Tuple[Array, …]

abstract property params: Tuple[Array, ...]

Parameters to the underlying distribution.

Return type

Tuple[Array, …]

multiply_fisher(vector)[source]

Right-multiplies a vector by the Fisher.

Parameters

vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.

Return type

Tuple[Array, …]

Returns

The vector right-multiplied by the Fisher. Will have the same shape(s) as self.inputs.

abstract multiply_fisher_unweighted(vector)[source]

Unweighted version of multiply_fisher().

Return type

Tuple[Array, …]

multiply_fisher_factor(vector)[source]

Right-multiplies a vector by a factor B of the Fisher.

Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.

Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.

Parameters

vector (Array) – The vector to multiply. Must have the same shape(s) as self.fisher_factor_inner_shape.

Return type

Tuple[Array, …]

Returns

The vector right-multiplied by B. Will have the same shape(s) as self.inputs.
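
As a concrete instance of a factor B with B * B^T = F, consider a single example from a categorical distribution parameterized by logits, with probabilities p. The Fisher with respect to the logits is diag(p) - p p^T, and one valid factor is B = diag(sqrt(p)) - p sqrt(p)^T. The NumPy sketch below checks this numerically; it illustrates the identity only and is not necessarily the factor used by any particular loss class.

```python
import numpy as np

logits = np.array([0.0, 1.0, -1.0])
p = np.exp(logits - logits.max())
p /= p.sum()                      # softmax probabilities
sqrt_p = np.sqrt(p)

fisher = np.diag(p) - np.outer(p, p)              # F w.r.t. the logits
factor = np.diag(sqrt_p) - np.outer(p, sqrt_p)    # a B with B @ B.T == F

v = np.array([1.0, -2.0, 0.5])
fv_direct = fisher @ v
fv_factored = factor @ (factor.T @ v)             # F v = B (B^T v)
```

Multiplying by B^T and then by B gives the same result as multiplying by F directly, which is the relationship the factor methods of this class rely on.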

abstract multiply_fisher_factor_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor().

Return type

Tuple[Array, …]

multiply_fisher_factor_transpose(vector)[source]

Right-multiplies a vector by the transpose of a factor B of the Fisher.

Here the Fisher is the Fisher information matrix (i.e. expected outer-product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.

Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.

Parameters

vector (Sequence[Array]) – The vector to multiply. Must have the same shape(s) as self.inputs.

Return type

Array

Returns

The vector right-multiplied by B^T. Will have the shape given by self.fisher_factor_inner_shape.

abstract multiply_fisher_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor_transpose().

Return type

Array

multiply_fisher_factor_replicated_one_hot(index)[source]

Right-multiplies a replicated-one-hot vector by a factor B of the Fisher.

Here the Fisher is the Fisher information matrix (i.e. expected outer product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases.

A replicated-one-hot vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere.

Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class.

Parameters

index (Sequence[int]) – A tuple representing the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the fisher_factor_inner_shape tensor minus one.

Return type

Tuple[Array, …]

Returns

The vector right-multiplied by B. Will have the same shape(s) as self.inputs.
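As a minimal sketch of what a replicated-one-hot vector looks like, the NumPy snippet below builds one for an assumed inner shape of (3, 4), i.e. a batch of 3 cases with 4 entries each. The helper name replicated_one_hot is hypothetical and is not part of kfac_jax; it only illustrates the definition given above.

```python
import numpy as np

def replicated_one_hot(inner_shape, index):
    """Builds a tensor that is 1.0 at `index` in every batch slice.

    Hypothetical helper for illustration only; not a kfac_jax function.
    """
    # The batch dimension is assumed to be dimension 0, so `index`
    # addresses only the remaining dimensions.
    if len(index) != len(inner_shape) - 1:
        raise ValueError("len(index) must equal len(inner_shape) - 1.")
    out = np.zeros(inner_shape)
    # Select every batch slice, then the single entry given by `index`.
    out[(slice(None),) + tuple(index)] = 1.0
    return out

# Batch of 3 cases, 4 entries per case, with entry 2 set to 1.0 in each case.
vec = replicated_one_hot(inner_shape=(3, 4), index=(2,))
```

Each row of vec is the same one-hot vector, which is what "replicated" refers to.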

abstract multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_fisher_factor_replicated_one_hot().

Return type

Tuple[Array, …]

abstract property fisher_factor_inner_shape: Shape

The shape of the array returned by multiply_fisher_factor().

Return type

Shape

abstract sample(rng)[source]

Sample targets from the underlying distribution.

Return type

Array

grad_of_evaluate_on_sample(rng, coefficient_mode)[source]

Evaluates the gradient of the log probability on a random sample.

Parameters
  • rng (Array) – Jax PRNG key for sampling.

  • coefficient_mode (str) – The coefficient mode to use for evaluation.

Return type

Tuple[Array, …]

Returns

The gradient of the log probability of targets sampled from the distribution.

DistributionNegativeLogProbLoss

class kfac_jax.DistributionNegativeLogProbLoss(weight)[source]

Negative log-probability loss that uses a Distrax distribution.

abstract property dist: distrax.Distribution

The underlying Distrax distribution.

Return type

distrax.Distribution

sample(rng)[source]

Sample targets from the underlying distribution.

Return type

Array

property fisher_factor_inner_shape: Shape

The shape of the array returned by multiply_fisher_factor().

Return type

Shape

NormalMeanNegativeLogProbLoss

class kfac_jax.NormalMeanNegativeLogProbLoss(mean, targets=None, variance=0.5, weight=1.0)[source]

Negative log prob loss for a normal distribution parameterized by a mean vector.

Note that, with the default variance=0.5, the covariance is treated as the identity divided by 2. Also note that the Fisher for such a normal distribution with respect to the mean parameter is given by:

F = (1 / variance) * I

See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
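The formula F = (1 / variance) * I can be checked numerically for a single dimension. The sketch below is plain NumPy rather than the library's implementation; the mean value and the use of Gauss-Hermite quadrature to evaluate the expectation are illustrative choices.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

variance = 0.5  # the default used by NormalMeanNegativeLogProbLoss
mean = 1.3      # arbitrary illustrative value

# Gauss-Hermite nodes/weights for the weight function exp(-z^2 / 2).
# Dividing by sqrt(2 * pi) turns the weighted sum into E[f(z)], z ~ N(0, 1).
nodes, weights = hermegauss(10)
weights = weights / np.sqrt(2.0 * np.pi)

# Samples of x ~ N(mean, variance) expressed through the quadrature nodes.
x = mean + np.sqrt(variance) * nodes

# Score: gradient of log N(x; mean, variance) with respect to the mean.
score = (x - mean) / variance

# Fisher = E[score^2]; this should equal 1 / variance.
fisher = np.sum(weights * score**2)
```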

__init__(mean, targets=None, variance=0.5, weight=1.0)[source]

Initializes the loss instance.

Parameters
  • mean (Array) – The mean of the normal distribution.

  • targets (Optional[Array]) – Optional targets to use for evaluation.

  • variance (Numeric) – The scalar variance of the normal distribution.

  • weight (Numeric) – The relative weight of the loss.

property targets: Optional[Array]

The targets (if present) used for evaluating the loss.

Return type

Optional[Array]

property parameter_independants: Tuple[Numeric, ...]

All the parameter independent arrays of the loss.

Return type

Tuple[Numeric, …]

property dist: distrax.MultivariateNormalDiag

The underlying Distrax distribution.

Return type

distrax.MultivariateNormalDiag

property params: Tuple[Array]

Parameters to the underlying distribution.

Return type

Tuple[Array]

copy_with_different_inputs(parameter_dependants)[source]

Creates the same LossFunction object, but with different inputs.

Parameters

parameter_dependants (Sequence[Array]) – The inputs to pass to the constructor of the class. This must be a sequence of length 1.

Return type

‘NormalMeanNegativeLogProbLoss’

Returns

An instance of NormalMeanNegativeLogProbLoss with the provided inputs.

Raises

A ValueError if the inputs are a sequence whose length is not 1.

multiply_fisher_unweighted(vector)[source]

Unweighted version of multiply_fisher().

Return type

Tuple[Array]

multiply_fisher_factor_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor().

Return type

Tuple[Array]

multiply_fisher_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor_transpose().

Return type

Array

multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_fisher_factor_replicated_one_hot().

Return type

Tuple[Array]

NormalMeanVarianceNegativeLogProbLoss

class kfac_jax.NormalMeanVarianceNegativeLogProbLoss(mean, variance, targets=None, weight=1.0)[source]

Negative log prob loss for a normal distribution with mean and variance.

This class parameterizes a multivariate normal distribution with n independent dimensions. Unlike NormalMeanNegativeLogProbLoss, this class does not assume the variance is held constant. The Fisher Information for n = 1 is given by:

F = [[1 / variance, 0               ],
     [0,            0.5 / variance^2]]

where the parameters of the distribution are concatenated into a single vector as [mean, variance]. For n > 1, the mean parameter vector is concatenated with the variance parameter vector. For further details, check out the Wikipedia page.
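The n = 1 Fisher above can be reproduced as the expected outer product of the score vector. The following NumPy sketch is an independent check and not the library's code; the parameter values and the quadrature order are arbitrary assumptions.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

mean, variance = 0.7, 1.5  # arbitrary illustrative values

# Gauss-Hermite quadrature for expectations under a standard normal.
z, w = hermegauss(20)
w = w / np.sqrt(2.0 * np.pi)
x = mean + np.sqrt(variance) * z

# Score: gradient of log N(x; mean, variance) w.r.t. [mean, variance].
score_mean = (x - mean) / variance
score_var = -0.5 / variance + 0.5 * (x - mean) ** 2 / variance**2
scores = np.stack([score_mean, score_var])  # shape (2, num_nodes)

# Fisher = E[score score^T], evaluated by the quadrature.
fisher = (scores * w) @ scores.T

# The closed form stated in the class description.
expected = np.array([[1.0 / variance, 0.0],
                     [0.0, 0.5 / variance**2]])
```

The off-diagonal entries vanish because they are odd moments of a symmetric distribution.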

__init__(mean, variance, targets=None, weight=1.0)[source]

Initializes the loss instance.

Parameters
  • mean (Array) – The mean of the normal distribution.

  • variance (Array) – The variance of the normal distribution.

  • targets (Optional[Array]) – Optional targets to use for evaluation.

  • weight (Numeric) – The relative weight of the loss.

property targets: Optional[Array]

The targets (if present) used for evaluating the loss.

Return type

Optional[Array]

property parameter_independants: Tuple[Numeric, ...]

All the parameter independent arrays of the loss.

Return type

Tuple[Numeric, …]

property dist: distrax.MultivariateNormalDiag

The underlying Distrax distribution.

Return type

distrax.MultivariateNormalDiag

property params: Tuple[Array, Array]

Parameters to the underlying distribution.

Return type

Tuple[Array, Array]

copy_with_different_inputs(parameter_dependants)[source]

Creates the same LossFunction object, but with different inputs.

Parameters

parameter_dependants (Sequence[Array]) – The inputs to pass to the constructor of the class. This must be a sequence of length 2.

Return type

‘NormalMeanVarianceNegativeLogProbLoss’

Returns

An instance of NormalMeanVarianceNegativeLogProbLoss with the provided inputs.

Raises

A ValueError if the inputs are a sequence whose length is not 2.

multiply_fisher_unweighted(vector)[source]

Unweighted version of multiply_fisher().

Return type

Tuple[Array, Array]

multiply_fisher_factor_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor().

Return type

Tuple[Array, Array]

multiply_fisher_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor_transpose().

Return type

Array

multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_fisher_factor_replicated_one_hot().

Return type

Tuple[Array, Array]

property fisher_factor_inner_shape: Shape

The shape of the array returned by multiply_fisher_factor().

Return type

Shape

multiply_ggn_unweighted(vector)[source]

Unweighted version of multiply_ggn().

Return type

Tuple[Array, …]

multiply_ggn_factor_unweighted(vector)[source]

Unweighted version of multiply_ggn_factor().

Return type

Tuple[Array, …]

multiply_ggn_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_ggn_factor_transpose().

Return type

Array

multiply_ggn_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_ggn_factor_replicated_one_hot().

Return type

Tuple[Array, …]

property ggn_factor_inner_shape: Shape

The shape of the array returned by self.multiply_ggn_factor.

Return type

Shape

MultiBernoulliNegativeLogProbLoss

class kfac_jax.MultiBernoulliNegativeLogProbLoss(logits, targets=None, weight=1.0)[source]

Negative log prob loss for multiple Bernoulli distributions parametrized by logits.

Represents N independent Bernoulli distributions where N = len(logits). Its Fisher Information matrix is given by F = diag(p * (1-p)), where p = sigmoid(logits).

As F is diagonal with positive entries, its factor B is B = diag(sqrt(p * (1-p))).
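Both statements can be verified directly. The NumPy sketch below uses arbitrary example logits, computes the Fisher diagonal as the expectation of the squared score over the two outcomes, and compares it with the closed form and its factor. It is an illustration, not the library's implementation.

```python
import numpy as np

logits = np.array([-1.0, 0.0, 2.5])   # arbitrary example logits
p = 1.0 / (1.0 + np.exp(-logits))     # p = sigmoid(logits)

# The score of log Bernoulli(t; sigmoid(logit)) w.r.t. the logit is (t - p).
# Expectation of the squared score over t = 1 (prob. p) and t = 0 (prob. 1 - p):
fisher_diag = p * (1.0 - p) ** 2 + (1.0 - p) * (0.0 - p) ** 2

# Closed forms from the class description.
F = np.diag(p * (1.0 - p))
B = np.diag(np.sqrt(p * (1.0 - p)))
```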

__init__(logits, targets=None, weight=1.0)[source]

Initializes the loss instance.

Parameters
  • logits (Array) – The logits of the Bernoulli distribution.

  • targets (Optional[Array]) – Optional targets to use for evaluation.

  • weight (Numeric) – The relative weight of the loss.

property targets: Optional[Array]

The targets (if present) used for evaluating the loss.

Return type

Optional[Array]

property parameter_independants: Tuple[Numeric, ...]

All the parameter independent arrays of the loss.

Return type

Tuple[Numeric, …]

property dist: distrax.Bernoulli

The underlying Distrax distribution.

Return type

distrax.Bernoulli

property params: Tuple[Array]

Parameters to the underlying distribution.

Return type

Tuple[Array]

copy_with_different_inputs(parameter_dependants)[source]

Creates a copy of the loss function object, but with different inputs.

Return type

‘MultiBernoulliNegativeLogProbLoss’

multiply_fisher_unweighted(vector)[source]

Unweighted version of multiply_fisher().

Return type

Tuple[Array]

multiply_fisher_factor_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor().

Return type

Tuple[Array]

multiply_fisher_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor_transpose().

Return type

Array

multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_fisher_factor_replicated_one_hot().

Return type

Tuple[Array]

CategoricalLogitsNegativeLogProbLoss

class kfac_jax.CategoricalLogitsNegativeLogProbLoss(logits, targets=None, mask=None, weight=1.0)[source]

Negative log prob loss for a categorical distribution parameterized by logits.

Note that the Fisher (for a single case) of a categorical distribution, with respect to the natural parameters (i.e. the logits), is given by F = diag(p) - p*p^T, where p = softmax(logits). F can be factorized as F = B * B^T, where B = diag(q) - p*q^T and q is the entry-wise square root of p. This is easy to verify using the fact that q^T*q = 1.
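The factorization can be confirmed numerically for a single case. The sketch below is plain NumPy with arbitrary example logits, not the library's code. It also shows that a Fisher-vector product can be formed from two factor products, F v = B (B^T v).

```python
import numpy as np

logits = np.array([0.5, -1.0, 2.0, 0.0])  # arbitrary example logits
p = np.exp(logits - logits.max())
p = p / p.sum()     # p = softmax(logits)
q = np.sqrt(p)      # entry-wise square root, so q @ q == 1

F = np.diag(p) - np.outer(p, p)   # Fisher w.r.t. the logits
B = np.diag(q) - np.outer(p, q)   # factor with B @ B.T == F

# A Fisher-vector product computed directly and via the factor.
v = np.array([1.0, -2.0, 0.5, 3.0])
Fv_direct = F @ v
Fv_factored = B @ (B.T @ v)
```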

__init__(logits, targets=None, mask=None, weight=1.0)[source]

Initializes the loss instance.

Parameters
  • logits (Array) – The logits of the Categorical distribution of shape (batch_size, output_size).

  • targets (Optional[Array]) – Optional targets to use for evaluation, which specify an integer index of the correct class. Must be of shape (batch_size,).

  • mask (Optional[Array]) – Optional mask to apply to losses over the batch. Should be 0/1-valued and of shape (batch_size,). The tensors returned by evaluate and grad_of_evaluate, as well as the various matrix vector products, will be multiplied by mask (with broadcasting to later dimensions).

  • weight (Numeric) – The relative weight of the loss.

property targets: Optional[Array]

The targets (if present) used for evaluating the loss.

Return type

Optional[Array]

property parameter_independants: Tuple[Numeric, ...]

All the parameter independent arrays of the loss.

Return type

Tuple[Numeric, …]

property dist: distrax.Categorical

The underlying Distrax distribution.

Return type

distrax.Categorical

property params: Tuple[Array]

Parameters to the underlying distribution.

Return type

Tuple[Array]

property fisher_factor_inner_shape: Shape

The shape of the array returned by multiply_fisher_factor().

Return type

Shape

copy_with_different_inputs(parameter_dependants)[source]

Creates a copy of the loss function object, but with different inputs.

Return type

‘CategoricalLogitsNegativeLogProbLoss’

multiply_fisher_unweighted(vector)[source]

Unweighted version of multiply_fisher().

Return type

Tuple[Array]

multiply_fisher_factor_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor().

Return type

Tuple[Array]

multiply_fisher_factor_transpose_unweighted(vector)[source]

Unweighted version of multiply_fisher_factor_transpose().

Return type

Array

multiply_fisher_factor_replicated_one_hot_unweighted(index)[source]

Unweighted version of multiply_fisher_factor_replicated_one_hot().

Return type

Tuple[Array]

OneHotCategoricalLogitsNegativeLogProbLoss

class kfac_jax.OneHotCategoricalLogitsNegativeLogProbLoss(logits, targets=None, mask=None, weight=1.0)[source]

Neg log prob loss for a categorical distribution with onehot targets.

Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying distribution is OneHotCategorical as opposed to Categorical.

property dist: distrax.OneHotCategorical

The underlying Distrax distribution.

Return type

distrax.OneHotCategorical

copy_with_different_inputs(parameter_dependants)[source]

Creates a copy of the loss function object, but with different inputs.

Return type

‘OneHotCategoricalLogitsNegativeLogProbLoss’

register_sigmoid_cross_entropy_loss

kfac_jax.register_sigmoid_cross_entropy_loss(logits, targets=None, weight=1.0)[source]

Registers a sigmoid cross-entropy loss function.

Note that this is distinct from register_softmax_cross_entropy_loss() and should not be confused with it. It is similar to register_multi_bernoulli_predictive_distribution() but without the explicit probabilistic interpretation. It behaves identically for now.

Parameters
  • logits (Array) – The input logits of the loss as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).

  • targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Must be of the same shape as logits. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)

  • weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the loss function by. (Default: 1.0)

register_multi_bernoulli_predictive_distribution

kfac_jax.register_multi_bernoulli_predictive_distribution(logits, targets=None, weight=1.0)[source]

Registers a multi-Bernoulli predictive distribution.

Note that this is distinct from register_categorical_predictive_distribution() and should not be confused with it.

Parameters
  • logits (Array) – The logits of the distribution (i.e. its parameters) as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).

  • targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)

  • weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. This is NOT equivalent to changing the temperature of the distribution since we don’t renormalize the log prob in the objective function. (Default: 1.0)

register_softmax_cross_entropy_loss

kfac_jax.register_softmax_cross_entropy_loss(logits, targets=None, mask=None, weight=1.0)[source]

Registers a softmax cross-entropy loss function.

Note that this is distinct from register_sigmoid_cross_entropy_loss() and should not be confused with it. It is similar to register_categorical_predictive_distribution() but without the explicit probabilistic interpretation. It behaves identically for now.

Parameters
  • logits (Array) – The input logits of the loss as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator). The second dimension is the one over which the softmax is computed.

  • targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Must be a 1D array of integers with shape (logits.shape[0],). Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)

  • mask (Optional[Array]) – (OPTIONAL) Mask to apply to losses. Should be 0/1-valued and of shape (logits.shape[0],). Losses corresponding to mask values of False will be treated as constant and equal to 0. (Default: None)

  • weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the loss function by. (Default: 1.0)

Return type

Array
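To make the role of mask concrete, the following NumPy sketch computes a per-case softmax cross-entropy and zeroes out a masked case. It only illustrates the semantics described above under assumed example data; it does not call kfac_jax and is not its implementation.

```python
import numpy as np

logits = np.array([[2.0, 0.0, -1.0],
                   [0.0, 1.0, 0.0]])   # shape (batch_size, num_classes)
targets = np.array([0, 2])             # integer class index per case
mask = np.array([1.0, 0.0])            # the second case is masked out

# Numerically stable log-softmax over the second dimension.
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# Negative log-probability of the target class for each case.
per_case_loss = -log_probs[np.arange(len(targets)), targets]

# Masked cases contribute a constant 0 to the loss.
masked_loss = mask * per_case_loss
```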

register_categorical_predictive_distribution

kfac_jax.register_categorical_predictive_distribution(logits, targets=None, mask=None, weight=1.0)[source]

Registers a categorical predictive distribution.

Note that this is distinct from register_multi_bernoulli_predictive_distribution() and should not be confused with it.

Parameters
  • logits (Array) – The logits of the distribution (i.e. its parameters) as a 2D array of floats. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator). The second dimension is the one over which the softmax is computed.

  • targets (Optional[Array]) – (OPTIONAL) The values at which the log probability of this distribution is evaluated (to give the loss). Must be a 1D array of integers with shape (logits.shape[0],). Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)

  • mask (Optional[Array]) – (OPTIONAL) Mask to apply to log probabilities generated by the distribution. Should be 0/1-valued and of shape (logits.shape[0],). Log probabilities corresponding to mask values of False will be treated as constant and equal to 0. (Default: None)

  • weight (Numeric) – (OPTIONAL) a scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. This is NOT equivalent to changing the temperature of the distribution since we don’t renormalize the log prob in the objective function. (Default: 1.0)

register_squared_error_loss

kfac_jax.register_squared_error_loss(prediction, targets=None, weight=1.0)[source]

Registers a squared error loss function.

This assumes the squared error loss of the form ||target - prediction||^2, averaged across the mini-batch. If your loss uses a coefficient of 0.5 you need to set the “weight” argument to reflect this.

Parameters
  • prediction (Array) – The prediction made by the network (i.e. its output). The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).

  • targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)

  • weight (Numeric) – A float coefficient to multiply the loss function by. (Default: 1.0)

Return type

Array

register_normal_predictive_distribution

kfac_jax.register_normal_predictive_distribution(mean, targets=None, variance=0.5, weight=1.0)[source]

Registers a normal predictive distribution.

This corresponds to a squared error loss of the form

weight/(2*var) * ||target - mean||^2

Parameters
  • mean (Array) – A tensor defining the mean vector of the distribution. The first dimension will usually be the batch size, but doesn’t need to be (unless using estimation_mode='fisher_exact' or estimation_mode='ggn_exact' in the optimizer/estimator).

  • targets (Optional[Array]) – (OPTIONAL) The targets for the loss function. Only required if using estimation_mode='fisher_empirical' in the optimizer/estimator. (Default: None)

  • variance (float) – The variance of the distribution. Note that the default value of 0.5 corresponds to a standard squared error loss weight * ||target - prediction||^2. If you want your squared error loss to be of the form 0.5 * weight * ||target - prediction||^2 you should use variance=1.0. (Default: 0.5)

  • weight (Numeric) – A scalar coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. In general this is NOT equivalent to changing the temperature of the distribution, but in the case of normal distributions it may be. (Default: 1.0)
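The relationship between variance and the squared error coefficient is simple arithmetic, sketched below in NumPy. The helper normal_sq_error is hypothetical and only evaluates the expression weight/(2*var) * ||target - mean||^2 stated above; the example arrays are arbitrary.

```python
import numpy as np

def normal_sq_error(mean, target, variance=0.5, weight=1.0):
    """Evaluates weight / (2 * variance) * ||target - mean||^2.

    Hypothetical helper for illustration only; not a kfac_jax function.
    """
    return weight / (2.0 * variance) * np.sum((target - mean) ** 2)

mean = np.array([0.2, -0.4, 1.0])
target = np.array([0.0, 0.1, 0.5])
sq_err = np.sum((target - mean) ** 2)

# variance=0.5 (the default) gives the plain squared error.
default_loss = normal_sq_error(mean, target)

# variance=1.0 gives the squared error with a coefficient of 0.5.
half_loss = normal_sq_error(mean, target, variance=1.0)
```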

Curvature Blocks

CurvatureBlock(layer_tag_eq, name)

Abstract class for curvature approximation blocks.

ScaledIdentity(layer_tag_eq, name[, scale])

A block that assumes that the curvature is a scaled identity matrix.

Diagonal(layer_tag_eq, name)

An abstract class for approximating only the diagonal of curvature.

Full(layer_tag_eq, name[, ...])

An abstract class for approximating the block matrix with a full matrix.

TwoKroneckerFactored(layer_tag_eq, name)

A Kronecker factored block for layers with weights and an optional bias.

NaiveDiagonal(layer_tag_eq, name)

Approximates the diagonal of the curvature in the most obvious way.

NaiveFull(layer_tag_eq, name[, ...])

Approximates the full curvature in the most obvious way.

DenseDiagonal(layer_tag_eq, name)

A Diagonal block specifically for dense layers.

DenseFull(layer_tag_eq, name[, ...])

A Full block specifically for dense layers.

DenseTwoKroneckerFactored(layer_tag_eq, name)

A TwoKroneckerFactored block specifically for dense layers.

Conv2DDiagonal(layer_tag_eq, name[, ...])

A Diagonal block specifically for 2D convolution layers.

Conv2DFull(layer_tag_eq, name[, ...])

A Full block specifically for 2D convolution layers.

Conv2DTwoKroneckerFactored(layer_tag_eq, name)

A TwoKroneckerFactored block specifically for 2D convolution layers.

ScaleAndShiftDiagonal(layer_tag_eq, name)

A diagonal approximation specifically for scale and shift layers.

ScaleAndShiftFull(layer_tag_eq, name[, ...])

A full dense approximation specifically for scale and shift layers.

set_max_parallel_elements(value)

Sets the default value of maximum parallel elements in the module.

get_max_parallel_elements()

Returns the default value of maximum parallel elements in the module.

set_default_eigen_decomposition_threshold(value)

Sets the default value of the eigen decomposition threshold.

get_default_eigen_decomposition_threshold()

Returns the default value of the eigen decomposition threshold.

CurvatureBlock

class kfac_jax.CurvatureBlock(layer_tag_eq, name)[source]

Abstract class for curvature approximation blocks.

A CurvatureBlock defines a curvature matrix to be estimated, and gives methods to multiply powers of this with a vector. Powers can be computed exactly or with a class-determined approximation. Cached versions of the powers can be pre-computed to make repeated multiplications cheaper. During initialization, you must explicitly specify all powers that you will need to cache.

class State(cache)[source]

Persistent state of the block.

Any subclasses of CurvatureBlock should also internally extend this class, with any attributes needed for the curvature estimation.

cache

A dictionary containing any state data that is updated at irregular intervals, such as inverses, eigenvalues, etc. Elements of this are updated via calls to update_cache(), and do not necessarily correspond to the most up-to-date curvature estimate.

Type

Optional[Dict[str, Union[Array, Dict[str, Array]]]]

__init__(cache)

__init__(layer_tag_eq, name)[source]

Initializes the block.

Parameters
  • layer_tag_eq (tags.LayerTagEqn) – The Jax equation corresponding to the layer tag whose curvature this block will approximate.

  • name (str) – The name of this block.

property layer_tag_primitive: tags.LayerTag

The jax.core.Primitive corresponding to the block’s tag equation.

Return type

tags.LayerTag

property parameter_variables: Tuple[jax.core.Var, ...]

The parameter variables of the underlying Jax equation.

Return type

Tuple[jax.core.Var, …]

property outputs_shapes: Tuple[Shape, ...]

The shapes of the output variables of the block’s tag equation.

Return type

Tuple[Shape, …]

property inputs_shapes: Tuple[Shape, ...]

The shapes of the input variables of the block’s tag equation.

Return type

Tuple[Shape, …]

property parameters_shapes: Tuple[Shape, ...]

The shapes of the parameter variables of the block’s tag equation.

Return type

Tuple[Shape, …]

property parameters_canonical_order: Tuple[int, ...]

The canonical order of the parameter variables.

Return type

Tuple[int, …]

property layer_tag_extra_params: Dict[str, Any]

Any extra parameters passed into the Jax primitive of this block.

Return type

Dict[str, Any]

property number_of_parameters: int

Number of parameter variables of this block.

Return type

int

property dim: int

The number of elements of all parameter variables together.

Return type

int

scale(state, use_cache)[source]

A scalar pre-factor of the curvature approximation.

Importantly, all methods assume that whenever a user requests cached values, any state-dependent scale is taken into account by the cache (e.g. either stored explicitly and used or mathematically added to values).

Parameters
  • state ('CurvatureBlock.State') – The state for this block.

  • use_cache (bool) – Whether the method requesting this is using cached values or not.

Return type

Numeric

Returns

A scalar value to be multiplied with any unscaled block representation.

fixed_scale()[source]

A fixed scalar pre-factor of the curvature (e.g. constant).

Return type

Numeric

state_dependent_scale(state)[source]

A scalar pre-factor of the curvature, computed from the most recent curvature estimate.

Return type

Numeric

init(rng, exact_powers_to_cache, approx_powers_to_cache, cache_eigenvalues)[source]

Initializes the state for this block.

Parameters
  • rng (PRNGKey) – The PRNGKey to be used for any randomness in the initialization.

  • exact_powers_to_cache (Optional[ScalarOrSequence]) – A single value, or multiple values in a list, which specify which exact matrix powers the block should be caching. Any matrix powers that are expected to be used in multiply_matpower(), multiply_inverse() or multiply() with exact_power=True and use_cached=True must be provided here.

  • approx_powers_to_cache (Optional[ScalarOrSequence]) – A single value, or multiple values in a list, which specify which approximate matrix powers the block should be caching. Any matrix powers that are expected to be used in multiply_matpower(), multiply_inverse() or multiply() with exact_power=False and use_cached=True must be provided here.

  • cache_eigenvalues (bool) – Specifies whether the block should be caching the eigenvalues of its approximate curvature.

Return type

‘CurvatureBlock.State’

Returns

A dictionary with the initialized state.

multiply_matpower(state, vector, identity_weight, power, exact_power, use_cached)[source]

Computes (BlockMatrix + identity_weight I)**power times vector.

Parameters
  • state ('CurvatureBlock.State') – The state for this block.

  • vector (Sequence[Array]) – A tuple of arrays that should have the same shapes as the block’s parameters_shapes, which represent the vector you want to multiply.

  • identity_weight (Numeric) – A scalar specifying the weight on the identity matrix that is added to the block matrix before raising it to a power. If use_cached=False it is guaranteed that this argument will be used in the computation. When returning cached values, this argument may be ignored in favor of whatever value was last passed to update_cache(). The precise semantics of this depend on the concrete subclass and its particular behavior in regard to caching.

  • power (Scalar) – The power to which to raise the matrix.

  • exact_power (bool) – Specifies whether to compute the exact matrix power of BlockMatrix + identity_weight I. When this argument is False the exact behaviour will depend on the concrete subclass and the result will in general be an approximation to (BlockMatrix + identity_weight I)^power, although some subclasses may still compute the exact matrix power.

  • use_cached (bool) – Whether to use a cached version for computing the product or to use the most recent curvature estimates. The cached version is going to be at least as fresh as the value provided to the last call to update_cache() with the same value of power.

Return type

Tuple[Array, …]

Returns

A tuple of arrays, representing the result of the matrix-vector product.
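For intuition, the exact (uncached) computation of (BlockMatrix + identity_weight I)**power times vector can be sketched with an eigendecomposition. The NumPy snippet below uses a dense random stand-in for the block matrix and a hypothetical helper name; concrete kfac_jax blocks may use structured representations and different algorithms.

```python
import numpy as np

def matpower_times_vector(block_matrix, vector, identity_weight, power):
    """Computes (block_matrix + identity_weight * I)**power @ vector.

    Hypothetical dense helper for illustration; not a kfac_jax function.
    """
    # A symmetric matrix admits M = Q diag(s) Q^T, hence
    # (M + c I)^power = Q diag((s + c)^power) Q^T.
    s, q = np.linalg.eigh(block_matrix)
    return q @ ((s + identity_weight) ** power * (q.T @ vector))

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
M = a @ a.T                 # symmetric positive semi-definite stand-in
v = rng.normal(size=4)
damping = 0.1               # plays the role of identity_weight

# power=-1 reproduces a damped linear solve, as in multiply_inverse().
inv_v = matpower_times_vector(M, v, damping, power=-1)
direct = np.linalg.solve(M + damping * np.eye(4), v)
```

With power=1 the same helper reproduces the plain product computed by multiply().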

multiply(state, vector, identity_weight, exact_power, use_cached)[source]

Computes (BlockMatrix + identity_weight I) times vector.

Return type

Tuple[Array, …]

multiply_inverse(state, vector, identity_weight, exact_power, use_cached)[source]

Computes (BlockMatrix + identity_weight I)^-1 times vector.

Return type

Tuple[Array, …]

eigenvalues(state, use_cached)[source]

Computes the eigenvalues for this block approximation.

Parameters
  • state ('CurvatureBlock.State') – The state dict for this block.

  • use_cached (bool) – Whether to use a cached version of the eigenvalues or to use the most recent curvature estimates to compute them. The cached version is going to be at least as fresh as the last time you called update_cache() with eigenvalues=True.

Return type

Array

Returns

An array containing the eigenvalues of the block.

abstract update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]

Updates the block’s curvature estimates using the info provided.

Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.

Parameters
  • state ('CurvatureBlock.State') – The state dict for this block to update.

  • estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size used in computing the values in estimation_data.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.

Return type

‘CurvatureBlock.State’
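The roles of ema_old and ema_new in the moving average can be sketched as a weighted sum of the old estimate and the new value. The exact form below is an assumption made for illustration (the library may track additional normalization), and ema_update is a hypothetical helper, not a kfac_jax function.

```python
import numpy as np

def ema_update(old_estimate, new_value, ema_old, ema_new):
    """Weighted moving-average update (assumed form, for illustration)."""
    return ema_old * old_estimate + ema_new * new_value

old = np.eye(2)                            # previous curvature estimate
new = np.array([[2.0, 0.5], [0.5, 1.0]])   # statistics from the current batch

# A typical smoothing step.
smoothed = ema_update(old, new, ema_old=0.95, ema_new=0.05)

# ema_old=0 and ema_new=1 discards the history and keeps only the new value.
no_average = ema_update(old, new, ema_old=0.0, ema_new=1.0)
```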

update_cache(state, identity_weight, exact_powers, approx_powers, eigenvalues)[source]

Updates the cached estimates of the different powers specified.

Parameters
  • state ('CurvatureBlock.State') – The state dict for this block to update.

  • identity_weight (Numeric) – The weight of the identity added to the block’s curvature matrix before computing the cached matrix power.

  • exact_powers (Optional[ScalarOrSequence]) – Specifies any cached exact matrix powers to be updated.

  • approx_powers (Optional[ScalarOrSequence]) – Specifies any cached approximate matrix powers to be updated.

  • eigenvalues (bool) – Specifies whether to update the cached eigenvalues of the block. If they have not been cached before, this will create an entry with them in the block’s cache.

Return type

‘CurvatureBlock.State’

Returns

The updated state.

to_dense_matrix(state)[source]

Returns a dense representation of the approximate curvature matrix.

Return type

Array

ScaledIdentity

class kfac_jax.ScaledIdentity(layer_tag_eq, name, scale=1.0)[source]

A block that assumes that the curvature is a scaled identity matrix.

__init__(layer_tag_eq, name, scale=1.0)[source]

Initializes the block.

Parameters
  • layer_tag_eq (tags.LayerTagEqn) – The JAX equation corresponding to the layer tag whose curvature this block will approximate.

  • name (str) – The name of this block.

  • scale (Numeric) – The scale of the identity matrix.
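Because the curvature is assumed to be scale * I, applying a power of the damped block to a vector reduces to a scalar multiplication. The sketch below checks this against a dense computation; the variable names are illustrative only and the library's own code path may differ.

```python
import numpy as np

scale = 2.0
identity_weight = 0.5
power = -1
vector = np.array([1.0, 2.0, 3.0])

# Dense route: build (scale * I + identity_weight * I) and invert it.
dense_block = scale * np.eye(3) + identity_weight * np.eye(3)
dense_result = np.linalg.inv(dense_block) @ vector

# Scalar route: the same operation without forming any matrix.
scalar_result = (scale + identity_weight) ** power * vector
```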

fixed_scale()[source]

A fixed scalar pre-factor of the curvature (e.g. constant).

Return type

Numeric

update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]

Updates the block’s curvature estimates using the info provided.

Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.

Parameters
  • state (CurvatureBlock.State) – The state dict for this block to update.

  • estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size used in computing the values in estimation_data.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.

Return type

CurvatureBlock.State

Diagonal

class kfac_jax.Diagonal(layer_tag_eq, name)[source]

An abstract class for approximating only the diagonal of curvature.

class State(cache, diagonal_factors)[source]

Persistent state of the block.

diagonal_factors

A tuple of the moving averages of the estimated diagonals of the curvature for each parameter that is part of the associated layer.

Type

Tuple[utils.WeightedMovingAverage]

__init__(cache, diagonal_factors)
update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]

Updates the block’s curvature estimates using the info provided.

Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.

Parameters
  • state ('Diagonal.State') – The state dict for this block to update.

  • estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size used in computing the values in estimation_data.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.

Return type

‘Diagonal.State’

Full

class kfac_jax.Full(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]

An abstract class for approximating the block matrix with a full matrix.

class State(cache, matrix)[source]

Persistent state of the block.

matrix

A moving average of the estimated curvature matrix for all parameters that are part of the associated layer.

Type

utils.WeightedMovingAverage

__init__(cache, matrix)
__init__(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]

Initializes the block.

Parameters
  • layer_tag_eq (tags.LayerTagEqn) – The JAX equation corresponding to the layer tag whose curvature this block will approximate.

  • name (str) – The name of this block.

  • eigen_decomposition_threshold (Optional[int]) – During calls to init and update_cache, if more matrix powers than this threshold are requested, the block computes the eigendecomposition directly (which provides access to any matrix power) instead of computing individual approximate powers. If this is None, the value returned from get_default_eigen_decomposition_threshold() is used.

parameters_list_to_single_vector(parameters_shaped_list)[source]

Converts values corresponding to parameters of the block to vector.

Return type

Array

single_vector_to_parameters_list(vector)[source]

Reverses the transformation self.parameters_list_to_single_vector.

Return type

Tuple[Array, …]
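The pair of conversions can be pictured as a flatten and unflatten round trip. The sketch below assumes the vector is the concatenation of the flattened parameter arrays in order; the helper names flatten_parameters and unflatten_parameters are hypothetical and the block's actual ordering may differ.

```python
import numpy as np


def flatten_parameters(parameters):
    """Concatenates the flattened parameter arrays into one vector."""
    return np.concatenate([np.reshape(p, -1) for p in parameters])


def unflatten_parameters(vector, shapes):
    """Splits a vector back into arrays with the given shapes."""
    sizes = [int(np.prod(shape)) for shape in shapes]
    pieces = np.split(vector, np.cumsum(sizes)[:-1])
    return tuple(np.reshape(piece, shape) for piece, shape in zip(pieces, shapes))


weights = np.arange(6.0).reshape(3, 2)
bias = np.array([10.0, 20.0])

vector = flatten_parameters((weights, bias))
restored = unflatten_parameters(vector, [weights.shape, bias.shape])
```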

update_curvature_matrix_estimate(state, estimation_data, ema_old, ema_new, batch_size, pmap_axis_name)[source]

Updates the block’s curvature estimates using the info provided.

Each block in general estimates a moving average of its associated curvature matrix. If you don’t want a moving average you can set ema_old=0 and ema_new=1.

Parameters
  • state ('Full.State') – The state dict for this block to update.

  • estimation_data (Dict[str, Sequence[Array]]) – A map containing data used for updating the curvature matrix estimate for this block. This can be computed by calling the function returned from layer_tags_vjp(). Please see its implementation for more details on the name of the fields and how they are constructed.

  • ema_old (Numeric) – Specifies the weight of the old value when computing the updated estimate in the moving average.

  • ema_new (Numeric) – Specifies the weight of the new value when computing the updated estimate in the moving average.

  • batch_size (Numeric) – The batch size used in computing the values in estimation_data.

  • pmap_axis_name (Optional[str]) – The name of any pmap axis, which might be needed for computing the updates.

Return type

‘Full.State’

TwoKroneckerFactored

class kfac_jax.TwoKroneckerFactored(layer_tag_eq, name)[source]

A Kronecker factored block for layers with weights and an optional bias.

__init__(layer_tag_eq, name)[source]

Initializes the block.

Parameters
  • layer_tag_eq (tags.LayerTagEqn) – The JAX equation corresponding to the layer tag whose curvature this block will approximate.

  • name (str) – The name of this block.

property has_bias: bool

Whether this layer’s equation has a bias.

Return type

bool

parameters_shaped_list_to_array(parameters_shaped_list)[source]

Combines all parameters to a single non axis grouped array.

Return type

Array

array_to_parameters_shaped_list(array)[source]

An inverse transformation of self.parameters_shaped_list_to_array.

Return type

Tuple[Array, …]
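The practical benefit of a two-factor Kronecker approximation is that the inverse of A ⊗ G can be applied using only the small factors. The sketch below demonstrates the underlying identity with NumPy; the names A (input factor) and G (output factor), the column-major vectorisation, and the random test matrices are assumptions for illustration and not the library's conventions.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_spd(n):
    """Returns a random symmetric positive definite n x n matrix."""
    m = rng.normal(size=(n, n))
    return m @ m.T + n * np.eye(n)


A = random_spd(3)             # hypothetical input factor
G = random_spd(2)             # hypothetical output factor
V = rng.normal(size=(2, 3))   # a vector reshaped to (size of G, size of A)

# Dense route: form the 6 x 6 Kronecker product and solve.
dense_solution = np.linalg.solve(np.kron(A, G), V.flatten(order="F"))

# Factored route: (A kron G)^{-1} vec(V) = vec(G^{-1} V A^{-1}) for symmetric A.
factored_solution = np.linalg.solve(G, V) @ np.linalg.inv(A)
```

Both routes give the same result, but the factored one never materialises the full Kronecker product.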

NaiveDiagonal

class kfac_jax.NaiveDiagonal(layer_tag_eq, name)[source]

Approximates the diagonal of the curvature in the most obvious way.

The update to the curvature estimate is computed by (sum_i g_i) ** 2 / N, where g_i is the gradient of each individual data point, and N is the batch size.
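A small numeric sketch of the formula, using made-up per-example gradients (4 data points, 3 parameters):

```python
import numpy as np

# Hypothetical per-example gradients g_i, one row per data point.
per_example_grads = np.array([
    [1.0, 0.0, 2.0],
    [0.0, 1.0, 2.0],
    [1.0, 1.0, 0.0],
    [2.0, 0.0, 0.0],
])
batch_size = per_example_grads.shape[0]          # N = 4

summed = per_example_grads.sum(axis=0)           # sum_i g_i = [4, 2, 4]
diagonal_update = summed ** 2 / batch_size       # (sum_i g_i) ** 2 / N = [4, 1, 4]
```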

NaiveFull

class kfac_jax.NaiveFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]

Approximates the full curvature in the most obvious way.

The update to the curvature estimate is computed by (sum_i g_i) (sum_i g_i)^T / N, where g_i is the gradient of each individual data point, and N is the batch size.
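The same formula on made-up data (2 data points, 2 parameters), written as an outer product:

```python
import numpy as np

# Hypothetical per-example gradients g_i, one row per data point.
per_example_grads = np.array([
    [1.0, 0.0],
    [1.0, 2.0],
])
batch_size = per_example_grads.shape[0]                  # N = 2

summed = per_example_grads.sum(axis=0)                   # sum_i g_i = [2, 2]
full_update = np.outer(summed, summed) / batch_size      # [[2, 2], [2, 2]]
```

The diagonal of this matrix equals (sum_i g_i) ** 2 / N, so it is consistent with the diagonal formula above it in this reference.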

DenseDiagonal

class kfac_jax.DenseDiagonal(layer_tag_eq, name)[source]

A Diagonal block specifically for dense layers.

property has_bias: bool

Whether the layer has a bias parameter.

Return type

bool

DenseFull

class kfac_jax.DenseFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]

A Full block specifically for dense layers.

DenseTwoKroneckerFactored

class kfac_jax.DenseTwoKroneckerFactored(layer_tag_eq, name)[source]

A TwoKroneckerFactored block specifically for dense layers.

Conv2DDiagonal

class kfac_jax.Conv2DDiagonal(layer_tag_eq, name, max_elements_for_vmap=None)[source]

A Diagonal block specifically for 2D convolution layers.

__init__(layer_tag_eq, name, max_elements_for_vmap=None)[source]

Initializes the block.

Since there is no ‘nice’ formula for computing the average of the tangents for a 2D convolution, we use a function, self.conv2d_tangent_squared, that computes the square of the tangents for the kernel of the convolution for a single feature map. To average over the batch we have two choices: vmap, or loop over the batch sequentially using scan. This utility function provides a trade-off by letting you specify the maximum batch size that we can vmap over. This means that the maximum memory usage will be max_batch_size_for_vmap times the memory needed when calling self.conv2d_tangent_squared, and the actual vmap will be called ceil(total_batch_size / max_batch_size_for_vmap) times in a loop to find the final average.

Parameters
  • layer_tag_eq (tags.LayerTagEqn) – The JAX equation corresponding to the layer tag whose curvature this block will approximate.

  • name (str) – The name of this block.

  • max_elements_for_vmap (Optional[int]) – The threshold used for determining how much computation to do in parallel and how much serially. If None, the value returned by get_max_parallel_elements() is used.
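The memory and parallelism trade-off described above can be sketched as chunked averaging: each chunk is processed in one vectorised step, and chunks are accumulated in a sequential loop. The function chunked_mean_of_squares and its arguments are hypothetical, and plain NumPy vectorisation plus a Python loop stand in for vmap and scan.

```python
import math

import numpy as np


def chunked_mean_of_squares(values, max_chunk_size):
    """Averages values ** 2 over the batch axis, one chunk at a time."""
    total = values.shape[0]
    num_chunks = math.ceil(total / max_chunk_size)
    running_sum = np.zeros(values.shape[1:])
    for i in range(num_chunks):
        chunk = values[i * max_chunk_size:(i + 1) * max_chunk_size]
        # Vectorised over the chunk; peak memory scales with the chunk size.
        running_sum += (chunk ** 2).sum(axis=0)
    return running_sum / total, num_chunks


values = np.arange(10.0).reshape(5, 2)   # batch of 5 examples
chunked_average, num_chunks = chunked_mean_of_squares(values, max_chunk_size=2)
```

With a batch of 5 and a chunk size of 2, the loop runs ceil(5 / 2) = 3 times and returns the same average as processing the whole batch at once.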

conv2d_tangent_squared(image_features_map, output_tangent)[source]

Computes the elementwise square of a tangent for a single feature map.

Return type

Array

Conv2DFull

class kfac_jax.Conv2DFull(layer_tag_eq, name, max_elements_for_vmap=None)[source]

A Full block specifically for 2D convolution layers.

__init__(layer_tag_eq, name, max_elements_for_vmap=None)[source]

Initializes the block.

Since there is no ‘nice’ formula for computing the average of the tangents for a 2D convolution, we use a function, self.conv2d_tangent_squared, that computes the square of the tangents for the kernel of the convolution for a single feature map. To average over the batch we have two choices: vmap, or loop over the batch sequentially using scan. This utility function provides a trade-off by letting you specify the maximum batch size that will be handled in a single iteration of the loop. This means that the maximum memory usage will be max_batch_size_for_vmap times the memory needed when calling self.conv2d_tangent_squared, and the actual vmap will be called ceil(total_batch_size / max_batch_size_for_vmap) times in a loop to find the final average.

Parameters
  • layer_tag_eq (tags.LayerTagEqn) – The JAX equation corresponding to the layer tag whose curvature this block will approximate.

  • name (str) – The name of this block.

  • max_elements_for_vmap (Optional[int]) – The threshold used for determining how much computation to do in parallel and how much serially. If None, the value returned by get_max_parallel_elements() is used.

conv2d_tangent_outer_product(inputs, tangent_of_outputs)[source]

Computes the outer product of a tangent for a single feature map.

Return type

Array

Conv2DTwoKroneckerFactored

class kfac_jax.Conv2DTwoKroneckerFactored(layer_tag_eq, name)[source]

A TwoKroneckerFactored block specifically for 2D convolution layers.

fixed_scale()[source]

A fixed scalar pre-factor of the curvature (e.g. constant).

Return type

Numeric

property outputs_channel_index: int

The channels index in the outputs of the layer.

Return type

int

property inputs_channel_index: int

The channels index in the inputs of the layer.

Return type

int

property weights_output_channel_index: int

The channels index in weights of the layer.

Return type

int

property weights_spatial_size: int

The spatial filter size of the weights.

Return type

int

property num_locations: int

The number of spatial locations that each filter is applied to.

Return type

int

property num_inputs_channels: int

The number of channels in the inputs to the layer.

Return type

int

property num_outputs_channels: int

The number of channels in the outputs to the layer.

Return type

int

compute_inputs_stats(inputs, weighting_array=None)[source]

Computes the statistics for the inputs factor.

Return type

Array

compute_outputs_stats(tangent_of_output)[source]

Computes the statistics for the outputs factor.

Return type

Array

ScaleAndShiftDiagonal

class kfac_jax.ScaleAndShiftDiagonal(layer_tag_eq, name)[source]

A diagonal approximation specifically for scale and shift layers.

property has_scale: bool

Whether this layer’s equation has a scale.

Return type

bool

property has_shift: bool

Whether this layer’s equation has a shift.

Return type

bool

ScaleAndShiftFull

class kfac_jax.ScaleAndShiftFull(layer_tag_eq, name, eigen_decomposition_threshold=None)[source]

A full dense approximation specifically for scale and shift layers.

set_max_parallel_elements

kfac_jax.set_max_parallel_elements(value)[source]

Sets the default value of maximum parallel elements in the module.

This value is used to determine the parallel-to-memory tradeoff in the curvature estimation procedure of Conv2DDiagonal and Conv2DFull. See their corresponding docs for further details.

Parameters

value (int) – The default value for maximum number of parallel elements.

get_max_parallel_elements

kfac_jax.get_max_parallel_elements()[source]

Returns the default value of maximum parallel elements in the module.

This value is used to determine the parallel-to-memory tradeoff in the curvature estimation procedure of Conv2DDiagonal and Conv2DFull. See their corresponding docs for further details.

Return type

int

Returns

The default value for maximum number of parallel elements.

set_default_eigen_decomposition_threshold

kfac_jax.set_default_eigen_decomposition_threshold(value)[source]

Sets the default value of the eigen decomposition threshold.

This value is used in Full when updating the cache: it is the number of distinct matrix powers above which the implementation switches from computing each matrix power individually to computing an eigendecomposition.

Parameters

value (int) – The default value for eigen decomposition threshold.

get_default_eigen_decomposition_threshold

kfac_jax.get_default_eigen_decomposition_threshold()[source]

Returns the default value of the eigen decomposition threshold.

This value is used in Full when updating the cache: it is the number of distinct matrix powers above which the implementation switches from computing each matrix power individually to computing an eigendecomposition.

Return type

int

Returns

The default value of the eigen decomposition threshold.

Advanced Features API

Layer and loss tags

LossTag(cls, parameter_dependants, ...)

A Jax primitive for tagging K-FAC losses.

LayerTag(name, num_inputs, num_outputs)

A Jax primitive for tagging K-FAC layers.

register_generic(parameter)

Registers a generic tag around the provided parameter array.

register_dense(y, x, w[, b])

Registers a dense layer: y = matmul(x, w) + b.

register_conv2d(y, x, w[, b])

Registers a 2d convolution layer: y = conv2d(x, w) + b.

register_scale_and_shift(y, x[, scale, shift])

Registers a scale and shift layer: y = x * scale + shift.

LossTag

class kfac_jax.LossTag(cls, parameter_dependants, parameter_independants)[source]

A Jax primitive for tagging K-FAC losses.

The primitive is a no-op at runtime; its purpose is to tag (annotate) the JAX computation graph with exactly which expression is the loss and what type of loss it represents. This is the only way for K-FAC to know how to compute the curvature matrix.

__init__(cls, parameter_dependants, parameter_independants)[source]

Initializes a loss tag primitive for the given LossFunction class.

When the primitive is created, the constructor automatically registers it with the standard JAX machinery for differentiation, jax.vmap() and XLA lowering. For further details, please see the JAX documentation on primitives.

Parameters
  • cls (Type[T]) – The corresponding class of LossFunction that this tag represents.

  • parameter_dependants (Sequence[str]) – The names of each of the parameter dependent inputs to the tag.

  • parameter_independants (Sequence[str]) – The names of each of the parameter independent inputs to the tag.

property parameter_dependants_names: Tuple[str, ...]

The names of the parameter dependent inputs to the tag primitive.

Return type

Tuple[str, …]

property parameter_independants_names: Tuple[str, ...]

The names of the parameter independent inputs to the tag primitive.

Return type

Tuple[str, …]

loss(*args, args_names)[source]

Constructs an instance of the corresponding LossFunction class.

Return type

T

get_outputs(*args, args_names)[source]

Verifies that the number of arguments matches expectations.

Return type

Tuple[ArrayOrXla, …]

LayerTag

class kfac_jax.LayerTag(name, num_inputs, num_outputs)[source]

A Jax primitive for tagging K-FAC layers.

The primitive is a no-op at runtime; its purpose is to tag (annotate) the JAX computation graph with which expressions represent a single unique layer type. This is the only way for K-FAC to know how to compute the curvature matrix.

__init__(name, num_inputs, num_outputs)[source]

Initializes a layer tag primitive with the given name.

Any layer tag primitive must have the following interface: layer_tag(*outputs, *inputs, *parameters, **kwargs). We refer collectively to inputs, outputs and parameters as operands. All operands must be JAX arrays, while any of the values in kwargs must be hashable fixed constants.

When the primitive is created, the constructor automatically registers it with the standard JAX machinery for differentiation, jax.vmap() and XLA lowering. For further details, please see the JAX documentation on primitives.

Parameters
  • name (str) – The name of the layer primitive.

  • num_inputs (int) – The number of inputs to the layer.

  • num_outputs (int) – The number of outputs to the layer.

property num_outputs: int

The number of outputs of the layer tag that this primitive represents.

Return type

int

property num_inputs: int

The number of inputs of the layer tag that this primitive represents.

Return type

int

split_all_inputs(all_inputs)[source]

Splits the operands of the primitive into (outputs, inputs, params).

Return type

Tuple[Tuple[T, …], Tuple[T, …], Tuple[T, …]]
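Given the operand order (outputs first, then inputs, then parameters), the split can be pictured as plain slicing driven by num_outputs and num_inputs. The standalone function split_operands below is a hypothetical sketch of that idea, not the method's actual implementation, and strings stand in for arrays.

```python
def split_operands(operands, num_outputs, num_inputs):
    """Splits a flat operand sequence into (outputs, inputs, params)."""
    outputs = tuple(operands[:num_outputs])
    inputs = tuple(operands[num_outputs:num_outputs + num_inputs])
    params = tuple(operands[num_outputs + num_inputs:])
    return outputs, inputs, params


# A dense-like layer: one output, one input, then weights and bias.
outputs, inputs, params = split_operands(
    ("y", "x", "w", "b"), num_outputs=1, num_inputs=1)
```

Everything after the first num_outputs + num_inputs operands is treated as a parameter, so layers with or without a bias are handled by the same slicing.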

get_outputs(*operands, **_)[source]

Extracts the outputs of a layer from the operands of the primitive.

Return type

Array

register_generic

kfac_jax.register_generic(parameter)[source]

Registers a generic tag around the provided parameter array.

Return type

Array

register_dense

kfac_jax.register_dense(y, x, w, b=None, **kwargs)[source]

Registers a dense layer: y = matmul(x, w) + b.

Return type

Array

register_conv2d

kfac_jax.register_conv2d(y, x, w, b=None, **kwargs)[source]

Registers a 2d convolution layer: y = conv2d(x, w) + b.

Return type

Array

register_scale_and_shift

kfac_jax.register_scale_and_shift(y, x, scale=None, shift=None)[source]

Registers a scale and shift layer: y = x * scale + shift.

Return type

Array

Automatic Tags Registration

auto_register_tags(func, func_args[, ...])

Transforms the function by automatically registering layer tags.

auto_register_tags

kfac_jax.auto_register_tags(func, func_args, params_index=0, register_only_generic=False, compute_only_loss_tags=True, patterns_to_skip=(), allow_multiple_registrations=False, graph_matcher_rules=<kfac_jax._src.tag_graph_matcher.GraphMatcherComparator object>, graph_patterns=(GraphPattern(name='dense_with_bias', tag_primitive=dense_tag, compute_func=<function _dense>, parameters_extractor_func=<function _dense_parameter_extractor>, example_args=[<zeros, shape (2, 13)>, [<zeros, shape (13, 7)>, <zeros, shape (7,)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='dense_with_bias', tag_primitive=dense_tag, compute_func=<function _dense_with_reshape>, parameters_extractor_func=<function _dense_parameter_extractor>, example_args=[<zeros, shape (2, 13)>, [<zeros, shape (13, 7)>, <zeros, shape (7,)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='dense_no_bias', tag_primitive=dense_tag, compute_func=<function _dense>, parameters_extractor_func=<function _dense_parameter_extractor>, example_args=[<zeros, shape (2, 13)>, [<zeros, shape (13, 7)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='conv2d_with_bias', tag_primitive=conv2d_tag, compute_func=<function _conv2d>, parameters_extractor_func=<function _conv2d_parameter_extractor>, example_args=[<zeros, shape (2, 8, 8, 5)>, [<zeros, shape (3, 3, 5, 4)>, <zeros, shape (4,)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='conv2d_no_bias', tag_primitive=conv2d_tag, compute_func=<function _conv2d>, parameters_extractor_func=<function _conv2d_parameter_extractor>, example_args=[<zeros, shape (2, 8, 8, 5)>, [<zeros, shape (3, 3, 5, 4)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='scale_and_shift_broadcast_1', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _scale_and_shift>, has_scale=True, has_shift=True), parameters_extractor_func=<function _make_scale_and_shift_pattern.<locals>.<lambda>>, example_args=[<zeros, shape (2, 13)>, [<zeros, shape (13,)>, <zeros, shape (13,)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='scale_and_shift_broadcast_0', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _scale_and_shift>, has_scale=True, has_shift=True), parameters_extractor_func=<function _make_scale_and_shift_pattern.<locals>.<lambda>>, example_args=[<zeros, shape (13,)>, [<zeros, shape (13,)>, <zeros, shape (13,)>]], in_values_preprocessor=None, _graph=None), GraphPattern(name='normalization_haiku_broadcast_1', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _normalization_haiku>, has_scale=True, has_shift=True), parameters_extractor_func=<function _make_normalization_haiku_pattern.<locals>.<lambda>>, example_args=[[<zeros, shape (2, 13)>, <zeros, shape (2, 13)>], [<zeros, shape (13,)>, <zeros, shape (13,)>]], in_values_preprocessor=<function _normalization_haiku_preprocessor>, _graph=None), GraphPattern(name='normalization_haiku_broadcast_0', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _normalization_haiku>, has_scale=True, has_shift=True), parameters_extractor_func=<function _make_normalization_haiku_pattern.<locals>.<lambda>>, example_args=[[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0.]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])], [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]], in_values_preprocessor=<function _normalization_haiku_preprocessor>, _graph=None), GraphPattern(name='scale_only_broadcast_1', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _scale_and_shift>, has_scale=True, has_shift=False), parameters_extractor_func=<function _make_scale_and_shift_pattern.<locals>.<lambda>>, example_args=[array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]], in_values_preprocessor=None, _graph=None), GraphPattern(name='scale_only_broadcast_0', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _scale_and_shift>, has_scale=True, has_shift=False), parameters_extractor_func=<function _make_scale_and_shift_pattern.<locals>.<lambda>>, example_args=[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]], in_values_preprocessor=None, _graph=None), GraphPattern(name='shift_only_broadcast_1', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _scale_and_shift>, has_scale=False, has_shift=True), parameters_extractor_func=<function _make_scale_and_shift_pattern.<locals>.<lambda>>, example_args=[array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]], in_values_preprocessor=None, _graph=None), GraphPattern(name='shift_only_broadcast_0', tag_primitive=scale_and_shift_tag, compute_func=functools.partial(<function _scale_and_shift>, has_scale=False, has_shift=True), parameters_extractor_func=<function _make_scale_and_shift_pattern.<locals>.<lambda>>, 
example_args=[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]], in_values_preprocessor=None, _graph=None)))[source]

Transforms the function by automatically registering layer tags.

Parameters
  • func (utils.Func) – The original function to transform.

  • func_args (utils.FuncArgs) – Example arguments to func, which are used for tracing it.

  • params_index (int) – Specifies which input to the function is to be treated as the parameter variables, namely inputs[params_index].

  • register_only_generic (bool) – If True, registers all parameters not already in a layer tag with a generic tag, effectively ignoring graph_patterns.

  • compute_only_loss_tags (bool) – If set to True (default), the resulting function will only compute the loss tags in func, not its full computation and actual output.

  • patterns_to_skip (Sequence[str]) – The names of any patterns from the provided list that should be skipped during pattern matching.

  • allow_multiple_registrations (bool) – Whether a parameter is allowed to be registered with more than one layer tag. If False, an error is raised in that case.

  • graph_matcher_rules (GraphMatcherComparator) – A GraphMatcherRules instance, which is used for determining equivalence of individual Jax primitives.

  • graph_patterns (Sequence[GraphPattern]) – A sequence of GraphPattern objects, in order of precedence, that the matcher tries to find in the graph before registering a parameter with a generic layer tag.

Return type

TaggedFunction

Returns

A transformed function as described above.
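The interaction between graph_patterns, patterns_to_skip and register_only_generic can be pictured with a toy selection routine. This is only a sketch of the precedence logic described above, not library code: the real matcher works on the traced Jax graph, whereas here a "pattern" is a hypothetical (name, predicate) pair, a parameter is a plain dict, and the function name choose_tag and the tag names are made up for illustration.

```python
from typing import Callable, Sequence, Tuple

# Hypothetical stand-in for a GraphPattern: a name plus a predicate that
# says whether the pattern applies to a (toy) parameter description.
ToyPattern = Tuple[str, Callable[[dict], bool]]


def choose_tag(
    param: dict,
    patterns: Sequence[ToyPattern],
    patterns_to_skip: Sequence[str] = (),
    register_only_generic: bool = False,
) -> str:
    """Returns the name of the first applicable pattern, else "generic"."""
    if not register_only_generic:
        # Patterns are tried in the order given, i.e. in order of precedence.
        for name, matches in patterns:
            if name in patterns_to_skip:
                continue
            if matches(param):
                return name
    # Fallback: no specific pattern was used, so a generic tag is assigned.
    return "generic"


# Illustrative pattern list; the more specific pattern comes first.
patterns = [
    ("dense_with_bias", lambda p: p["op"] == "dot" and p["has_bias"]),
    ("dense_no_bias", lambda p: p["op"] == "dot"),
]
param = {"op": "dot", "has_bias": True}

first_choice = choose_tag(param, patterns)
skipped_choice = choose_tag(param, patterns, patterns_to_skip=("dense_with_bias",))
generic_choice = choose_tag(param, patterns, register_only_generic=True)
```

Listing the more specific pattern first matters in this sketch, because the first match wins.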

Function tracing and Jacobian computation

ProcessedJaxpr(jaxpr, consts, in_tree, ...)

A wrapper around Jaxpr, with useful additional data.

loss_tags_vjp(func[, params_index])

Creates a function for the vector-Jacobian product w.r.t. all loss tags.

loss_tags_jvp(func[, params_index])

Creates a function for the Jacobian-vector product w.r.t. all loss tags.

loss_tags_hvp(func[, params_index])

Creates a function for the Hessian-vector product w.r.t. all loss tags.

layer_tags_vjp(func[, params_index, ...])

Creates a function for primal values and tangents w.r.t. all layer tags.

ProcessedJaxpr

class kfac_jax.ProcessedJaxpr(jaxpr, consts, in_tree, params_index, allow_left_out_params=False)[source]

A wrapper around Jaxpr, with useful additional data.

jaxpr

The original Jaxpr that is being wrapped.

consts

The constants returned from the tracing of the original Jaxpr.

in_tree

The PyTree structure of the inputs to the function that the original Jaxpr has been created from.

params_index

Specifies which input to the function is to be treated as the parameter variables, namely inputs[params_index].

loss_tags

A tuple of all of the loss tags in the original Jaxpr.

layer_tags

A sorted tuple of all of the layer tags in the original Jaxpr. The sorting order is based on the indices of the parameters associated with each layer tag.

layer_indices

A sequence of tuples, where each tuple has the indices of the parameters associated with the corresponding layer tag.

__init__(jaxpr, consts, in_tree, params_index, allow_left_out_params=False)[source]

Initializes the instance.

Parameters
  • jaxpr (jax.core.Jaxpr) – The raw Jaxpr.

  • consts (Sequence[Any]) – The constants needed for evaluation of the raw Jaxpr.

  • in_tree (utils.PyTreeDef) – The PyTree structure of the inputs to the function that the jaxpr has been created from.

  • params_index (int) – Specifies which input to the function is to be treated as the parameter variables, namely inputs[params_index].

  • allow_left_out_params (bool) – Whether parameter variables are allowed to be left out of every layer tag. If False, an error is raised when any parameter variable is not included in any layer tag.

property in_vars_flat: List[Var]

A flat list of all of the abstract input variables.

Return type

List[Var]

property in_vars: utils.PyTree[Var]

The abstract input variables, as an unflattened structure.

Return type

utils.PyTree[Var]

property params_vars: utils.PyTree[Var]

The abstract parameter variables, as an unflattened structure.

Return type

utils.PyTree[Var]

property params_vars_flat: List[Var]

A flat list of all abstract parameter variables.

Return type

List[Var]

property params_tree: utils.PyTreeDef

The PyTree structure of the parameter variables.

Return type

utils.PyTreeDef

classmethod make_from_func(func, func_args, params_index=0, auto_register_tags=True, allow_left_out_params=False, **auto_registration_kwargs)[source]

Constructs a ProcessedJaxpr from the given function.

Parameters
  • func (utils.Func) – The model function, which will be traced.

  • func_args (FuncArgs) – Function arguments to use for tracing.

  • params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.

  • auto_register_tags (bool) – Whether to run an automatic layer registration on the function (e.g. auto_register_tags()).

  • allow_left_out_params (bool) – If set to False, an error is raised if any model parameters have not been assigned to a layer tag.

  • **auto_registration_kwargs – Any additional keyword arguments, to be passed to the automatic registration pass.

Return type

ProcJaxpr

Returns

A ProcessedJaxpr representing the model function.
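To make the jaxpr, consts, in_tree and params_index ingredients concrete, the sketch below obtains analogous objects with plain JAX. It does not call ProcessedJaxpr and makes no claim about how make_from_func is implemented; the model function and the variable names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # A tiny dense layer; `params` is a dict, so it is a non-trivial PyTree.
    return jnp.tanh(x @ params["w"] + params["b"])


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
x = jnp.ones((4, 3))
func_args = (params, x)
params_index = 0  # func_args[params_index] holds the model parameters

# Tracing the function with example arguments yields a Jaxpr and its constants.
closed_jaxpr = jax.make_jaxpr(model)(*func_args)
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts

# The PyTree structure of all inputs, and of the parameters alone.
flat_args, in_tree = jax.tree_util.tree_flatten(func_args)
params_flat, params_tree = jax.tree_util.tree_flatten(func_args[params_index])

# The Jaxpr has one abstract input variable per flattened input leaf.
n_in_vars = len(jaxpr.invars)
```

Here the flat inputs are the two parameter leaves followed by x, which is the kind of flat versus structured view that in_vars_flat and in_vars expose.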

loss_tags_vjp

kfac_jax.loss_tags_vjp(func, params_index=0)[source]

Creates a function for the vector-Jacobian product w.r.t. all loss tags.

The returned function has a similar interface to jax.vjp(). It takes as inputs the concrete values of the primals at which the Jacobian will be evaluated. It returns a pair (losses, losses_vjp), where losses is a tuple of LossFunction objects and losses_vjp is a function that takes as inputs the concrete values of the tangents of the inputs for each loss tag (corresponding to a loss object in losses) and returns the corresponding tangents of the parameters.

Parameters
  • func (utils.Func) – The model function, which must include at least one loss registration.

  • params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.

Return type

TransformedFunction[LossTagsVjp]

Returns

A function that computes the vector-Jacobian product with signature Callable[[FuncArgs], LossTagsVjp].
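For readers unfamiliar with the jax.vjp() interface that this function mirrors, here is a minimal plain-JAX sketch. It does not use kfac_jax or loss tags: the output of predict plays the role of the inputs to a loss tag, and all names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Linear model whose output would feed into a registered loss.
    return x @ params["w"] + params["b"]


params = {"w": jnp.array([[1.0], [2.0]]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, 1.0], [0.0, 2.0]])

# jax.vjp returns the primal output and a function that maps tangents of
# the output back to tangents of the differentiated input (the parameters).
outputs, vjp_func = jax.vjp(lambda p: predict(p, x), params)

outputs_tangents = jnp.ones_like(outputs)
(params_tangents,) = vjp_func(outputs_tangents)
```

The result params_tangents has the same PyTree structure as params.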

loss_tags_jvp

kfac_jax.loss_tags_jvp(func, params_index=0)[source]

Creates a function for the Jacobian-vector product w.r.t. all loss tags.

The returned function has a similar interface to jax.jvp(). It takes as inputs the concrete values of the primals at which the Jacobian will be evaluated and the concrete values of the tangents for the parameters, as specified by processed_jaxpr.params_index. It returns a pair (losses, losses_tangents), where losses is a tuple of LossFunction objects, and losses_tangents is a tuple containing the tangents of the inputs for each loss tag (corresponding to a loss object in losses).

Parameters
  • func (utils.Func) – The model function, which must include at least one loss registration.

  • params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.

Return type

Returns

A function that computes the Jacobian-vector product with signature Callable[[FuncArgs, Params], LossTagsVjp].
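The forward-mode counterpart can be sketched the same way with jax.jvp(). This is plain JAX under illustrative names, not the kfac_jax function itself: tangents for the parameters are pushed forward to tangents of the model output, which stands in for the inputs of a loss tag.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    return x @ params["w"] + params["b"]


params = {"w": jnp.array([[1.0], [2.0]]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, 1.0], [0.0, 2.0]])

# Tangents must have the same PyTree structure and shapes as the parameters.
params_tangents = {"w": jnp.ones((2, 1)), "b": jnp.zeros((1,))}

# jax.jvp returns the primal output together with its tangent.
outputs, outputs_tangents = jax.jvp(
    lambda p: predict(p, x), (params,), (params_tangents,)
)
```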

loss_tags_hvp

kfac_jax.loss_tags_hvp(func, params_index=0)[source]

Creates a function for the Hessian-vector product w.r.t. all loss tags.

The returned function takes as inputs the concrete values of the primals for the function arguments at which the Hessian will be evaluated and the concrete values of the tangents for the parameters, as specified by processed_jaxpr.params_index. It returns the product of the Hessian with these tangents via backward-over-forward mode autodiff.

Parameters
  • func (utils.Func) – The model function, which must include at least one loss registration.

  • params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.

Return type

Returns

A function that computes the Hessian-vector product and also returns all losses, with signature Callable[[FuncArgs, Params], Tuple[LossTagsVjp, Tuple[loss_functions.LossFunction, …]]].
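Backward-over-forward differentiation can be illustrated on a scalar function with plain JAX. The helper hvp and the toy loss below are assumptions made for this sketch and are unrelated to loss tags; the point is only the composition of jax.jvp() inside jax.grad().

```python
import jax
import jax.numpy as jnp


def loss(params):
    # Non-quadratic scalar function, so the Hessian depends on params.
    return jnp.sum(params ** 3)


def hvp(f, primals, tangents):
    # Forward mode gives the directional derivative <grad f(p), v>;
    # reverse mode then differentiates that scalar w.r.t. p, giving H(p) v.
    return jax.grad(lambda p: jax.jvp(f, (p,), (tangents,))[1])(primals)


params = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])
hv = hvp(loss, params, v)

# Cross-check against the explicitly materialized Hessian.
hv_dense = jax.hessian(loss)(params) @ v
```

This composition never forms the full Hessian, which is why it is preferred when the number of parameters is large.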

layer_tags_vjp

kfac_jax.layer_tags_vjp(func, params_index=0, auto_register_tags=True, raise_error_on_diff_jaxpr=True, **auto_registration_kwargs)[source]

Creates a function for primal values and tangents w.r.t. all layer tags.

The returned function has a similar interface to jax.vjp(). It takes as inputs the concrete values of the primals at which the Jacobian will be evaluated. It returns a pair (losses, vjp_func), where losses is a tuple of LossFunction objects, and vjp_func is a function taking as inputs the concrete values of the tangents of the inputs for each loss tag (corresponding to a loss object in losses) and returns a list of quantities computed for each layer tag in processed_jaxpr. Each entry of the list is a dictionary with the following self-explanatory keys: inputs, outputs, params, outputs_tangents, params_tangents.

Parameters
  • func (utils.Func) – The model function, which must include at least one loss registration.

  • params_index (int) – The variables from the function arguments which are at this index (e.g. func_args[params_index]) are to be considered model parameters.

  • auto_register_tags (bool) – Whether to run an automatic layer registration on the function (e.g. auto_register_tags()).

  • raise_error_on_diff_jaxpr (bool) – If True, an exception is raised when tracing with different arguments returns a jaxpr with a different graph.

  • **auto_registration_kwargs – Any additional keyword arguments, to be passed to the automatic registration pass.

Return type

Returns

A function that computes primal values and tangents w.r.t. all layer tags, with signature Callable[[FuncArgs, Params], LossTagsVjp].
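The per-layer dictionary keys listed above can be illustrated for a single dense layer with plain JAX. This is a hand-built analogue under assumed names (dense, loss_from_outputs, layer_info). It does not call layer_tags_vjp, and the exact container types of the real entries are not stated on this page.

```python
import jax
import jax.numpy as jnp


def dense(params, inputs):
    return inputs @ params["w"] + params["b"]


def loss_from_outputs(outputs, targets):
    return 0.5 * jnp.sum((outputs - targets) ** 2)


params = {"w": jnp.array([[1.0], [2.0]]), "b": jnp.array([0.0])}
inputs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
targets = jnp.zeros((2, 1))

# Primal values of the layer, plus a function mapping output tangents
# back to parameter tangents.
outputs, layer_vjp = jax.vjp(lambda p: dense(p, inputs), params)

# Tangents of the layer outputs, obtained here from a squared-error loss.
outputs_tangents = jax.grad(loss_from_outputs)(outputs, targets)
(params_tangents,) = layer_vjp(outputs_tangents)

layer_info = dict(
    inputs=inputs,
    outputs=outputs,
    params=params,
    outputs_tangents=outputs_tangents,
    params_tangents=params_tangents,
)
```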

Patches Second Moments

patches_moments(inputs, kernel_spatial_shape)

Computes the first and second moment of the convolutional patches.

patches_moments_explicit(inputs, ...[, ...])

The exact same functionality as patches_moments(), but explicitly extracts the patches via jax.lax.conv_general_dilated_patches(), potentially having a higher memory usage.

patches_moments

kfac_jax.patches_moments(inputs, kernel_spatial_shape, strides=1, padding='VALID', data_format='NHWC', dim_numbers=None, inputs_dilation=None, kernel_dilation=None, feature_group_count=1, batch_group_count=1, unroll_loop=False, precision=None, weighting_array=None)[source]

Computes the first and second moment of the convolutional patches.

Because the code supports arbitrary convolution data formats (e.g. both NHWC and NCHW), the comments above each step in the source describe a simplified version of what the statements do, as if the data format were fixed to NHWC.

Parameters
  • inputs (Array) – The batch of images.

  • kernel_spatial_shape (Union[int, Shape]) – The spatial dimensions of the filter (int or list of ints).

  • strides (Union[int, Shape]) – The spatial dimensions of the strides (int or list of ints).

  • padding (PaddingVariants) – The padding (str or list of pairs of ints).

  • data_format (Optional[str]) – The data format of the inputs (None, NHWC, NCHW).

  • dim_numbers (Optional[Union[DimNumbers, lax.ConvDimensionNumbers]]) – Instance of jax.lax.ConvDimensionNumbers instead of data_format.

  • inputs_dilation (Optional[Sequence[int]]) – An integer or sequence of integers, specifying the dilation for the image. Currently, patches_moments does not support dilation, so the only allowed values are None, 1, (1,1).

  • kernel_dilation (Optional[Sequence[int]]) – An integer or sequence of integers, specifying the dilation for the kernel. Currently, patches_moments does not support dilation, so the only allowed values are None, 1, (1,1).

  • feature_group_count (int) – The feature grouping for grouped convolutions. Currently, patches_moments supports only 1 or the number of channels of the inputs.

  • batch_group_count (int) – The batch grouping for grouped convolutions. Currently, patches_moments supports only 1.

  • unroll_loop (bool) – Whether to unroll the loop in Python.

  • precision (Optional[jax.lax.Precision]) – The precision in which to run the computation. For more details, see the JAX documentation of jax.lax.conv_general_dilated().

  • weighting_array (Optional[Array]) – A tensor specifying additional weighting of each element of the moment’s average.

Return type

Tuple[Array, Array]

Returns

The patches’ second and first moments, as a pair. The tensor of the patches’ second moment has shape kernel_spatial_shape + (channels,) + kernel_spatial_shape + (channels,). The tensor of the patches’ first moment has shape kernel_spatial_shape + (channels,).
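As a reference for what the two returned tensors contain, the following sketch extracts the patches explicitly for the simplest case: NHWC inputs, stride 1, VALID padding, no dilation and no grouping. The function name patches_moments_reference is hypothetical, and averaging over all patches is an assumption of this sketch; the exact normalization used by the library is not stated here.

```python
import jax.numpy as jnp


def patches_moments_reference(inputs, kernel_spatial_shape):
    """Second and first moments of patches (NHWC, stride 1, VALID padding)."""
    kh, kw = kernel_spatial_shape
    n, h, w, c = inputs.shape
    out_h, out_w = h - kh + 1, w - kw + 1

    # Build patches[n, i, j, a, b, c] = inputs[n, i + a, j + b, c].
    rows = []
    for a in range(kh):
        cols = [inputs[:, a:a + out_h, b:b + out_w, :] for b in range(kw)]
        rows.append(jnp.stack(cols, axis=3))  # (n, out_h, out_w, kw, c)
    patches = jnp.stack(rows, axis=3)  # (n, out_h, out_w, kh, kw, c)

    # Assumption: average over the batch and all spatial patch locations.
    num_patches = n * out_h * out_w
    second = jnp.einsum("nijabc,nijdef->abcdef", patches, patches) / num_patches
    first = jnp.sum(patches, axis=(0, 1, 2)) / num_patches
    return second, first


images = jnp.ones((2, 4, 4, 3))
second_moment, first_moment = patches_moments_reference(images, (2, 2))
```

For a 2x2 kernel and 3 channels, the second moment has shape (2, 2, 3, 2, 2, 3) and the first moment has shape (2, 2, 3). Materializing patches multiplies memory use by roughly the kernel area, which is the cost that the note on patches_moments_explicit() warns about.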

patches_moments_explicit

kfac_jax.patches_moments_explicit(inputs, kernel_spatial_shape, strides=1, padding='VALID', data_format='NHWC', dim_numbers=None, inputs_dilation=None, kernel_dilation=None, feature_group_count=1, batch_group_count=1, unroll_loop=False, precision=None, weighting_array=None)[source]

The exact same functionality as patches_moments(), but explicitly extracts the patches via jax.lax.conv_general_dilated_patches(), potentially having a higher memory usage.

Return type

Tuple[Array, Array]