mlax package

Subpackages

Submodules

mlax.module module

MLAX module base class and parameter.

class mlax.module.Module

Bases: object

MLAX layer base class. PyTree of mlax.Parameters.

apply(x: Any, rng: Array | None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Tuple[Any, Any]

Thin wrapper around forward that returns self in addition to the output features. Can be jit-compiled to speed up __call__.
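A minimal sketch of the wrapper idea, using a hypothetical ToyModule rather than the real mlax.module.Module. The (output, self) ordering of the returned tuple is an assumption made for illustration only.

```python
from typing import Any, Optional, Tuple


class ToyModule:
    """Hypothetical stand-in for an MLAX module; not mlax.module.Module."""

    def __init__(self, scale: float):
        self.scale = scale

    def forward(self, x: Any, rng: Optional[Any] = None,
                inference_mode: bool = False, batch_axis_name=()) -> Any:
        # The layer's actual computation.
        return x * self.scale

    def apply(self, x: Any, rng: Optional[Any] = None,
              inference_mode: bool = False, batch_axis_name=()) -> Tuple[Any, Any]:
        # Run forward, then also return the module itself so the caller can
        # pick up any state changes. (Output-first ordering is an assumption.)
        out = self.forward(x, rng, inference_mode, batch_axis_name)
        return out, self


out, module = ToyModule(2.0).apply(3.0)
print(out)  # 6.0
```

Returning the module alongside the output lets a pure function hand updated state (for example, running statistics) back to its caller instead of mutating it in place.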

combine(*rest)

Combine self’s parameters with those of rest.

filter(f=<function is_trainable_param>, inverse=False) → Any

Apply filter f to self’s parameters. Filtered-out parameters have their data field replaced with None.
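The filtering semantics can be sketched over a flat dictionary of parameters. Param, is_trainable, and filter_params are hypothetical names, and a real Module is a nested PyTree rather than a dict, so this only illustrates the "replace data with None" behavior and the effect of inverse.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional


@dataclass
class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""
    trainable: Optional[bool]
    data: Any = None


def is_trainable(p: Param) -> bool:
    return p.trainable is True


def filter_params(params: Dict[str, Param],
                  f: Callable[[Param], bool] = is_trainable,
                  inverse: bool = False) -> Dict[str, Param]:
    """Keep the data of parameters selected by f; set the rest to None."""
    out = {}
    for name, p in params.items():
        keep = f(p) != inverse  # inverse=True flips the selection
        out[name] = p if keep else replace(p, data=None)
    return out


params = {"weight": Param(True, [1.0, 2.0]), "count": Param(False, 5)}
print(filter_params(params))
# {'weight': Param(trainable=True, data=[1.0, 2.0]), 'count': Param(trainable=False, data=None)}
```

The input dictionary is left untouched; dataclasses.replace builds new Param objects for the filtered-out entries.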

filter_with_path(f, inverse=False) → Any

Variant of filter in which f is called with each parameter’s path as well as the parameter.

forward(x: Any, rng: Array | None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any

Perform the forward pass assuming setup has been called.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashables naming the batch axis (or axes) when called inside a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
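One way a framework could read this convention is through inspect.signature. The sketch below is a hypothetical check: MLAX's actual mechanism may differ, and ToyDense, ToyDropout, and needs_rng are invented for illustration.

```python
import inspect
from typing import Any, Optional


class ToyDense:
    # Deterministic layer: rng defaults to None, so no key is needed.
    def forward(self, x: Any, rng: Optional[Any] = None, inference_mode: bool = False):
        return x


class ToyDropout:
    # Stochastic layer: rng has no default, so a key is required.
    def forward(self, x: Any, rng: Any, inference_mode: bool = False):
        return x


def needs_rng(module: Any) -> bool:
    """Hypothetical helper: True unless forward's rng parameter defaults to None."""
    rng_param = inspect.signature(module.forward).parameters["rng"]
    # A parameter with no default reports inspect.Parameter.empty, not None.
    return rng_param.default is not None


print(needs_rng(ToyDense()), needs_rng(ToyDropout()))  # False True
```

A caller using such a check could skip jax.random.split for every module that reports False.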

partition(f=<function is_trainable_param>) → Tuple[Any, Any]

Partition self’s parameters using filter f. Unselected parameters have their data field replaced with None.
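A sketch of how partition and combine can round-trip, again over a flat dictionary with a hypothetical Param class. It assumes that combine takes, for each parameter, the first non-None data found across the partitions. That rule is a guess made for illustration and is not confirmed by the MLAX documentation.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""
    trainable: Optional[bool]
    data: Any = None


def partition(params: Dict[str, Param],
              f: Callable[[Param], bool]) -> Tuple[Dict[str, Param], Dict[str, Param]]:
    # Both halves keep every key; each parameter's data survives in exactly one.
    selected = {k: p if f(p) else replace(p, data=None) for k, p in params.items()}
    rest = {k: replace(p, data=None) if f(p) else p for k, p in params.items()}
    return selected, rest


def combine(first: Dict[str, Param], *rest: Dict[str, Param]) -> Dict[str, Param]:
    # Assumed rule: for each key, use the first non-None data across partitions.
    out = {}
    for k, p in first.items():
        data = p.data
        for other in rest:
            if data is None:
                data = other[k].data
        out[k] = replace(p, data=data)
    return out


params = {"w": Param(True, 1.5), "n": Param(False, 3)}
trainables, non_trainables = partition(params, lambda p: p.trainable is True)
restored = combine(trainables, non_trainables)
assert restored == params
```

Splitting trainables from non-trainables this way lets a gradient function differentiate only the first half and then reassemble the full set of parameters afterward.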

partition_with_path(f) → Tuple[Any, Any]

Variant of partition in which f is called with each parameter’s path as well as the parameter.

setup(x: Any) → None

Initialize parameters and put self into a valid state for forward. Submodules might not be initialized until __call__ is called.

Parameters:

x – Compatible input features.
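The lazy-initialization pattern can be sketched with a hypothetical LazyScale layer that infers its parameter shape from the first input. The initialized flag and the choice to trigger setup from __call__ are assumptions for this illustration. Plain lists stand in for JAX arrays.

```python
from typing import List, Optional


class LazyScale:
    """Hypothetical layer whose parameter shape is inferred from the first input."""

    def __init__(self):
        self.initialized = False
        self.weights: Optional[List[float]] = None

    def setup(self, x: List[float]) -> None:
        # Infer the parameter shape from a sample input.
        self.weights = [1.0] * len(x)
        self.initialized = True

    def forward(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        # Assumes setup has already been called.
        return [w * xi for w, xi in zip(self.weights, x)]

    def __call__(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        # Initialize lazily on the first call, then run forward.
        if not self.initialized:
            self.setup(x)
        return self.forward(x, rng, inference_mode, batch_axis_name)


layer = LazyScale()
print(layer([2.0, 3.0]))  # [2.0, 3.0]
```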

tree_flatten_with_keys()

Flatten into parameters and auxiliary hyperparameters.

classmethod tree_unflatten(aux, param_values)

Unflatten parameters and auxiliary hyperparameters.
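A stdlib-only sketch of the flatten/unflatten split for a module: Param-valued attributes become keyed children, and everything else becomes auxiliary data. ToyModule and Param are hypothetical. Real JAX key paths use key objects such as jax.tree_util.GetAttrKey, so the plain string keys here are a simplification. JAX also requires auxiliary data to be hashable, which is why tuples are used.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""
    trainable: Optional[bool]
    data: Any = None


class ToyModule:
    """Hypothetical module: Param attributes are children, the rest is aux data."""

    def __init__(self, features: int):
        self.features = features                      # hyperparameter -> aux data
        self.kernel = Param(True, (0.1,) * features)  # parameter -> child

    def tree_flatten_with_keys(self):
        items = sorted(vars(self).items())
        param_names = tuple(n for n, v in items if isinstance(v, Param))
        hyper = tuple((n, v) for n, v in items if not isinstance(v, Param))
        # (key, child) pairs; plain strings stand in for JAX key objects.
        children = [(n, getattr(self, n)) for n in param_names]
        # Param names go into aux because unflatten receives only child values.
        return children, (param_names, hyper)

    @classmethod
    def tree_unflatten(cls, aux, param_values):
        param_names, hyper = aux
        obj = object.__new__(cls)  # bypass __init__; state comes from aux + children
        for n, v in hyper:
            setattr(obj, n, v)
        for n, v in zip(param_names, param_values):
            setattr(obj, n, v)
        return obj


m = ToyModule(3)
children, aux = m.tree_flatten_with_keys()
rebuilt = ToyModule.tree_unflatten(aux, [v for _, v in children])
assert rebuilt.features == 3 and rebuilt.kernel == m.kernel
```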

class mlax.module.Parameter(trainable: bool | None, data: Any | None = None)

Bases: object

PyTree wrapper around a valid JAX object and metadata.

tree_flatten()

Flatten into a valid JAX object and auxiliary metadata.

classmethod tree_unflatten(aux, children)

Unflatten from a valid JAX object and auxiliary metadata.
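A sketch of a Parameter-like node that follows the two-method protocol expected by jax.tree_util.register_pytree_node_class: a tree_flatten method returning (children, aux) and a tree_unflatten classmethod. The class is hypothetical. It assumes that data is the single child and trainable is the static metadata. The decorator is left out so the example runs without JAX installed.

```python
from typing import Any, Optional


class Param:
    """Hypothetical re-creation of a Parameter-like PyTree node."""

    def __init__(self, trainable: Optional[bool], data: Any = None):
        self.trainable = trainable
        self.data = data

    def tree_flatten(self):
        # data is the only child JAX transforms would see; trainable is static metadata.
        return (self.data,), self.trainable

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux, children[0])


p = Param(True, 1.0)
children, aux = p.tree_flatten()
doubled = Param.tree_unflatten(aux, (children[0] * 2,))
print(doubled.trainable, doubled.data)  # True 2.0
```

Because trainable lives in the auxiliary data, a transformation that maps over leaves changes only data and leaves the metadata intact.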

mlax.module.is_leaf_param(p)

Whether p is a parameter whose trainable is not None.

mlax.module.is_non_trainable_param(p)

Whether p is a parameter whose trainable is False.

mlax.module.is_trainable_param(p)

Whether p is a parameter whose trainable is True.
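The three predicates can be sketched as simple checks on a hypothetical Param class. The names is_leaf, is_trainable, and is_non_trainable are illustrative re-creations based on the descriptions above, not the library's own functions.

```python
from typing import Any, Optional


class Param:
    """Hypothetical stand-in for mlax.module.Parameter."""

    def __init__(self, trainable: Optional[bool], data: Any = None):
        self.trainable = trainable
        self.data = data


def is_leaf(p: Any) -> bool:
    # A parameter whose trainable is not None.
    return isinstance(p, Param) and p.trainable is not None


def is_trainable(p: Any) -> bool:
    # A parameter whose trainable is True.
    return isinstance(p, Param) and p.trainable is True


def is_non_trainable(p: Any) -> bool:
    # A parameter whose trainable is False.
    return isinstance(p, Param) and p.trainable is False


params = [Param(True, 1), Param(False, 2), Param(None)]
print([is_leaf(p) for p in params])           # [True, True, False]
print([is_trainable(p) for p in params])      # [True, False, False]
print([is_non_trainable(p) for p in params])  # [False, True, False]
```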

Module contents