mlax.nn package
Submodules
mlax.nn.bias module
- class mlax.nn.bias.Bias(rng: Array, in_features: int | Sequence[int], bias_initializer=<function zeros>, dtype=<class 'jax.numpy.float32'>)
Bases: Module
Bias addition layer.
- forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Array
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Array) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
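As a rough illustration of what a bias addition layer computes, here is a minimal pure-JAX sketch. It does not use the mlax API; init_bias and bias_forward are hypothetical helpers, and it assumes the bias has shape in_features and is broadcast against the trailing axes of x.

```python
import jax.numpy as jnp

def init_bias(in_features, dtype=jnp.float32):
    # Hypothetical helper: zero-initialized bias, mirroring the default
    # bias_initializer=zeros. Assumes the bias shape equals in_features.
    shape = (in_features,) if isinstance(in_features, int) else tuple(in_features)
    return jnp.zeros(shape, dtype)

def bias_forward(bias, x):
    # Broadcast the bias over the trailing axes of x and add it.
    return x + bias

bias = init_bias(4)
x = jnp.ones((2, 4))
y = bias_forward(bias, x)
```

Because the bias starts at zero, the layer is initially an identity map; training then moves the bias away from zero.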
mlax.nn.conv module
- class mlax.nn.conv.Conv(rng: Array, out_channels: int, filter_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, filter_dilation: int | Sequence[int] | None = None, feature_group_count: int = 1, batch_group_count: int = 1, data_format: str | Tuple[str, str, str] = 'channel_last', precision=None, accum_dtype=None, kernel_initializer=<function variance_scaling.<locals>.init>, dtype=<class 'jax.numpy.float32'>)
Bases: Module
Convolution transformation layer.
- forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Array
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Array) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
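The core of a channel-last convolution can be sketched directly with jax.lax.conv_general_dilated. This is not mlax's implementation; conv2d_channel_last is a hypothetical helper, and it assumes an unbatched input of shape (height, width, in_channels) and a kernel in HWIO layout, adding and removing a size-1 batch axis around the call.

```python
import jax
import jax.numpy as jnp

def conv2d_channel_last(x, kernel, strides=(1, 1), padding="VALID"):
    # x: unbatched input of shape (height, width, in_channels).
    # kernel: shape (filter_h, filter_w, in_channels, out_channels).
    y = jax.lax.conv_general_dilated(
        x[None],  # add a leading batch axis of size 1
        kernel,
        window_strides=strides,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y[0]  # drop the batch axis again

key = jax.random.PRNGKey(0)
x = jnp.ones((8, 8, 3))
kernel = jax.random.normal(key, (3, 3, 3, 16))
y = conv2d_channel_last(x, kernel)  # "VALID" 3x3 on 8x8 gives 6x6
```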
mlax.nn.embed module
- class mlax.nn.embed.Embed(rng: Array, vocab_size: int, embed_dim: int, embed_initializer=<function variance_scaling.<locals>.init>, dtype=<class 'jax.numpy.float32'>)
Bases: Module
Embedding layer.
- forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Array
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Array) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
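An embedding layer is essentially a row lookup into a (vocab_size, embed_dim) table. The sketch below shows that lookup in plain JAX; embed_lookup is a hypothetical helper, not the mlax API, and it assumes x holds in-range integer token ids.

```python
import jax
import jax.numpy as jnp

def embed_lookup(table, token_ids):
    # table: (vocab_size, embed_dim); token_ids: integer array of any shape.
    # Each id is replaced by its row, giving token_ids.shape + (embed_dim,).
    return jnp.take(table, token_ids, axis=0)

vocab_size, embed_dim = 10, 4
table = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, embed_dim))
ids = jnp.array([1, 3, 3, 7])
out = embed_lookup(table, ids)
```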
mlax.nn.f module
- class mlax.nn.f.F(train_fn: Callable[[Any], Any], infer_fn: Callable[[Any], Any] | None = None)
Bases: Module
Wrapper to create pure function layers.
- forward(x: Any, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
- class mlax.nn.f.FRng(train_fn: Callable[[Any, Array], Any], infer_fn: Callable[[Any, Array], Any] | None = None)
Bases: Module
Wrapper to create pure function layers that may require rng.
- forward(x: Any, rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
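F and FRng take a train_fn and an optional infer_fn. The sketch below shows one plausible dispatch rule, assuming infer_fn is used in inference mode when given and train_fn is used otherwise. apply_pure_layer and apply_pure_rng_layer are hypothetical helpers, not the mlax classes themselves.

```python
import jax
import jax.numpy as jnp

def apply_pure_layer(x, train_fn, infer_fn=None, inference_mode=False):
    # Assumed behavior of F: prefer infer_fn in inference mode, else train_fn.
    if inference_mode and infer_fn is not None:
        return infer_fn(x)
    return train_fn(x)

def apply_pure_rng_layer(x, rng, train_fn, infer_fn=None, inference_mode=False):
    # Same idea for FRng, whose functions also receive a PRNG key.
    if inference_mode and infer_fn is not None:
        return infer_fn(x, rng)
    return train_fn(x, rng)

x = jnp.arange(4.0)
# Noise injection during training, deterministic identity at inference.
noisy = lambda v, rng: v + 0.1 * jax.random.normal(rng, v.shape)
clean = lambda v, rng: v
key = jax.random.PRNGKey(0)
train_out = apply_pure_rng_layer(x, key, noisy, clean, inference_mode=False)
infer_out = apply_pure_rng_layer(x, key, noisy, clean, inference_mode=True)
relu_out = apply_pure_layer(jnp.array([-1.0, 2.0]), jax.nn.relu)
```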
mlax.nn.functional module
- mlax.nn.functional.apply_attention_weights(attention_weights: Array, value: Array) → Array
Apply attention weights to values.
- Parameters:
attention_weights – Attention weights of the same dtype as value and of shape (query_length, key_value_length).
value – Value array of shape (key_value_length, value_depth).
- Returns activations:
value with attention_weights applied, of shape (query_length, value_depth).
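Given the shapes above, applying attention weights amounts to a matrix product: each output row is a weighted sum of the rows of value. The helper below is a hypothetical sketch of that shape contract, not the mlax source.

```python
import jax.numpy as jnp

def apply_attention_weights_sketch(attention_weights, value):
    # attention_weights: (query_length, key_value_length)
    # value:             (key_value_length, value_depth)
    # Each output row is a weighted sum of the rows of value.
    return attention_weights @ value

weights = jnp.array([[1.0, 0.0], [0.5, 0.5]])
value = jnp.array([[2.0, 4.0], [6.0, 8.0]])
out = apply_attention_weights_sketch(weights, value)
# out -> [[2., 4.], [4., 6.]]
```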
- mlax.nn.functional.avg_pool(x: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') → Array
Apply average pooling over input features.
- Parameters:
x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.
window_shape – See the window_shape parameter of pool.
strides – See the strides parameter of pool. Default: 1.
padding – See the padding parameter of pool.
input_dilation – See the input_dilation parameter of pool. Default: None, no input dilation.
window_dilation – See the window_dilation parameter of pool. Default: None, no window dilation.
data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.
- Returns y:
x with average pooling applied.
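Average pooling can be expressed as a windowed sum divided by the window size. The sketch below does this with jax.lax.reduce_window on an unbatched channel-last input. It is a hypothetical helper, not the mlax implementation, and it only handles "VALID" padding, where every window lies fully inside the input.

```python
import math

import jax
import jax.numpy as jnp

def avg_pool_sketch(x, window_shape, strides):
    # x: unbatched channel-last input of shape spatial_dims + (channels,).
    # The channel axis gets a window and stride of 1, so it is never pooled.
    window = tuple(window_shape) + (1,)
    stride = tuple(strides) + (1,)
    summed = jax.lax.reduce_window(x, 0.0, jax.lax.add, window, stride, "VALID")
    # Divide each window sum by the number of elements in a window.
    return summed / math.prod(window_shape)

x = jnp.arange(16.0).reshape(4, 4, 1)
y = avg_pool_sketch(x, (2, 2), (2, 2))
# y[..., 0] -> [[2.5, 4.5], [10.5, 12.5]]
```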
- mlax.nn.functional.dot_product_attention_logits(query: Array, key: Array) → Array
Compute scaled dot-product attention logits.
- Parameters:
query – Query array of shape (query_length, query_key_depth).
key – Key array of the same dtype as query and of shape (key_value_length, query_key_depth).
- Returns:
Attention logits of shape (query_length, key_value_length).
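A minimal sketch of scaled dot-product logits, assuming the conventional 1/sqrt(query_key_depth) scaling (the exact scaling mlax uses is not stated above). The softmax step that turns logits into attention weights is included only for context; dot_product_attention_logits_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def dot_product_attention_logits_sketch(query, key):
    # query: (query_length, depth); key: (key_value_length, depth).
    depth = query.shape[-1]
    # Assumes the conventional 1/sqrt(depth) scaling.
    return (query @ key.T) / jnp.sqrt(jnp.asarray(depth, dtype=query.dtype))

q = jnp.ones((3, 4))
k = jnp.ones((5, 4))
logits = dot_product_attention_logits_sketch(q, k)  # shape (3, 5), each entry 4 / 2 = 2
weights = jax.nn.softmax(logits, axis=-1)           # each row sums to 1
```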
- mlax.nn.functional.dropout(x: Array, rng: Any, rate: float, axis: int | Sequence[int]) → Array
Apply random dropouts to input features.
- Parameters:
x – Input features.
rng – PRNG key for randomizing dropouts.
rate – Probability at which each element is dropped out. Must be in [0, 1).
axis – Axis or sequence of axes to drop features along.
- Returns y:
x with dropouts applied.
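The sketch below shows inverted dropout in plain JAX. Two points are assumptions rather than documented mlax behavior: the keep mask varies only along the axes in axis and is shared along all others, and kept entries are rescaled by 1 / (1 - rate) so the expected value is unchanged. dropout_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def dropout_sketch(x, rng, rate, axis):
    # Assumption: the mask varies along `axis` and is broadcast elsewhere.
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    axes = tuple(a % x.ndim for a in axes)
    mask_shape = tuple(d if i in axes else 1 for i, d in enumerate(x.shape))
    keep = jax.random.bernoulli(rng, 1.0 - rate, mask_shape)
    # Inverted dropout: rescale kept entries to preserve the expectation.
    return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))

x = jnp.ones((4, 8))
y = dropout_sketch(x, jax.random.PRNGKey(0), rate=0.5, axis=1)
# Every entry is 0 or 2, and all rows share the same mask.
```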
- mlax.nn.functional.identity(*xs: Any) → Any
Identity function.
- Parameters:
x – Input features.
- Returns y:
x.
- mlax.nn.functional.max_pool(x: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') → Array
Apply max pooling over input features.
- Parameters:
x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.
window_shape – See the window_shape parameter of pool.
strides – See the strides parameter of pool. Default: 1.
padding – See the padding parameter of pool.
input_dilation – See the input_dilation parameter of pool. Default: None, no input dilation.
window_dilation – See the window_dilation parameter of pool. Default: None, no window dilation.
data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.
- Returns y:
x with max pooling applied.
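Max pooling uses the same windowed reduction with jax.lax.max as the reduce function and negative infinity, the identity element for max, as the initial value. A hypothetical sketch for an unbatched channel-last input with "VALID" padding, not the mlax source:

```python
import jax
import jax.numpy as jnp

def max_pool_sketch(x, window_shape, strides):
    # x: unbatched channel-last input; windows never span the channel axis.
    window = tuple(window_shape) + (1,)
    stride = tuple(strides) + (1,)
    # -inf is the identity element for max, so it is the natural init value.
    return jax.lax.reduce_window(x, -jnp.inf, jax.lax.max, window, stride, "VALID")

x = jnp.arange(16.0).reshape(4, 4, 1)
y = max_pool_sketch(x, (2, 2), (2, 2))
# y[..., 0] -> [[5., 7.], [13., 15.]]
```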
- mlax.nn.functional.pool(x: Array, init_value: Any, reduce_fn: Callable[[Any, Any], Any], window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') → Array
Apply an arbitrary reduce function over pooling windows of input features.
- Parameters:
x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.
init_value – Initial value of the reduce function over each pooling window.
reduce_fn – Reduce function.
window_shape – An integer or a sequence of n_spatial_dims integers, specifying the shape of the pooling window used on input features. A single integer specifies the same value for all spatial dimensions.
strides – An integer or a sequence of n_spatial_dims integers, specifying the strides of the window along the spatial dimensions. A single integer specifies the same value for all spatial dimensions. Default: 1.
padding – String, integer, or a sequence of n_spatial_dims integers or integer tuple pairs that give the padding to apply before and after each spatial dimension. If an integer, the same padding is applied before and after all spatial dimensions. If a sequence of integers, each integer gives the padding applied both before and after its corresponding spatial dimension. See the padding parameter of jax.lax.reduce_window, which is used internally.
input_dilation – None, an integer, or a sequence of n_spatial_dims integers, specifying the input dilation rate in each spatial dimension. See the base_dilation parameter of jax.lax.reduce_window. Default: None, no input dilation.
window_dilation – None, an integer, or a sequence of n_spatial_dims integers, specifying the window dilation rate in each spatial dimension. See the window_dilation parameter of jax.lax.reduce_window. Default: None, no window dilation.
data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.
- Returns y:
x with pooling applied.
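Because pool is documented as using jax.lax.reduce_window internally, the parameter handling can be sketched as a normalization step in front of that call. This hypothetical pool_sketch assumes channel_last data, omits input_dilation for brevity, and expands integer padding into (before, after) pairs with no padding on the channel axis. Note that reduce_window fills padded positions with init_value.

```python
import jax
import jax.numpy as jnp

def pool_sketch(x, init_value, reduce_fn, window_shape, strides=1,
                padding="VALID", window_dilation=None):
    # Assumes channel_last, unbatched input: x.shape == spatial + (channels,).
    n_spatial = x.ndim - 1
    as_tuple = lambda v: (v,) * n_spatial if isinstance(v, int) else tuple(v)
    window = as_tuple(window_shape) + (1,)
    stride = as_tuple(strides) + (1,)
    if not isinstance(padding, str):
        # Integers mean equal padding before and after; pairs pass through.
        pads = [(p, p) if isinstance(p, int) else tuple(p) for p in as_tuple(padding)]
        padding = pads + [(0, 0)]  # never pad the channel axis
    dilation = None if window_dilation is None else as_tuple(window_dilation) + (1,)
    return jax.lax.reduce_window(x, init_value, reduce_fn, window, stride,
                                 padding, window_dilation=dilation)

x = jnp.ones((4, 4, 1))
# 3x3 sum pooling with 1 pixel of padding keeps the 4x4 spatial shape.
padded = pool_sketch(x, 0.0, jax.lax.add, 3, 1, padding=1)
# A 2x2 window with dilation 2 spans 3x3 input pixels.
dilated = pool_sketch(x, 0.0, jax.lax.add, 2, 1, window_dilation=2)
```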
- mlax.nn.functional.sum_pool(x: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] = 1, padding: str | int | Sequence[int | Tuple[int, int]] = 'VALID', input_dilation: int | Sequence[int] | None = None, window_dilation: int | Sequence[int] | None = None, data_format: str = 'channel_last') → Array
Apply sum pooling over input features.
- Parameters:
x – Input features. Must be unbatched, thus having n_spatial_dims + 1 dimensions.
window_shape – See the window_shape parameter of pool.
strides – See the strides parameter of pool. Default: 1.
padding – See the padding parameter of pool.
input_dilation – See the input_dilation parameter of pool. Default: None, no input dilation.
window_dilation – See the window_dilation parameter of pool. Default: None, no window dilation.
data_format – “channel_last”, “channel_first”, or a string representing the kernel spec as described in jax.lax.conv_general_dilated, but without N, the batch dimension. Default: “channel_last”.
- Returns y:
x with sum pooling applied.
- mlax.nn.functional.z_norm(x: Array, axis: str | int | Sequence[int], batch_axis_name: Hashable | Tuple[Hashable] = (), epsilon: float = 1e-05)
Apply Z-score normalization.
- Parameters:
axis – “all”, “channel_last”, “channel_first”, axis, or sequence of axes to normalize input features along. “all” indicates normalization along all axes (layer norm). “channel_last” and “channel_first” indicate normalization along all but the channel axis, assumed to be the last or first axis (instance norm).
epsilon – Small number added to variance to avoid divisions by zero. Default: 1e-05.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) to normalize along in addition to those in axis. Default: (), no normalization along any batch axis.
- Returns:
x with normalization applied.
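The role of batch_axis_name can be illustrated with jax.vmap and jax.lax.pmean: per-example statistics are averaged across the named batch axis before normalizing. This hypothetical z_norm_sketch accepts only integer axes (not "all", "channel_last", or "channel_first"), and computing the variance as E[x²] − E[x]² is an assumption, not necessarily what mlax does.

```python
import jax
import jax.numpy as jnp

def z_norm_sketch(x, axis, batch_axis_name=(), epsilon=1e-5):
    # Statistics over the given axes of a single (unbatched) example.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    mean_sq = jnp.mean(jnp.square(x), axis=axis, keepdims=True)
    if batch_axis_name:
        # Inside jax.vmap, also average the statistics across the batch axis.
        mean = jax.lax.pmean(mean, batch_axis_name)
        mean_sq = jax.lax.pmean(mean_sq, batch_axis_name)
    var = mean_sq - jnp.square(mean)
    return (x - mean) / jnp.sqrt(var + epsilon)

# Layer-norm style: normalize a single example over its only axis.
y = z_norm_sketch(jnp.array([1.0, 2.0, 3.0, 4.0]), axis=0)

# Each example alone is constant, but the batch statistics are not.
batch = jnp.array([[1.0, 1.0], [3.0, 3.0]])
yb = jax.vmap(lambda e: z_norm_sketch(e, axis=0, batch_axis_name="batch"),
              axis_name="batch")(batch)
# yb is approximately [[-1., -1.], [1., 1.]]
```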
mlax.nn.linear module
- class mlax.nn.linear.Linear(rng: Array, out_features: int, precision=None, accum_dtype=None, transposed_kernel: bool = False, kernel_initializer=<function variance_scaling.<locals>.init>, dtype=<class 'jax.numpy.float32'>)
Bases: Module
Linear transformation layer (without bias) with lazy kernel initialization.
- forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Array
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Array) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
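A bias-free linear layer contracts the last axis of x with the kernel's input axis. The hypothetical linear_sketch below uses jax.lax.dot_general and makes two assumptions not confirmed by the text above: transposed_kernel means the kernel is stored as (out_features, in_features), and accum_dtype corresponds to dot_general's preferred_element_type.

```python
import jax
import jax.numpy as jnp

def linear_sketch(x, kernel, transposed_kernel=False, precision=None, accum_dtype=None):
    # kernel: (in_features, out_features), or (out_features, in_features)
    # when transposed_kernel=True (assumed meaning of the flag).
    kernel_in_axis = 1 if transposed_kernel else 0
    return jax.lax.dot_general(
        x, kernel,
        dimension_numbers=(((x.ndim - 1,), (kernel_in_axis,)), ((), ())),
        precision=precision,
        preferred_element_type=accum_dtype,  # assumed role of accum_dtype
    )

x = jnp.ones((2, 3))
kernel = jnp.ones((3, 5))
y = linear_sketch(x, kernel)
y_t = linear_sketch(x, kernel.T, transposed_kernel=True)  # same result
```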
mlax.nn.parallel module
- class mlax.nn.parallel.Parallel(layers: Iterable[Module])
Bases: Module
Parallel combination of layers that do not require rng.
- forward(x: Iterable[Any], rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → List[Any]
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
- class mlax.nn.parallel.ParallelRng(layers: Iterable[Module])
Bases: Module
Parallel combination of layers that may require rng.
- forward(x: Iterable[Any], rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → List[Any]
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
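Conceptually, Parallel applies the i-th layer to the i-th input and returns the outputs as a list, and ParallelRng also hands each branch its own key. The functions below are hypothetical stand-ins for layers; splitting one key per branch with jax.random.split is an assumption about how ParallelRng distributes randomness.

```python
import jax
import jax.numpy as jnp

def parallel_sketch(fns, xs):
    # Apply the i-th function to the i-th input; collect outputs in a list.
    return [fn(x) for fn, x in zip(fns, xs)]

def parallel_rng_sketch(fns, xs, rng):
    # Assumed: split the incoming key into one independent key per branch.
    keys = jax.random.split(rng, len(fns))
    return [fn(x, k) for fn, x, k in zip(fns, xs, keys)]

outs = parallel_sketch([jnp.sin, jnp.cos], [jnp.zeros(2), jnp.zeros(2)])
noisy = lambda v, k: v + jax.random.normal(k, v.shape)
rng_outs = parallel_rng_sketch([noisy, noisy], [jnp.zeros(3), jnp.zeros(3)],
                               jax.random.PRNGKey(0))
```

Since each branch receives a different key, identical branches applied to identical inputs still produce different random outputs.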
mlax.nn.scaler module
- class mlax.nn.scaler.Scaler(rng: Array, in_features: int | Sequence[int | None], scaler_initializer=<function ones>, dtype=<class 'jax.numpy.float32'>)
Bases: Module
Scaler layer.
- forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Array
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Array) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
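A scaler layer multiplies the input elementwise by a learnable scale. The in_features signature allows None entries. This hypothetical sketch assumes a None entry becomes a size-1 axis, so a single scale is shared along that input axis. init_scaler and scaler_forward are illustrative names, not the mlax API.

```python
import jax.numpy as jnp

def init_scaler(in_features, dtype=jnp.float32):
    # Assumption: None becomes a size-1 axis, broadcast along that input axis.
    if isinstance(in_features, int):
        shape = (in_features,)
    else:
        shape = tuple(1 if d is None else d for d in in_features)
    return jnp.ones(shape, dtype)  # mirrors scaler_initializer=ones

def scaler_forward(scale, x):
    # Elementwise scaling with broadcasting.
    return x * scale

# One scale per column, shared by every row; column 1 is doubled.
scale = init_scaler((None, 3)).at[:, 1].set(2.0)
x = jnp.ones((4, 3))
y = scaler_forward(scale, x)
```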
mlax.nn.series module
- class mlax.nn.series.Series(layers: Iterable[Module])
Bases: Module
Series combination of layers that do not require rng.
- forward(x: Any, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
- class mlax.nn.series.SeriesRng(layers: Iterable[Module])
Bases: Module
Series combination of layers that may require rng.
- forward(x: Any, rng: Array, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Any
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Any) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
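Series feeds each layer's output into the next, and SeriesRng additionally gives each stage a key. In this sketch, plain functions stand in for layers, and splitting one key per stage is an assumption about how SeriesRng handles randomness. series_sketch and series_rng_sketch are hypothetical names.

```python
import jax
import jax.numpy as jnp

def series_sketch(fns, x):
    # Feed the output of each function into the next one.
    for fn in fns:
        x = fn(x)
    return x

def series_rng_sketch(fns, x, rng):
    # Assumed: every stage gets its own key split from the incoming one.
    keys = jax.random.split(rng, len(fns))
    for fn, k in zip(fns, keys):
        x = fn(x, k)
    return x

# [-1, 1, 2] -> doubled [-2, 2, 4] -> relu [0, 2, 4] -> sum 6
y = series_sketch([lambda v: v * 2.0, jax.nn.relu, jnp.sum], jnp.array([-1.0, 1.0, 2.0]))

add_noise = lambda v, k: v + jax.random.normal(k, v.shape)
z = series_rng_sketch([add_noise, lambda v, k: v * 0.0], jnp.ones(3), jax.random.PRNGKey(0))
```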
mlax.nn.z_norm module
- class mlax.nn.z_norm.ZNorm(rng: Array, axis: str | int | Sequence[int], epsilon: float = 1e-05, momentum: float = 0.9, mean_initializer=<function zeros>, variance_initializer=<function ones>, dtype=<class 'jax.numpy.float32'>)
Bases: Module
Z-score normalization across batch axes with running mean and variance.
- forward(x: Array, rng: None = None, inference_mode: bool = False, batch_axis_name: Hashable | Tuple[Hashable] = ()) → Array
Perform the forward pass assuming setup has been called.
- Parameters:
x – Compatible input features.
rng – PRNG key. Only necessary for some modules.
inference_mode – Whether in inference or training mode. Default: training mode.
batch_axis_name – Hashable or tuple of hashable representing the batch axis name(s) when called in a jax.vmap or jax.pmap context. Used by modules such as ZNorm to normalize along the batch axis. Default: (), no batch axis.
- Returns:
Output features.
Note
When overriding, set rng’s default value to None if a key is not required. MLAX uses this information to avoid splitting and passing keys to modules that do not need them.
- setup(x: Array) → None
Initialize parameters and put self into a valid state for forward. Submodules may not be initialized until __call__ is called.
- Parameters:
x – Compatible input features.
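ZNorm keeps a running mean and variance controlled by momentum. The sketch below shows the common batch-norm convention: an exponential moving average that keeps momentum of the old statistic, with batch statistics used in training and running statistics used at inference. This convention is an assumption; the exact mlax update rule is not stated above, and the helper names are hypothetical.

```python
import jax.numpy as jnp

def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.9):
    # Exponential moving average: keep `momentum` of the old statistic and
    # blend in (1 - momentum) of the current batch statistic.
    new_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * running_var + (1.0 - momentum) * batch_var
    return new_mean, new_var

def normalize(x, mean, var, epsilon=1e-5):
    return (x - mean) / jnp.sqrt(var + epsilon)

# Initial statistics mirror mean_initializer=zeros and variance_initializer=ones.
running_mean, running_var = jnp.zeros(2), jnp.ones(2)
batch = jnp.array([[1.0, 2.0], [3.0, 6.0]])
batch_mean, batch_var = batch.mean(axis=0), batch.var(axis=0)  # [2, 4], [1, 4]
running_mean, running_var = update_running_stats(
    running_mean, running_var, batch_mean, batch_var)           # [0.2, 0.4], [1.0, 1.3]
train_out = normalize(batch, batch_mean, batch_var)       # training: batch statistics
infer_out = normalize(batch, running_mean, running_var)   # inference: running statistics
```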