Overview

MLAX is a purely functional ML library built on top of Google JAX.

Despite being purely functional, MLAX follows object-oriented semantics similar to Keras and PyTorch.

Modules are PyTrees whose leaves are parameters and whose auxiliary data are hyperparameters.
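To make the PyTree idea concrete, here is a minimal sketch of a dense layer registered as a custom PyTree with jax.tree_util. The class name Linear and its fields are hypothetical and are not MLAX's actual API. The sketch shows only the split described above: the weight and bias arrays are leaves, while the activation name is static auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Linear:
    """Hypothetical dense layer (illustration only, not MLAX's class)."""

    def __init__(self, weight, bias, activation_name="relu"):
        self.weight = weight                    # parameter -> PyTree leaf
        self.bias = bias                        # parameter -> PyTree leaf
        self.activation_name = activation_name  # hyperparameter -> aux data

    def tree_flatten(self):
        # Leaves are the trainable arrays; aux data is static, hashable metadata.
        return (self.weight, self.bias), self.activation_name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        weight, bias = children
        return cls(weight, bias, aux_data)

    def __call__(self, x):
        y = x @ self.weight + self.bias
        return jax.nn.relu(y) if self.activation_name == "relu" else y


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
leaves, treedef = tree_util.tree_flatten(layer)
print(len(leaves))  # prints 2: only weight and bias are leaves
```

Because the hyperparameter lives in the aux data, JAX treats it as part of the tree structure rather than as an array to differentiate or trace.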

This means MLAX is fully compatible with native JAX transformations, notably:

  1. grad

  2. vmap

  3. pmap

  4. jit
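The following sketch illustrates what that compatibility looks like in practice, assuming a module is an ordinary registered PyTree. The Linear class, loss_fn, and sgd_step are hypothetical names chosen for illustration, not MLAX functions. It passes the module straight to jax.grad, jax.jit, and jax.vmap. jax.pmap follows the same pattern but requires multiple devices, so it is omitted here.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Linear:
    """Hypothetical module: weight/bias are leaves, use_bias is aux data."""

    def __init__(self, weight, bias, use_bias=True):
        self.weight, self.bias, self.use_bias = weight, bias, use_bias

    def tree_flatten(self):
        return (self.weight, self.bias), self.use_bias

    @classmethod
    def tree_unflatten(cls, use_bias, children):
        return cls(*children, use_bias)

    def __call__(self, x):
        y = x @ self.weight
        return y + self.bias if self.use_bias else y


def loss_fn(model, x, y):
    # Mean squared error; `model` is just another PyTree argument.
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.zeros((3, 1)), jnp.zeros(1))
x = jnp.ones((8, 3))
y = jnp.ones((8, 1))

# grad w.r.t. the module returns gradients with the same module structure.
grads = jax.grad(loss_fn)(model, x, y)


@jax.jit
def sgd_step(model, x, y):
    # jit traces the leaves; use_bias stays static because it is aux data.
    g = jax.grad(loss_fn)(model, x, y)
    return tree_util.tree_map(lambda p, dp: p - 0.1 * dp, model, g)


model = sgd_step(model, x, y)

# vmap maps over the batch axis of the inputs while sharing one module.
per_example = jax.vmap(lambda xi: model(xi))(x)
```

No wrapper versions of the transformations are needed: the gradients come back as a Linear instance, and tree_map rebuilds an updated module from the old one.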

Why MLAX?

Compared to existing JAX libraries,

  1. MLAX stores parameters in the modules themselves rather than in a separate structure, making layer development and parameter surgery easier.

  2. MLAX does not require special versions of grad, vmap, pmap, and jit, making it easier to learn and to integrate with other JAX libraries.

  3. MLAX allows parameters to be updated in the forward pass, making it easy to develop stateful layers like BatchNorm.
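One way a purely functional library can let the forward pass update parameters is to return an updated copy of the module alongside the output. The sketch below assumes that pattern. RunningNorm is a hypothetical, simplified BatchNorm without learned scale or shift, and MLAX's actual module interface may differ. The running statistics are leaves, while momentum and eps are aux data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class RunningNorm:
    """Hypothetical simplified BatchNorm (no learned scale/shift)."""

    def __init__(self, running_mean, running_var, momentum=0.9, eps=1e-5):
        self.running_mean = running_mean
        self.running_var = running_var
        self.momentum = momentum
        self.eps = eps

    def tree_flatten(self):
        # State arrays are leaves; momentum and eps are hyperparameters (aux data).
        return (self.running_mean, self.running_var), (self.momentum, self.eps)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children, *aux)

    def __call__(self, x, inference_mode=False):
        if inference_mode:
            # Use stored statistics and leave the state unchanged.
            mean, var = self.running_mean, self.running_var
            new_self = self
        else:
            mean, var = jnp.mean(x, axis=0), jnp.var(x, axis=0)
            m = self.momentum
            # Build a new module with updated statistics instead of mutating self.
            new_self = RunningNorm(
                m * self.running_mean + (1 - m) * mean,
                m * self.running_var + (1 - m) * var,
                self.momentum,
                self.eps,
            )
        return (x - mean) / jnp.sqrt(var + self.eps), new_self


layer = RunningNorm(jnp.zeros(2), jnp.ones(2))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
train_step = jax.jit(lambda layer, x: layer(x))
out, new_layer = train_step(layer, x)  # new_layer carries the updated state
```

Returning the updated module keeps the forward pass pure, so it still composes with jit and grad; the caller simply replaces the old module with the returned one.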

Installation

See MLAX’s GitHub page for installation instructions.

Worked Examples

End-to-end examples with reference PyTorch implementations can be found on GitHub as well.