Modules and Parameters
MLAX modules are PyTrees that inherit from mlax.Module. Fields that are
parameters and modules are considered children of the PyTree (parameters). All
other fields are considered auxiliary data (hyperparameters),
MLAX parameters inherit from mlax.Parameter. They are PyTrees that wrap
some data. data can be JAX types (jax.Array, np.array, etc.),
PyTrees of valid JAX types (list/tuple/dict of jax.Array), or even other
mlax.Parameter s and mlax.Module s.
Use parameters’ trainable field to indicate whether a parameter is a leaf in
a PyTree. trainable=None means a parameter contains nested parameters.
trainable=False indicates a leaf parameter whose data is not trainable.
trainable=True indicates a leaf parameter whose data is trainable.
MLAX modules’ auxiliary data can contain non JAX types (str, lambda,
etc.) but they must be comparable and hashable if the module were to be
jit-compiled.
The following code illustrate the different possible fields a module can have.
class Foo(Module):
def __init__(self):
self.foo = Parameter(trainable=True, data=None) # Ok, empty trainable parameter
class Bar(Module):
def __init__(self):
super().__init__()
self.a = Parameter(trainable=True, data=jnp.ones((3, 4))) # Ok, trainable parameter
self.b = Parameter(trainable=False, data=np.ones((3, 4))) # Ok, non-trainable parameter
self.c = Parameter(trainable=False, data=[1, 2, 3]) # Ok, non-trainable PyTree parameter
# self.d = Parameter(trainable=False, data="abc") # Not ok, not a valid JAX type in parameter
# self.e = Parameter(trainable=True, data=Foo()) # Not ok, contains a nested parameter/module in a leaf parameter
self.f = Parameter(trainable=None, data=Foo()) # Ok, `trainable=None` means not a leaf parameter
self.g = Foo() # Ok, submodule
self.h = "abc" # Ok, hyperparameter
self.i = 1 # Ok, hyperparameter
# self.j = [1, 2, 3] # Not ok, not hashable
Note
MLAX modules use vars() to determine their fields during PyTree
flattening and unflattening. This means all variables must be stored in
__dict__. Avoid storing variables in __slots__.
To define a custom module, inherit from mlax.Module and implement three
functions:
__init__to initialize the custom module’s hyperparameters.initto initialize the trainable and non-trainable parameters given asample input.
applyto perform the forward pass assuming the custom module has beeninitialized with
init.
MLAX module implements the __call__ function, which initilizes the module
and submodules if they are not initialized. The function then performs the
forward pass and returns the results and an updated self.
class MyLayer(Module):
def __init__(self, ...):
super().__init__()
...
def init(self, x: Any) -> None:
...
def apply(
self,
x: Any,
rng: Optional[Array],
inference_mode: bool = False,
batch_axis_name: Union[Hashable, Tuple[Hashable]] = ()
) -> Tuple[Any, Any]:
...
Parameters can always be mutated, even in apply. Hyperparameters can also be
mutated in apply, but this may result in recompilation inside a jax.jit.
Both parameters and hyperparameters cannot be deleted once initialized.
A module can be treated like any other PyTree, but MLAX modules have some convenience functions to filter, partition, and combine module parameters.
For example, to get a copy of a module’s trainable parameters, use
module.filter(f=is_trainable_param). Non-trainable parameters’ data are
set to None.
To partition a module into trainable and non-trainable parameters, use
module.partition(f=is_trainable_param). This is equivalent to
module.filter(f=is_trainable_param), module.filter(f=is_non_trainable_param)
To combine partitioned modules trainables and non_trainables into a new
module, use trainables.combine(non_trainables). Parameters in trainables
are copied and those whose data is None will have it replaced with
the data of the corresponding parameter in non_trainables.