KFAC-JAX Documentation
KFAC-JAX is a library built on top of JAX for second-order optimization of neural networks and for computing scalable curvature approximations. The main goal of the library is to provide researchers with an easy-to-use implementation of the optimizer and curvature estimator described in the K-FAC paper.
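The core idea behind a "scalable curvature approximation" can be illustrated on a single dense layer y = W a. The sketch below shows that Kronecker-factored idea in plain NumPy. It does not use the KFAC-JAX API.

Some parts of the sketch are illustrative assumptions and are not part of the library:

- the function name `kfac_precondition_dense`
- the simple per-factor damping
- the use of caller-supplied output gradients

The full method in the K-FAC paper goes further. It keeps running estimates of the factors, uses gradients sampled from the model's predictive distribution, and applies more refined damping.

```python
import numpy as np


def kfac_precondition_dense(acts, out_grads, damping=1e-3):
    """Hypothetical helper: precondition a dense layer's weight gradient with a
    Kronecker-factored Fisher approximation F ~= A (x) G.

    acts:      (batch, in_dim)  inputs a to the layer y = W a
    out_grads: (batch, out_dim) per-example gradients g = dL/dy
               (assumed given here; K-FAC proper samples them from the model)
    Returns (grad, precond_grad), both of shape (out_dim, in_dim).
    """
    n = acts.shape[0]
    # Mini-batch weight gradient: mean over examples of g a^T.
    grad = out_grads.T @ acts / n
    # Kronecker factors: second moments of layer inputs and output gradients.
    a_factor = acts.T @ acts / n              # (in_dim, in_dim)
    g_factor = out_grads.T @ out_grads / n    # (out_dim, out_dim)
    # Simple per-factor damping (illustrative choice, not the paper's scheme).
    eps = np.sqrt(damping)
    a_damped = a_factor + eps * np.eye(a_factor.shape[0])
    g_damped = g_factor + eps * np.eye(g_factor.shape[0])
    # (A (x) G)^{-1} vec(V) = vec(G^{-1} V A^{-1}) for column-major vec and
    # symmetric A: two small solves replace one (in*out) x (in*out) solve.
    tmp = np.linalg.solve(g_damped, grad)          # G^{-1} V
    precond = np.linalg.solve(a_damped, tmp.T).T   # (G^{-1} V) A^{-1}
    return grad, precond


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    acts = rng.normal(size=(32, 4))
    out_grads = rng.normal(size=(32, 3))
    grad, precond = kfac_precondition_dense(acts, out_grads)

    # Cross-check against an explicit solve with the full Kronecker product.
    eps = np.sqrt(1e-3)
    A = acts.T @ acts / 32 + eps * np.eye(4)
    G = out_grads.T @ out_grads / 32 + eps * np.eye(3)
    dense = np.linalg.solve(np.kron(A, G), grad.flatten(order="F"))
    print(np.allclose(dense, precond.flatten(order="F")))  # True
```

The two factors can be inverted separately. This costs roughly O(in³ + out³) rather than O((in·out)³) for the full curvature block, which is what lets the approximation scale to large layers.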
Installation
KFAC-JAX is written in pure Python, but depends on C++ code via JAX.
First, follow the official JAX installation instructions to install JAX with the relevant accelerator support.
Then, install KFAC-JAX using pip:
$ pip install git+https://github.com/google-deepmind/kfac-jax
Alternatively, you can install via PyPI:
$ pip install -U kfac-jax
Our examples rely on additional libraries. From the root of a clone of the repository, you can install all of them using:
$ pip install -r requirements_examples.txt
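After installing, you can run a quick sanity check to confirm that the packages are importable and that JAX sees the intended accelerator. The helper `check_install` below is hypothetical and is not part of KFAC-JAX. It uses only the Python standard library to look up the packages, and calls `jax.devices()` only when JAX is importable.

```python
import importlib.util
from importlib import metadata


def check_install(packages=("jax", "jaxlib", "kfac_jax")):
    """Hypothetical helper: map each module name to its installed version,
    or None if the module cannot be imported."""
    report = {}
    for name in packages:
        if importlib.util.find_spec(name) is None:
            report[name] = None  # not importable in this environment
            continue
        try:
            # Distribution names use hyphens (e.g. "kfac-jax"); modules use underscores.
            report[name] = metadata.version(name.replace("_", "-"))
        except metadata.PackageNotFoundError:
            report[name] = "unknown"  # importable, but no distribution metadata found
    return report


if __name__ == "__main__":
    for pkg, version in check_install().items():
        print(f"{pkg}: {version or 'NOT INSTALLED'}")
    if check_install(("jax",))["jax"] is not None:
        import jax
        # Lists CPU, GPU, or TPU devices, depending on which JAX build was installed.
        print("JAX devices:", jax.devices())
```

If only CPU devices are listed on a machine with a GPU or TPU, the accelerator-specific JAX installation step was most likely skipped.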
Guides
High Level Overview
API Documentation
Advanced Topics
Contribute
We are not accepting code contributions via GitHub at this time, and all pull requests opened on GitHub will be rejected. To report a bug, or to share feature requests and ideas, please email us at kfac-jax-dev@google.com or raise an issue on the GitHub issue tracker.
Issue tracker: https://github.com/google-deepmind/kfac-jax/issues
Source code: https://github.com/google-deepmind/kfac-jax/tree/main
Support
If you are having issues, please let us know by filing an issue on our issue tracker.
License
KFAC-JAX is licensed under the Apache 2.0 License.