KFAC-JAX
stable
Guides
Quickstart
High Level Overview
Curvature Estimation
Optimizer
API Documentation
Standard API
Advanced Features API
Advanced Topics
KFAC-JAX internals
Extending KFAC-JAX
KFAC-JAX
»
Index
Edit on GitHub
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
J
|
L
|
M
|
N
|
O
|
P
|
R
|
S
|
T
|
U
|
V
|
W
_
__init__() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.BlockDiagonalCurvature.State method)
(kfac_jax.CategoricalLogitsNegativeLogProbLoss method)
(kfac_jax.Conv2DDiagonal method)
(kfac_jax.Conv2DFull method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureBlock.State method)
(kfac_jax.CurvatureEstimator method)
(kfac_jax.Diagonal.State method)
(kfac_jax.ExplicitExactCurvature method)
(kfac_jax.Full method)
(kfac_jax.Full.State method)
(kfac_jax.ImplicitExactCurvature method)
(kfac_jax.LayerTag method)
(kfac_jax.LossFunction method)
(kfac_jax.LossTag method)
(kfac_jax.MultiBernoulliNegativeLogProbLoss method)
(kfac_jax.NormalMeanNegativeLogProbLoss method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
(kfac_jax.Optimizer method)
(kfac_jax.Optimizer.State method)
(kfac_jax.ProcessedJaxpr method)
(kfac_jax.ScaledIdentity method)
(kfac_jax.TwoKroneckerFactored method)
A
array_to_parameters_shaped_list() (kfac_jax.TwoKroneckerFactored method)
auto_register_tags() (in module kfac_jax)
B
batch_index (kfac_jax.ExplicitExactCurvature property)
batch_size() (kfac_jax.ImplicitExactCurvature method)
block_dims (kfac_jax.BlockDiagonalCurvature property)
block_eigenvalues() (kfac_jax.BlockDiagonalCurvature method)
BlockDiagonalCurvature (class in kfac_jax)
BlockDiagonalCurvature.State (class in kfac_jax)
blocks (kfac_jax.BlockDiagonalCurvature property)
blocks_states (kfac_jax.BlockDiagonalCurvature.State attribute)
blocks_vectors_to_params_vector() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.ExplicitExactCurvature method)
burnin() (kfac_jax.Optimizer method)
C
cache (kfac_jax.CurvatureBlock.State attribute)
CategoricalLogitsNegativeLogProbLoss (class in kfac_jax)
compute_approx_quad_model() (kfac_jax.Optimizer method)
compute_exact_quad_model() (kfac_jax.Optimizer method)
compute_inputs_stats() (kfac_jax.Conv2DTwoKroneckerFactored method)
compute_l2_quad_matrix() (kfac_jax.Optimizer method)
compute_loss_value() (kfac_jax.Optimizer method)
compute_outputs_stats() (kfac_jax.Conv2DTwoKroneckerFactored method)
compute_quadratic_model_value() (kfac_jax.Optimizer method)
consts (kfac_jax.ProcessedJaxpr attribute)
conv2d_tangent_outer_product() (kfac_jax.Conv2DFull method)
conv2d_tangent_squared() (kfac_jax.Conv2DDiagonal method)
Conv2DDiagonal (class in kfac_jax)
Conv2DFull (class in kfac_jax)
Conv2DTwoKroneckerFactored (class in kfac_jax)
copy_with_different_inputs() (kfac_jax.CategoricalLogitsNegativeLogProbLoss method)
(kfac_jax.LossFunction method)
(kfac_jax.MultiBernoulliNegativeLogProbLoss method)
(kfac_jax.NormalMeanNegativeLogProbLoss method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
(kfac_jax.OneHotCategoricalLogitsNegativeLogProbLoss method)
CurvatureBlock (class in kfac_jax)
CurvatureBlock.State (class in kfac_jax)
CurvatureEstimator (class in kfac_jax)
D
damping (kfac_jax.Optimizer.State attribute)
damping_decay_factor (kfac_jax.Optimizer property)
data_seen (kfac_jax.Optimizer.State attribute)
default_estimation_mode (kfac_jax.CurvatureEstimator attribute)
default_mat_type (kfac_jax.CurvatureEstimator property)
DenseDiagonal (class in kfac_jax)
DenseFull (class in kfac_jax)
DenseTwoKroneckerFactored (class in kfac_jax)
Diagonal (class in kfac_jax)
Diagonal.State (class in kfac_jax)
diagonal_factors (kfac_jax.Diagonal.State attribute)
dim (kfac_jax.BlockDiagonalCurvature property)
(kfac_jax.CurvatureBlock property)
(kfac_jax.CurvatureEstimator property)
dist (kfac_jax.CategoricalLogitsNegativeLogProbLoss property)
(kfac_jax.DistributionNegativeLogProbLoss property)
(kfac_jax.MultiBernoulliNegativeLogProbLoss property)
(kfac_jax.NormalMeanNegativeLogProbLoss property)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss property)
(kfac_jax.OneHotCategoricalLogitsNegativeLogProbLoss property)
DistributionNegativeLogProbLoss (class in kfac_jax)
E
eigenvalues() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
estimator (kfac_jax.Optimizer property)
estimator_state (kfac_jax.Optimizer.State attribute)
evaluate() (kfac_jax.LossFunction method)
ExplicitExactCurvature (class in kfac_jax)
F
fisher_factor_inner_shape (kfac_jax.CategoricalLogitsNegativeLogProbLoss property)
(kfac_jax.DistributionNegativeLogProbLoss property)
(kfac_jax.NegativeLogProbLoss property)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss property)
fixed_scale() (kfac_jax.Conv2DTwoKroneckerFactored method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.ScaledIdentity method)
Full (class in kfac_jax)
Full.State (class in kfac_jax)
func (kfac_jax.CurvatureEstimator attribute)
G
get_default_eigen_decomposition_threshold() (in module kfac_jax)
get_default_tag_to_block_ctor() (in module kfac_jax)
get_loss_inner_vector_shapes_and_batch_size() (kfac_jax.ImplicitExactCurvature method)
get_loss_input_shapes_and_batch_size() (kfac_jax.ImplicitExactCurvature method)
get_max_parallel_elements() (in module kfac_jax)
get_outputs() (kfac_jax.LayerTag method)
(kfac_jax.LossTag method)
ggn_factor_inner_shape (kfac_jax.LossFunction property)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss property)
grad_of_evaluate() (kfac_jax.LossFunction method)
grad_of_evaluate_on_sample() (kfac_jax.NegativeLogProbLoss method)
H
has_bias (kfac_jax.DenseDiagonal property)
(kfac_jax.TwoKroneckerFactored property)
has_scale (kfac_jax.ScaleAndShiftDiagonal property)
has_shift (kfac_jax.ScaleAndShiftDiagonal property)
I
ImplicitExactCurvature (class in kfac_jax)
in_tree (kfac_jax.ProcessedJaxpr attribute)
in_vars (kfac_jax.ProcessedJaxpr property)
in_vars_flat (kfac_jax.ProcessedJaxpr property)
indices_to_block_map (kfac_jax.BlockDiagonalCurvature property)
init() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
(kfac_jax.Optimizer method)
inputs_channel_index (kfac_jax.Conv2DTwoKroneckerFactored property)
inputs_shapes (kfac_jax.CurvatureBlock property)
J
jaxpr (kfac_jax.ProcessedJaxpr attribute)
L
l2_reg (kfac_jax.Optimizer property)
layer_indices (kfac_jax.ProcessedJaxpr attribute)
layer_tag_extra_params (kfac_jax.CurvatureBlock property)
layer_tag_primitive (kfac_jax.CurvatureBlock property)
layer_tags (kfac_jax.ProcessedJaxpr attribute)
layer_tags_vjp() (in module kfac_jax)
LayerTag (class in kfac_jax)
loss() (kfac_jax.LossTag method)
loss_tags (kfac_jax.ProcessedJaxpr attribute)
loss_tags_hvp() (in module kfac_jax)
loss_tags_jvp() (in module kfac_jax)
loss_tags_vjp() (in module kfac_jax)
LossFunction (class in kfac_jax)
LossTag (class in kfac_jax)
M
make_from_func() (kfac_jax.ProcessedJaxpr class method)
matrix (kfac_jax.Full.State attribute)
MultiBernoulliNegativeLogProbLoss (class in kfac_jax)
multiply() (kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
multiply_fisher() (kfac_jax.ImplicitExactCurvature method)
(kfac_jax.NegativeLogProbLoss method)
multiply_fisher_factor() (kfac_jax.ImplicitExactCurvature method)
(kfac_jax.NegativeLogProbLoss method)
multiply_fisher_factor_replicated_one_hot() (kfac_jax.NegativeLogProbLoss method)
multiply_fisher_factor_replicated_one_hot_unweighted() (kfac_jax.CategoricalLogitsNegativeLogProbLoss method)
(kfac_jax.MultiBernoulliNegativeLogProbLoss method)
(kfac_jax.NegativeLogProbLoss method)
(kfac_jax.NormalMeanNegativeLogProbLoss method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_fisher_factor_transpose() (kfac_jax.ImplicitExactCurvature method)
(kfac_jax.NegativeLogProbLoss method)
multiply_fisher_factor_transpose_unweighted() (kfac_jax.CategoricalLogitsNegativeLogProbLoss method)
(kfac_jax.MultiBernoulliNegativeLogProbLoss method)
(kfac_jax.NegativeLogProbLoss method)
(kfac_jax.NormalMeanNegativeLogProbLoss method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_fisher_factor_unweighted() (kfac_jax.CategoricalLogitsNegativeLogProbLoss method)
(kfac_jax.MultiBernoulliNegativeLogProbLoss method)
(kfac_jax.NegativeLogProbLoss method)
(kfac_jax.NormalMeanNegativeLogProbLoss method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_fisher_unweighted() (kfac_jax.CategoricalLogitsNegativeLogProbLoss method)
(kfac_jax.MultiBernoulliNegativeLogProbLoss method)
(kfac_jax.NegativeLogProbLoss method)
(kfac_jax.NormalMeanNegativeLogProbLoss method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_ggn() (kfac_jax.ImplicitExactCurvature method)
(kfac_jax.LossFunction method)
multiply_ggn_factor() (kfac_jax.ImplicitExactCurvature method)
(kfac_jax.LossFunction method)
multiply_ggn_factor_replicated_one_hot() (kfac_jax.LossFunction method)
multiply_ggn_factor_replicated_one_hot_unweighted() (kfac_jax.LossFunction method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_ggn_factor_transpose() (kfac_jax.ImplicitExactCurvature method)
(kfac_jax.LossFunction method)
multiply_ggn_factor_transpose_unweighted() (kfac_jax.LossFunction method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_ggn_factor_unweighted() (kfac_jax.LossFunction method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_ggn_unweighted() (kfac_jax.LossFunction method)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss method)
multiply_hessian() (kfac_jax.ImplicitExactCurvature method)
multiply_inverse() (kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
multiply_jacobian_transpose() (kfac_jax.ImplicitExactCurvature method)
multiply_matpower() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
N
NaiveDiagonal (class in kfac_jax)
NaiveFull (class in kfac_jax)
NegativeLogProbLoss (class in kfac_jax)
NormalMeanNegativeLogProbLoss (class in kfac_jax)
NormalMeanVarianceNegativeLogProbLoss (class in kfac_jax)
num_blocks (kfac_jax.BlockDiagonalCurvature property)
num_burnin_steps (kfac_jax.Optimizer property)
num_inputs (kfac_jax.LayerTag property)
num_inputs_channels (kfac_jax.Conv2DTwoKroneckerFactored property)
num_locations (kfac_jax.Conv2DTwoKroneckerFactored property)
num_outputs (kfac_jax.LayerTag property)
num_outputs_channels (kfac_jax.Conv2DTwoKroneckerFactored property)
num_parameter_dependants (kfac_jax.LossFunction property)
num_parameter_independants (kfac_jax.LossFunction property)
num_params_variables (kfac_jax.BlockDiagonalCurvature property)
number_of_parameters (kfac_jax.CurvatureBlock property)
O
OneHotCategoricalLogitsNegativeLogProbLoss (class in kfac_jax)
Optimizer (class in kfac_jax)
Optimizer.State (class in kfac_jax)
outputs_channel_index (kfac_jax.Conv2DTwoKroneckerFactored property)
outputs_shapes (kfac_jax.CurvatureBlock property)
P
parameter_dependants (kfac_jax.LossFunction property)
(kfac_jax.NegativeLogProbLoss property)
parameter_dependants_names (kfac_jax.LossTag property)
parameter_independants (kfac_jax.CategoricalLogitsNegativeLogProbLoss property)
(kfac_jax.LossFunction property)
(kfac_jax.MultiBernoulliNegativeLogProbLoss property)
(kfac_jax.NormalMeanNegativeLogProbLoss property)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss property)
parameter_independants_names (kfac_jax.LossTag property)
parameter_variables (kfac_jax.CurvatureBlock property)
parameters_canonical_order (kfac_jax.CurvatureBlock property)
parameters_list_to_single_vector() (kfac_jax.Full method)
parameters_shaped_list_to_array() (kfac_jax.TwoKroneckerFactored method)
parameters_shapes (kfac_jax.CurvatureBlock property)
params (kfac_jax.CategoricalLogitsNegativeLogProbLoss property)
(kfac_jax.MultiBernoulliNegativeLogProbLoss property)
(kfac_jax.NegativeLogProbLoss property)
(kfac_jax.NormalMeanNegativeLogProbLoss property)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss property)
params_block_index (kfac_jax.BlockDiagonalCurvature property)
params_index (kfac_jax.CurvatureEstimator attribute)
(kfac_jax.ProcessedJaxpr attribute)
params_structure_vector_of_indices (kfac_jax.BlockDiagonalCurvature property)
params_tree (kfac_jax.ProcessedJaxpr property)
params_vars (kfac_jax.ProcessedJaxpr property)
params_vars_flat (kfac_jax.ProcessedJaxpr property)
params_vector_to_blocks_vectors() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.ExplicitExactCurvature method)
patches_moments() (in module kfac_jax)
patches_moments_explicit() (in module kfac_jax)
ProcessedJaxpr (class in kfac_jax)
R
register_categorical_predictive_distribution() (in module kfac_jax)
register_conv2d() (in module kfac_jax)
register_dense() (in module kfac_jax)
register_generic() (in module kfac_jax)
register_multi_bernoulli_predictive_distribution() (in module kfac_jax)
register_normal_predictive_distribution() (in module kfac_jax)
register_scale_and_shift() (in module kfac_jax)
register_sigmoid_cross_entropy_loss() (in module kfac_jax)
register_softmax_cross_entropy_loss() (in module kfac_jax)
register_squared_error_loss() (in module kfac_jax)
S
sample() (kfac_jax.DistributionNegativeLogProbLoss method)
(kfac_jax.NegativeLogProbLoss method)
scale() (kfac_jax.CurvatureBlock method)
ScaleAndShiftDiagonal (class in kfac_jax)
ScaleAndShiftFull (class in kfac_jax)
ScaledIdentity (class in kfac_jax)
set_default_eigen_decomposition_threshold() (in module kfac_jax)
set_default_tag_to_block_ctor() (in module kfac_jax)
set_max_parallel_elements() (in module kfac_jax)
should_update_damping() (kfac_jax.Optimizer method)
single_vector_to_parameters_list() (kfac_jax.Full method)
split_all_inputs() (kfac_jax.LayerTag method)
state_dependent_scale() (kfac_jax.CurvatureBlock method)
step() (kfac_jax.Optimizer method)
step_counter (kfac_jax.Optimizer.State attribute)
T
targets (kfac_jax.CategoricalLogitsNegativeLogProbLoss property)
(kfac_jax.LossFunction property)
(kfac_jax.MultiBernoulliNegativeLogProbLoss property)
(kfac_jax.NormalMeanNegativeLogProbLoss property)
(kfac_jax.NormalMeanVarianceNegativeLogProbLoss property)
to_dense_matrix() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
to_diagonal_block_dense_matrix() (kfac_jax.BlockDiagonalCurvature method)
TwoKroneckerFactored (class in kfac_jax)
U
update_cache() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
(kfac_jax.ExplicitExactCurvature method)
update_curvature_matrix_estimate() (kfac_jax.BlockDiagonalCurvature method)
(kfac_jax.CurvatureBlock method)
(kfac_jax.CurvatureEstimator method)
(kfac_jax.Diagonal method)
(kfac_jax.ExplicitExactCurvature method)
(kfac_jax.Full method)
(kfac_jax.ScaledIdentity method)
V
velocities (kfac_jax.Optimizer.State attribute)
verify_args_and_get_step_counter() (kfac_jax.Optimizer method)
W
weight (kfac_jax.LossFunction property)
weighted_sum_of_objects() (kfac_jax.Optimizer method)
weights_output_channel_index (kfac_jax.Conv2DTwoKroneckerFactored property)
weights_spatial_size (kfac_jax.Conv2DTwoKroneckerFactored property)