mlax package
Subpackages
Submodules
mlax.module module
MLAX module base class and parameter.
- class mlax.module.Module
Bases: object
MLAX layer base class. A PyTree of mlax.Parameters.
- apply(x: Any, rng: Array | None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Tuple[Any, Any]
A thin wrapper around forward that returns self in addition to the output features. It can be jit-compiled to speed up __call__.
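The following is a minimal sketch of why returning self matters. It uses a hypothetical CountingDoubler class, not the real MLAX API. Inside jax.jit, any state change made by forward happens on a traced copy of the module, so the caller only sees the updated state if the module is returned alongside the output. The (output, module) ordering here is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class CountingDoubler:
    """Hypothetical module (not part of MLAX) that counts forward calls."""

    def __init__(self, count):
        self.count = count  # state that forward updates

    def tree_flatten(self):
        return (self.count,), None  # count is the only leaf; no aux data

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def forward(self, x):
        self.count = self.count + 1  # updates this (possibly traced) copy only
        return 2 * x

    def apply(self, x):
        # Return the output together with the updated module.
        # The (output, module) ordering is an assumption of this sketch.
        return self.forward(x), self


module = CountingDoubler(0)
out, new_module = jax.jit(CountingDoubler.apply)(module, jnp.arange(3.0))
```

Because jit rebuilds the module from its flattened leaves, the original module keeps its old count. Only new_module carries the increment.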
- combine(*rest)
Combine self's parameters with rest's.
- filter(f=<function is_trainable_param>, inverse=False) → Any
Apply filter f to self's parameters. Parameters that are filtered out have their data field replaced with None.
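The following is a minimal pure-Python sketch of these filtering semantics. It operates on a flat dict of a hypothetical Param dataclass rather than on a real MLAX module. It also assumes that inverse=True keeps the complement of what f selects.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional


@dataclass
class Param:
    """Hypothetical stand-in for mlax.Parameter."""
    trainable: Optional[bool]
    data: Any = None


def is_trainable(p: Param) -> bool:
    return p.trainable is True


def filter_params(params: Dict[str, Param],
                  f: Callable[[Param], bool] = is_trainable,
                  inverse: bool = False) -> Dict[str, Param]:
    """Blank the data of every parameter not selected by f (or selected, if inverse)."""
    out = {}
    for name, p in params.items():
        keep = f(p) != inverse  # assumed: inverse flips the selection
        out[name] = p if keep else replace(p, data=None)
    return out


params = {"kernel": Param(True, [1.0, 2.0]), "running_mean": Param(False, [0.5])}
trainable = filter_params(params)
non_trainable = filter_params(params, inverse=True)
```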
- filter_with_path(f, inverse=False) → Any
Like filter, but operating on parameters together with their paths.
- forward(x: Any, rng: Array | None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any
Perform the forward pass, assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable, or tuple of hashables, naming the batch axis (or axes) when called inside a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), meaning no batch axis.
- Returns:
Output features.
Note
When overriding, set rng's default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
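The sketch below shows one way a framework could act on this convention, using the standard inspect module. The Dropout and Linear classes and the needs_rng helper are hypothetical. They are not necessarily how MLAX performs the check.

```python
import inspect


class Dropout:
    """Hypothetical module that needs randomness: rng has no default."""

    def forward(self, x, rng, inference_mode=False, batch_axis_name=()):
        return x


class Linear:
    """Hypothetical deterministic module: rng defaults to None."""

    def forward(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        return x


def needs_rng(module) -> bool:
    """True unless forward's rng parameter defaults to None."""
    rng_param = inspect.signature(module.forward).parameters["rng"]
    # A missing default is inspect.Parameter.empty, which is not None.
    return rng_param.default is not None


dropout_needs_key = needs_rng(Dropout())
linear_needs_key = needs_rng(Linear())
```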
- partition(f=<function is_trainable_param>) → Tuple[Any, Any]
Partition self's parameters using filter f. Unselected parameters have their data field replaced with None.
- partition_with_path(f) → Tuple[Any, Any]
Like partition, but operating on parameters together with their paths.
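The following pure-Python sketch shows how partition and combine fit together. It uses a flat dict of a hypothetical Param dataclass. It assumes that combine fills each None slot with the first non-None data found among the other partitions.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for mlax.Parameter."""
    trainable: Optional[bool]
    data: Any = None


def partition(params: Dict[str, Param], f) -> Tuple[Dict[str, Param], Dict[str, Param]]:
    """Split into (selected, unselected); each side blanks the other side's data."""
    selected = {k: p if f(p) else replace(p, data=None) for k, p in params.items()}
    rest = {k: replace(p, data=None) if f(p) else p for k, p in params.items()}
    return selected, rest


def combine(first: Dict[str, Param], *rest: Dict[str, Param]) -> Dict[str, Param]:
    """Merge partitions: for each name, keep the first non-None data (assumed rule)."""
    out = {}
    for k, p in first.items():
        data = p.data
        for other in rest:
            if data is None:
                data = other[k].data
        out[k] = replace(p, data=data)
    return out


params = {"kernel": Param(True, [1.0]), "stats": Param(False, [0.0])}
trainables, non_trainables = partition(params, lambda p: p.trainable is True)
restored = combine(trainables, non_trainables)
```

A typical pattern is to differentiate a loss only with respect to the trainable half and then recombine it with the untouched half. This sketch does not show that step.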
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
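The following is a minimal sketch of the setup-then-forward lifecycle. It uses a hypothetical LazyLinear layer, not an MLAX class. The kernel's input dimension is unknown until the layer sees its first input, so setup infers it from x.shape and __call__ runs setup lazily. Passing the initialization key to the constructor is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


class LazyLinear:
    """Hypothetical module: kernel shape is inferred from the first input."""

    def __init__(self, out_features: int, key: jax.Array):
        self.out_features = out_features
        self.key = key          # used only for initialization (assumption)
        self.kernel = None      # unknown until setup sees an input

    def setup(self, x: jax.Array) -> None:
        in_features = x.shape[-1]
        self.kernel = jax.random.normal(self.key, (in_features, self.out_features))

    def forward(self, x: jax.Array) -> jax.Array:
        return x @ self.kernel  # assumes setup has already run

    def __call__(self, x: jax.Array) -> jax.Array:
        if self.kernel is None:  # initialize lazily on the first call
            self.setup(x)
        return self.forward(x)


layer = LazyLinear(4, jax.random.PRNGKey(0))
y = layer(jnp.ones((2, 3)))
```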
- tree_flatten_with_keys()
Flatten into parameters and auxiliary hyperparameters.
- classmethod tree_unflatten(aux, param_values)
Unflatten parameters and auxiliary hyperparameters.
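Together, tree_flatten_with_keys and a tree_unflatten classmethod match the interface expected by jax.tree_util.register_pytree_with_keys_class. The sketch below shows that pattern on a hypothetical TinyModule. Array fields become keyed children, and a string hyperparameter travels as auxiliary data. This is not MLAX's own implementation.

```python
import jax
from jax.tree_util import GetAttrKey, register_pytree_with_keys_class


@register_pytree_with_keys_class
class TinyModule:
    """Hypothetical module: array fields are children, hyperparameters are aux data."""

    def __init__(self, weight, bias, activation: str):
        self.weight = weight
        self.bias = bias
        self.activation = activation  # static hyperparameter, not a leaf

    def tree_flatten_with_keys(self):
        children = ((GetAttrKey("weight"), self.weight),
                    (GetAttrKey("bias"), self.bias))
        return children, self.activation

    @classmethod
    def tree_unflatten(cls, aux, children):
        weight, bias = children
        return cls(weight, bias, aux)


m = TinyModule(1.0, 0.5, "relu")
paths_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(m)
paths = [jax.tree_util.keystr(path) for path, _ in paths_and_leaves]
doubled = jax.tree_util.tree_map(lambda v: v * 2, m)
```

Because the hyperparameter sits in the auxiliary data, tree_map only touches weight and bias, and doubled keeps activation == "relu".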
- class mlax.module.Parameter(trainable: bool | None, data: Any | None = None)
Bases: object
PyTree wrapper around a valid JAX object and metadata.
- tree_flatten()
Flatten into a valid JAX object and auxiliary metadata.
- classmethod tree_unflatten(aux, children)
Unflatten a valid JAX object and auxiliary metadata.
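The following is a minimal sketch of a Parameter-like wrapper registered with jax.tree_util.register_pytree_node_class. The class is hypothetical. The data field is the single child, and the trainable flag is carried as auxiliary metadata. Because JAX treats None as an empty pytree, a wrapper whose data is None contributes no leaves.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Param:
    """Hypothetical sketch of a Parameter-like pytree wrapper."""

    def __init__(self, trainable, data=None):
        self.trainable = trainable  # metadata, carried as aux data
        self.data = data            # the wrapped JAX object (a subtree)

    def tree_flatten(self):
        return (self.data,), self.trainable

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux, children[0])


p = Param(True, jnp.array([1.0, 2.0]))
scaled = jax.tree_util.tree_map(lambda a: a * 10, p)
empty_leaves = jax.tree_util.tree_leaves(Param(False, None))
```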
- mlax.module.is_leaf_param(p)
Whether p is a parameter whose trainable is not None.
- mlax.module.is_non_trainable_param(p)
Whether p is a parameter whose trainable is False.
- mlax.module.is_trainable_param(p)
Whether p is a parameter whose trainable is True.
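The following is a minimal sketch of how predicates like these can serve as the is_leaf argument of jax.tree_util functions. With this argument, parameters are collected as whole objects instead of being flattened down to their data. Param is a hypothetical stand-in, and the isinstance checks are an assumption of this sketch.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Param:
    """Hypothetical Parameter-like wrapper (trainable is aux metadata)."""

    def __init__(self, trainable, data=None):
        self.trainable, self.data = trainable, data

    def tree_flatten(self):
        return (self.data,), self.trainable

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux, children[0])


def is_leaf_param(p):
    return isinstance(p, Param) and p.trainable is not None


def is_trainable_param(p):
    return isinstance(p, Param) and p.trainable is True


def is_non_trainable_param(p):
    return isinstance(p, Param) and p.trainable is False


tree = {"w": Param(True, jnp.ones(2)), "stats": Param(False, jnp.zeros(2))}
# Stop flattening at Param objects so whole parameters come back as leaves.
params = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf_param)
n_trainable = sum(is_trainable_param(p) for p in params)
```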