Overview
MLAX is a purely functional ML library built on top of Google JAX.
MLAX follows object-oriented semantics like Keras and PyTorch.
Modules are PyTrees whose leaves are parameters and whose auxiliary data are hyperparameters.
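This means MLAX is fully compatible with native JAX transformations, notably:
- jax.grad (automatic differentiation)
- jax.vmap (automatic vectorization)
- jax.pmap (parallelization across devices)
- jax.jit (just-in-time compilation)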
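The following is a minimal sketch in plain JAX of what it means for a module to be a PyTree. It does not use MLAX's actual API. The Linear class, its use_relu hyperparameter, and loss_fn are hypothetical names chosen for illustration. The layer registers itself with jax.tree_util.register_pytree_node_class so that its weights are leaves and use_relu is static auxiliary data. Because of this, jax.grad, jax.jit, and jax.vmap can accept the module directly.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical dense layer (not MLAX's API).

    weight and bias are PyTree leaves (parameters); use_relu is
    auxiliary data (a hyperparameter) that JAX treats as static.
    """

    def __init__(self, weight, bias, use_relu=False):
        self.weight = weight
        self.bias = bias
        self.use_relu = use_relu

    def tree_flatten(self):
        # (children, aux_data): parameters first, hyperparameters second.
        return (self.weight, self.bias), self.use_relu

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        weight, bias = children
        return cls(weight, bias, use_relu=aux_data)

    def __call__(self, x):
        y = x @ self.weight + self.bias
        # use_relu is static, so ordinary Python branching is fine under jit.
        return jax.nn.relu(y) if self.use_relu else y


def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.ones((3, 2)), jnp.zeros((2,)))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# The module is passed straight to native transformations; no wrappers needed.
grads = jax.jit(jax.grad(loss_fn))(model, x, y)  # a Linear holding gradients
new_model = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)

# vmap over the batch axis of x while broadcasting the module (in_axes=None).
per_example = jax.vmap(lambda m, xi: m(xi), in_axes=(None, 0))(model, x)
```

The gradient has the same PyTree structure as the module. As a result, a plain SGD step reduces to a single tree_map over the module and its gradients.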
Why MLAX?
Compared to existing JAX libraries, MLAX:
- stores parameters in the modules themselves rather than in a separate structure, making layer development and parameter surgery easier;
- does not require special versions of grad, vmap, pmap, and jit, making it easier to learn and to integrate with other JAX libraries;
- allows parameters to be updated in the forward pass, making it easy to develop stateful layers such as BatchNorm.
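The following simplified, BatchNorm-style layer illustrates the last point. It is a sketch under stated assumptions, not MLAX's implementation. RunningMeanNorm, its momentum hyperparameter, and the convention that the forward pass returns an (output, updated_module) pair are all hypothetical. The layer only centers its inputs, with no variance, scale, or shift, so that the state-update pattern stays visible. The forward pass remains pure because it returns a new module carrying the updated running mean instead of mutating the old one.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class RunningMeanNorm:
    """Hypothetical centering-only layer (not MLAX's API).

    running_mean is a PyTree leaf updated in the forward pass;
    momentum is static auxiliary data (a hyperparameter).
    """

    def __init__(self, running_mean, momentum=0.9):
        self.running_mean = running_mean
        self.momentum = momentum

    def tree_flatten(self):
        return (self.running_mean,), self.momentum

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], momentum=aux_data)

    def __call__(self, x, training=True):
        if not training:
            # Inference: use the stored statistic and leave the state unchanged.
            return x - self.running_mean, self
        batch_mean = jnp.mean(x, axis=0)
        new_mean = (self.momentum * self.running_mean
                    + (1 - self.momentum) * batch_mean)
        # Return an updated copy rather than mutating self.
        return x - batch_mean, RunningMeanNorm(new_mean, self.momentum)


@jax.jit
def train_step(layer, x):
    return layer(x, training=True)


layer = RunningMeanNorm(jnp.zeros((2,)))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

out, layer = train_step(layer, x)      # layer now holds the updated running mean
out_eval, _ = layer(x, training=False)  # inference path, state untouched
```

The running mean is an ordinary attribute, so inspecting it (for example, layer.running_mean) needs no separate state dictionary.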
Installation
See MLAX’s GitHub page for installation instructions.
Worked Examples
End-to-end examples with reference PyTorch implementations can be found on GitHub as well.