KFAC-JAX Documentation

KFAC-JAX is a library built on top of JAX for second-order optimization of neural networks and for computing scalable curvature approximations. The main goal of the library is to provide researchers with an easy-to-use implementation of the K-FAC (Kronecker-Factored Approximate Curvature) optimizer and curvature estimator.
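To illustrate the core idea behind K-FAC's curvature approximation (this is not the KFAC-JAX API), here is a minimal JAX sketch for a single dense layer. It assumes the layer computes s = W a, approximates that layer's Fisher block by the Kronecker product A ⊗ G with A = E[a aᵀ] and G = E[g gᵀ] (g = dL/ds), and applies damping by adding sqrt(damping)·I to each factor, a simplification chosen for this sketch. The helper names `kfac_factors` and `kfac_precondition` are hypothetical.

```python
import jax
import jax.numpy as jnp


def kfac_factors(activations, output_grads):
    """Estimates the two Kronecker factors for one dense layer s = W a.

    activations:  [batch, in_dim]  layer inputs a
    output_grads: [batch, out_dim] backpropagated gradients g = dL/ds
    """
    n = activations.shape[0]
    a_factor = activations.T @ activations / n    # A = E[a a^T], [in, in]
    g_factor = output_grads.T @ output_grads / n  # G = E[g g^T], [out, out]
    return a_factor, g_factor


def kfac_precondition(weight_grad, a_factor, g_factor, damping):
    """Approximately computes F^{-1} grad, with F ~= A kron G.

    Uses (A kron G)^{-1} vec(dW) = vec(G^{-1} dW A^{-1}), so only the two
    small factors are ever inverted. Damping is split evenly by adding
    sqrt(damping) * I to each factor (a simplification for this sketch).
    """
    eps = jnp.sqrt(damping)
    a_damped = a_factor + eps * jnp.eye(a_factor.shape[0])
    g_damped = g_factor + eps * jnp.eye(g_factor.shape[0])
    # Left-multiply by G^{-1} via a linear solve: [out, in].
    left = jnp.linalg.solve(g_damped, weight_grad)
    # Right-multiply by A^{-1}; valid because A is symmetric: [out, in].
    return jnp.linalg.solve(a_damped, left.T).T


# Toy data for a layer with 4 inputs and 3 outputs. In K-FAC proper, g is
# backpropagated from targets sampled from the model's predictive
# distribution; random values stand in for it here.
key_a, key_g = jax.random.split(jax.random.PRNGKey(0))
acts = jax.random.normal(key_a, (32, 4))
grads_out = jax.random.normal(key_g, (32, 3))

# Batch-averaged weight gradient dL/dW = E[g a^T], shape [3, 4].
weight_grad = grads_out.T @ acts / 32

A, G = kfac_factors(acts, grads_out)
update = kfac_precondition(weight_grad, A, G, damping=0.1)
print(update.shape)  # (3, 4)
```

The factorization is what makes this scalable: inverting an in × in and an out × out matrix replaces inverting the full (in·out) × (in·out) curvature block.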

Installation

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

First, follow the official JAX installation instructions to install JAX with the relevant accelerator support (for example, GPU or TPU).
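Optionally, before continuing, you can check which devices JAX detects. If the accelerator-enabled install worked, GPU or TPU devices should be listed rather than only CPU devices:

```python
import jax

# Lists the devices JAX can use; a CPU-only install lists only CPU devices.
print(jax.devices())
# The platform JAX uses by default, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())
```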

Then, install KFAC-JAX directly from GitHub using pip:

$ pip install git+https://github.com/google-deepmind/kfac-jax

Alternatively, you can install via PyPI:

$ pip install -U kfac-jax

Our examples rely on additional libraries, all of which you can install using:

$ pip install -r requirements_examples.txt

Guides

High Level Overview

Contribute

We are not accepting code contributions via GitHub at this time, and all pull requests submitted on GitHub will be rejected. If you find a bug, or have feature requests or ideas, please email us at kfac-jax-dev@google.com or raise an issue on the GitHub issue tracker.

Support

If you are having issues, please let us know by filing an issue on our issue tracker.

License

KFAC-JAX is licensed under the Apache 2.0 License.
