mlax package

Subpackages

Submodules

mlax.module module

MLAX module base class and parameter.

class mlax.module.Module

Bases: object

MLAX layer base class. PyTree of mlax.Parameters.

apply(x: Any, rng: Array | None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Tuple[Any, Any]

Thin wrapper around forward that returns self in addition to the output features. It can be jit-compiled to speed up __call__.

combine(*rest)

Combine self’s parameters with those of rest.
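The exact merge rule is not spelled out here. The following is a minimal sketch that assumes combine undoes a filter or partition: for each parameter, the first non-None data among self and rest wins. Param, combine, and the dict-of-parameters stand-in for a module PyTree are all hypothetical, not MLAX code.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""
    trainable: Optional[bool]
    data: Any = None


def combine(tree: Dict[str, Param], *rest: Dict[str, Param]) -> Dict[str, Param]:
    """Assumed semantics: keep the first non-None data for each parameter."""
    out = {}
    for name, p in tree.items():
        data = p.data
        for other in rest:
            if data is not None:
                break
            data = other[name].data  # fill the hole from the next tree
        out[name] = replace(p, data=data)
    return out


trainable = {"w": Param(True, [1.0, 2.0]), "count": Param(False, None)}
frozen = {"w": Param(True, None), "count": Param(False, 3)}
merged = combine(trainable, frozen)
print(merged["w"].data, merged["count"].data)  # [1.0, 2.0] 3
```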

filter(f=<function is_trainable_param>, inverse=False) → Any

Apply filter f to self’s parameters. Parameters that are filtered out have their data field replaced with None.
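A minimal sketch of the documented behavior on a flat dict of parameters. It assumes that parameters for which f returns True are kept and that inverse=True flips the selection; Param, is_trainable, and filter_params are hypothetical stand-ins, not MLAX's implementation.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""
    trainable: Optional[bool]
    data: Any = None


def is_trainable(p: Param) -> bool:
    return p.trainable is True


def filter_params(
    tree: Dict[str, Param],
    f: Callable[[Param], bool] = is_trainable,
    inverse: bool = False,
) -> Dict[str, Param]:
    """Keep parameters selected by f; replace the others' data with None.

    Assumption: inverse=True keeps the parameters f rejects instead.
    """
    return {
        name: p if f(p) != inverse else replace(p, data=None)
        for name, p in tree.items()
    }


params = {"w": Param(True, [0.5]), "moving_mean": Param(False, [0.0])}
kept = filter_params(params)
print(kept["w"].data, kept["moving_mean"].data)  # [0.5] None
flipped = filter_params(params, inverse=True)
print(flipped["w"].data, flipped["moving_mean"].data)  # None [0.0]
```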

filter_with_path(f, inverse=False) → Any

Like filter, but f also receives each parameter’s path.

forward(x: Any, rng: Array | None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, which initializes them recursively.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

    Note: When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashables naming the batch axis or axes when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.
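The note on rng says MLAX reads the default of forward’s rng argument to decide whether a module needs a key. How MLAX does this internally is not documented here. The following sketch shows how such a check can be written with the standard inspect module. NeedsKey, NoKey, and needs_rng are hypothetical names.

```python
import inspect
from typing import Any, Optional


class NeedsKey:
    # Hypothetical module that requires a PRNG key (e.g. dropout): no default.
    def forward(self, x: Any, rng: Any, inference_mode: bool = False, batch_axis_name=()):
        return x


class NoKey:
    # Hypothetical deterministic module: rng defaults to None.
    def forward(self, x: Any, rng: Optional[Any] = None, inference_mode: bool = False, batch_axis_name=()):
        return x


def needs_rng(module: Any) -> bool:
    """True unless forward declares rng with a default value of None."""
    rng_param = inspect.signature(module.forward).parameters["rng"]
    # A missing default is inspect.Parameter.empty, which is not None.
    return rng_param.default is not None


print(needs_rng(NeedsKey()), needs_rng(NoKey()))  # True False
```

A framework following this convention can skip splitting a key for every module where needs_rng returns False.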

initialized

Whether a module’s hyperparameters, parameters, and submodules are all initialized.

partition(f=<function is_trainable_param>) → Tuple[Any, Any]

Partition self’s parameters using filter f. Unselected parameters have their data field replaced with None.
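A minimal sketch of the partition idea on a flat dict. It assumes the returned pair is (selected, unselected) and that each half has the other half’s data set to None, so the two halves can be merged back without loss. Param, is_trainable, and partition here are hypothetical stand-ins, not MLAX’s implementation.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""
    trainable: Optional[bool]
    data: Any = None


Tree = Dict[str, Param]


def is_trainable(p: Param) -> bool:
    return p.trainable is True


def partition(tree: Tree, f: Callable[[Param], bool] = is_trainable) -> Tuple[Tree, Tree]:
    """Split into (selected, unselected); each side blanks the other side's data."""
    selected = {k: p if f(p) else replace(p, data=None) for k, p in tree.items()}
    unselected = {k: replace(p, data=None) if f(p) else p for k, p in tree.items()}
    return selected, unselected


params = {"w": Param(True, [0.5]), "moving_mean": Param(False, [0.0])}
trainables, others = partition(params)
print(trainables["w"].data, trainables["moving_mean"].data)  # [0.5] None
print(others["w"].data, others["moving_mean"].data)  # None [0.0]

# Recombine by taking whichever half still holds the data.
restored = {k: trainables[k] if trainables[k].data is not None else others[k] for k in params}
```

A common motivation for this split is differentiating with respect to the trainable half only, then recombining the halves.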

partition_with_path(f) → Tuple[Any, Any]

Like partition, but f also receives each parameter’s path.

setup(x: Any) → None

Initialize parameters. Submodules are not necessarily initialized.

Parameters:

x – Compatible input features.
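The split between setup, forward, __call__, and apply can be illustrated with a toy module. This is a sketch, not MLAX code. It assumes that __call__ runs setup on first use and then forward, and that apply returns (output, self) in that order. The Scale class and its plain-list "parameter" are hypothetical.

```python
from typing import Any, List, Optional, Tuple


class Scale:
    """Hypothetical module: one per-feature scale, created lazily in setup."""

    def __init__(self) -> None:
        self.scale: Optional[List[float]] = None

    @property
    def initialized(self) -> bool:
        return self.scale is not None

    def setup(self, x: List[float]) -> None:
        # The parameter shape is inferred from the first input.
        self.scale = [1.0] * len(x)

    def forward(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        # Assumes setup has already run.
        return [s * xi for s, xi in zip(self.scale, x)]

    def __call__(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        if not self.initialized:
            self.setup(x)
        return self.forward(x, rng, inference_mode, batch_axis_name)

    def apply(self, x, rng=None, inference_mode=False, batch_axis_name=()) -> Tuple[Any, "Scale"]:
        # Return the (possibly newly initialized) module with the output.
        return self(x, rng, inference_mode, batch_axis_name), self


m = Scale()
y, m = m.apply([2.0, 3.0])
print(y, m.initialized)  # [2.0, 3.0] True
```

Returning the module from apply is what lets a purely functional caller (such as a jit-compiled step) keep any state created during the call.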

tree_flatten_with_keys()

Flatten into parameters and auxiliary hyperparameters.

classmethod tree_unflatten(aux, param_values)

Unflatten parameters and auxiliary hyperparameters.

class mlax.module.Parameter(trainable: bool | None, data: Any | None = None)

Bases: object

PyTree wrapper around a valid JAX object and metadata.

tree_flatten()

Flatten into a valid JAX object and auxiliary metadata.

classmethod tree_unflatten(aux, children)

Unflatten a valid JAX object and auxiliary metadata.
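The two methods above follow JAX’s convention for custom PyTree classes: tree_flatten returns (children, aux_data), and the classmethod tree_unflatten(aux_data, children) rebuilds the object, which is the interface jax.tree_util.register_pytree_node_class expects. The sketch below mirrors that interface without importing JAX. It assumes the children are just (data,) and the auxiliary metadata is just trainable; Param is a hypothetical stand-in.

```python
from typing import Any, Optional, Tuple


class Param:
    """Hypothetical stand-in mirroring mlax.module.Parameter's interface."""

    def __init__(self, trainable: Optional[bool], data: Any = None):
        self.trainable = trainable
        self.data = data

    def tree_flatten(self) -> Tuple[Tuple[Any], Optional[bool]]:
        # Children: the JAX-compatible payload. Aux: static metadata.
        return (self.data,), self.trainable

    @classmethod
    def tree_unflatten(cls, aux: Optional[bool], children: Tuple[Any]) -> "Param":
        return cls(trainable=aux, data=children[0])


p = Param(True, [1.0, 2.0])
children, aux = p.tree_flatten()
q = Param.tree_unflatten(aux, children)
print(q.trainable, q.data)  # True [1.0, 2.0]
```

Keeping trainable in the auxiliary data means JAX treats it as static structure rather than as an array leaf.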

mlax.module.is_leaf_param(p)

Whether p is a parameter whose trainable is not None.

mlax.module.is_non_trainable_param(p)

Whether p is a parameter whose trainable is False.

mlax.module.is_trainable_param(p)

Whether p is a parameter whose trainable is True.
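The three predicates are fully described by their one-line definitions. Below is a sketch that reimplements them against a hypothetical Param stand-in. It assumes that anything that is not a parameter makes each predicate return False.

```python
from typing import Any, Optional


class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""

    def __init__(self, trainable: Optional[bool], data: Any = None):
        self.trainable = trainable
        self.data = data


def is_leaf_param(p: Any) -> bool:
    return isinstance(p, Param) and p.trainable is not None


def is_trainable_param(p: Any) -> bool:
    return isinstance(p, Param) and p.trainable is True


def is_non_trainable_param(p: Any) -> bool:
    return isinstance(p, Param) and p.trainable is False


for item in [Param(True), Param(False), Param(None), 3.0]:
    print(is_leaf_param(item), is_trainable_param(item), is_non_trainable_param(item))
# True True False
# True False True
# False False False
# False False False
```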

Module contents