mlax.nn package

Submodules

mlax.nn.bias module

class mlax.nn.bias.Bias(rng: ~jax.Array, in_features: int | ~typing.Sequence[int], bias_initializer=<function zeros>, dtype=<class 'jax.numpy.float32'>)

Bases: Module

Bias addition layer.

forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Array

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.
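The note on rng's default value can be illustrated with a small sketch. How MLAX actually inspects the default is not stated here; the helper needs_rng and the two toy forward functions below are hypothetical and show only one way such a check could be done with inspect.signature.

```python
import inspect


def needs_rng(forward_fn) -> bool:
    """Hypothetical helper: report whether `forward_fn` requires a PRNG key.

    Assumption: a forward function signals "no key needed" by giving its
    `rng` parameter a default value of None.
    """
    rng_param = inspect.signature(forward_fn).parameters["rng"]
    return rng_param.default is not None


def forward_without_key(x, rng=None, inference_mode=False, batch_axis_name=()):
    # Deterministic layer: rng defaults to None.
    return x


def forward_with_key(x, rng, inference_mode=False, batch_axis_name=()):
    # Stochastic layer: rng has no default, so a key must be supplied.
    return x
```

Under this assumption, a container could skip key splitting for every layer where needs_rng returns False.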

setup(x: Array) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.conv module

class mlax.nn.conv.Conv(rng: ~jax.Array, out_channels: int, filter_shape: int | ~typing.Sequence[int], strides: int | ~typing.Sequence[int] = 1, padding: str | int | ~typing.Sequence[int | ~typing.Tuple[int, int]] = 'VALID', input_dilation: int | ~typing.Sequence[int] | None = None, filter_dilation: int | ~typing.Sequence[int] | None = None, feature_group_count: int = 1, batch_group_count: int = 1, data_format: str | ~typing.Tuple[str, str, str] = 'channel_last', precision=None, accum_dtype=None, kernel_initializer=<function variance_scaling.<locals>.init>, dtype=<class 'jax.numpy.float32'>)

Bases: Module

Convolution transformation layer.
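As a rough illustration of the 'channel_last' layout, the sketch below calls jax.lax.conv_general_dilated directly. It is not MLAX's implementation: the unbatched input (mirroring the pooling functions in mlax.nn.functional), the kernel layout, and the all-ones values are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

# Assumed unbatched, channel-last input: (height, width, in_channels).
x = jnp.ones((5, 5, 3), dtype=jnp.float32)
# Assumed kernel layout: (height, width, in_channels, out_channels).
kernel = jnp.ones((3, 3, 3, 8), dtype=jnp.float32)

# conv_general_dilated expects a batch dimension, so add one and drop it again.
y = jax.lax.conv_general_dilated(
    x[None],
    kernel,
    window_strides=(1, 1),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)[0]
```

With 'VALID' padding and a 3x3 filter, each spatial dimension shrinks from 5 to 3.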

forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Array

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Array) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.embed module

class mlax.nn.embed.Embed(rng: ~jax.Array, vocab_size: int, embed_dim: int, embed_initializer=<function variance_scaling.<locals>.init>, dtype=<class 'jax.numpy.float32'>)

Bases: Module

Embedding layer.
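Conceptually, an embedding layer is a table lookup. The sketch below uses plain jax.numpy indexing with a stand-in table; the variable names and the random initializer are assumptions for illustration, not MLAX internals.

```python
import jax
import jax.numpy as jnp

vocab_size, embed_dim = 10, 4
# Stand-in embedding table of shape (vocab_size, embed_dim).
table = jax.random.normal(
    jax.random.PRNGKey(0), (vocab_size, embed_dim), dtype=jnp.float32
)

# Integer token ids select rows of the table.
token_ids = jnp.array([1, 5, 5, 9])
embeddings = table[token_ids]
```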

forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Array

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Array) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.f module

class mlax.nn.f.F(train_fn: Callable[[Any], Any], infer_fn: Callable[[Any], Any] | None = None)

Bases: Module

Wrapper to create pure function layers.
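The idea of wrapping separate training and inference functions can be sketched in plain Python. The factory make_fn_layer is hypothetical, and the fallback from infer_fn to train_fn when infer_fn is None is an assumption based on the signature, not confirmed behavior.

```python
def make_fn_layer(train_fn, infer_fn=None):
    """Hypothetical stand-in for a pure-function layer."""
    if infer_fn is None:
        # Assumption: reuse the training function in inference mode.
        infer_fn = train_fn

    def forward(x, rng=None, inference_mode=False, batch_axis_name=()):
        # Dispatch on the mode flag; rng and batch_axis_name are unused here.
        return infer_fn(x) if inference_mode else train_fn(x)

    return forward


double_in_training = make_fn_layer(lambda x: 2 * x, lambda x: x)
same_in_both_modes = make_fn_layer(lambda x: x + 1)
```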

forward(x: Any, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Any) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

class mlax.nn.f.FRng(train_fn: Callable[[Any, Array], Any], infer_fn: Callable[[Any, Array], Any] | None = None)

Bases: Module

Wrapper to create pure function layers that may require rng.

forward(x: Any, rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Any) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.functional module

mlax.nn.functional.apply_attention_weights(attention_weights: Array, value: Array) Array

Apply attention weights to values.

Parameters:
  • attention_weights – Attention weights of the same dtype as value and of shape (query_length, key_value_length).

  • value – Value array of shape (key_value_length, value_depth).

Returns activations:

value with attention_weights applied, of shape (query_length, value_depth).
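The documented shapes are consistent with a matrix product between the weights and the values. The sketch below shows that product in plain jax.numpy; obtaining the weights with a softmax over random logits is an assumption made only to have realistic inputs.

```python
import jax
import jax.numpy as jnp

query_length, key_value_length, value_depth = 3, 5, 2
k_logits, k_value = jax.random.split(jax.random.PRNGKey(0))

# Stand-in weights: each row sums to 1 after the softmax.
logits = jax.random.normal(k_logits, (query_length, key_value_length))
attention_weights = jax.nn.softmax(logits, axis=-1)
value = jax.random.normal(k_value, (key_value_length, value_depth))

# (query_length, key_value_length) @ (key_value_length, value_depth)
activations = attention_weights @ value
```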

mlax.nn.functional.avg_pool(x: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') Array

Apply average pooling over input features.

Parameters:
  • x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.

  • window_shape – See the window_shape parameter of pool.

  • strides – See the strides parameter of pool. Default: 1.

  • padding – See the padding parameter of pool.

  • input_dilation – See the input_dilation parameter of pool. Default: None, no input dilation.

  • window_dilation – See the window_dilation parameter of pool. Default: None, no window dilation.

  • data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.

Returns y:

x with average pool applied.
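Average pooling can be expressed as a windowed sum divided by the window size. The sketch below uses jax.lax.reduce_window directly on an unbatched, channel-last input; it is an illustration under 'VALID' padding, and how MLAX treats padded positions when averaging is not covered here.

```python
import jax
import jax.numpy as jnp

# Unbatched, channel-last input: (height, width, channels).
x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4, 1)

# A window size and stride of 1 on the channel axis leave channels separate.
window = (2, 2, 1)
strides = (2, 2, 1)
summed = jax.lax.reduce_window(x, 0.0, jax.lax.add, window, strides, "VALID")
averaged = summed / (2 * 2)
```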

mlax.nn.functional.dot_product_attention_logits(query: Array, key: Array) Array

Compute scaled dot-product attention logits.

Parameters:
  • query – Query array of shape (query_length, query_key_depth).

  • key – Key array of the same dtype as query and of shape (key_value_length, query_key_depth).

Returns:

Attention logits of shape (query_length, key_value_length).
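A minimal sketch of scaled dot-product logits, assuming the conventional scaling by the square root of query_key_depth. The function name scaled_logits is hypothetical, and the exact scaling used by MLAX is not stated in this reference.

```python
import jax
import jax.numpy as jnp


def scaled_logits(query, key):
    # query: (query_length, depth), key: (key_value_length, depth)
    depth = query.shape[-1]
    return (query @ key.T) / (depth ** 0.5)


k_query, k_key = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(k_query, (3, 8))
key = jax.random.normal(k_key, (5, 8))
logits = scaled_logits(query, key)
```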

mlax.nn.functional.dropout(x: Array, rng: Any, rate: float, axis: int | Sequence[int]) Array

Apply random dropouts to input features.

Parameters:
  • x – Input features.

  • rng – PRNG key for randomizing dropouts.

  • rate – Probability at which each element is dropped out. Must be in [0, 1).

  • axis – Axis or sequence of axes to drop features along.

Returns y:

x with dropouts applied.
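The sketch below shows element-wise inverted dropout in plain JAX. Two things are assumptions: that kept elements are rescaled by 1 / (1 - rate), which is the common convention, and that every element is dropped independently. The axis parameter is deliberately left out.

```python
import jax
import jax.numpy as jnp


def dropout_sketch(x, rng, rate):
    """Hypothetical element-wise dropout; not MLAX's implementation."""
    keep_prob = 1.0 - rate
    # True where an element is kept.
    mask = jax.random.bernoulli(rng, keep_prob, x.shape)
    # Rescale kept elements so the expected value is unchanged (assumption).
    return jnp.where(mask, x / keep_prob, 0.0)


y = dropout_sketch(jnp.ones((1000,)), jax.random.PRNGKey(0), rate=0.5)
```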

mlax.nn.functional.identity(*xs: Any) Any

Identity function.

Parameters:

x – Input features.

Returns y:

x.

mlax.nn.functional.max_pool(x: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') Array

Apply max pooling over input features.

Parameters:
  • x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.

  • window_shape – See the window_shape parameter of pool.

  • strides – See the strides parameter of pool. Default: 1.

  • padding – See the padding parameter of pool.

  • input_dilation – See the input_dilation parameter of pool. Default: None, no input dilation.

  • window_dilation – See the window_dilation parameter of pool. Default: None, no window dilation.

  • data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.

Returns y:

x with max pool applied.

mlax.nn.functional.pool(x: Array, init_value: Any, reduce_fn: Callable[[Any, Any], Any], window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') Array

Apply an arbitrary reduce function over pooling windows of input features.

Parameters:
  • x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.

  • init_value – Initial value of the reduce function over each pooling window.

  • reduce_fn – Reduce function.

  • window_shape – An integer or a sequence of n_spatial_dims integers, specifying the shape of the pooling window used on input features. A single integer specifies the same value for all spatial dimensions.

  • strides – An integer or a sequence of n_spatial_dims integers, specifying the strides of the window shape along the spatial dimensions. A single integer specifies the same value for all spatial dimensions. Default: 1.

  • padding – String, integer, or a sequence of n_spatial_dims integers or integer tuple pairs that give the padding to apply before and after each spatial dimension. If integer, the same padding is applied before and after all spatial dimensions. If a sequence of integers, then the same padding is applied before and after each spatial dimension. See the padding parameter of jax.lax.reduce_window, which is used internally.

  • input_dilation – None, an integer, or a sequence of n_spatial_dims integers, specifying the input dilation rate in each spatial dimension. See the base_dilation parameter of jax.lax.reduce_window. Default: None, no input dilation.

  • window_dilation – None, an integer, or a sequence of n_spatial_dims integers, specifying the window dilation rate in each spatial dimension. See the window_dilation parameter of jax.lax.reduce_window. Default: None, no window dilation.

  • data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.

Returns y:

x with pooling applied.
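Since the parameters above refer to jax.lax.reduce_window, the sketch below calls it directly to show how init_value and reduce_fn interact, using max pooling as the example. The input values and window sizes are illustrative, and the mapping of MLAX's arguments onto reduce_window is assumed rather than documented here.

```python
import jax
import jax.numpy as jnp

# Unbatched, channel-last input: (height, width, channels).
x = jnp.arange(9, dtype=jnp.float32).reshape(3, 3, 1)

# For max pooling, the initial value is the identity of max: -inf.
pooled = jax.lax.reduce_window(
    x,
    -jnp.inf,        # init_value
    jax.lax.max,     # reduce_fn
    (2, 2, 1),       # window over the two spatial dims, size 1 on channels
    (1, 1, 1),       # strides
    "VALID",         # padding
)
```

Swapping in 0.0 and jax.lax.add gives sum pooling with the same window arguments.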

mlax.nn.functional.sum_pool(x: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') Array

Apply sum pooling over input features.

Parameters:
  • x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.

  • window_shape – See the window_shape parameter of pool.

  • strides – See the strides parameter of pool. Default: 1.

  • padding – See the padding parameter of pool.

  • input_dilation – See the input_dilation parameter of pool. Default: None, no input dilation.

  • window_dilation – See the window_dilation parameter of pool. Default: None, no window dilation.

  • data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.

Returns y:

x with sum pooling applied.

mlax.nn.functional.z_norm(x: Array, axis: str | int | Sequence[int], batch_axis_name: Hashable | Tuple[Hashable] = (), epsilon: float = 1e-05)

Apply Z-score normalization.

Parameters:
  • axis – “all”, “channel_last”, “channel_first”, axis, or sequence of axes to normalize input features along. “all” indicates normalization along all axes (layer norm). “channel_last” and “channel_first” indicate normalization along all but the channel axis, assumed to be the last or first axis (instance norm).

  • epsilon – Small number added to variance to avoid divisions by zero. Default: 1e-05.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) to normalize along in addition to those in axis. Default: (), no normalization along any batch axis.

Returns:

x with normalization applied.
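Z-score normalization subtracts the mean and divides by the standard deviation along the chosen axes. The sketch below is a plain jax.numpy version without batch-axis handling; the function name z_norm_sketch is hypothetical, and placing epsilon inside the square root follows the description of epsilon above.

```python
import jax
import jax.numpy as jnp


def z_norm_sketch(x, axis, epsilon=1e-5):
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.var(x, axis=axis, keepdims=True)
    # epsilon is added to the variance to avoid division by zero.
    return (x - mean) / jnp.sqrt(var + epsilon)


x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 4, 3)) + 2.0

# "channel_last": normalize along every axis except the last one.
per_channel = z_norm_sketch(x, axis=tuple(range(x.ndim - 1)))
# "all": normalize along every axis.
over_all = z_norm_sketch(x, axis=None)
```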

mlax.nn.linear module

class mlax.nn.linear.Linear(rng: ~jax.Array, out_features: int, precision=None, accum_dtype=None, transposed_kernel: bool = False, kernel_initializer=<function variance_scaling.<locals>.init>, dtype=<class 'jax.numpy.float32'>)

Bases: Module

Linear transformation layer without bias with lazy kernel initialization.
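Lazy kernel initialization means the kernel's shape is only fixed once an input has been seen. The class below is a hypothetical, stripped-down sketch of that idea in plain JAX; its attribute names and the scaled normal initializer are assumptions and do not reflect MLAX's internals.

```python
import jax
import jax.numpy as jnp


class LazyLinearSketch:
    """Hypothetical bias-free linear layer with a lazily created kernel."""

    def __init__(self, rng, out_features):
        self.rng = rng
        self.out_features = out_features
        self.kernel = None  # created on first use

    def setup(self, x):
        # The number of input features is read from the input itself.
        in_features = x.shape[-1]
        self.kernel = jax.random.normal(
            self.rng, (in_features, self.out_features)
        ) / (in_features ** 0.5)

    def __call__(self, x):
        if self.kernel is None:
            self.setup(x)
        return x @ self.kernel


layer = LazyLinearSketch(jax.random.PRNGKey(0), out_features=3)
y = layer(jnp.ones((7,)))
```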

forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Array

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Array) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.parallel module

class mlax.nn.parallel.Parallel(layers: Iterable[Module])

Bases: Module

Combination of layers that do not require rng in parallel.
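The forward signature takes an iterable of inputs and returns a list, which suggests that the i-th layer is applied to the i-th input. The helper below sketches that reading in plain Python; parallel_forward is a hypothetical name and the pairing rule is an assumption.

```python
def parallel_forward(layers, xs):
    # Pair each layer with its own input and collect the outputs in order.
    return [layer(x) for layer, x in zip(layers, xs)]


outputs = parallel_forward([lambda v: v + 1, lambda v: v * 2], [1, 3])
```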

forward(x: Iterable[Any], rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) List[Any]

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Any) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

class mlax.nn.parallel.ParallelRng(layers: Iterable[Module])

Bases: Module

Combination of layers that may require rng in parallel.

forward(x: Iterable[Any], rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) List[Any]

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Any) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.recurrent module

class mlax.nn.recurrent.Recurrent(cell, reverse: bool = False, unroll: int = 1)

Bases: Module

Wrapper around a recurrent cell that does not require rng.
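The reverse and unroll parameters match arguments of jax.lax.scan, so the sketch below shows their effect on a toy cell. Whether Recurrent forwards them to jax.lax.scan in exactly this way is an assumption, and the running-sum cell is purely illustrative.

```python
import jax
import jax.numpy as jnp


def cell(h, x):
    # Toy cell: the hidden state is a running sum and doubles as the output.
    h_new = h + x
    return h_new, h_new


xs = jnp.arange(1.0, 5.0)  # sequence [1, 2, 3, 4]
h0 = jnp.asarray(0.0)

# Forward in time, one step per loop iteration.
h_final, ys = jax.lax.scan(cell, h0, xs, reverse=False, unroll=1)
# Backward in time; outputs stay aligned with their input positions.
h_final_rev, ys_rev = jax.lax.scan(cell, h0, xs, reverse=True)
```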

forward(xh: Tuple[Any, Any], rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(xh: Tuple[Any, Any]) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

class mlax.nn.recurrent.RecurrentRng(cell, reverse: bool = False, unroll: int = 1)

Bases: Module

Wrapper around a recurrent cell that may require rng.

forward(xh: Tuple[Any, Any], rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(xh: Tuple[Any, Any]) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.scaler module

class mlax.nn.scaler.Scaler(rng: ~jax.Array, in_features: int | ~typing.Sequence[int | None], scaler_initializer=<function ones>, dtype=<class 'jax.numpy.float32'>)

Bases: Module

Scaler layer.

forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Array

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Array) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.series module

class mlax.nn.series.Series(layers: Iterable[Module])

Bases: Module

Combination of layers that do not require rng in series.
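Combining layers in series amounts to function composition: each layer receives the previous layer's output. The helper series_forward below is a hypothetical plain-Python sketch of that behavior.

```python
from functools import reduce


def series_forward(layers, x):
    # Feed x through the layers from first to last.
    return reduce(lambda acc, layer: layer(acc), layers, x)


result = series_forward([lambda v: v + 1, lambda v: v * 2], 3)
```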

forward(x: Any, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Any) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

class mlax.nn.series.SeriesRng(layers: Iterable[Module])

Bases: Module

Combination of layers that may require rng in series.
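When layers in a series need randomness, each one should receive its own key. The sketch below splits one key per layer with jax.random.split; series_rng_forward and the noisy toy layer are hypothetical, and MLAX may split keys only for the layers that need them.

```python
import jax
import jax.numpy as jnp


def series_rng_forward(layers, x, rng):
    # One independent key per layer (assumption: every layer takes a key).
    keys = jax.random.split(rng, len(layers))
    for layer, key in zip(layers, keys):
        x = layer(x, key)
    return x


def noisy(x, key):
    # Toy stochastic layer: add Gaussian noise.
    return x + jax.random.normal(key, x.shape)


out = series_rng_forward([noisy, noisy], jnp.zeros((4,)), jax.random.PRNGKey(0))
```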

forward(x: Any, rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Any

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.

setup(x: Any) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

mlax.nn.z_norm module

class mlax.nn.z_norm.ZNorm(rng: ~jax.Array, axis: str | int | ~typing.Sequence[int], epsilon: float = 1e-05, momentum: float = 0.9, mean_initializer=<function zeros>, variance_initializer=<function ones>, dtype=<class 'jax.numpy.float32'>)

Bases: Module

Z-score normalization across batch axes with running mean and variance.
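Running statistics with a momentum are usually maintained as an exponential moving average. The sketch below assumes the convention new = momentum * old + (1 - momentum) * batch; the function name update_running is hypothetical, and the convention MLAX uses is not stated in this reference.

```python
import jax.numpy as jnp


def update_running(running, batch_stat, momentum=0.9):
    # Exponential moving average of a batch statistic (assumed convention).
    return momentum * running + (1.0 - momentum) * batch_stat


running_mean = jnp.zeros((3,))
batch_mean = jnp.array([1.0, 2.0, 3.0])
running_mean = update_running(running_mean, batch_mean)
```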

forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) Array

Perform the forward pass assuming setup has been called.

Note

Because setup may not initialize submodules, forward may need to initialize submodules before using them. This is commonly done by calling their __call__ method, recursively initializing them.

Parameters:
  • x – Compatible input features.

  • rng – PRNG key. Only necessary for some modules.

Note

When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.

Parameters:
  • inference_mode – Whether in inference or training mode. Default: training mode.

  • batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.

Returns:

Output features.
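The role of batch_axis_name can be illustrated with jax.vmap and a collective. The function sees one unbatched example at a time, and the named axis lets it reduce across the batch. The function center_over_batch is hypothetical and only centers the data; it is not the ZNorm computation.

```python
import jax
import jax.numpy as jnp


def center_over_batch(x, batch_axis_name):
    # pmean averages across the named (vmapped) batch axis.
    mean = jax.lax.pmean(x, axis_name=batch_axis_name)
    return x - mean


batch = jnp.array([[1.0, 2.0], [3.0, 6.0]])
centered = jax.vmap(
    lambda x: center_over_batch(x, "batch"), axis_name="batch"
)(batch)
```

The axis name passed to jax.vmap must match the one used inside the function.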

setup(x: Array) None

Initialize parameters. Submodules may not be initialized.

Parameters:

x – Compatible input features.

Module contents